RPC系列之:通过一个最简单的RPC服务探究net/rpc源码实现

码字不易,转载请附原链,搬砖繁忙回复不及时见谅,技术交流请加QQ群:909211071

目录

前言

示例代码

服务端代码

客户端代码

源码探究

服务端

核心结构体

服务注册

整体服务请求逻辑

Request逻辑

方法调用核心代码

Response逻辑

服务端逻辑总结

客户端

核心结构体

连接服务端并启动 channel 等待接收响应

发起客户端调用

整体逻辑

总结


前言

之前做的服务大多基于 HTTP 实现服务端和客户端,可以通过 API 接口的方式进行调用,优点是不受语言限制、调用方便。但是众所周知 HTTP 是应用层协议,所以对于性能要求较高的服务来说,更适合用基于 TCP 的  RPC 服务。由于最近在转 Go 语言,标准库提供了开箱即用的 RPC 包,之前一直没接触过,最近正好打算学习 RPC,所以通过一个最简单的 RPC 服务来研究了下 net/rpc 包的实现原理。

示例代码

例子非常简单,客户端发送一个字符串,服务端返回 hello + 字符串作为响应,代码如下:

服务端代码

  1. 定义一个服务结构体和 Hello 方法,注意方法的第一个参数为请求数据,第二个参数为响应数据,注意必须要和客户端的请求参数和响应参数类型保持一致,否则会触发panic
  2. 通过 rpc.RegisterName 注册服务和方法,第一个参数接收一个 string 作为服务名,如果不传默认为服务结构体名称,第二个参数接收一个地址
  3. 通过 net.Listen 开启一个 TCP 服务端
  4. 通过 Accept 不断接收连接请求,一旦有连接到来,调用 rpc.ServeConn 处理请求
package main

import (
	"log"
	"net"
	"net/rpc"
)

type HelloService struct{}

func (p *HelloService) Hello(req string, rep *string) error {
	*rep = "hello " + req
	return nil
}

func main() {
	rpc.RegisterName("HelloService", new(HelloService))

	listener, err := net.Listen("tcp", ":1234")
	if err != nil {
		log.Fatal("listen tpc error:", err)
	}

	for {
		conn, err := listener.Accept()
		if err != nil {
			log.Fatal("accept error:", err)
		}
		rpc.ServeConn(conn)
	}
}

客户端代码

  1. 通过 net.Dial 开启一个连接到服务端的 TCP 会话连接
  2. 通过 client.Call 调用服务对应方法,第一个参数为 “服务名.方法名”,第二个参数为请求数据,第三个参数为接收返回数据的变量地址,注意第二个和第三个参数必须要和服务端的请求参数和响应参数类型保持一致,否则会触发panic
package main

import (
	"fmt"
	"log"
	"net/rpc"
)

func main() {
	client, err := rpc.Dial("tcp", "localhost:1234")
	if err != nil {
		log.Fatal("dial error:", err)
	}

	var rep string
	err = client.Call("HelloService.Hello", "hello", &rep)
	if err != nil {
		log.Fatal(err)
	}

	fmt.Println(rep)
}

只需要几十行代码,就可以实现一个简单的 RPC 服务,不得不感叹 net/rpc 包的强大,越是简单的代码,就越对其实现原理感兴趣,下面我们一步步看一下标准库是如何实现的。

 

源码探究

服务端

核心结构体

首先看一下核心的 4 个结构体:

一定要把每个变量是做什么的印在自己脑子里,这对对于下面阅读源码起到至关重要的作用,如果没印象可能会看得一头雾水

//保存单个服务信息
type service struct {
	name   string                 // 服务名
	rcvr   reflect.Value          // 服务接受者值
	typ    reflect.Type           // 服务接受者类型
	method map[string]*methodType // 当前服务注册的方法Map
}

//RPC请求结构体
// Request is a header written before every RPC call. It is used internally
// but documented here as an aid to debugging, such as when analyzing
// network traffic.
type Request struct {
	ServiceMethod string   // 当前请求对应的服务方法
	Seq           uint64   // 请求序列号
	next          *Request // 记录server下一个请求指针
}

// Response is a header written before every RPC return. It is used internally
// but documented here as an aid to debugging, such as when analyzing
// network traffic.
type Response struct {
	ServiceMethod string    // 当前响应对应的服务方法
	Seq           uint64    // 对应响应序列号
	Error         string    // 记录响应错误
	next          *Response // 记录server下一个响应指针
}

