手写RPC(六) 核心模块网络协议模块编写 ---- 实现编解码器

前面的基础已经写好了,现在我们来实现编码器。
为什么需要编码器?
netty只负责传输数据,至于数据长什么样它是不关注的。前面也提到了自定义协议就是把我们要传输的数据按照我们的规则进行组织、传输、解码,编码器就是对我们要发送的数据进行组织的作用。
netty已经为我们做好了封装,我们只需要集成MessageToByteEncoder实现其encode方法即可,然后把这个编码器添加到我们netty处理器的pipeline(流水线)即可。

package com.info.protocol.netty.core.codec;

import com.info.protocol.netty.core.Header;
import com.info.protocol.netty.core.Protocol;
import com.info.protocol.serial.Serializer;
import com.info.protocol.serial.SerializerManager;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.MessageToByteEncoder;
import lombok.extern.slf4j.Slf4j;

@Slf4j
public class CustomEncoder extends MessageToByteEncoder<Protocol<Object>> {

        /*
        +--------------------------------------------------------------------------------------------+
        |魔数 16bit|协议版本 8bit|序列化方式 8bit|消息长度 32bit |消息类型(请求还是响应)2bit|messageId 64bit |
        +--------------------------------------------------------------------------------------------+
        */

    @Override
    protected void encode(ChannelHandlerContext ctx, Protocol<Object> msg, ByteBuf out) throws Exception {
        log.info("----------------- begin encode -----------------");
        final Header header = msg.getHeader();
        // 写入魔数
        out.writeShort(header.getMagic());
        // 写入使用的协议版本
        out.writeByte(header.getProtocolVersion());
        // 写入序列化方式
        out.writeByte(header.getSerializeType());
        // 对消息内容进行序列化 又因为前面已经实现好了序列化,但是这里只有序列化方式,
        // 我们需要通过这个序列化方式拿到对应的序列化器,因此序列化方式最好有一个管理器,并且提供根据序列化方式
        // 获取具体的序列化器的功能
        Serializer serializer = SerializerManager.getSerializerByCode(header.getSerializeType());
        // 序列化消息
        final Object content = msg.getContent();
        final byte[] data = serializer.serialize(content);
        // 写入消息长度
        out.writeInt(data.length);
        // 写入消息类型
        out.writeByte(header.getMessageType());
        // 写入消息id
        out.writeLong(header.getMessageId());
        // 写入具体的消息内容
        out.writeBytes(data);
    }
}

序列化管理器

package com.info.serial;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

public class SerializerManager {

    private static Map<Byte, Serializer> serializerMap = new ConcurrentHashMap<>(1 << 1);

    static {
        Serializer javaSerializer = new JavaSerializer();
        Serializer jsonSerializer = new JsonSerializer();
        serializerMap.put(javaSerializer.getType(), javaSerializer);
        serializerMap.put(jsonSerializer.getType(), jsonSerializer);
    }

    public static Serializer getSerializerByCode(byte code) {
        Serializer serializer = serializerMap.get(code);
        if (serializer == null) {
            serializer = new JavaSerializer();
        }
        return serializer;
    }
}

编码器实现好了,解码器是反向对我们编码好的数据进行解码,因为要区分消息类型(发送请求和接收请求返回的内容是不一样的),我们定义好要反序列化的对象

package com.info.protocol.netty.core;

import lombok.Getter;
import lombok.Setter;
import lombok.ToString;

import java.io.Serializable;

@Getter
@Setter
@ToString
public class Request implements Serializable {

    private String className;

    private String methodName;

    private Object[] params;

    private Class<?>[] parameterTypes;
}

package com.info.protocol.netty.core;

import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.Setter;

@Getter
@Setter
@AllArgsConstructor
public class Response {

    private Object data;

    private String msg;
}

实现解码器

package com.info.protocol.netty.core.codec;

import com.info.protocol.constants.CommonConstant;
import com.info.protocol.enums.MessageTypeEnum;
import com.info.protocol.netty.core.Header;
import com.info.protocol.netty.core.Protocol;
import com.info.protocol.netty.core.Request;
import com.info.protocol.netty.core.Response;
import com.info.protocol.serial.Serializer;
import com.info.protocol.serial.SerializerManager;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.ByteToMessageDecoder;
import lombok.extern.slf4j.Slf4j;

import java.util.List;

@Slf4j
public class CustomDecoder extends ByteToMessageDecoder {

        /*
        +--------------------------------------------------------------------------------------------+
        |魔数 16bit|协议版本 8bit|序列化方式 8bit|消息长度 32bit |消息类型(请求还是响应)2bit|messageId 64bit |
        +--------------------------------------------------------------------------------------------+
        */

    @Override
    protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
        log.info("------------------ begin decode ------------------");
        if (in.readableBytes() < CommonConstant.HEADER_LENGTH) {
            // 消息长度不够,暂不解析
            return;
        }
        //标记一个读取数据的索引,必要的时候用来重置。
        in.markReaderIndex();

