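Preface
Most services I have built so far expose an HTTP API between server and client. The advantages are that HTTP is language-agnostic and easy to call. HTTP, however, is an application-layer protocol that adds its own overhead on top of TCP, so services with strict performance requirements are often better served by RPC running directly over TCP. I have recently been moving to Go, whose standard library ships a ready-to-use RPC package. I had never worked with it and was planning to learn RPC anyway, so I used the simplest possible RPC service to study how the net/rpc package is implemented.
Example code
The example is very simple: the client sends a string, and the server responds with "hello " followed by that string. The code is as follows:
Server code
- Define a service struct with a Hello method. The method's first parameter is the request data and the second, a pointer, receives the response data. Both must match the argument and reply types the client uses; otherwise the call fails with an encoding or decoding error.
- Register the service and its methods with rpc.RegisterName. The first argument is a string used as the service name (rpc.Register, which takes no name, defaults to the name of the service struct type); the second is a pointer to the service value.
- Start a TCP listener with net.Listen.
- Accept connections in a loop; whenever one arrives, call rpc.ServeConn to serve its requests.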
package main
import (
"log"
"net"
"net/rpc"
)
type HelloService struct{}
func (p *HelloService) Hello(req string, rep *string) error {
*rep = "hello " + req
return nil
}
func main() {
if err := rpc.RegisterName("HelloService", new(HelloService)); err != nil {
log.Fatal("register error:", err)
}
listener, err := net.Listen("tcp", ":1234")
if err != nil {
log.Fatal("listen tcp error:", err)
}
for {
conn, err := listener.Accept()
if err != nil {
log.Fatal("accept error:", err)
}
// ServeConn blocks until the client hangs up, so serve each connection in its own goroutine.
go rpc.ServeConn(conn)
}
}
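Client code
- Open a connection to the server with rpc.Dial, which dials over TCP with net.Dial and wraps the connection in an RPC client.
- Invoke the method with client.Call. The first argument is "ServiceName.MethodName", the second is the request data, and the third is the address of the variable that receives the reply. The second and third arguments must match the server's argument and reply types; otherwise the call fails with an encoding or decoding error.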
package main
import (
"fmt"
"log"
"net/rpc"
)
func main() {
client, err := rpc.Dial("tcp", "localhost:1234")
if err != nil {
log.Fatal("dial error:", err)
}
var rep string
err = client.Call("HelloService.Hello", "hello", &rep)
if err != nil {
log.Fatal(err)
}
fmt.Println(rep)
}
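With only a few dozen lines of code we have a working RPC service, which says a lot about how much net/rpc does for us. The simpler the calling code, the more curious I get about what is behind it, so let's walk through the standard library implementation step by step.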
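To try the server and client above without opening a port, here is a minimal sketch that runs both in one process over an in-memory net.Pipe. The helper names newPipeClient, callHello and callWithWrongArgType are mine, not part of net/rpc, and the private rpc.NewServer is used only to keep the example self-contained. It also sends an int where the server expects a string, to show what a type mismatch looks like from the caller's side in this setup.

```go
package main

import (
	"fmt"
	"log"
	"net"
	"net/rpc"
)

// HelloService mirrors the service from the article.
type HelloService struct{}

// Hello writes "hello " + req into rep.
func (p *HelloService) Hello(req string, rep *string) error {
	*rep = "hello " + req
	return nil
}

// newPipeClient (hypothetical helper) registers HelloService on a private
// rpc.Server, serves one end of an in-memory pipe, and returns a client
// bound to the other end.
func newPipeClient() (*rpc.Client, error) {
	srv := rpc.NewServer()
	if err := srv.RegisterName("HelloService", new(HelloService)); err != nil {
		return nil, err
	}
	serverConn, clientConn := net.Pipe()
	go srv.ServeConn(serverConn)
	return rpc.NewClient(clientConn), nil
}

// callHello performs one well-typed call.
func callHello(name string) (string, error) {
	client, err := newPipeClient()
	if err != nil {
		return "", err
	}
	defer client.Close()
	var rep string
	err = client.Call("HelloService.Hello", name, &rep)
	return rep, err
}

// callWithWrongArgType sends an int where the server expects a string.
func callWithWrongArgType() error {
	client, err := newPipeClient()
	if err != nil {
		return err
	}
	defer client.Close()
	var rep string
	return client.Call("HelloService.Hello", 42, &rep)
}

func main() {
	rep, err := callHello("world")
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(rep)
	fmt.Println("mismatched argument type:", callWithWrongArgType())
}
```

In this sketch the mismatched call comes back from client.Call as an ordinary error value carrying the decoder's message.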
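Exploring the source
Server
Core structs
First, look at the four core structs:
Keep in mind what each field is for. It is essential for reading the source below, and without it the code is hard to follow.
// Information about a single registered service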
type service struct {
name string // service name
rcvr reflect.Value // receiver of the service's methods
typ reflect.Type // type of the receiver
method map[string]*methodType // methods registered for this service
}
// RPC request header
// Request is a header written before every RPC call. It is used internally
// but documented here as an aid to debugging, such as when analyzing
// network traffic.
type Request struct {
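ServiceMethod string // service and method this request targets, in the form "Service.Method"
Seq uint64 // sequence number chosen by the client
next *Request // link in the Server's free list of Requests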
}
// Response is a header written before every RPC return. It is used internally
// but documented here as an aid to debugging, such as when analyzing
// network traffic.
type Response struct {
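ServiceMethod string // echoes the ServiceMethod of the request
Seq uint64 // echoes the sequence number of the request
Error string // error message, if any
next *Response // link in the Server's free list of Responses
}
// RPC server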
// Server represents an RPC Server.
type Server struct {
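serviceMap sync.Map // registered services, keyed by service name
reqLock sync.Mutex // protects freeReq
freeReq *Request // head of the free list of Request structs, kept for reuse
respLock sync.Mutex // protects freeResp
freeResp *Response // head of the free list of Response structs, kept for reuse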
}
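Service registration
Set a breakpoint on rpc.RegisterName in the server, start the debugger, and step through until you reach the register method in server.go of the net/rpc package:
Note the method call on line 255: the third argument is true, which means the name passed in by the caller is used as the service name.
The full register function is shown below, with comments on the key steps: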
func (server *Server) register(rcvr interface{}, name string, useName bool) error {
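// This is the service struct:
//type service struct {
// name string // service name
// rcvr reflect.Value // receiver of the service's methods
// typ reflect.Type // type of the receiver
// method map[string]*methodType // methods registered for this service
//}
s := new(service)
// Record the receiver's type via reflection.
s.typ = reflect.TypeOf(rcvr)
// Record the receiver's value via reflection.
s.rcvr = reflect.ValueOf(rcvr)
// Derive the default service name from the type (Indirect dereferences a pointer receiver).
sname := reflect.Indirect(s.rcvr).Type().Name()
// If useName is true, the name passed in takes precedence.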
if useName {
sname = name
}
if sname == "" {
s := "rpc.Register: no service name for type " + s.typ.String()
log.Print(s)
return errors.New(s)
}
if !token.IsExported(sname) && !useName {
s := "rpc.Register: type " + sname + " is not exported"
log.Print(s)
return errors.New(s)
}
s.name = sname
// Build the method-name-to-method map by reflecting over all methods of the receiver type.
s.method = suitableMethods(s.typ, true)
if len(s.method) == 0 {
str := ""
// To help the user, see if a pointer receiver would work.
method := suitableMethods(reflect.PtrTo(s.typ), false)
if len(method) != 0 {
str = "rpc.Register: type " + sname + " has no exported methods of suitable type (hint: pass a pointer to value of that type)"
} else {
str = "rpc.Register: type " + sname + " has no exported methods of suitable type"
}
log.Print(str)
return errors.New(str)
}
if _, dup := server.serviceMap.LoadOrStore(sname, s); dup {
return errors.New("rpc: service already defined: " + sname)
}
return nil
}
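The method filtering that register delegates to suitableMethods can be approximated in a few lines. This is a simplified sketch, not the real function: rpcMethodNames and the Ping method are hypothetical, and the real code additionally checks that the argument and reply types are exported or builtin.

```go
package main

import (
	"fmt"
	"reflect"
)

type HelloService struct{}

// Hello has the shape net/rpc accepts: (args, *reply) error.
func (p *HelloService) Hello(req string, rep *string) error {
	*rep = "hello " + req
	return nil
}

// Ping is exported but has the wrong shape, so it is skipped.
func (p *HelloService) Ping() string { return "pong" }

var typeOfError = reflect.TypeOf((*error)(nil)).Elem()

// rpcMethodNames is a simplified, hypothetical stand-in for suitableMethods:
// it only checks the parameter count, the pointer reply and the error result.
func rpcMethodNames(rcvr interface{}) []string {
	typ := reflect.TypeOf(rcvr)
	var names []string
	for i := 0; i < typ.NumMethod(); i++ {
		m := typ.Method(i)
		mt := m.Type
		// In(0) is the receiver, In(1) the argument, In(2) the reply.
		if mt.NumIn() != 3 || mt.In(2).Kind() != reflect.Ptr {
			continue
		}
		if mt.NumOut() != 1 || mt.Out(0) != typeOfError {
			continue
		}
		names = append(names, m.Name)
	}
	return names
}

func main() {
	fmt.Println(rpcMethodNames(new(HelloService))) // pointer: Hello is in the method set
	fmt.Println(rpcMethodNames(HelloService{}))    // value: the method set is empty
}
```

Passing a value instead of a pointer yields an empty list because Hello has a pointer receiver, which is the situation the "hint: pass a pointer to value of that type" message in register is meant for.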
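Overall request-handling logic
Service registration is now complete. It is implemented entirely with reflection: in short, the service name is taken from the service struct (or from the name passed in), and the struct's methods are inspected to register every suitable method. Next, run the client under the debugger to follow the call path:
Stepping straight through to the interesting part, here is the call stack:
ServeCodec mainly does the following (the source below is annotated):
- Uses a sync.Mutex to serialize the goroutines that write responses to the connection.
- Uses a sync.WaitGroup to wait for every in-flight call to finish sending its response before the codec is closed.
- Calls server.readRequest to read and decode each request.
- Handles each call in its own goroutine, started with the go keyword.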
// ServeCodec is like ServeConn but uses the specified codec to
// decode requests and encode responses.
func (server *Server) ServeCodec(codec ServerCodec) {
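// Serializes concurrent writes of responses.
sending := new(sync.Mutex)
// Lets ServeCodec wait for in-flight calls to finish writing before the codec is closed.
wg := new(sync.WaitGroup)
for {
// Read and decode the next request sent over the client connection.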
service, mtype, req, argv, replyv, keepReading, err := server.readRequest(codec)
if err != nil {
if debugLog && err != io.EOF {
log.Println("rpc:", err)
}
if !keepReading {
break
}
// send a response if we actually managed to read a header.
if req != nil {
server.sendResponse(sending, req, invalidRequest, codec, err.Error())
server.freeRequest(req)
}
continue
}
wg.Add(1)
// Invoke the method and write the response in a goroutine; service.call is analyzed below.
go service.call(server, sending, wg, mtype, req, argv, replyv, codec)
}
// We've seen that there are no more requests.
// Wait for responses to be sent before closing codec.
wg.Wait()
// Close the codec, which closes the underlying connection.
codec.Close()
}
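The concurrency pattern in ServeCodec (one goroutine per request, a mutex around the shared output, a WaitGroup before closing) can be isolated from the RPC details. The following is a sketch under my own assumptions: serveAll is a hypothetical function, and a slice stands in for the connection that responses are written to.

```go
package main

import (
	"fmt"
	"sort"
	"sync"
)

// serveAll is a hypothetical miniature of the ServeCodec loop: every request
// is handled in its own goroutine, writes to the shared output are serialized
// by a mutex, and the function returns only after all responses are written.
func serveAll(requests []string) []string {
	var responses []string // stands in for the connection
	sending := new(sync.Mutex)
	wg := new(sync.WaitGroup)
	for _, req := range requests {
		wg.Add(1)
		go func(req string) {
			defer wg.Done()
			resp := "hello " + req // the "method call"
			sending.Lock()
			responses = append(responses, resp)
			sending.Unlock()
		}(req)
	}
	// Like ServeCodec, wait for in-flight calls before "closing" the output.
	wg.Wait()
	// Goroutines finish in no particular order; sort for stable output.
	sort.Strings(responses)
	return responses
}

func main() {
	fmt.Println(serveAll([]string{"b", "a", "c"}))
}
```

Without the mutex the appends would race, and without wg.Wait the function could return before the slower handlers have written anything.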
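Take a closer look at the codec, which leaves two extension points:
- Encoding and decoding sit behind the ServerCodec interface (the dec and enc fields are the gob implementation used by default), so a custom codec can be plugged in.
- The RPC protocol is built on the abstract io.ReadWriteCloser interface, so the underlying transport can be swapped freely.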
type gobServerCodec struct {
rwc io.ReadWriteCloser
dec *gob.Decoder
enc *gob.Encoder
encBuf *bufio.Writer
closed bool
}
Request logic
The key steps are commented; read the code alongside the comments:
func (server *Server) readRequest(codec ServerCodec) (service *service, mtype *methodType, req *Request, argv, replyv reflect.Value, keepReading bool, err error) {
service, mtype, req, keepReading, err = server.readRequestHeader(codec)
if err != nil {
if !keepReading {
return
}
// discard body
codec.ReadRequestBody(nil)
return
}
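// The rest (elided here) uses reflection to allocate the argument and reply values and decodes the request body into the argument.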
}
func (server *Server) readRequestHeader(codec ServerCodec) (svc *service, mtype *methodType, req *Request, keepReading bool, err error) {
// Grab the request header.
// Take a Request from the free list, or allocate one.
req = server.getRequest()
// Decode the request header.
err = codec.ReadRequestHeader(req)
if err != nil {
req = nil
if err == io.EOF || err == io.ErrUnexpectedEOF {
return
}
err = errors.New("rpc: server cannot decode request: " + err.Error())
return
}
// We read the header successfully. If we see an error now,
// we can still recover and move on to the next request.
keepReading = true
dot := strings.LastIndex(req.ServiceMethod, ".")
if dot < 0 {
err = errors.New("rpc: service/method request ill-formed: " + req.ServiceMethod)
return
}
serviceName := req.ServiceMethod[:dot]
methodName := req.ServiceMethod[dot+1:]
// Look up the service by name.
svci, ok := server.serviceMap.Load(serviceName)
if !ok {
err = errors.New("rpc: can't find service " + req.ServiceMethod)
return
}
svc = svci.(*service)
// Look up the method on that service.
mtype = svc.method[methodName]
if mtype == nil {
err = errors.New("rpc: can't find method " + req.ServiceMethod)
}
return
}
func (server *Server) getRequest() *Request {
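// Lock the free list.
server.reqLock.Lock()
// Take the head of the free list.
req := server.freeReq
if req == nil {
// The free list is empty: allocate a new Request.
req = new(Request)
} else {
// Pop the head: the free list now starts at the next entry.
server.freeReq = req.next
// Zero the reused Request so it can hold the current request without a new allocation.
*req = Request{}
}
// Unlock.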
server.reqLock.Unlock()
return req
}
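The parsing step in readRequestHeader is worth a second look: the ServiceMethod string is split at its last dot, so a service name may itself contain dots. A small sketch of just that step follows; splitServiceMethod is a hypothetical helper name, not a function in net/rpc.

```go
package main

import (
	"errors"
	"fmt"
	"strings"
)

// splitServiceMethod is a hypothetical helper that mirrors the parsing step
// in readRequestHeader: everything before the last dot is the service name.
func splitServiceMethod(serviceMethod string) (service, method string, err error) {
	dot := strings.LastIndex(serviceMethod, ".")
	if dot < 0 {
		return "", "", errors.New("rpc: service/method request ill-formed: " + serviceMethod)
	}
	return serviceMethod[:dot], serviceMethod[dot+1:], nil
}

func main() {
	inputs := []string{"HelloService.Hello", "path/to/pkg.HelloService.Hello", "Hello"}
	for _, s := range inputs {
		svc, m, err := splitServiceMethod(s)
		fmt.Printf("%q -> service=%q method=%q err=%v\n", s, svc, m, err)
	}
}
```

This is why a name such as "path/to/pkg.HelloService", passed to RegisterName, would still resolve: only the final dot separates the method.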
Core method-invocation code
Now for the core of the service call, service.call:
func (s *service) call(server *Server, sending *sync.Mutex, wg *sync.WaitGroup, mtype *methodType, req *Request, argv, replyv reflect.Value, codec ServerCodec) {
// Defer wg.Done so the WaitGroup is released once the call has been handled.
if wg != nil {
defer wg.Done()
}
// The lock protects the per-method call counter on the server.
//type methodType struct {
// sync.Mutex // protects counters
// method reflect.Method
// ArgType reflect.Type
// ReplyType reflect.Type
// numCalls uint
//}
mtype.Lock()
mtype.numCalls++
mtype.Unlock()
function := mtype.method.Func
// Invoke the method, providing a new value for the reply.
// Call the method via reflection; it writes its result into replyv.
returnValues := function.Call([]reflect.Value{s.rcvr, argv, replyv})
// The return value for the method is an error.
errInter := returnValues[0].Interface()
errmsg := ""
if errInter != nil {
errmsg = errInter.(error).Error()
}
// Send the response.
server.sendResponse(sending, req, replyv.Interface(), codec, errmsg)
// Return the Request to the free list.
server.freeRequest(req)
}
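The reflective call at the heart of service.call can be reproduced outside net/rpc. Here is a minimal sketch, assuming a method with the (string, *string) error shape; invoke is a hypothetical helper and does none of the validation that register performs up front.

```go
package main

import (
	"fmt"
	"reflect"
)

type HelloService struct{}

func (p *HelloService) Hello(req string, rep *string) error {
	*rep = "hello " + req
	return nil
}

// invoke is a hypothetical helper showing the reflective call made by
// service.call: receiver, argument and reply are passed as reflect.Values.
// It assumes the method takes (string, *string) and returns error.
func invoke(rcvr interface{}, methodName string, arg string) (string, error) {
	method, ok := reflect.TypeOf(rcvr).MethodByName(methodName)
	if !ok {
		return "", fmt.Errorf("no method %q", methodName)
	}
	// In(2) is the reply type (*string); allocate a fresh value for it.
	replyv := reflect.New(method.Type.In(2).Elem())
	returnValues := method.Func.Call([]reflect.Value{
		reflect.ValueOf(rcvr), reflect.ValueOf(arg), replyv,
	})
	// The single return value is an error, possibly nil.
	if errInter := returnValues[0].Interface(); errInter != nil {
		return "", errInter.(error)
	}
	return replyv.Elem().String(), nil
}

func main() {
	rep, err := invoke(new(HelloService), "Hello", "gopher")
	fmt.Println(rep, err)
}
```

Note that method.Func takes the receiver as its first argument, which is why s.rcvr leads the argument slice in service.call.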
Response logic
func (server *Server) sendResponse(sending *sync.Mutex, req *Request, reply interface{}, codec ServerCodec, errmsg string) {
// Take a Response struct to hold the response header.
resp := server.getResponse()
// Encode the response header
// Echo the method name of the request.
resp.ServiceMethod = req.ServiceMethod
if errmsg != "" {
resp.Error = errmsg
reply = invalidRequest
}
resp.Seq = req.Seq
sending.Lock()
// Write the response header and body.
err := codec.WriteResponse(resp, reply)
if debugLog && err != nil {
log.Println("rpc: writing response:", err)
}
sending.Unlock()
// Return the Response to the free list.
server.freeResponse(resp)
}
getResponse and freeResponse follow exactly the same approach as getRequest and freeRequest, reusing structs from a free list, so they are not analyzed in detail. Here is the code:
func (server *Server) getResponse() *Response {
server.respLock.Lock()
resp := server.freeResp
if resp == nil {
resp = new(Response)
} else {
server.freeResp = resp.next
*resp = Response{}
}
server.respLock.Unlock()
return resp
}
func (server *Server) freeResponse(resp *Response) {
server.respLock.Lock()
resp.next = server.freeResp
server.freeResp = resp
server.respLock.Unlock()
}
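The get/free pairs above are a mutex-protected free list: a singly linked stack of idle structs. The sketch below shows the same idea on a stand-in type; node and pool are hypothetical names, not types from net/rpc.

```go
package main

import (
	"fmt"
	"sync"
)

// node plays the role of Request/Response: a payload plus a free-list link.
type node struct {
	Seq  uint64
	next *node
}

// pool is a hypothetical miniature of the freeReq/reqLock pair in Server.
type pool struct {
	mu   sync.Mutex
	free *node
}

// get pops a node off the free list, or allocates when the list is empty.
func (p *pool) get() *node {
	p.mu.Lock()
	defer p.mu.Unlock()
	n := p.free
	if n == nil {
		return new(node)
	}
	p.free = n.next
	*n = node{} // clear stale fields, including the next link
	return n
}

// put pushes a node back onto the head of the free list.
func (p *pool) put(n *node) {
	p.mu.Lock()
	defer p.mu.Unlock()
	n.next = p.free
	p.free = n
}

func main() {
	var p pool
	a := p.get()
	a.Seq = 42
	p.put(a)
	b := p.get()
	fmt.Println(a == b, b.Seq) // same memory is reused, fields are reset
}
```

The second get returns the very pointer that was freed, with its fields zeroed, so no new allocation happens once the list has warmed up.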
Summary of the server-side logic
The complete call chain on the server is:
RegisterName -> register -> Listen -> Accept -> ServeConn -> ServeCodec -> readRequest -> getRequest -> ReadRequestHeader -> call -> sendResponse -> freeRequest
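Client
The client is simpler, so a rough look at the flow is enough:
- rpc.Dial connects to the server and starts a background goroutine that keeps reading responses from the connection and hands each one back through a channel.
- client.Call blocks until that channel delivers the completed call, then returns its error status.
Core structs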
// Call represents an active RPC.
type Call struct {
ServiceMethod string // The name of the service and method to call.
Args interface{} // The argument to the function (*struct).
Reply interface{} // The reply from the function (*struct).
Error error // After completion, the error status.
Done chan *Call // Strobes when call is complete.
}
// Client represents an RPC Client.
// There may be multiple outstanding Calls associated
// with a single Client, and a Client may be used by
// multiple goroutines simultaneously.
type Client struct {
codec ClientCodec
reqMutex sync.Mutex // serializes sending and protects request
request Request
mutex sync.Mutex // protects seq, pending, closing and shutdown
seq uint64
pending map[uint64]*Call // in-flight calls keyed by sequence number
closing bool // user has called Close
shutdown bool // server has told us to stop
}
Connecting to the server and starting the goroutine that waits for responses
Here is the source; the added comments mark the key points:
// rpc.Dial("tcp", "localhost:1234")
// Dial connects to an RPC server at the specified network address.
func Dial(network, address string) (*Client, error) {
// Open the TCP connection.
conn, err := net.Dial(network, address)
if err != nil {
return nil, err
}
// Wrap it in an RPC client.
return NewClient(conn), nil
}
func NewClient(conn io.ReadWriteCloser) *Client {
// Default codec: requests and responses are encoded with gob.
encBuf := bufio.NewWriter(conn)
client := &gobClientCodec{conn, gob.NewDecoder(conn), gob.NewEncoder(encBuf), encBuf}
return NewClientWithCodec(client)
}
func NewClientWithCodec(codec ClientCodec) *Client {
client := &Client{
codec: codec,
pending: make(map[uint64]*Call), // calls waiting for a response
}
// A background goroutine keeps reading responses.
go client.input()
return client
}
func (client *Client) input() {
var err error
var response Response
for err == nil {
response = Response{}
// Read the response header.
err = client.codec.ReadResponseHeader(&response)
if err != nil {
break
}
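// The sequence number identifies which call this response belongs to.
seq := response.Seq
// Under the lock, look up the pending call and remove it from the map.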
client.mutex.Lock()
call := client.pending[seq]
delete(client.pending, seq)
client.mutex.Unlock()
switch {
case call == nil:
// We've got no pending call. That usually means that
// WriteRequest partially failed, and call was already
// removed; response is a server telling us about an
// error reading request body. We should still attempt
// to read error body, but there's no one to give it to.
err = client.codec.ReadResponseBody(nil)
if err != nil {
err = errors.New("reading error body: " + err.Error())
}
case response.Error != "":
// We've got an error response. Give this to the request;
// any subsequent requests will get the ReadResponseBody
// error if there is one.
call.Error = ServerError(response.Error)
err = client.codec.ReadResponseBody(nil)
if err != nil {
err = errors.New("reading error body: " + err.Error())
}
call.done()
default:
// Read the response body into the caller's reply value.
err = client.codec.ReadResponseBody(call.Reply)
if err != nil {
call.Error = errors.New("reading body " + err.Error())
}
// Signal completion by sending the call on its Done channel.
call.done()
}
}
// Terminate pending calls.
client.reqMutex.Lock()
client.mutex.Lock()
client.shutdown = true
closing := client.closing
if err == io.EOF {
if closing {
err = ErrShutdown
} else {
err = io.ErrUnexpectedEOF
}
}
for _, call := range client.pending {
call.Error = err
call.done()
}
client.mutex.Unlock()
client.reqMutex.Unlock()
if debugLog && err != io.EOF && !closing {
log.Println("rpc: client protocol error:", err)
}
}
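What input does per response is a lookup by sequence number in the pending map. The following sketch isolates that bookkeeping under my own assumptions: call and mux are cut-down, hypothetical stand-ins for rpc.Call and rpc.Client, with no network involved.

```go
package main

import (
	"fmt"
	"sync"
)

// call is a cut-down stand-in for rpc.Call.
type call struct {
	Reply string
	Done  chan *call
}

// mux is a hypothetical miniature of the Client bookkeeping.
type mux struct {
	mu      sync.Mutex
	seq     uint64
	pending map[uint64]*call
}

// register assigns the next sequence number to a new call, as send does.
func (m *mux) register() (uint64, *call) {
	c := &call{Done: make(chan *call, 1)}
	m.mu.Lock()
	defer m.mu.Unlock()
	seq := m.seq
	m.seq++
	m.pending[seq] = c
	return seq, c
}

// deliver does what input does per response: find the call by Seq, remove it
// from pending, fill in the reply and signal Done. It reports whether a
// pending call was found.
func (m *mux) deliver(seq uint64, reply string) bool {
	m.mu.Lock()
	c := m.pending[seq]
	delete(m.pending, seq)
	m.mu.Unlock()
	if c == nil {
		return false
	}
	c.Reply = reply
	c.Done <- c
	return true
}

func main() {
	m := &mux{pending: make(map[uint64]*call)}
	seqA, a := m.register()
	seqB, b := m.register()
	// Responses arrive in the opposite order of the requests.
	m.deliver(seqB, "reply for B")
	m.deliver(seqA, "reply for A")
	fmt.Println((<-a.Done).Reply)
	fmt.Println((<-b.Done).Reply)
}
```

Because replies are matched by Seq rather than by arrival order, several calls can share one connection and still complete independently.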
Making a client call
// err = client.Call("HelloService.Hello", "hello", &rep)
// Call invokes the named function, waits for it to complete, and returns its error status.
func (client *Client) Call(serviceMethod string, args interface{}, reply interface{}) error {
call := <-client.Go(serviceMethod, args, reply, make(chan *Call, 1)).Done
return call.Error
}
// Go invokes the function asynchronously. It returns the Call structure representing
// the invocation. The done channel will signal when the call is complete by returning
// the same Call object. If done is nil, Go will allocate a new channel.
// If non-nil, done must be buffered or Go will deliberately crash.
func (client *Client) Go(serviceMethod string, args interface{}, reply interface{}, done chan *Call) *Call {
call := new(Call)
call.ServiceMethod = serviceMethod
call.Args = args
call.Reply = reply
if done == nil {
done = make(chan *Call, 10) // buffered.
} else {
// If caller passes done != nil, it must arrange that
// done has enough buffer for the number of simultaneous
// RPCs that will be using that channel. If the channel
// is totally unbuffered, it's best not to run at all.
if cap(done) == 0 {
log.Panic("rpc: done channel is unbuffered")
}
}
call.Done = done
client.send(call)
return call
}
func (client *Client) send(call *Call) {
// Serialize senders so requests are written one at a time.
client.reqMutex.Lock()
defer client.reqMutex.Unlock()
// Register this call.
// Lock the pending map, which is read and written concurrently.
client.mutex.Lock()
// If the client is closing or shut down, stop sending requests.
if client.shutdown || client.closing {
client.mutex.Unlock()
call.Error = ErrShutdown
call.done()
return
}
// Take the next sequence number and record the call under it.
seq := client.seq
client.seq++
client.pending[seq] = call
client.mutex.Unlock()
// Encode and send the request.
client.request.Seq = seq
client.request.ServiceMethod = call.ServiceMethod
// Write the request.
err := client.codec.WriteRequest(&client.request, call.Args)
if err != nil {
client.mutex.Lock()
call = client.pending[seq]
delete(client.pending, seq)
client.mutex.Unlock()
if call != nil {
call.Error = err
call.done()
}
}
}
func (c *gobClientCodec) WriteRequest(r *Request, body interface{}) (err error) {
// Encode the request header.
if err = c.enc.Encode(r); err != nil {
return
}
// Encode the request body.
if err = c.enc.Encode(body); err != nil {
return
}
// Flush the buffered data to the connection.
return c.encBuf.Flush()
}
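WriteRequest shows the wire layout: one gob value for the header, then one for the body, on the same stream. A rough sketch of that framing, using a bytes.Buffer in place of the connection, looks like this; roundTrip is a hypothetical helper, and only the exported rpc.Request type comes from net/rpc.

```go
package main

import (
	"bytes"
	"encoding/gob"
	"fmt"
	"log"
	"net/rpc"
)

// roundTrip (hypothetical helper) encodes a header and a body in the order
// WriteRequest uses, then decodes them back in the same order, as the
// server side would.
func roundTrip(serviceMethod string, seq uint64, body string) (rpc.Request, string, error) {
	var wire bytes.Buffer // stands in for the network connection
	enc := gob.NewEncoder(&wire)
	if err := enc.Encode(&rpc.Request{ServiceMethod: serviceMethod, Seq: seq}); err != nil {
		return rpc.Request{}, "", err
	}
	if err := enc.Encode(body); err != nil {
		return rpc.Request{}, "", err
	}

	dec := gob.NewDecoder(&wire)
	var header rpc.Request
	var gotBody string
	if err := dec.Decode(&header); err != nil {
		return rpc.Request{}, "", err
	}
	if err := dec.Decode(&gotBody); err != nil {
		return rpc.Request{}, "", err
	}
	return header, gotBody, nil
}

func main() {
	header, body, err := roundTrip("HelloService.Hello", 7, "hello")
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(header.ServiceMethod, header.Seq, body)
}
```

The unexported next field of Request never reaches the wire, since gob only encodes exported fields.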
Overall logic
rpc.Dial -> NewClient -> NewClientWithCodec -> go client.input -> client.codec.ReadResponseHeader -> client.codec.ReadResponseBody
client.Call -> client.Go -> client.send -> client.codec.WriteRequest -> call.done
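Since client.Call is just client.Go plus a receive on Done, the asynchronous path can also be used directly. Here is a minimal sketch that issues several calls before waiting for any of them; helloAsync is a hypothetical helper, and the in-memory net.Pipe and private rpc.NewServer are only there to keep it self-contained.

```go
package main

import (
	"fmt"
	"log"
	"net"
	"net/rpc"
)

type HelloService struct{}

func (p *HelloService) Hello(req string, rep *string) error {
	*rep = "hello " + req
	return nil
}

// helloAsync (hypothetical helper) issues one client.Go call per name and
// only then waits for the completions.
func helloAsync(names []string) ([]string, error) {
	srv := rpc.NewServer()
	if err := srv.RegisterName("HelloService", new(HelloService)); err != nil {
		return nil, err
	}
	serverConn, clientConn := net.Pipe()
	go srv.ServeConn(serverConn)
	client := rpc.NewClient(clientConn)
	defer client.Close()

	// One shared Done channel, buffered for every outstanding call.
	done := make(chan *rpc.Call, len(names))
	replies := make([]string, len(names))
	for i, name := range names {
		client.Go("HelloService.Hello", name, &replies[i], done)
	}
	// Completions may arrive in any order; each reply lands in its own slot.
	for range names {
		if call := <-done; call.Error != nil {
			return nil, call.Error
		}
	}
	return replies, nil
}

func main() {
	replies, err := helloAsync([]string{"a", "b", "c"})
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(replies)
}
```

The channel is sized to the number of outstanding calls because, as the comment in Go says, the caller must provide enough buffer for all RPCs sharing it.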
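Summary
- The argument and reply types must match between server and client; otherwise the call fails with an encoding or decoding error.
- net/rpc uses reflection to inspect the service struct and to handle method calls, request arguments, and server replies.
- Mutexes protect the server's free lists of Request and Response structs, which are reused to reduce the number of memory allocations.
- Mutexes also serialize concurrent writes of requests and responses to the connection.
- mtype.numCalls records how many times each method has been called since the server started.
- Sequence numbers tie each response to the request it answers.
- Encoding and decoding sit behind codec interfaces (dec and enc are the default gob implementation), so a custom codec can be plugged in.
- The RPC protocol is built on the abstract io.ReadWriteCloser interface, so the underlying transport can be swapped freely.
Food for thought: gob encoding is specific to Go, which makes it hard for services written in other languages to call in. Since the standard library leaves room for custom encoding and decoding, could we implement a JSON codec ourselves?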
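One data point for the question above: the standard library already ships net/rpc/jsonrpc, which provides JSON codecs for both sides, so it makes a useful reference before writing your own. The sketch below swaps the codec while keeping the same service; callHelloJSON is a hypothetical helper, and the in-memory net.Pipe is only there to avoid opening a port.

```go
package main

import (
	"fmt"
	"log"
	"net"
	"net/rpc"
	"net/rpc/jsonrpc"
)

type HelloService struct{}

func (p *HelloService) Hello(req string, rep *string) error {
	*rep = "hello " + req
	return nil
}

// callHelloJSON (hypothetical helper) serves and calls HelloService with the
// JSON codecs from net/rpc/jsonrpc instead of the default gob codecs.
func callHelloJSON(name string) (string, error) {
	srv := rpc.NewServer()
	if err := srv.RegisterName("HelloService", new(HelloService)); err != nil {
		return "", err
	}
	serverConn, clientConn := net.Pipe()
	// Same Server, different codec: ServeCodec instead of ServeConn.
	go srv.ServeCodec(jsonrpc.NewServerCodec(serverConn))
	// Same Client type, built from a JSON ClientCodec.
	client := rpc.NewClientWithCodec(jsonrpc.NewClientCodec(clientConn))
	defer client.Close()

	var rep string
	err := client.Call("HelloService.Hello", name, &rep)
	return rep, err
}

func main() {
	rep, err := callHelloJSON("json")
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(rep)
}
```

Nothing in the service or in the calling code changes; only the two codec constructors differ, which is exactly the extension point the ServerCodec and ClientCodec interfaces were left for.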