AI 摘要

文章剖析了自研 RPC 框架的网络传输模块:先定义 RpcRequest/Response 实体,再抽象 RpcRequestTransport 接口,分别用 Socket 与 Netty 实现客户端与服务端收发;Netty 版引入编解码器、自定义协议及心跳保活,并借助 CompletableFuture 实现异步回调,完成高性能远程调用。

09 RPC 框架代码分析之网络传输模块

以下提到的 服务端 指的是提供服务/方法的一端,客户端 指的是调用远程(服务端)服务/方法的一端。

我们之前在“如何自己实现一个 RPC 框架?”这篇文章中介绍到说:既然我们要调用远程的方法,就要发送网络请求来传递目标类和方法的信息以及方法的参数等数据到服务端。 这就涉及到了网络传输!网络传输具体实现你可以使用 Socket ( Java 中最原始、最基础的网络通信方式。但是,Socket 是阻塞 IO、性能低并且功能单一)。你也可以使用同步非阻塞的 I/O 模型 NIO ,但是用它来进行网络编程真的太麻烦了。不过没关系,你可以使用基于 NIO 的网络编程框架 Netty ,它将是你最好的选择!

guide-rpc-framework 使用了一种基于 Socket,一种基于 Netty 的方式(循序渐进)。

网络传输模块整体结构如下:

一共被分为了 4 个包

  1. constants : 存放一些网络传输模块共用的常量
  2. dto : 用于网络传输的类。
  3. handler : 里面只有一个用于处理 rpc 请求的类RpcRequestHandler(根据 rpc 请求调用目标类的目标方法)。
  4. transport : 用户网络传输相关类(真正传输网络请求的地方。提供了 Socket 和 Netty 两种网络传输方式)。

网络传输实体类

网络传输实体类在 dto 包下,主要有两个类。

**RpcRequest.java**

rpc 请求实体类。当你要调用远程方法的时候,你需要先传输一个 RpcRequest 给对方,RpcRequest 里面包含了要调用的目标方法和类的名称、参数等数据。

另外,version 字段(服务版本)主要是为后续不兼容升级提供可能。group 字段主要用于处理一个接口有多个类实现的情况。

@AllArgsConstructor
@NoArgsConstructor
@Getter
@Builder
@ToString
public class RpcRequest implements Serializable {
    private static final long serialVersionUID = 1905122041950251207L;
    private String requestId;
    private String interfaceName;
    private String methodName;
    private Object[] parameters;
    private Class<?>[] paramTypes;
    private RpcMessageType rpcMessageType;
    private String version;
    private String group;

    public RpcServiceProperties toRpcProperties() {
        return RpcServiceProperties.builder().serviceName(this.getInterfaceName())
                .version(this.getVersion())
                .group(this.getGroup()).build();
    }
}

**RpcResponse.java**

既然有了 rpc 请求实体类,那肯定就要有 rpc 响应实体类了。

当服务端通过 RpcRequest 中的相关数据调用到目标服务的目标方法之后,调用结果就通过 RpcResponse 返回给客户端。

@AllArgsConstructor
@NoArgsConstructor
@Getter
@Setter
@Builder
@ToString
public class RpcResponse<T> implements Serializable {

    private static final long serialVersionUID = 715745410605631233L;
    private String requestId;
    /**
     * response code
     */
    private Integer code;
    /**
     * response message
     */
    private String message;
    /**
     * response body
     */
    private T data;

    public static <T> RpcResponse<T> success(T data, String requestId) {
        RpcResponse<T> response = new RpcResponse<>();
        response.setCode(RpcResponseCode.SUCCESS.getCode());
        response.setMessage(RpcResponseCode.SUCCESS.getMessage());
        response.setRequestId(requestId);
        if (null != data) {
            response.setData(data);
        }
        return response;
    }

    public static <T> RpcResponse<T> fail(RpcResponseCode rpcResponseCode) {
        RpcResponse<T> response = new RpcResponse<>();
        response.setCode(rpcResponseCode.getCode());
        response.setMessage(rpcResponseCode.getMessage());
        return response;
    }

}

网络传输

由于,这部分我提供了一种基于 Socket,一种基于 Netty 的网络传输方式(循序渐进)。

因此,我先定义了一个发送 RPC 请求的顶层接口,然后我们分别使用 Socket 和 Netty 两种方式对这个接口进行实现即可!

