spring-cloud / spring-cloud-gateway

An API Gateway built on Spring Framework and Spring Boot providing routing and more.
http://cloud.spring.io
Apache License 2.0
4.54k stars 3.33k forks source link

ModifyResponseBodyGatewayFilterFactory block stream response #2775

Closed vineetkala11 closed 8 months ago

vineetkala11 commented 2 years ago

Describe the bug: ModifyResponseBodyGatewayFilterFactory breaks SSE-based resources. The writeWith method of ModifiedServerHttpResponse modifies the response body for both streaming and non-streaming MIME types, and it does so by converting the whole response body into a Mono. Aggregating the body this way makes the filter unusable for SSE.

Because of this, a consumer of an SSE endpoint receives all events at once instead of as a stream. Without ModifyResponseBodyGatewayFilterFactory, everything works as expected.

Current behavior of ModifyResponseBodyGatewayFilterFactory: blocked_events-2

Expected behavior from ModifyResponseBodyGatewayFilterFactory: unblocked_events-2

Fix

This PR fixes the issue by handling writeAndFlushWith differently for streaming MIME types. The writer calls writeAndFlushWith only when the response content type is a streaming media type, so that method is the natural place to modify each event individually instead of buffering the whole body.

Please review this PR and approve it if it looks OK.

Sample

The test below should fail with a DataBufferLimitException. The failure shows that ModifyResponseBodyGatewayFilterFactory holds back all outgoing SSE events until the upstream publisher terminates, and then republishes them all at once after modifying the response body. Because all events are held in memory together, they exceed the configured buffer limit (spring.codec.max-in-memory-size=40):

    /**
     * Streams SSE events through the modifying route and expects the gateway to reject them
     * with a buffer-limit error, demonstrating that the events are aggregated in memory.
     */
    @Test
    public void should_failWithExcedingLimitOfBuffer() {
        // 20 events, 100 ms apart: buffered together they overflow the 40-byte codec limit.
        final int eventCount = 20;
        final long delayMillis = 100L;
        String ssePath = "/sse?events=" + eventCount + "&delay=" + delayMillis;
        URI sseUri = UriComponentsBuilder.fromUriString(this.baseUri + ssePath).build(true).toUri();

        testClient.get()
                .uri(sseUri)
                .header("Host", "www.modifyresponsebodyjavawithsse.org")
                .accept(MediaType.TEXT_EVENT_STREAM)
                .exchange()
                .expectStatus().isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR)
                .expectBody()
                .jsonPath("message").isEqualTo("Exceeded limit on max bytes to buffer : 40");
    }

Adding test class -

/**
 * Integration tests for the {@code modifyResponseBody} gateway filter, covering a plain body
 * rewrite, an oversized rewritten body, and an SSE stream routed through the filter.
 */
public class ModifyResponseBodyGatewayFilterFactoryTests extends BaseWebClientTests {

    /**
     * "to-large-" repeated 1000 times (9000 characters); the tests using it expect the
     * "max bytes to buffer : 40" error.
     */
    private static final String TOO_LARGE_BODY = "to-large-".repeat(1000);

    /** The home page body is wrapped into a JSON object carrying the value and its length. */
    @Test
    public void testModificationOfResponseBody() {
        URI homeUri = UriComponentsBuilder.fromUriString(this.baseUri + "/").build(true).toUri();

        testClient.get()
                .uri(homeUri)
                .header("Host", "www.modifyresponsebodyjava.org")
                .accept(MediaType.APPLICATION_JSON)
                .exchange()
                .expectBody()
                .json("{\"value\": \"httpbin compatible home\", \"length\": 23}");
    }