//RPC服务结构体
// Server represents an RPC Server.
type Server struct {
	serviceMap sync.Map   // 服务列表Map
	reqLock    sync.Mutex // 保护读取请求缓冲区的互斥锁
	freeReq    *Request   // 空闲Request地址,用于内存复用
	respLock   sync.Mutex // 保护写入响应缓冲区的互斥锁
	freeResp   *Response  // 空闲Response地址,用于内存复用
}

服务注册

首先我们在服务端的 rpc.RegisterName 添加断点,运行调试,一路下一步,进入到 net/rpc 包下的 server.go 的 register 方法:

RPC系列之:通过一个最简单的RPC服务探究net/rpc源码实现

注意看第 255 行的方法调用,第三个参数传的 true,代表使用传入的服务名 name。

register 完整函数代码如下,关键注释已给出:

func (server *Server) register(rcvr interface{}, name string, useName bool) error {
    //这是 service 结构体
    //type service struct {
	//    name   string                 // 保存服务名
    //	  rcvr   reflect.Value          // 服务方法接收者
    //	  typ    reflect.Type           // 服务方法接受者变量对应类型
    //    method map[string]*methodType // 服务启动前注册的方法Map
    //}
	s := new(service)
    //通过反射解析并保存服务方法接受者类型
	s.typ = reflect.TypeOf(rcvr)
    //通过反射解析并保存服务方法接受者
	s.rcvr = reflect.ValueOf(rcvr)
    //通过反射解析并保存服务名(Indirect解析变量地址)
	sname := reflect.Indirect(s.rcvr).Type().Name()
    //如果第三个参数为true,代表默认以传入服务名为准
	if useName {
		sname = name
	}
	if sname == "" {
		s := "rpc.Register: no service name for type " + s.typ.String()
		log.Print(s)
		return errors.New(s)
	}
	if !token.IsExported(sname) && !useName {
		s := "rpc.Register: type " + sname + " is not exported"
		log.Print(s)
		return errors.New(s)
	}
	s.name = sname

	// 注册保存方法名和方法映射,这里用反射解析服务struct定义的所有方法
	s.method = suitableMethods(s.typ, true)

	if len(s.method) == 0 {
		str := ""

		// To help the user, see if a pointer receiver would work.
		method := suitableMethods(reflect.PtrTo(s.typ), false)
		if len(method) != 0 {
			str = "rpc.Register: type " + sname + " has no exported methods of suitable type (hint: pass a pointer to value of that type)"
		} else {
			str = "rpc.Register: type " + sname + " has no exported methods of suitable type"
		}
		log.Print(str)
		return errors.New(str)
	}

	if _, dup := server.serviceMap.LoadOrStore(sname, s); dup {
		return errors.New("rpc: service already defined: " + sname)
	}
	return nil
}

整体服务请求逻辑

至此,服务注册已经完成,整体通过反射实现,简单来说就是根据服务结构体注册服务名,解析结构体方法注册服务的所有方法,下面我们运行客户端代码调试下调用逻辑:

我们直接一路下一步,直接看重点,下面是调用栈:

RPC系列之:通过一个最简单的RPC服务探究net/rpc源码实现

SerceCodec 方法主要做了以下几件事,源码已加中文注释:

关于协程的同步控制可以参考这篇文章

  1. 通过 sync.Mutex 控多个协程向响应缓冲区写入响应数据
  2. 通过 sync.WaitGroup 控制 goroutine 等待响应写入完成后关闭
  3. 调用 server.readRequest 方法解析请求数据

  4. 通过 go 关键字用 goroutine 处理调用请求
// ServeCodec is like ServeConn but uses the specified codec to
// decode requests and encode responses.
func (server *Server) ServeCodec(codec ServerCodec) {
    //控制并发写入response
	sending := new(sync.Mutex)
    //控制调用goroutine等待缓冲区写入完成后关闭
	wg := new(sync.WaitGroup)
	for {
        //解析客户端连接发送的请求数据
		service, mtype, req, argv, replyv, keepReading, err := server.readRequest(codec)
		if err != nil {
			if debugLog && err != io.EOF {
				log.Println("rpc:", err)
			}
			if !keepReading {
				break
			}
			// send a response if we actually managed to read a header.
			if req != nil {
				server.sendResponse(sending, req, invalidRequest, codec, err.Error())
				server.freeRequest(req)
			}
			continue
		}
		wg.Add(1)
        //通过goroutine调用方法处理并写入响应,别急,下面会进去分析
		go service.call(server, sending, wg, mtype, req, argv, replyv, codec)
	}
	// We've seen that there are no more requests.
	// Wait for responses to be sent before closing codec.
	wg.Wait()
    //关闭codec写入流
	codec.Close()
}

