xiwenAndlejian / my-blog

Java基础学习练习题
1 stars 0 forks source link

RPC 实现笔记(二) #22

Open xiwenAndlejian opened 5 years ago

xiwenAndlejian commented 5 years ago

image

接上文,为了完成 RPC 请求与响应的完整流程,还需要编写服务端和客户端的 RPC 请求处理器。

RPC 请求处理器

客户端请求处理

注:此处的雪花算法(SnowFlake)可参考网上的实现,也可以替换为其他分布式 ID 生成算法。

/**
 * Client-side handler that correlates RPC responses with the callers waiting for them.
 *
 * <p>Each outgoing request receives a fresh id from the {@link SnowFlake} generator. The calling
 * thread blocks on a {@link CountDownLatch} keyed by that id until {@link #channelRead0} delivers
 * the {@link RpcResponsePacket} carrying the same id.
 */
public class RpcClientHandler extends SimpleChannelInboundHandler<RpcResponsePacket> {

    // requestId -> response delivered by the server, waiting to be picked up by the caller
    private static final Map<Long, RpcResponsePacket> resultMap = new ConcurrentHashMap<>();
    // requestId -> latch the calling thread is blocked on
    private static final Map<Long, CountDownLatch>    lockMap   = new ConcurrentHashMap<>();

    // Snowflake id generator; dataCenterId and machineId are hard-coded to 1 for now.
    // Created eagerly: channelActive may fire after the connect future has already completed,
    // so creating the generator there let an early request observe a null generator.
    private static final SnowFlake snowFlake = new SnowFlake(1, 1);

    // Channel that requests are written to, set once the connection is established.
    // volatile because it is usually set on a Netty event-loop thread and read by caller threads.
    @Setter
    private volatile Channel channel;

    @Override
    protected void channelRead0(ChannelHandlerContext ctx, RpcResponsePacket msg) throws Exception {
        setResponse(msg.getRequestId(), msg);
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        cause.printStackTrace();
    }

    /**
     * Hands {@code response} to the caller waiting on {@code requestId}.
     *
     * <p>A response nobody is waiting for (unknown id, or the caller already gave up) is dropped.
     * It is not stored forever, and it does not crash on a missing latch.
     */
    private void setResponse(Long requestId, RpcResponsePacket response) {
        CountDownLatch lock = lockMap.get(requestId);
        if (lock == null) {
            return;
        }
        // Publish the result before releasing the waiter; the latch gives happens-before.
        resultMap.put(requestId, response);
        lock.countDown();
    }

    /**
     * Sends {@code request} and blocks until the matching response arrives.
     *
     * @param request the request to send; its requestId is overwritten with a freshly generated id
     * @return the server's response for this request
     * @throws InterruptedException  if the calling thread is interrupted while waiting
     * @throws IllegalStateException if no channel has been set yet
     */
    public RpcResponsePacket getRpcResponse(RpcRequestPacket request) throws InterruptedException {
        Channel target = channel;
        if (target == null) {
            throw new IllegalStateException("RPC channel is not connected");
        }
        request.setRequestId(snowFlake.getNextId());
        Long requestId = request.getRequestId();
        var lock = new CountDownLatch(1);
        // Register the latch BEFORE writing. Otherwise a fast response can arrive while no latch
        // exists (NPE in setResponse), and this caller would then wait forever.
        lockMap.put(requestId, lock);
        try {
            target.writeAndFlush(request);
            lock.await();
            return resultMap.remove(requestId);
        } finally {
            // Always clean up so interrupted or failed calls do not leak map entries.
            lockMap.remove(requestId);
            resultMap.remove(requestId);
        }
    }
}

服务端请求处理器

@Slf4j
public class RpcServerHandler extends SimpleChannelInboundHandler<RpcRequestPacket> {

    private ProxyFactory proxyFactory = RpcServerProxy.getInstance();

    @Override
    protected void channelRead0(ChannelHandlerContext ctx, RpcRequestPacket msg) throws Exception {
        log.debug("request: ", msg);

        Class<?>   target     = msg.getClazz();
        Class<?>[] paramTypes = msg.getParameterTypes();

        var args       = msg.getParameters();
        var methodName = msg.getMethodName();
        var requestId  = msg.getRequestId();

        // 异步执行方法
        CompletableFuture<RpcResponsePacket> future = CompletableFuture.supplyAsync(() -> {
            RpcResponsePacket response = new RpcResponsePacket();
            try {
                // 动态代理
                Object proxy  = proxyFactory.getProxy(target);
                Method method = target.getMethod(methodName, paramTypes);
                Object result = method.invoke(proxy, args);
                response.setResponse(result);
                response.setClazz(result.getClass());
            } catch (Exception e) {
                // 注意这里获取实际的 exception 时,连续调用了两次 getCause
                // 原因是 cglib 的 invoke 方法包裹了两次异常,需要调用两次才能获取实际抛出的异常
                response.setException(new ProxyException(e.getCause().getCause()));
            }
            response.setRequestId(requestId);
            return response;
        });
        RpcResponsePacket response = future.get();
        log.debug("proxy result response: ", response.getResponse());
        // 发送响应
        ctx.channel().writeAndFlush(response);
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        log.error("RPC request handle failed cause:", cause);
        cause.printStackTrace();
    }
}

