【Golang】gorilla/websocket实战和底层代码分析
原创前言
在【为什么有了http,还需要websocket,我懂了!】中介绍了web端即时通讯的方式,以及websocket如何进行连接、验证、数据帧的格式,这些都是了解websocket的基础知识。
本期将会继续上次话题,这次是进行实操websocket框架,分享它使用和底层实现!
📚 全文字数 : 4k+
⏳ 阅读时长 : 7min
📢 关键词 : gorilla/websocket 、数据帧、Upgrader
相信很多使用Golang的小伙伴都知道Gorilla这个工具包,长久以来 gorilla/websocket 都是比官方包更好的websocket包。
题外话 gorilla:大猩猩(不过这个猩猩还挺可爱的)
gorilla/websocket 框架开源地址为: https://github.com/gorilla/websocket
今天小许就用【gorilla/websocket】框架来展开本期文章内容,文章会设计到核心代码的走读,会涉及到不少代码,需要小伙伴们保持耐心往下看,然后结合之前分享的websocket基础,彻底学个明白!
简单使用
安装Gorilla Websocket Go软件包,您只需要使用即可go get
go get github.com/gorilla/websocket
在正式使用之前我们先简单了解下两个数据结构 Upgrader 和 Conn
Upgrader
Upgrader指定用于将 HTTP 连接升级到 WebSocket 连接
type Upgrader struct {
HandshakeTimeout time.Duration
ReadBufferSize, WriteBufferSize int
WriteBufferPool BufferPool
Subprotocols []string
Error func(w http.ResponseWriter, r *http.Request, status int, reason error)
CheckOrigin func(r *http.Request) bool
EnableCompression bool
}
- HandshakeTimeout: 握手完成的持续时间
- ReadBufferSize和WriteBufferSize:以字节为单位指定I/O缓冲区大小。如果缓冲区大小为零,则使用HTTP服务器分配的缓冲区
- CheckOrigin : 函数应仔细验证请求来源 防止跨站点请求伪造
这里一般会设置下CheckOrigin来解决跨域问题
Conn
Conn类型表示WebSocket连接,这个结构体的组成包括两部分,写入字段(Write fields)和 读取字段(Read fields)
type Conn struct {
conn net.Conn
isServer bool
// Write fields
writeBuf []byte
writePool BufferPool
writeBufSize int
writer io.WriteCloser
isWriting bool
// Read fields
readRemaining int64
readFinal bool
readLength int64
messageReader *messageReader
}
isServer : 字段来区分我们是否用Conn作为客户端还是服务端,也就是说说gorilla/websocket中同时编写客户端程序和服务器程序,但是一般是Web应用程序使用单独的前端作为客户端程序。
部分字段说明如下图:
服务端示例
出于说明的目的,我们将在Go中同时编写客户端程序和服务端程序(其实小许是前端小趴菜😅 🤭)。
当然我们在开发程序的时候基本都是单独的前端,通常使用(Javascript,vue等)实现websocket客户端,这里为了让大家有比较直观的感受,用【gorilla/websocket】分别写了服务端和客户端示例。
var upGrader = websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool {
return true
func main() {
http.HandleFunc("/ws", wsUpGrader)
err := http.ListenAndServe("localhost:8080", nil)
if err != nil {
log.Println("server start err", err)
func wsUpGrader(w http.ResponseWriter, r *http.Request) {
//转换为升级为websocket
conn, err := upGrader.Upgrade(w, r, nil)
if err != nil {
log.Println(err)
return
//释放连接
defer conn.Close()
for {
//接收消息
messageType, message, err := conn.ReadMessage()
if err != nil {
log.Println(err)
return
log.Println("server receive messageType", messageType, "message", string(message))
//发送消息
err = conn.WriteMessage(messageType, []byte("pong"))
if err != nil {
log.Println(err)
return
}
我们知道websocket协议是基于http协议进行upgrade升级的, 这里使用 net/http提供原始的http连接。
http.HandleFunc接受两个参数:第一个参数是字符串表示的 url 路径,第二个参数是该 url 实际的处理对象
http.ListenAndServe 监听在某个端口,启动服务,准备接受客户端的请求
HandleFunc的作用:通过类型转换让我们可以将普通的函数作为HTTP处理器使用
服务端代码流程:
- Gorilla在使用websocket之前是先将http装维websocket,用的是初始化的upGrader结构体变量调用Upgrade方法进行请求协议升级
- 升级后返回 *Conn(此时isServer = true),后续使用它来处理websocket连接
- 服务端消息读写分别用 ReadMessage()、WriteMessage()
客户端示例
import (
"fmt"
"github.com/gorilla/websocket"
"log"
"time"
func main() {
//服务器地址 websocket 统一使用 ws://
url := "ws://localhost:8080/ws"
//使用默认拨号器,向服务器发送连接请求
ws, _, err := websocket.DefaultDialer.Dial(url, nil)
if err != nil {
log.Fatal(err)
//关闭连接
defer conn.Close()
//发送消息
go func() {
for {
err := ws.WriteMessage(websocket.BinaryMessage, []byte("ping"))
if err != nil {
log.Fatal(err)
//休眠两秒
time.Sleep(time.Second * 2)
//接收消息
for {
_, data, err := ws.ReadMessage()
if err != nil {
log.Fatal(err)
fmt.Println("client receive message: ", string(data))
}
客户端的实现看起来也是简单,先使用默认拨号器,向服务器地址发送连接请求,拨号成功时也返回一个*Conn,开启一个协程每隔两秒向服务端发送消息,同样都是使用ReadMessage和W riteMessage读写消息。
示例代码运行结果如下:
源码走读
看完上面基本的客户端和服务端案例之后,我们对整个消息发送和接收的使用已经熟悉了,实际开发中要做的就是如何结合业务去定义消息类型和发送场景了,我们接着走读下底层的实现逻辑!
代码走读我们分了四部分,主要了解协议是如何升级、已经消息如何读写、解析数据帧【 🚩 🚩核心】!
Upgrade 协议升级
Upgrade顾名思义【升级】,在进行协议升级之前是需要对协议进行校验的,之前我们知道待升级的http请求是有固定请求头的,这里列举几个:
✏️ Upgrade进行校验的目的是看该请求是否符合协议升级的规定
Upgrade的部分校验代码如下,return处进行了省略
func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*Conn, error) {
if !tokenListContainsValue(r.Header, "Connection", "upgrade") {
return ...
if !tokenListContainsValue(r.Header, "Upgrade", "websocket") {
return ...
//必须是get请求方法
if r.Method != http.MethodGet {
return ...
if !tokenListContainsValue(r.Header, "Sec-Websocket-Version", "13") {
return ...
if _, ok := responseHeader["Sec-Websocket-Extensions"]; ok {
return ...
c := newConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize, u.WriteBufferPool, br, writeBuf)
}
tokenListContainsValue的目的是校验请求的Header中是否有upgrade需要的特定参数,比如我们上图列举的一些。
newConn就是初始化部分Conn结构体的,方法中的第二个参数为true代表这是服务端
computeAcceptKey 计算接受密钥:
这个函数重点说下,在上一期中在websocket【连接确认】这一章节中知道,websocket协议升级时,需要满足如下条件:
✏️只有当请求头参数Sec-WebSocket-Key字段的值经过固定算法加密后的数据和响应头里的Sec-WebSocket-Accept的值保持一致,该连接才会被认可建立。
var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
func computeAcceptKey(challengeKey string) string {
h := sha1.New()
h.Write([]byte(challengeKey))
h.Write(keyGUID)
return base64.StdEncoding.EncodeToString(h.Sum(nil))
}
上面 computeAcceptKey 函数的实现,验证了之前说的关于 Sec-WebSocket-Accept的生成
服务端需将Sec-WebSocket-Key和固定的 GUID 字符串( 258EAFA5-E914-47DA-95CA-C5AB0DC85B11) 拼接后使用 SHA-1 进行哈希,并采用 base64 编码后返回
ReadMessage 读消息
ReadMessage方法内部使用NextReader获取读取器并从该读取器读取到缓冲区,如果是一条消息由多个数据帧,则会拼接成完整的消息,返回给业务层。
func (c *Conn) ReadMessage() (messageType int, p []byte, err error) {
var r io.Reader
messageType, r, err = c.NextReader()
if err != nil {
return messageType, nil, err
//ReadAll从r读取,直到出现错误或EOF,并返回读取的数据
p, err = io.ReadAll(r)
return messageType, p, err
}
该方法,返回三个参数,分别是消息类型、内容、error
messageType是int型,值可能是 BinaryMessage(二进制消息) 或 TextMessage(文本消息)
NextReader: 该方法得到一个消息类型 messageType,io.Reader,err
func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
for c.readErr == nil {
//解析数据帧方法advanceFrame
// frameType : 帧类型
frameType, err := c.advanceFrame()
if err != nil {
c.readErr = hideTempErr(err)
break
//数据类型是 文本或二进制类型
if frameType == TextMessage || frameType == BinaryMessage {
c.messageReader = &messageReader{c}
c.reader = c.messageReader
if c.readDecompress {
c.reader = c.newDecompressionReader(c.reader)
return frameType, c.reader, nil
}
c.advanceFrame() 是核心代码,主要是实现解析这条消息,这里在最后章节会讲。
这里有个 c.messageReader (当前的低级读取器),赋值给c.reader,为什么要这样呢?
c.messageReader 是更低级读取器,而 c.reader 的作用是当前读取器返回到应用程序。简单就是messageReader 是实现了 c.reader 接口的结构体, 从而也实现了 io.Reader接口
图上加一个 bufio.Read方法:Read读取数据写入p。本方法返回写入p的字节数。本方法一次调用最多会调用下层Reader接口一次Read方法,因此返回值n可能小于len(p)。读取到达结尾时,返回值n将为0而err将为io.EOF
messageReader的 Read方法: 我们看下Read的具体实现,Read方法主要是读取数据帧内容,直到出现并返回io.EOF或者其他错误为止,而实际调用它的正是 io.ReadAll。
func (r *messageReader) Read(b []byte) (int, error) {
for c.readErr == nil {
//当前帧中剩余的字节
if c.readRemaining > 0 {
if int64(len(b)) > c.readRemaining {
b = b[:c.readRemaining]
//读取到切片b中
n, err := c.br.Read(b)
c.readErr = hideTempErr(err)
//当Conn是服务端
if c.isServer {
c.readMaskPos = maskBytes(c.readMaskKey, c.readMaskPos, b[:n])
//readRemaining字节数转int64
rem := c.readRemaining
rem -= int64(n)
//跟踪连接上剩余的字节数
if err := c.setReadRemaining(rem); err != nil {
return 0, err
if c.readRemaining > 0 && c.readErr == io.EOF {
c.readErr = errUnexpectedEOF
//返回读后字节数
return n, c.readErr
//标记是否最后一个数据帧
if c.readFinal {
// messageRader 置为nil
c.messageReader = nil
return 0, io.EOF
//获取数据帧类型
frameType, err := c.advanceFrame()
switch {
case err != nil:
c.readErr = hideTempErr(err)
case frameType == TextMessage || frameType == BinaryMessage:
c.readErr = errors.New("websocket: internal error, unexpected text or binary in Reader")
err := c.readErr
if err == io.EOF && c.messageReader == r {
err = errUnexpectedEOF
return 0, err
}
io.ReadAll : ReadAll从r读取,这里是实现如果一条消息由多个数据帧,会一直读直到最后一帧的关键。
func ReadAll(r Reader) ([]byte, error) {
b := make([]byte, 0, 512)
for {
if len(b) == cap(b) {
// 给[]byte添加更多容量
b = append(b, 0)[:len(b)]
n, err := r.Read(b[len(b):cap(b)])
b = b[:len(b)+n]
if err != nil {
if err == EOF {
err = nil
return b, err
}
可以看出在for 循环中一直读取,直至读取到最后一帧,直到返回io.EOF或网络原因错误为止,否则一直进行阻塞读,这些 error 可以从上面讲到的messageReader的 Read方法可以看出来。
总结下,整个流程如下:
整个读消息的流程就结束了,我们继续看如何写消息
WriteMessage 写消息
既然读消息是对数据帧进行解析,那么写消息就自然会联想到将数据按照数据帧的规范组装写入到一个writebuf中,然后写入到网络中。
我们继续看 WriteMessage 是如何实现的
func (c *Conn) WriteMessage(messageType int, data []byte) error {
//w 是一个io.WriteCloser
w, err := c.NextWriter(messageType)
if err != nil {
return err
//将data写入writeBuf中
if _, err = w.Write(data); err != nil {
return err
return w.Close()
}
WriteMessage方法接收一个消息类型和数据,主要逻辑是先调用 Conn的NextWriter 方法得到一个io.WriteCloser,然后写消息到这个Conn的writeBuf,写完消息后close它。
NextWriter实现如下:
func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
var mw messageWriter
if err := c.beginMessage(&mw, messageType); err != nil {
return nil, err
c.writer = &mw
return c.writer, nil
}
注意看这里有个messageWriter赋值给了Conn的writer,也就是说messageWriter实现了io.WriterCloser接口。
这里的实现跟读消息中的NextReader方法中的messageReader很像,也是通过实现io.Reader接口,然后赋值给了Conn的Reader,这里可以做个小联动,找到读写消息实际的实现者 messageReader、messageWriter。
messageWriter的Write实现:
前置知识:如果没有设置Conn中writeBufferSize, 默认情况下会设置为 4096个字节,另外加上14字节的数据帧头部大小【这些在newConn中初始化的时候有代码说明】
func (w *messageWriter) Write(p []byte) (int, error) {
//如果字节长度大于初始化的writeBuf空间大小
if len(p) > 2*len(w.c.writeBuf) && w.c.isServer {
//写入方法
err := w.flushFrame(false, p)
//字节长度不大于初始化的writeBuf空间大小
nn := len(p)
for len(p) > 0 {
//内部也是调用的flushFrame
n, err := w.ncopy(len(p))
return nn, nil
}
messageWriter中的Write方法主要的目的是将数据写入到writeBuf中,它主要存储结构化的数据帧内容,所谓结构化就是按照数据帧的格式,用Go实现写入的。
总结下,整个流程如下:
而flushFrame方法将缓冲数据和额外数据作为帧写入网络,这个final参数表示这是消息中的最后一帧。
至于flushFrame内部是如何实现写入网络中的,你可以看看 net.Conn 是怎么Write的,因为最终就是调这个写入网络的,这里就不再深究了,有兴趣的同学可以自己挖一挖!
advanceFrame 解析数据帧
解析数据帧放在最后,前面的代码走读主要是为了方便大家能把整体流程搞清楚,而数据帧的解析,是更加需要对websocket基础有了解,特别是数据帧的组成,因为解析就是按照协定用Go代码实现的一种方式而已!
强烈推荐大家看完:websocket基础文章链接
根据上图【来自网络】回顾下数据帧各部分代表的意思:
FIN : 1个bit位,用来标记当前数据帧是不是最后一个数据帧 RSV1, RSV2, RSV3 :这三个各占用一个bit位用做扩展用途,没有这个需求的话设置位0 Opcode : 该值定义的是数据帧的数据类型 1 表示文本 2 表示二进制 MASK : 表示数据有没有使用掩码 Payload length :数据的长度,Payload data的长度,占7bits,7+16bits,7+64bits Masking-key :数据掩码 (设置位0,则该部分可以省略,如果设置位1,则用来解码客户端发送给服务端的数据帧) Payload data : 帧真正要发送的数据,可以是任意长度
advanceFrame 解析方法
实现代码会比较长,如果直接贴代码,会看不下去,该方法返回数据类型和error, 这里我们只会截取其中一部分
func (c *Conn) advanceFrame() (int, error) {
//读取前两个字节
p, err := c.read(2)
if err != nil {
return noFrame, err
//数据帧类型
frameType := int(p[0] & 0xf)
// FIN 标记位
final := p[0]&finalBit != 0
//三个扩展用
rsv1 := p[0]&rsv1Bit != 0
rsv2 := p[0]&rsv2Bit != 0
rsv3 := p[0]&rsv3Bit != 0
//mask :是否使用掩码
mask := p[1]&maskBit != 0
switch c.readRemaining {
case 126:
p, err := c.read(2)
if err != nil {
return noFrame, err
if err := c.setReadRemaining(int64(binary.BigEndian.Uint16(p))); err != nil {
return noFrame, err
case 127:
p, err := c.read(8)
if err != nil {
return noFrame, err
if err := c.setReadRemaining(int64(binary.BigEndian.Uint64(p))); err != nil {
return noFrame, err