RpcRequestTransport.java 传输请求的接口

/**
 * send RpcRequest。
 *
 * @author shuang.kou
 * @createTime 2020年05月29日 13:26:00
 */
@SPI
public interface RpcRequestTransport {
    /**
     * send rpc request to server and get result
     *
     * @param rpcRequest message body
     * @return data from server
     */
    Object sendRpcRequest(RpcRequest rpcRequest);
}

下面,我们先来看一下比较简单点的使用 Socket 进行网络传输的方式。

Socket

客户端

这里的客户端实际就是发送 RPC 请求的一端,可以对照我们之间画的 RPC 调用的原理图来理解。

客户端主要用于发送网络请求到服务端(目标方法所在的服务器)。当我们知道了服务端的地址之后,我们就可以通过 SocketRpcClient 发送 rpc 请求(RpcRequest) 到服务端了(如果我们要找到服务端的地址,涉及到了注册中心相关的知识。下一节会提到。)。

我们直接实现上面定义的 RpcRequestTransport.java 即可。这样的话,通过 Socket 来传输消息的模块就写好了!

/**
 * 基于 Socket 传输 RpcRequest
 *
 * @author shuang.kou
 * @createTime 2020年05月10日 18:40:00
 */
@AllArgsConstructor
@Slf4j
public class SocketRpcClient implements RpcRequestTransport {
    private final ServiceDiscovery serviceDiscovery;

    public SocketRpcClient() {
        this.serviceDiscovery = ExtensionLoader.getExtensionLoader(ServiceDiscovery.class).getExtension("zk");
    }

    @Override
    public Object sendRpcRequest(RpcRequest rpcRequest) {
        // build rpc service name by rpcRequest
        String rpcServiceName = RpcServiceProperties.builder().serviceName(rpcRequest.getInterfaceName())
                .group(rpcRequest.getGroup()).version(rpcRequest.getVersion()).build().toRpcServiceName();
        InetSocketAddress inetSocketAddress = serviceDiscovery.lookupService(rpcServiceName);
        try (Socket socket = new Socket()) {
            socket.connect(inetSocketAddress);
            ObjectOutputStream objectOutputStream = new ObjectOutputStream(socket.getOutputStream());
            // Send data to the server through the output stream
            objectOutputStream.writeObject(rpcRequest);
            ObjectInputStream objectInputStream = new ObjectInputStream(socket.getInputStream());
            // Read RpcResponse from the input stream
            return objectInputStream.readObject();
        } catch (IOException | ClassNotFoundException e) {
            throw new RpcException("调用服务失败:", e);
        }
    }
}

上面的逻辑很简单,就是对 Socket 发送网络请求这个基础知识的运用。

我这里就不再对上面的代码进行解析了,看不懂的小伙伴自行翻看之前关于 Socket 讲解的章节。

服务端

**SocketRpcServer.java**

Socket 服务端。用于等待客户端连接。当客户端成功连接之后,就可以发送 rpc 请求(RpcRequest) 到服务端了。然后,服务端拿到 RpcRequest就会去执行对应的方法。执行完对应的方法之后,就把执行得到的结果放在 RpcResponse 中返回给客户端。

/**
 * @author shuang.kou
 * @createTime 2020年05月10日 08:01:00
 */
@Slf4j
public class SocketRpcServer {

    private final ExecutorService threadPool;
    private final ServiceProvider serviceProvider;


    public SocketRpcServer() {
        threadPool = ThreadPoolFactoryUtils.createCustomThreadPoolIfAbsent("socket-server-rpc-pool");
        serviceProvider = SingletonFactory.getInstance(ServiceProviderImpl.class);
    }

    public void registerService(Object service) {
        serviceProvider.publishService(service);
    }

    public void registerService(Object service, RpcServiceProperties rpcServiceProperties) {
        serviceProvider.publishService(service, rpcServiceProperties);
    }

    public void start() {
        try (ServerSocket server = new ServerSocket()) {
            String host = InetAddress.getLocalHost().getHostAddress();
            server.bind(new InetSocketAddress(host, PORT));
            CustomShutdownHook.getCustomShutdownHook().clearAll();
            Socket socket;
            while ((socket = server.accept()) != null) {
                log.info("client connected [{}]", socket.getInetAddress());
                threadPool.execute(new SocketRpcRequestHandlerRunnable(socket));
            }
            threadPool.shutdown();
        } catch (IOException e) {
            log.error("occur IOException:", e);
        }
    }

}

