diff --git a/client.go b/client.go new file mode 100644 index 0000000..996ae75 --- /dev/null +++ b/client.go @@ -0,0 +1,103 @@ +package main + +import ( + "bytes" + "log" + "net/http" + "time" + + "github.com/gorilla/websocket" +) + +const ( + writeWait = 10 * time.Second + pongWait = 60 * time.Second + pingPeriod = (pongWait * 9) / 10 + maxMessageSize = 512 +) + +var ( + newLine = []byte{'\n'} + space = []byte{' '} +) + +var upgrader = websocket.Upgrader{} + +type Client struct { + hub *Hub + conn *websocket.Conn + send chan []byte +} + +func (c *Client) readPump() { + defer func() { + c.hub.unregister <- c + c.conn.Close() + }() + c.conn.SetReadLimit(maxMessageSize) + c.conn.SetReadDeadline(time.Now().Add(pongWait)) + c.conn.SetPongHandler(func(string) error { c.conn.SetReadDeadline(time.Now().Add(pongWait)); return nil }) + for { + _, message, err := c.conn.ReadMessage() + if err != nil { + if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { + log.Printf("error: %v", err) + } + break + } + message = bytes.TrimSpace(bytes.Replace(message, newLine, space, -1)) + c.hub.broadcast <- message + } +} + +func (c *Client) writePump() { + ticker := time.NewTicker(pingPeriod) + defer func() { + ticker.Stop() + c.conn.Close() + }() + for { + select { + case message, ok := <-c.send: + c.conn.SetWriteDeadline(time.Now().Add(writeWait)) + if !ok { + c.conn.WriteMessage(websocket.CloseMessage, []byte{}) + return + } + + w, err := c.conn.NextWriter(websocket.TextMessage) + if err != nil { + return + } + w.Write(message) + + n := len(c.send) + for i := 0; i < n; i++ { + w.Write(newLine) + w.Write(<-c.send) + } + + if err := w.Close(); err != nil { + return + } + case <-ticker.C: + c.conn.SetWriteDeadline(time.Now().Add(writeWait)) + if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil { + return + } + } + } +} + +func serveWs(hub *Hub, w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + log.Println(err) + return + } + client := &Client{hub: hub, conn: conn, send: make(chan []byte, 256)} + client.hub.register <- client + + go client.writePump() + go client.readPump() +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..8d78f96 --- /dev/null +++ b/go.mod @@ -0,0 +1,5 @@ +module go-websocket-example + +go 1.15 + +require github.com/gorilla/websocket v1.4.2 // indirect diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..85efffd --- /dev/null +++ b/go.sum @@ -0,0 +1,2 @@ +github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc= +github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= diff --git a/hub.go b/hub.go new file mode 100644 index 0000000..9a8ffcb --- /dev/null +++ b/hub.go @@ -0,0 +1,40 @@ +package main + +type Hub struct { + clients map[*Client]bool + broadcast chan []byte + register chan *Client + unregister chan *Client +} + +func newHub() *Hub { + return &Hub{ + broadcast: make(chan []byte), + register: make(chan *Client), + unregister: make(chan *Client), + clients: make(map[*Client]bool), + } +} + +func (h *Hub) run() { + for { + select { + case client := <-h.register: + h.clients[client] = true + case client := <-h.unregister: + if _, ok := h.clients[client]; ok { + delete(h.clients, client) + close(client.send) + } + case message := <-h.broadcast: + for client := range h.clients { + select { + case client.send <- message: + default: + close(client.send) + delete(h.clients, client) + } + } + } + } +} diff --git a/index.html b/index.html new file mode 100644 index 0000000..67199f9 --- /dev/null +++ b/index.html @@ -0,0 +1,98 @@ + + + +Chat Example + + + + +
+
+ + +
+ + \ No newline at end of file diff --git a/main.go b/main.go new file mode 100644 index 0000000..d289d09 --- /dev/null +++ b/main.go @@ -0,0 +1,36 @@ +package main + +import ( + "flag" + "log" + "net/http" +) + +var addr = flag.String("addr", ":8080", "http service address") + +func serveIndex(w http.ResponseWriter, r *http.Request) { + log.Println(r.URL) + if r.URL.Path != "/" { + http.Error(w, "Not found", http.StatusNotFound) + return + } + if r.Method != "GET" { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + http.ServeFile(w, r, "index.html") +} + +func main() { + flag.Parse() + hub := newHub() + go hub.run() + http.HandleFunc("/", serveIndex) + http.HandleFunc("/ws", func(w http.ResponseWriter, r *http.Request) { + serveWs(hub, w, r) + }) + err := http.ListenAndServe(*addr, nil) + if err != nil { + log.Fatal("Liatend and serve: ", err) + } +}