我们这里着重看一下 codec 的结构,这里预留了两个扩展点:

  • dec 和 enc 两个字段可以通过插件实现自定义的编码和解码
  • RPC 协议建立在抽象的 io.ReadWriteCloser 接口之上,可以灵活替换通信协议
type gobServerCodec struct {
	rwc    io.ReadWriteCloser
	dec    *gob.Decoder
	enc    *gob.Encoder
	encBuf *bufio.Writer
	closed bool
}

Request逻辑

核心逻辑已给出中文注释,直接结合注释看代码:

func (server *Server) readRequest(codec ServerCodec) (service *service, mtype *methodType, req *Request, argv, replyv reflect.Value, keepReading bool, err error) {
	service, mtype, req, keepReading, err = server.readRequestHeader(codec)
	if err != nil {
		if !keepReading {
			return
		}
		// discard body
		codec.ReadRequestBody(nil)
		return
	}
    // 下面是通过反射解析请求体中的服务和方法
}

func (server *Server) readRequestHeader(codec ServerCodec) (svc *service, mtype *methodType, req *Request, keepReading bool, err error) {
	// Grab the request header.
    //获得Request
	req = server.getRequest()
    //解析Request Header
	err = codec.ReadRequestHeader(req)
	if err != nil {
		req = nil
		if err == io.EOF || err == io.ErrUnexpectedEOF {
			return
		}
		err = errors.New("rpc: server cannot decode request: " + err.Error())
		return
	}

	// We read the header successfully. If we see an error now,
	// we can still recover and move on to the next request.
	keepReading = true

	dot := strings.LastIndex(req.ServiceMethod, ".")
	if dot < 0 {
		err = errors.New("rpc: service/method request ill-formed: " + req.ServiceMethod)
		return
	}
	serviceName := req.ServiceMethod[:dot]
	methodName := req.ServiceMethod[dot+1:]

	// 加载服务名
	svci, ok := server.serviceMap.Load(serviceName)
	if !ok {
		err = errors.New("rpc: can't find service " + req.ServiceMethod)
		return
	}
	svc = svci.(*service)
    //获得服务对应方法
	mtype = svc.method[methodName]
	if mtype == nil {
		err = errors.New("rpc: can't find method " + req.ServiceMethod)
	}
	return
}
func (server *Server) getRequest() *Request {
    //加锁
	server.reqLock.Lock()
    //保存当前空闲Req
	req := server.freeReq
	if req == nil {
        //如果空闲Req指针为空,则new一个新的Request对象
		req = new(Request)
	} else {
        //如果空闲Req指针不为空,则将空闲Req指针指向请求
		server.freeReq = req.next
        //置空保存好的空闲Req,作为当前请求容器,避免new申请内存
		*req = Request{}
	}
    //解锁
	server.reqLock.Unlock()
	return req
}

方法调用核心代码

好了,接下来我们去看一下核心的服务调用代码 service.call:

func (s *service) call(server *Server, sending *sync.Mutex, wg *sync.WaitGroup, mtype *methodType, req *Request, argv, replyv reflect.Value, codec ServerCodec) {
    //首先开始 defer,在调用处理完后关闭等待组。
	if wg != nil {
		defer wg.Done()
	}
    //这里通过加锁,实现服务端每个方法的调用次数统计
    //type methodType struct {
    //	sync.Mutex // protects counters
    //	method     reflect.Method
    //	ArgType    reflect.Type
    //	ReplyType  reflect.Type
    //	numCalls   uint
    //}
	mtype.Lock()
	mtype.numCalls++
	mtype.Unlock()
	function := mtype.method.Func
	// Invoke the method, providing a new value for the reply.
    //通过反射调用对应方法,并将响应写入replyv中
	returnValues := function.Call([]reflect.Value{s.rcvr, argv, replyv})
	// The return value for the method is an error.
	errInter := returnValues[0].Interface()
	errmsg := ""
	if errInter != nil {
		errmsg = errInter.(error).Error()
	}
    //发送响应
	server.sendResponse(sending, req, replyv.Interface(), codec, errmsg)
    //释放请求资源
	server.freeRequest(req)
}

Response逻辑