Netty

Netty 这部分的原理也差不多,不过实现代码差别很大。

客户端

**NettyClient.java**

Netty 客户端主要提供了:

  • **doConnect()** :用于连接服务端(目标方法所在的服务器)并返回对应的 Channel。当我们知道了服务端的地址之后,我们就可以通过 NettyClient 成功连接服务端了。(有了 Channel 之后就能发送数据到服务端了)
  • **sendRpcRequest()** : 用于传输 rpc 请求(RpcRequest) 到服务端。
@Slf4j
public final class NettyRpcClient implements RpcRequestTransport {
    private final ServiceDiscovery serviceDiscovery;
    private final UnprocessedRequests unprocessedRequests;
    private final ChannelProvider channelProvider;
    private final Bootstrap bootstrap;
    private final EventLoopGroup eventLoopGroup;

    @SneakyThrows
    public Channel doConnect(InetSocketAddress inetSocketAddress) {
        CompletableFuture<Channel> completableFuture = new CompletableFuture<>();
        bootstrap.connect(inetSocketAddress).addListener((ChannelFutureListener) future -> {
            if (future.isSuccess()) {
                log.info("The client has connected [{}] successful!", inetSocketAddress.toString());
                completableFuture.complete(future.channel());
            } else {
                throw new IllegalStateException();
            }
        });
        return completableFuture.get();
    }

    @Override
    public Object sendRpcRequest(RpcRequest rpcRequest) {
        // build return value
        CompletableFuture<RpcResponse<Object>> resultFuture = new CompletableFuture<>();
        // build rpc service name by rpcRequest
        String rpcServiceName = rpcRequest.toRpcProperties().toRpcServiceName();
        // get server address
        InetSocketAddress inetSocketAddress = serviceDiscovery.lookupService(rpcServiceName);
        // get  server address related channel
        Channel channel = getChannel(inetSocketAddress);
        if (channel.isActive()) {
            // put unprocessed request
            unprocessedRequests.put(rpcRequest.getRequestId(), resultFuture);
            RpcMessage rpcMessage = new RpcMessage();
            rpcMessage.setData(rpcRequest);
            rpcMessage.setCodec(SerializationTypeEnum.PROTOSTUFF.getCode());
            rpcMessage.setCompress(CompressTypeEnum.GZIP.getCode());
            rpcMessage.setMessageType(RpcConstants.REQUEST_TYPE);
            channel.writeAndFlush(rpcMessage).addListener((ChannelFutureListener) future -> {
                if (future.isSuccess()) {
                    log.info("client send message: [{}]", rpcMessage);
                } else {
                    future.channel().close();
                    resultFuture.completeExceptionally(future.cause());
                    log.error("Send failed:", future.cause());
                }
            });
        } else {
            throw new IllegalStateException();
        }

        return resultFuture;
    }
}

**UnprocessedRequests.java**

用于存放未被服务端处理的请求(建议限制 map 容器大小,避免未处理请求过多 OOM)。

public class UnprocessedRequests {
    private static final Map<String, CompletableFuture<RpcResponse<Object>>> UNPROCESSED_RESPONSE_FUTURES = new ConcurrentHashMap<>();

    public void put(String requestId, CompletableFuture<RpcResponse<Object>> future) {
        UNPROCESSED_RESPONSE_FUTURES.put(requestId, future);
    }

    public void complete(RpcResponse<Object> rpcResponse) {
        CompletableFuture<RpcResponse<Object>> future = UNPROCESSED_RESPONSE_FUTURES.remove(rpcResponse.getRequestId());
        if (null != future) {
            future.complete(rpcResponse);
        } else {
            throw new IllegalStateException();
        }
    }
}

**NettyClientHandler**

自定义客户端 ChannelHandler 用于处理服务器发送的数据。

@Slf4j
public class NettyClientHandler extends ChannelInboundHandlerAdapter {
    private final UnprocessedRequests unprocessedRequests;
    private final ChannelProvider channelProvider;

    public NettyClientHandler() {
        this.unprocessedRequests = SingletonFactory.getInstance(UnprocessedRequests.class);
        this.channelProvider = SingletonFactory.getInstance(ChannelProvider.class);
    }