    /** Posting an oversized body through the "too large" host must trip the buffer limit. */
    @Test
    public void modifyResponeBodyToLarge() {
        testClient.post()
                .uri("/post")
                .header("Host", "www.modifyresponsebodyjavatoolarge.org")
                .header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
                .body(BodyInserters.fromValue(TOO_LARGE_BODY))
                .exchange()
                .expectStatus().isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR)
                .expectBody()
                .jsonPath("message").isEqualTo("Exceeded limit on max bytes to buffer : 40");
    }

    /** SSE events routed through the filter are buffered and overflow the 40-byte limit. */
    @Test
    public void should_failWithExcedingLimitOfBuffer() {
        final int eventCount = 20;
        final long delayMillis = 100L;
        String ssePath = "/sse?events=" + eventCount + "&delay=" + delayMillis;
        URI sseUri = UriComponentsBuilder.fromUriString(this.baseUri + ssePath).build(true).toUri();

        testClient.get()
                .uri(sseUri)
                .header("Host", "www.modifyresponsebodyjavawithsse.org")
                .accept(MediaType.TEXT_EVENT_STREAM)
                .exchange()
                .expectStatus().isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR)
                .expectBody()
                .jsonPath("message").isEqualTo("Exceeded limit on max bytes to buffer : 40");
    }

    /** Test application configuration declaring one route per scenario above. */
    @EnableAutoConfiguration
    @SpringBootConfiguration
    @Import(DefaultTestConfig.class)
    public static class TestConfig {

        @Value("${test.uri}")
        String uri;

        /**
         * Declares three routes, each selected by its Host header:
         * a String-to-Map rewrite of {@code /}, a rewrite that always emits the oversized
         * body, and a byte[]-to-String rewrite replacing "D -" with "MD -" in SSE payloads.
         */
        @Bean
        public RouteLocator testRouteLocator(RouteLocatorBuilder builder) {
            return builder.routes()
                    .route("modify_response_java_test", r -> r.path("/")
                            .and()
                            .host("www.modifyresponsebodyjava.org")
                            .filters(f -> f.prefixPath("/httpbin")
                                    .modifyResponseBody(String.class, Map.class, (ex, body) -> {
                                        Map<String, Object> wrapped = new HashMap<>();
                                        wrapped.put("value", body);
                                        wrapped.put("length", body.length());
                                        return Mono.just(wrapped);
                                    }))
                            .uri(uri))
                    .route("modify_response_java_test_to_large", r -> r.path("/")
                            .and()
                            .host("www.modifyresponsebodyjavatoolarge.org")
                            .filters(f -> f.prefixPath("/httpbin")
                                    .modifyResponseBody(String.class, String.class,
                                            (ex, body) -> Mono.just(TOO_LARGE_BODY)))
                            .uri(uri))
                    .route("modify_response_java_test_sse", r -> r
                            .host("www.modifyresponsebodyjavawithsse.org")
                            .filters(f -> f.modifyResponseBody(byte[].class, String.class, (ex, raw) -> {
                                String decoded = new String(raw, StandardCharsets.UTF_8);
                                return Mono.just(decoded.replace("D -", "MD -"));
                            }))
                            .uri(uri))
                    .build();
        }

    }

}

SSE test controller -

@RestController
@RequestMapping("/httpbin")
public class HttpBinCompatibleController {

    private static final Log log = LogFactory.getLog(HttpBinCompatibleController.class);

    private static final String HEADER_REQ_VARY = "X-Request-Vary";

    private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();

    @GetMapping("/")
    public String home() {
        return "httpbin compatible home";
    }

    @RequestMapping(path = "/headers", method = { RequestMethod.GET, RequestMethod.POST },
            produces = MediaType.APPLICATION_JSON_VALUE)
    public Map<String, Object> headers(ServerWebExchange exchange) {
        Map<String, Object> result = new HashMap<>();
        result.put("headers", getHeaders(exchange));
        return result;
    }

    @PatchMapping("/headers")
    public ResponseEntity<Map<String, Object>> headersPatch(ServerWebExchange exchange,
            @RequestBody Map<String, String> headersToAdd) {
        Map<String, Object> result = new HashMap<>();
        result.put("headers", getHeaders(exchange));
        ResponseEntity.BodyBuilder responseEntity = ResponseEntity.status(HttpStatus.OK);
        headersToAdd.forEach(responseEntity::header);

        return responseEntity.body(result);
    }

    @RequestMapping(path = "/multivalueheaders", method = { RequestMethod.GET, RequestMethod.POST },
            produces = MediaType.APPLICATION_JSON_VALUE)
    public Map<String, Object> multiValueHeaders(ServerWebExchange exchange) {
        Map<String, Object> result = new HashMap<>();
        result.put("headers", exchange.getRequest().getHeaders());
        return result;
    }

    @GetMapping(path = "/delay/{sec}/**", produces = MediaType.APPLICATION_JSON_VALUE)
    public Mono<Map<String, Object>> delay(ServerWebExchange exchange, @PathVariable int sec)
            throws InterruptedException {
        int delay = Math.min(sec, 10);
        return Mono.just(get(exchange)).delayElement(Duration.ofSeconds(delay));
    }

    @GetMapping(path = "/anything/{anything}", produces = MediaType.APPLICATION_JSON_VALUE)
    public Map<String, Object> anything(ServerWebExchange exchange, @PathVariable(required = false) String anything) {
        return get(exchange);
    }

    @GetMapping(path = "/get", produces = MediaType.APPLICATION_JSON_VALUE)
    public Map<String, Object> get(ServerWebExchange exchange) {
        if (log.isDebugEnabled()) {
            log.debug("httpbin /get");
        }
        HashMap<String, Object> result = new HashMap<>();
        HashMap<String, String> params = new HashMap<>();
        exchange.getRequest().getQueryParams().forEach((name, values) -> {
            params.put(name, values.get(0));
        });
        result.put("args", params);
        result.put("headers", getHeaders(exchange));
        return result;
    }

    @PostMapping(value = "/post", consumes = MediaType.MULTIPART_FORM_DATA_VALUE,
            produces = MediaType.APPLICATION_JSON_VALUE)
    public Mono<Map<String, Object>> postFormData(@RequestBody Mono<MultiValueMap<String, Part>> parts) {
        // StringDecoder decoder = StringDecoder.allMimeTypes(true);
        return parts.flux().flatMap(map -> Flux.fromIterable(map.values())).flatMap(Flux::fromIterable)
                .filter(part -> part instanceof FilePart).reduce(new HashMap<String, Object>(), (files, part) -> {
                    MediaType contentType = part.headers().getContentType();
                    long contentLength = part.headers().getContentLength();
                    // TODO: get part data
                    files.put(part.name(), "data:" + contentType + ";base64," + contentLength);
                    return files;
                }).map(files -> Collections.singletonMap("files", files));
    }