        final short magic = in.readShort();
        if (magic != CommonConstant.MAGIC) {
            throw new IllegalArgumentException("illegal request parameter 'magic' " + magic);
        }
        final byte protocolVersionCode = in.readByte();
        final byte serializerTypeCode = in.readByte();
        final int messageLength = in.readInt();
        final byte messageTypeCode = in.readByte();
        final long messageId = in.readLong();
        if (in.readableBytes() < messageLength) {
            // 消息不够 重置读标识
            in.resetReaderIndex();
            return;
        }
        byte[] content = new byte[messageLength];
        in.readBytes(content);

        Header header = new Header();
        header.setMagic(magic)
                .setMessageId(messageId)
                .setMessageLength(messageLength)
                .setMessageType(messageTypeCode)
                .setProtocolVersion(protocolVersionCode)
                .setSerializeType(serializerTypeCode);

        // 获取反序列器
        final Serializer serializer = SerializerManager.getSerializerByCode(serializerTypeCode);
        MessageTypeEnum messageType = MessageTypeEnum.getMessageTypeEnumByCode(messageTypeCode);

        switch (messageType) {
            case REQUEST:
                final Request request = serializer.deserialize(content, Request.class);
                Protocol<Request> requestProtocol = new Protocol<>();
                requestProtocol.setHeader(header);
                requestProtocol.setContent(request);
                out.add(requestProtocol);
                break;
            case RESPONSE:
                final Response response = serializer.deserialize(content, Response.class);
                Protocol<Response> responseProtocol = new Protocol<>();
                responseProtocol.setHeader(header);
                responseProtocol.setContent(response);
                out.add(responseProtocol);
                break;
            case HEART_BEAT:
            default:
        }

    }
}

最后需要实现一个处理器,目的是服务端在读取到请求的数据以后(已经解码完成),需要真正的去获取调用目标方法的结果,这里使用反射实现。处理器还需要继承SimpleChannelInboundHandler,然后实现其channelRead0 方法即可。

package com.info.protocol.netty.server;

import com.info.protocol.enums.MessageTypeEnum;
import com.info.protocol.netty.core.Header;
import com.info.protocol.netty.core.Protocol;
import com.info.protocol.netty.core.Request;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;

import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;

public class CustomServerHandler extends SimpleChannelInboundHandler<Protocol<Request>> {
    @Override
    protected void channelRead0(ChannelHandlerContext ctx, Protocol<Request> msg) throws Exception {
        Protocol protocol = new Protocol<>();
        final Header header = msg.getHeader();
        header.setMessageType(MessageTypeEnum.RESPONSE.getCode());
        // 反射调用目标方法
        Object result = invoke(msg.getContent());
        protocol.setHeader(header);
        protocol.setContent(result);
        ctx.writeAndFlush(protocol);
    }

    // 调用目标方法
    private Object invoke(Request request) {
        final String clzName = request.getClassName();
        try {
            final Class<?> clz = Class.forName(clzName);
            final Constructor<?> constructor = clz.getDeclaredConstructors()[0];
            final Object instance = constructor.newInstance();
            final Method method = clz.getDeclaredMethod(request.getMethodName(), request.getParameterTypes());
            return method.invoke(instance, request.getParams());
        } catch (ClassNotFoundException e) {
            e.printStackTrace();
        } catch (InstantiationException e) {
            e.printStackTrace();
        } catch (IllegalAccessException e) {
            e.printStackTrace();
        } catch (InvocationTargetException e) {
            e.printStackTrace();
        } catch (NoSuchMethodException e) {
            e.printStackTrace();
        }
        return null;
    }
}

服务端基础的代码都已经实现了,现在需要把我们的编码器、解码器、处理器结合起来使他们生效,其方法就是把它们添加到 nettypipeline中:

package com.info.protocol.netty.server;

import com.info.protocol.netty.core.codec.CustomDecoder;
import com.info.protocol.netty.core.codec.CustomEncoder;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.socket.SocketChannel;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import io.netty.handler.logging.LoggingHandler;

public class CustomServerInitializer extends ChannelInitializer<SocketChannel> {
    @Override
    protected void initChannel(SocketChannel ch) throws Exception {
        ch.pipeline()
                .addLast(new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 12, 4, 0, 0))
                // netty提供的,日志处理,便于常看调用流程
                .addLast(new LoggingHandler())
                .addLast(new CustomEncoder())
                .addLast(new CustomDecoder());
    }
}

最后需要把这个初始化器(连接编解码器)的组件添加到netty:

ServerBootstrap bootstrap = new ServerBootstrap();
            bootstrap.group(bossGroup, workerGroup)
                    .channel(NioServerSocketChannel.class)
                    // 指定 childHandler
                    .childHandler(new CustomServerInitializer());

至此,服务端的代码已经实现完毕,下节开始实现客户端的代码。

上一篇:Python - 自定义向量类


下一篇:PyQt5基础学习-QWebEngineView(构建网页显示器) 1.QWebEngineView().load(Qurl(加载对应的网址))