// Copyright 2022 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package conn import ( "io" "net" "time" "github.com/gorilla/websocket" ) func WrapWebSocketConn(c *websocket.Conn) *WebSocketConn { return &WebSocketConn{c: c} } type WebSocketConn struct { r io.Reader c *websocket.Conn } func (c *WebSocketConn) Write(p []byte) (int, error) { err := c.c.WriteMessage(websocket.BinaryMessage, p) if err != nil { return 0, err } return len(p), nil } func (c *WebSocketConn) Read(p []byte) (int, error) { for { if c.r == nil { // Advance to next message. var err error _, c.r, err = c.c.NextReader() if err != nil { return 0, err } } n, err := c.r.Read(p) if err == io.EOF { // At end of message. c.r = nil if n > 0 { return n, nil } else { // No data read, continue to next message. continue } } return n, err } } func (c *WebSocketConn) Close() error { return c.c.Close() } func (c *WebSocketConn) LocalAddr() net.Addr { return c.c.LocalAddr() } func (c *WebSocketConn) RemoteAddr() net.Addr { return c.c.RemoteAddr() } func (c *WebSocketConn) SetDeadline(t time.Time) error { if err := c.SetReadDeadline(t); err != nil { return err } if err := c.SetWriteDeadline(t); err != nil { return err } return nil } func (c *WebSocketConn) SetReadDeadline(t time.Time) error { return c.c.SetReadDeadline(t) } func (c *WebSocketConn) SetWriteDeadline(t time.Time) error { return c.c.SetWriteDeadline(t) }