    @PostMapping(path = "/post", consumes = MediaType.APPLICATION_FORM_URLENCODED_VALUE,
            produces = MediaType.APPLICATION_JSON_VALUE)
    public Mono<Map<String, Object>> postUrlEncoded(ServerWebExchange exchange) throws IOException {
        return post(exchange, null);
    }

    @PostMapping(path = "/post", produces = MediaType.APPLICATION_JSON_VALUE)
    public Mono<Map<String, Object>> post(ServerWebExchange exchange, @RequestBody(required = false) String body)
            throws IOException {
        HashMap<String, Object> ret = new HashMap<>();
        ret.put("headers", getHeaders(exchange));
        ret.put("data", body);
        HashMap<String, Object> form = new HashMap<>();
        ret.put("form", form);

        return exchange.getFormData().flatMap(map -> {
            for (Map.Entry<String, List<String>> entry : map.entrySet()) {
                for (String value : entry.getValue()) {
                    form.put(entry.getKey(), value);
                }
            }
            return Mono.just(ret);
        });
    }

    @GetMapping("/status/{status}")
    public ResponseEntity<String> status(@PathVariable int status) {
        return ResponseEntity.status(status).body("Failed with " + status);
    }

    @RequestMapping(value = "/responseheaders/{status}", method = { RequestMethod.GET, RequestMethod.POST })
    public ResponseEntity<Map<String, Object>> responseHeaders(@PathVariable int status, ServerWebExchange exchange) {
        HttpHeaders httpHeaders = exchange.getRequest().getHeaders().entrySet().stream()
                .filter(entry -> entry.getKey().startsWith("X-Test-"))
                .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue,
                        (list1, list2) -> Stream.concat(list1.stream(), list2.stream()).collect(Collectors.toList()),
                        HttpHeaders::new));

        return ResponseEntity.status(status).headers(httpHeaders).body(Collections.singletonMap("status", status));
    }

    @PostMapping(path = "/post/empty", produces = MediaType.APPLICATION_JSON_VALUE)
    public Mono<String> emptyResponse() {
        return Mono.empty();
    }

    @GetMapping(path = "/gzip", produces = MediaType.APPLICATION_JSON_VALUE)
    public Mono<Void> gzip(ServerWebExchange exchange) throws IOException {
        if (log.isDebugEnabled()) {
            log.debug("httpbin /gzip");
        }

        String jsonResponse = OBJECT_MAPPER.writeValueAsString("httpbin compatible home");
        byte[] bytes = jsonResponse.getBytes(StandardCharsets.UTF_8);

        ServerHttpResponse response = exchange.getResponse();
        response.getHeaders().add(HttpHeaders.CONTENT_ENCODING, "gzip");
        DataBufferFactory dataBufferFactory = response.bufferFactory();
        response.setStatusCode(HttpStatus.OK);

        ByteArrayOutputStream bos = new ByteArrayOutputStream();
        GZIPOutputStream is = new GZIPOutputStream(bos);
        FileCopyUtils.copy(bytes, is);

        byte[] gzippedResponse = bos.toByteArray();
        DataBuffer wrap = dataBufferFactory.wrap(gzippedResponse);
        return response.writeWith(Flux.just(wrap));
    }

    @GetMapping("/vary-on-header/**")
    public ResponseEntity<Map<String, Object>> varyOnAccept(ServerWebExchange exchange,
            @RequestHeader(name = HEADER_REQ_VARY, required = false) String headerToVary) {
        if (headerToVary == null) {
            return ResponseEntity.badRequest().body(Map.of("error", HEADER_REQ_VARY + " header is mandatory"));
        }
        else {
            var builder = ResponseEntity.ok();
            builder.varyBy(headerToVary);
            return builder.body(headers(exchange));
        }
    }

    @GetMapping(value = "/sse", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
    public Flux<ServerSentEvent<String>> sseEvents(
            @RequestParam(value = "events", required = false, defaultValue = "5") int events,
            @RequestParam(value = "delay", required = false, defaultValue = "500") long delay) {
        return Flux.interval(Duration.ofMillis(delay)).take(events).map(e -> {
            return ServerSentEvent.<String>builder().id(e.toString()).event("notice").data("D - " + e.toString())
                    .build();
        });
    }

    public Map<String, String> getHeaders(ServerWebExchange exchange) {
        return exchange.getRequest().getHeaders().toSingleValueMap();
    }

}
vineetkala11 commented 1 year ago

Hello @spencergibb

Please share your view on this issue.

usrivastava92 commented 1 year ago

@spencergibb This issue has been reported multiple times and has been open for a long time. Is there anything you would suggest improving in the PR?

spencergibb commented 8 months ago

Closing in favor of https://github.com/spring-cloud/spring-cloud-gateway/pull/2774