    /**
     * 读取从服务端返回的消息
     */
    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) {
        try {
            log.info("client receive msg: [{}]", msg);
            if (msg instanceof RpcResponse) {
                RpcResponse<Object> rpcResponse = (RpcResponse<Object>) msg;
                unprocessedRequests.complete(rpcResponse);
            }
        } finally {
            ReferenceCountUtil.release(msg);
        }
    }

    // Netty 心跳机制相关。保证客户端和服务端的连接不被断掉,避免重连。
    @Override
    public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
       //省略部分代码
    }


}

从代码中,可以看出当 rpc 请求被成功处理(客户端收到服务端的执行结果)之后,我们调用了 unprocessedRequests.complete(rpcResponse) 方法,这样的话,你只需要通过下面的方式就能成功接收到服务端返回的结果。

CompletableFuture<RpcResponse> completableFuture = (CompletableFuture<RpcResponse>) clientTransport.sendRpcRequest(rpcRequest);
rpcResponse = completableFuture.get();

**ChannelProvider.java**

用于存放 ChannelChannel用于在服务端和客户端之间传输数据)。

@Slf4j
public class ChannelProvider {

    private final Map<String, Channel> channelMap;

    public ChannelProvider() {
        channelMap = new ConcurrentHashMap<>();
    }

    public Channel get(InetSocketAddress inetSocketAddress) {
        String key = inetSocketAddress.toString();
        // determine if there is a connection for the corresponding address
        if (channelMap.containsKey(key)) {
            Channel channel = channelMap.get(key);
            // if so, determine if the connection is available, and if so, get it directly
            if (channel != null && channel.isActive()) {
                return channel;
            } else {
                channelMap.remove(key);
            }
        }
        return null;
    }

    public void set(InetSocketAddress inetSocketAddress, Channel channel) {
        String key = inetSocketAddress.toString();
        channelMap.put(key, channel);
    }

    public void remove(InetSocketAddress inetSocketAddress) {
        String key = inetSocketAddress.toString();
        channelMap.remove(key);
        log.info("Channel map size :[{}]", channelMap.size());
    }
}

服务端相关

**NettyRpcServer.java**

Netty 服务端。并监听客户端的连接。另外,还提供了两个用户手动注册服务的方法(_还可以通过注解__RpcService__注册服务,这个后面也会介绍到_)。

@Slf4j
@Component
public class NettyRpcServer {

    public static final int PORT = 9998;

    private final ServiceProvider serviceProvider = SingletonFactory.getInstance(ServiceProviderImpl.class);

    public void registerService(Object service, RpcServiceProperties rpcServiceProperties) {
        serviceProvider.publishService(service, rpcServiceProperties);
    }

    @SneakyThrows
    public void start() {
        CustomShutdownHook.getCustomShutdownHook().clearAll();
        String host = InetAddress.getLocalHost().getHostAddress();
        EventLoopGroup bossGroup = new NioEventLoopGroup(1);
        EventLoopGroup workerGroup = new NioEventLoopGroup();
        DefaultEventExecutorGroup serviceHandlerGroup = new DefaultEventExecutorGroup(
                RuntimeUtil.cpus() * 2,
                ThreadPoolFactoryUtils.createThreadFactory("service-handler-group", false)
        );
        try {
            ServerBootstrap b = new ServerBootstrap();
            b.group(bossGroup, workerGroup)
                    .channel(NioServerSocketChannel.class)
                    // TCP默认开启了 Nagle 算法,该算法的作用是尽可能的发送大数据快,减少网络传输。TCP_NODELAY 参数的作用就是控制是否启用 Nagle 算法。
                    .childOption(ChannelOption.TCP_NODELAY, true)
                    // 是否开启 TCP 底层心跳机制
                    .childOption(ChannelOption.SO_KEEPALIVE, true)
                    //表示系统用于临时存放已完成三次握手的请求的队列的最大长度,如果连接建立频繁,服务器处理创建新连接较慢,可以适当调大这个参数
                    .option(ChannelOption.SO_BACKLOG, 128)
                    .handler(new LoggingHandler(LogLevel.INFO))
                    // 当客户端第一次进行请求的时候才会进行初始化
                    .childHandler(new ChannelInitializer<SocketChannel>() {
                        @Override
                        protected void initChannel(SocketChannel ch) {
                            // 30 秒之内没有收到客户端请求的话就关闭连接
                            ChannelPipeline p = ch.pipeline();
                            p.addLast(new IdleStateHandler(30, 0, 0, TimeUnit.SECONDS));
                            p.addLast(new RpcMessageEncoder());
                            p.addLast(new RpcMessageDecoder());
                            p.addLast(serviceHandlerGroup, new NettyRpcServerHandler());
                        }
                    });

            // 绑定端口,同步等待绑定成功
            ChannelFuture f = b.bind(host, PORT).sync();
            // 等待服务端监听端口关闭
            f.channel().closeFuture().sync();
        } catch (InterruptedException e) {
            log.error("occur exception when start server:", e);
        } finally {
            log.error("shutdown bossGroup and workerGroup");
            bossGroup.shutdownGracefully();
            workerGroup.shutdownGracefully();
            serviceHandlerGroup.shutdownGracefully();
        }
    }


}