func (server *Server) sendResponse(sending *sync.Mutex, req *Request, reply interface{}, codec ServerCodec, errmsg string) {
    //获得Response结构体,用于写入响应
	resp := server.getResponse()
	// Encode the response header
    //记录请求方法
	resp.ServiceMethod = req.ServiceMethod
	if errmsg != "" {
		resp.Error = errmsg
		reply = invalidRequest
	}
	resp.Seq = req.Seq
	sending.Lock()
    //写入响应数据
	err := codec.WriteResponse(resp, reply)
	if debugLog && err != nil {
		log.Println("rpc: writing response:", err)
	}
	sending.Unlock()
    //释放响应结构体占用内存
	server.freeResponse(resp)
}

getResponse+freeResponse 逻辑思路和 getRequeset  + freeRequest 完全一致,也是复用内存地址,不再具体分析,下面是代码:

func (server *Server) getResponse() *Response {
	server.respLock.Lock()
	resp := server.freeResp
	if resp == nil {
		resp = new(Response)
	} else {
		server.freeResp = resp.next
		*resp = Response{}
	}
	server.respLock.Unlock()
	return resp
}


func (server *Server) freeResponse(resp *Response) {
	server.respLock.Lock()
	resp.next = server.freeResp
	server.freeResp = resp
	server.respLock.Unlock()
}

服务端逻辑总结

整套RPC服务调用逻辑如下:

RegisterName -> register -> Listen -> Accept -> ServeConn -> ServeCodec -> readRequest -> getRequest -> ReadRequestHeader -> call -> sendResponse -> freeRequest

客户端

客户端就比较简单了,我们大致看一下流程就好了:

  1.  rpc.Dial 连接服务端,启动后台 goroutine 不断从响应缓冲区中读取响应,通过 chan 返回响应数据
  2. client.Call 阻塞等待 chan 返回并解析响应数据

核心结构体

// Call represents an active RPC.
type Call struct {
	ServiceMethod string      // The name of the service and method to call.
	Args          interface{} // The argument to the function (*struct).
	Reply         interface{} // The reply from the function (*struct).
	Error         error       // After completion, the error status.
	Done          chan *Call  // Strobes when call is complete.
}

// Client represents an RPC Client.
// There may be multiple outstanding Calls associated
// with a single Client, and a Client may be used by
// multiple goroutines simultaneously.
type Client struct {
	codec ClientCodec

	reqMutex sync.Mutex // 保护并发请求安全
	request  Request

	mutex    sync.Mutex // 保护并发响应读取安全
	seq      uint64
	pending  map[uint64]*Call // 接收响应的通道map
	closing  bool // user has called Close
	shutdown bool // server has told us to stop
}

连接服务端并启动 channel 等待接收响应

下面是源码,重点看中文注释即可:

// rpc.Dial("tcp", "localhost:1234")
// Dial connects to an RPC server at the specified network address.
func Dial(network, address string) (*Client, error) {
    //开启TCP连接
	conn, err := net.Dial(network, address)
	if err != nil {
		return nil, err
	}
    //返回RPC客户端
	return NewClient(conn), nil
}

func NewClient(conn io.ReadWriteCloser) *Client {
    //通过参数控制通过gob编码请求和响应数据
	encBuf := bufio.NewWriter(conn)
	client := &gobClientCodec{conn, gob.NewDecoder(conn), gob.NewEncoder(encBuf), encBuf}
	return NewClientWithCodec(client)
}


func NewClientWithCodec(codec ClientCodec) *Client {
	client := &Client{
		codec:   codec,
		pending: make(map[uint64]*Call), //等待响应的channel
	}
    //后台协程持续监听响应
	go client.input()
	return client
}


func (client *Client) input() {
	var err error
	var response Response
	for err == nil {
		response = Response{}
        //读取响应头
		err = client.codec.ReadResponseHeader(&response)
		if err != nil {
			break
		}
        //读取响应序列号
		seq := response.Seq
        //加锁读取响应方法
		client.mutex.Lock()
		call := client.pending[seq]
		delete(client.pending, seq)
		client.mutex.Unlock()

		switch {
		case call == nil:
			// We've got no pending call. That usually means that
			// WriteRequest partially failed, and call was already
			// removed; response is a server telling us about an
			// error reading request body. We should still attempt
			// to read error body, but there's no one to give it to.
			err = client.codec.ReadResponseBody(nil)
			if err != nil {
				err = errors.New("reading error body: " + err.Error())
			}
		case response.Error != "":
			// We've got an error response. Give this to the request;
			// any subsequent requests will get the ReadResponseBody
			// error if there is one.
			call.Error = ServerError(response.Error)
			err = client.codec.ReadResponseBody(nil)
			if err != nil {
				err = errors.New("reading error body: " + err.Error())
			}
			call.done()
		default:
            //读取响应body
			err = client.codec.ReadResponseBody(call.Reply)
			if err != nil {
				call.Error = errors.New("reading body " + err.Error())
			}
            //通过chan发送响应数据
			call.done()
		}
	}
	// Terminate pending calls.
	client.reqMutex.Lock()
	client.mutex.Lock()
	client.shutdown = true
	closing := client.closing
	if err == io.EOF {
		if closing {
			err = ErrShutdown
		} else {
			err = io.ErrUnexpectedEOF
		}
	}
	for _, call := range client.pending {
		call.Error = err
		call.done()
	}
	client.mutex.Unlock()
	client.reqMutex.Unlock()
	if debugLog && err != io.EOF && !closing {
		log.Println("rpc: client protocol error:", err)
	}
}