服务端动态代理(cglib)

动态调用服务端 service 的具体实现,执行对应方法。

当实现类不存在时,需要抛出异常(ProxyServiceNotFoundException)。

/**
 * Server-side cglib proxy factory.
 *
 * <p>Calls on a returned proxy are routed to the implementation class annotated with
 * {@link ProxyProducer} whose {@code consumer()} is the interface declaring the invoked method.
 * If no such implementation exists, a {@code ProxyServiceNotFoundException} is thrown.
 */
@Slf4j
public class RpcServerProxy implements ProxyFactory {

    // consumer (interface) -> producer (implementation class); scanned once, on first construction
    private static volatile Map<Class<?>, Class<?>> producerClasses = null;
    // Cached service instances, one per implementation class. Concurrent because proxies are
    // invoked from several threads at once.
    private static final Map<Class<?>, Object> producerInstances =
            new java.util.concurrent.ConcurrentHashMap<>();
    // package containing the service implementation classes
    private static final String PACKAGE_PATH = "com.dekuofa.service.impl";

    public static RpcServerProxy getInstance() {
        return new RpcServerProxy();
    }

    private RpcServerProxy() {
        // Factories are created per channel handler, possibly on several worker threads at once.
        // Synchronizing makes the package scan happen once and publishes the map safely.
        synchronized (RpcServerProxy.class) {
            if (producerClasses == null) {
                // load producers annotated with ProxyProducer from the package
                List<Class<?>> classList = ClassScanUtil.scanner(PACKAGE_PATH);
                // consumer (interface) -> producer (implementation class)
                producerClasses = classList.stream()
                        .filter(clazz -> clazz.isAnnotationPresent(ProxyProducer.class))
                        .collect(Collectors.toMap(
                                clazz -> clazz.getAnnotation(ProxyProducer.class).consumer(),
                                Function.identity()));
            }
        }
    }

    @SuppressWarnings("unchecked")
    @Override
    public <T> T getProxy(Class<T> targetClass) {
        // A fresh Enhancer per call: Enhancer is mutable (setSuperclass/setCallback), so a
        // shared instance used concurrently could build a proxy of the wrong type.
        Enhancer enhancer = new Enhancer();
        enhancer.setSuperclass(targetClass);
        enhancer.setCallback((MethodInterceptor) (Object obj, Method method, Object[] args, MethodProxy proxy) -> {
            // interface that declares the remotely invoked method
            Class<?> clazz = method.getDeclaringClass();
            // look up its implementation class
            Class<?> producerClass = producerClasses.get(clazz);
            if (Objects.isNull(producerClass)) {
                log.error("proxy exception: cannot found producer class: {}", clazz.getName());
                throw new ProxyServiceNotFoundException("class: " + clazz.getName() + " not found", clazz);
            }
            // invoke the method on the cached implementation instance
            return method.invoke(producerFor(producerClass), args);
        });
        return (T) enhancer.create();
    }

    /**
     * Returns the cached instance of {@code producerClass}, creating it on first use via its
     * no-arg constructor.
     *
     * @throws ReflectiveOperationException if the implementation cannot be instantiated
     */
    private static Object producerFor(Class<?> producerClass) throws ReflectiveOperationException {
        Object producer = producerInstances.get(producerClass);
        if (producer != null) {
            return producer;
        }
        Object created = producerClass.getDeclaredConstructor().newInstance();
        // Another thread may have cached an instance meanwhile; always use the first one cached.
        Object existing = producerInstances.putIfAbsent(producerClass, created);
        return existing != null ? existing : created;
    }
}

客户端

客户端需要建立与服务端的长连接。

相比短连接,长连接无需在每次发送请求时重新建立连接,可以节省大量建立连接的开销(这部分开销在单次请求的总耗时中往往占比较大)。

长连接也需要一些机制来保证连接的有效性,例如:心跳包、客户端断线重连机制。

@Slf4j
public class RpcConnect {

    private String           host;
    private Integer          port;
    private RpcClientHandler rpcClientHandler;

    public RpcConnect(String host, Integer port) {
        this.host = host;
        this.port = port;
        rpcClientHandler = new RpcClientHandler();
    }

