Golang 实现 WebSocket 聊天室
本文代码基于 GitHub 上 gorilla/websocket 仓库的 chat 示例（examples/chat）改写，具体细节可参考该示例。
package websocket
import (
"bytes"
"github.com/gorilla/websocket"
"log"
"net/http"
"time"
)
var (
	// Byte separators. newline is written between messages that
	// WriteLoopGroup coalesces into a single websocket frame.
	newline = []byte{'\n'}
	space   = []byte{' '}

	// writeWait is the time allowed for one write to the peer.
	writeWait = 10 * time.Second

	// pongWait is how long the read side waits for the next pong from the peer
	// before its read deadline expires.
	pongWait = 60 * time.Second

	// pingPeriod is the interval between pings; it must stay below pongWait so a
	// ping (and its pong) can arrive before the read deadline runs out.
	pingPeriod = pongWait * 9 / 10

	// maxMessageSize caps the size of a single message accepted from the peer.
	maxMessageSize = 512
)
// upgrader promotes incoming HTTP requests to websocket connections, with
// read and write buffer sizes of 1024.
//
// NOTE(review): CheckOrigin is left nil, which presumably means gorilla's
// default same-origin check applies — confirm before serving cross-origin pages.
var upgrader = websocket.Upgrader{
	WriteBufferSize: 1024,
	ReadBufferSize:  1024,
}
// WsClientGroup is a chat room: only clients registered in the same group can
// chat with each other. All of its state is driven by HandleRun, which
// consumes the three channels below.
type WsClientGroup struct {
	broadcast   chan []byte    // messages to fan out to every client; the HandleRun loop listens on this channel
	clientEnter chan *WsClient // clients to register with the group
	clientExit  chan *WsClient // clients to unregister; HandleRun closes their send channel
	clients     map[*WsClient]bool // currently registered clients; read and written only by HandleRun
}
// HandleRun is the group's event loop. It registers clients arriving on
// clientEnter, unregisters clients arriving on clientExit, and fans each
// message from broadcast out to every registered client.
//
// g.clients is only touched from this loop, so it needs no lock. The loop never
// returns; start it once per group with `go g.HandleRun()` before calling Register.
func (g *WsClientGroup) HandleRun() {
	log.Printf("handle run")
	for {
		select {
		case client := <-g.clientEnter:
			// Register the client.
			g.clients[client] = true

		case client := <-g.clientExit:
			// Unregister the client. It may already have been dropped (and its
			// send channel closed) during a broadcast, so check first to avoid
			// closing the channel twice.
			if _, ok := g.clients[client]; ok {
				delete(g.clients, client)
				close(client.send)
			}

		case broadcastMsg := <-g.broadcast:
			// %q escapes newlines/control characters in client-supplied text so a
			// message cannot forge extra log lines.
			log.Printf("broadcastMsg, %q", broadcastMsg)
			for cli := range g.clients {
				select {
				case cli.send <- broadcastMsg:
				default:
					// The client's buffered send channel is full: it is not keeping
					// up. Drop it instead of blocking the whole group. Closing send
					// makes its WriteLoopGroup exit. Deleting during range is safe in Go.
					close(cli.send)
					delete(g.clients, cli)
				}
			}
		}
	}
}
// NewClientGroup returns an empty group whose event channels (all unbuffered)
// and client set are ready to use. Run HandleRun on the result before
// registering clients.
func NewClientGroup() *WsClientGroup {
	g := new(WsClientGroup)
	g.clients = make(map[*WsClient]bool)
	g.clientEnter = make(chan *WsClient)
	g.clientExit = make(chan *WsClient)
	g.broadcast = make(chan []byte)
	return g
}
// WsClient is a single websocket connection together with the group it is
// registered in.
type WsClient struct {
	send   chan []byte     // outbound messages; drained by WriteLoopGroup, closed by the group's HandleRun
	conn   *websocket.Conn // the underlying websocket connection
	Groups *WsClientGroup  // the group (chat room) this client belongs to
}
// ReadLoopGroup reads messages from the client's connection and forwards each
// one to the group's broadcast channel. It is meant to run in its own goroutine
// and returns when a read fails (including a read-deadline timeout or the peer
// closing). On return it unregisters the client and closes the connection.
func (cli *WsClient) ReadLoopGroup() {
	defer func() {
		// Unregister first so HandleRun closes cli.send (stopping WriteLoopGroup),
		// then release the connection.
		cli.Groups.clientExit <- cli
		cli.conn.Close()
	}()
	cli.conn.SetReadLimit(int64(maxMessageSize))
	cli.conn.SetReadDeadline(time.Now().Add(pongWait))
	cli.conn.SetPongHandler(func(s string) error {
		// Every pong from the peer extends the read deadline.
		cli.conn.SetReadDeadline(time.Now().Add(pongWait))
		return nil
	})
	log.Printf("begin read loop ")
	for {
		_, msg, err := cli.conn.ReadMessage()
		if err != nil {
			// Normal departures (going away / abnormal closure) are not logged.
			if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
				log.Printf("websocket client error: %+v", err)
			}
			return
		}
		// WriteLoopGroup joins queued messages with '\n' in one frame, and the
		// browser splits frames on '\n'. Flatten embedded newlines so one
		// message cannot turn into (or impersonate) several.
		msg = bytes.TrimSpace(bytes.Replace(msg, newline, space, -1))
		log.Printf("receive client msg = %q", msg)
		// Every message is broadcast. A front-end "type" field could later be
		// used to choose between broadcast and private messages.
		cli.Groups.broadcast <- msg
	}
}
// WriteLoopGroup delivers messages queued on cli.send to the peer and pings it
// every pingPeriod. It is meant to run in its own goroutine and returns when
// cli.send is closed or a write fails. On return it stops the ticker and
// closes the connection.
func (cli *WsClient) WriteLoopGroup() {
	ticker := time.NewTicker(pingPeriod)
	defer func() {
		log.Printf("exit writeLoopGroup")
		ticker.Stop()
		cli.conn.Close()
	}()
	for {
		select {
		case msg, ok := <-cli.send:
			// Bound every write (data or close frame) by writeWait so a stalled
			// peer cannot block this goroutine indefinitely.
			cli.conn.SetWriteDeadline(time.Now().Add(writeWait))
			if !ok {
				// The group closed cli.send. Tell the peer we are closing
				// (best effort; we are exiting either way).
				cli.conn.WriteMessage(websocket.CloseMessage, []byte{})
				return
			}
			w, err := cli.conn.NextWriter(websocket.TextMessage)
			if err != nil {
				return
			}
			w.Write(msg)
			// Coalesce messages that are already queued into the same frame,
			// separated by newline. Write errors surface via w.Close below.
			n := len(cli.send)
			for i := 0; i < n; i++ {
				w.Write(newline)
				w.Write(<-cli.send)
			}
			if err := w.Close(); err != nil {
				return
			}
		case <-ticker.C:
			cli.conn.SetWriteDeadline(time.Now().Add(writeWait))
			if err := cli.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
				// A write error or timeout means the peer is gone: exit.
				return
			}
		}
	}
}
// Register upgrades the HTTP request to a websocket connection, adds the
// resulting client to groups, and starts the client's read and write
// goroutines. groups.clientEnter is unbuffered, so groups.HandleRun must
// already be running or this call blocks.
func Register(groups *WsClientGroup, w http.ResponseWriter, r *http.Request) {
	conn, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		log.Printf("error info %+v", err)
		return
	}

	cli := &WsClient{
		send:   make(chan []byte, 256),
		conn:   conn,
		Groups: groups,
	}
	groups.clientEnter <- cli
	log.Printf("enter client ")

	go cli.ReadLoopGroup()
	go cli.WriteLoopGroup()
}
main.go 测试代码
package main
import (
"log"
"net/http"
"websocket/websocket"
)
// serverHome serves index.html for GET requests to exactly "/". Any other
// path gets 404; any other method on "/" gets 405. The path check comes first.
func serverHome(w http.ResponseWriter, r *http.Request) {
	log.Println(r.URL)
	switch {
	case r.URL.Path != "/":
		http.Error(w, "Not found", http.StatusNotFound)
	case r.Method != http.MethodGet:
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
	default:
		http.ServeFile(w, r, "index.html")
	}
}
// main serves the chat page at "/" and the websocket endpoint at "/ws" on :8080.
// All websocket clients share a single chat group.
func main() {
	http.HandleFunc("/", serverHome)

	groups := websocket.NewClientGroup()
	go groups.HandleRun()

	http.HandleFunc("/ws", func(w http.ResponseWriter, r *http.Request) {
		// Separate chat rooms would need a map from room ID to *WsClientGroup
		// here, e.g. rooms["chat_room_id"] = groups.
		websocket.Register(groups, w, r)
	})

	// Don't let a startup failure (e.g. the port is already in use) exit silently.
	log.Fatal(http.ListenAndServe(":8080", nil))
}
index.html
<!DOCTYPE html>
<html lang="en">
<head>
<title>Chat Example</title>
<script type="text/javascript">
window.onload = function () {
    var conn;
    var msg = document.getElementById("msg");
    var log = document.getElementById("log");

    // Append an element to the log. If the log was scrolled to the bottom
    // before the append, keep it pinned to the bottom afterwards.
    function appendLog(item) {
        var doScroll = log.scrollTop > log.scrollHeight - log.clientHeight - 1;
        log.appendChild(item);
        if (doScroll) {
            log.scrollTop = log.scrollHeight - log.clientHeight;
        }
    }

    document.getElementById("form").onsubmit = function () {
        // Only send on an open socket. send() throws while the socket is still
        // connecting, and an exception here would skip `return false` and let
        // the browser submit (reload) the page. On a closed socket the message
        // would be silently discarded, so keep it in the input instead.
        if (!conn || conn.readyState !== WebSocket.OPEN) {
            return false;
        }
        if (!msg.value) {
            return false;
        }
        conn.send(msg.value);
        msg.value = "";
        return false;
    };

    if (window["WebSocket"]) {
        // Use wss:// when the page was served over HTTPS; browsers block
        // insecure ws:// connections from secure pages.
        var scheme = document.location.protocol === "https:" ? "wss://" : "ws://";
        conn = new WebSocket(scheme + document.location.host + "/ws");
        conn.onclose = function (evt) {
            var item = document.createElement("div");
            item.innerHTML = "<b>Connection closed.</b>";
            appendLog(item);
        };
        conn.onmessage = function (evt) {
            // The server may batch several chat messages into one frame,
            // separated by newlines; show each on its own line.
            var messages = evt.data.split('\n');
            for (var i = 0; i < messages.length; i++) {
                var item = document.createElement("div");
                // innerText (not innerHTML) so peer-supplied text is never parsed as markup.
                item.innerText = messages[i];
                appendLog(item);
            }
        };
    } else {
        var item = document.createElement("div");
        item.innerHTML = "<b>Your browser does not support WebSockets.</b>";
        appendLog(item);
    }
};
</script>
<style type="text/css">
/* Full-window layout; the page itself never scrolls, only #log does. */
html {
    overflow: hidden;
}

body {
    margin: 0;
    padding: 0;
    width: 100%;
    height: 100%;
    overflow: hidden;
    background: gray;
}

/* Scrollable message area filling the window above the input form. */
#log {
    position: absolute;
    top: 0.5em;
    right: 0.5em;
    bottom: 3em;
    left: 0.5em;
    margin: 0;
    padding: 0.5em;
    overflow: auto;
    background: white;
}

/* Input form pinned to the bottom of the window. */
#form {
    position: absolute;
    bottom: 1em;
    left: 0;
    width: 100%;
    margin: 0;
    padding: 0 0.5em;
    overflow: hidden;
}
</style>
</head>
<body>
<div id="log"></div>
<form id="form">
<input type="submit" value="Send" />
<input type="text" id="msg" size="64" autofocus />
</form>
</body>
</html>