发起客户端调用

// err = client.Call("HelloService.Hello", "hello", &rep)

// Call invokes the named function, waits for it to complete, and returns its error status.
func (client *Client) Call(serviceMethod string, args interface{}, reply interface{}) error {
	call := <-client.Go(serviceMethod, args, reply, make(chan *Call, 1)).Done
	return call.Error
}


// Go invokes the function asynchronously. It returns the Call structure representing
// the invocation. The done channel will signal when the call is complete by returning
// the same Call object. If done is nil, Go will allocate a new channel.
// If non-nil, done must be buffered or Go will deliberately crash.
func (client *Client) Go(serviceMethod string, args interface{}, reply interface{}, done chan *Call) *Call {
	call := new(Call)
	call.ServiceMethod = serviceMethod
	call.Args = args
	call.Reply = reply
	if done == nil {
		done = make(chan *Call, 10) // buffered.
	} else {
		// If caller passes done != nil, it must arrange that
		// done has enough buffer for the number of simultaneous
		// RPCs that will be using that channel. If the channel
		// is totally unbuffered, it's best not to run at all.
		if cap(done) == 0 {
			log.Panic("rpc: done channel is unbuffered")
		}
	}
	call.Done = done
	client.send(call)
	return call
}


func (client *Client) send(call *Call) {
    //加锁保护并发请求
	client.reqMutex.Lock()
	defer client.reqMutex.Unlock()

	// Register this call.
    //加锁保护多个call channel对应的map并发读写
	client.mutex.Lock()
    //如果客户端关闭停止发送请求
	if client.shutdown || client.closing {
		client.mutex.Unlock()
		call.Error = ErrShutdown
		call.done()
		return
	}
    //递增序列号,并保存对应 call
	seq := client.seq
	client.seq++
	client.pending[seq] = call
	client.mutex.Unlock()

	// Encode and send the request.
	client.request.Seq = seq
	client.request.ServiceMethod = call.ServiceMethod
    //写入请求
	err := client.codec.WriteRequest(&client.request, call.Args)
	if err != nil {
		client.mutex.Lock()
		call = client.pending[seq]
		delete(client.pending, seq)
		client.mutex.Unlock()
		if call != nil {
			call.Error = err
			call.done()
		}
	}
}

func (c *gobClientCodec) WriteRequest(r *Request, body interface{}) (err error) {
    //编码request
	if err = c.enc.Encode(r); err != nil {
		return
	}
    //编码request body
	if err = c.enc.Encode(body); err != nil {
		return
	}
    //刷入缓冲区
	return c.encBuf.Flush()
}

整体逻辑

rpc.Dial -> NewClient -> NewClientWithCodec -> go client.input -> client.codec.ReadResponseHeader -> client.codec.ReadResponseBody

client.Call -> client.Go -> client.send -> client.codec.WriteRequest -> call.done

 

总结

  • 使用时,服务端和客户端的请求和响应类型应保持一致,否则会触发panic
  • net/rpc通过反射解析服务结构体、调用方法、请求参数、服务端响应
  • 通过加锁保护服务端的 Request 和 Response 结构体,并复用空闲结构体,减少内存分配次数
  • 通过加锁保护并发读写请求和响应缓冲区
  • 可以通过 mtype.numCalls 获得服务端每个方法的从启动开始的累计调用次数
  • 通过序列号保证读写数据的关联关系
  • dec 和 enc 两个字段可以通过插件实现自定义的编码和解码
  • RPC 协议建立在抽象的 io.ReadWriteCloser 接口之上,可以灵活替换通信协议

思考:Gob 编码是 Go 语言特有的,其他服务调用起来比较困难,既然标准库为我们预留了自定义编码和解码,能否自己实现一个 json 格式的编码解码?

 

 

上一篇:suse12 设置ssh 远程连接


下一篇:文本操作防止中文乱码