**NettyServerHandler.java**

自定义服务端 ChannelHandler 用于处理客户端发送的数据。

当客户端发的 rpc 请求(RpcRequest) 来了之后,服务端就会处理 rpc 请求(RpcRequest) ,处理完之后就把得到 rpc 相应(RpcResponse)传输给客户端。

@Slf4j
public class NettyServerHandler extends ChannelInboundHandlerAdapter {

    private final RpcRequestHandler rpcRequestHandler;

    public NettyServerHandler() {
        this.rpcRequestHandler = SingletonFactory.getInstance(RpcRequestHandler.class);
    }

     /**
     * 读取从客户端消息,然后调用目标服务的目标方法并返回给客户端。
     */
    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) {
      // 省略部分代码
    }


    // Netty 心跳机制相关。保证客户端和服务端的连接不被断掉,避免重连。
    @Override
    public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
     // 省略部分代码
    }

}

传输协议

在《如何自己实现一个 RPC 框架》这一节,我们就提到了传输协议的作用。

简单来说:通过设计协议,我们定义需要传输哪些类型的数据, 并且还会规定每一种类型的数据应该占多少字节。这样我们在接收到二级制数据之后,就可以正确的解析出我们需要的数据。这有一点像密文传输的感觉。

以下便是我们设计的传输协议(编解码器这里会用到!!!):

 *   0     1     2     3     4        5     6     7     8         9          10      11     12  13  14   15 16
 *   +-----+-----+-----+-----+--------+----+----+----+------+-----------+-------+----- --+-----+-----+-------+
 *   |   magic   code        |version | full length         | messageType| codec|compress|    RequestId       |
 *   +-----------------------+--------+---------------------+-----------+-----------+-----------+------------+
 *   |                                                                                                       |
 *   |                                         body                                                          |
 *   |                                                                                                       |
 *   |                                        ... ...                                                        |
 *   +-------------------------------------------------------------------------------------------------------+
 * 4B  magic code(魔法数)   1B version(版本)   4B full length(消息长度)    1B messageType(消息类型)
 * 1B compress(压缩类型) 1B codec(序列化类型)    4B  requestId(请求的Id)
  • 魔法数 : 通常是 4 个字节。这个魔数主要是为了筛选来到服务端的数据包,有了这个魔数之后,服务端首先取出前面四个字节进行比对,能够在第一时间识别出这个数据包并非是遵循自定义协议的,也就是无效数据包,为了安全考虑可以直接关闭连接以节省资源。
  • 序列化器类型 :标识序列化的方式,比如是使用 Java 自带的序列化,还是 json,kyro 等序列化方式。
  • 消息长度 : 运行时计算出来。
  • ......

编解码器

编解码器这里主要用到了 Kryo 序列化和反序列化以及 Netty 网络传输字节容器 ByteBuf 相关的知识。

编解码器的作用主要是让我们在 Netty 进行网络传输所用的对象类型 ByteBuf 与 我们代码层面需要的业务对象之间转换。这部分的代码还是比较多的,小伙伴们可以自己阅读以下,整体逻辑还是比较简单的。

一定要先搞懂传输协议之后再去看这部分代码。

**RpcMessageDecoder.java**

自定义解码器。负责处理"入站"消息,将 ByteBuf 消息格式的对象转换为我们需要的业务对象。

网络传输需要通过字节流来实现,ByteBuf 可以看作是 Netty 提供的字节数据的容器,使用它会让我们更加方便地处理字节数据。

**RpcMessageEncoder.java**

自定义编码器。负责处理"出站"消息,将消息格式转换字节数组然后写入到字节数据的容器 ByteBuf 对象中。