    public RpcConnect connect() {
        preCheck();
        try {
            EventLoopGroup group = new NioEventLoopGroup();
            Bootstrap bootstrap = new Bootstrap()
                    .group(group)
                    .channel(NioSocketChannel.class)
                    .handler(new ChannelInitializer<Channel>() {
                        @Override
                        protected void initChannel(Channel ch) throws Exception {
                            ChannelPipeline pipeline = ch.pipeline();
                            pipeline.addLast(new PacketEncoder())
                                    .addLast(new PacketDecoder())
                                    .addLast(rpcClientHandler);
                        }
                    });
            var connectFuture = bootstrap.connect(host, port);
            connectFuture.addListener(future -> {
                if (future.isSuccess()) {
                    log.info("client connect to {}:{} success", host, port);
                    rpcClientHandler.setChannel(((ChannelFuture) future).channel());
                } else {
                    log.info("client connect to {}:{} failed", host, port, future.cause());
                }
            });
        } catch (Exception e) {
            log.error("client connect failed: ", e);
        }
        return this;
    }

    private void preCheck() {
        if (Objects.isNull(port)) {
            throw new BootstrapException("端口号不能为空");
        }
        if (Objects.isNull(host) || "".equals(host)) {
            throw new BootstrapException("host不能为空");
        }
    }

    // RPC 消息发送入口
    public <T> T sendRequest(RpcRequestPacket request) {
        try {
            var response = rpcClientHandler.getRpcResponse(request);
            return response.getRpcResult();
        } catch (InterruptedException e) {
            log.error("get rpc response failed: ", e);
            return null;
        }
    }
}

客户端动态代理(cglib)

当调用代理对象上非 Object 类声明的方法时,向服务端发送 RPC 请求;Object 类中的方法(如 toString、equals)则在本地执行。

/**
 * Client-side cglib proxy factory.
 *
 * <p>Calls to methods not declared by {@link Object} are turned into RPC requests and sent over
 * the given connection. Object methods are executed locally.
 */
@Slf4j
public class RpcClientProxy implements ProxyFactory {

    // connection to the server, used to send requests
    private final RpcConnect connect;

    public RpcClientProxy(RpcConnect connect) {
        this.connect = connect;
    }

    @SuppressWarnings("unchecked")
    @Override
    public <T> T getProxy(Class<T> targetClass) {
        // A fresh Enhancer per call: Enhancer is mutable (setSuperclass/setCallback), so a
        // shared static instance is unsafe when proxies are created from several threads.
        Enhancer enhancer = new Enhancer();
        enhancer.setSuperclass(targetClass);
        enhancer.setCallback((MethodInterceptor) (Object obj, Method method, Object[] args, MethodProxy proxy) -> {
            Class<?> clazz = method.getDeclaringClass();
            if (Object.class.equals(clazz)) {
                // Run Object methods against the proxy itself. Invoking them on this factory made
                // proxy.equals(proxy) false and failed for protected methods such as clone().
                return proxy.invokeSuper(obj, args);
            }
            // build the request; the requestId is assigned later by RpcClientHandler
            var request = RpcRequestPacket.builder()
                    .clazz(clazz)
                    .methodName(method.getName())
                    .parameterTypes(method.getParameterTypes())
                    .parameters(args)
                    .build();
            // send the RPC request
            return connect.sendRequest(request);
        });
        return (T) enhancer.create();
    }

}

服务端

@Slf4j
public class RpcServer {
    private int port;

    private NioEventLoopGroup boss   = new NioEventLoopGroup(1);
    private NioEventLoopGroup worker = new NioEventLoopGroup();

    public RpcServer(int port) {
        this.port = port;
    }

    public static void main(String[] args) throws InterruptedException {
        RpcServer rpcServer = new RpcServer(8000);
        rpcServer.server();
    }

    public void server() throws InterruptedException {
        try {
            ServerBootstrap bootstrap = new ServerBootstrap();
            bootstrap.group(boss, worker)
                    .channel(NioServerSocketChannel.class)
                    .childHandler(new ChannelInitializer<Channel>() {
                        @Override
                        protected void initChannel(Channel ch) throws Exception {
                            ChannelPipeline pipeline = ch.pipeline();
                            pipeline.addLast(new PacketDecoder());
                            pipeline.addLast(new PacketEncoder());
                            pipeline.addLast(new RpcServerHandler());
                        }
                    });
            Channel channel = bind(bootstrap, port);
            channel.closeFuture().sync();
            System.out.println();
        } finally {
            boss.shutdownGracefully();
            worker.shutdownGracefully();
        }
    }

    private static Channel bind(ServerBootstrap bootstrap, int port) {
        log.info("尝试绑定[{}]端口", port);
        return bootstrap.bind(port).addListener(future -> {
            if (future.isSuccess()) {
                log.info("绑定端口[{}]成功", port);
            } else {
                log.error("绑定端口[{}]失败", port);
                bind(bootstrap, port + 1);
            }
        }).channel();
    }

}