aws / aws-xray-sdk-java

The official AWS X-Ray Recorder SDK for Java.
Apache License 2.0
95 stars 99 forks source link

Support for Spring Framework 6 (jakarta vs javax) #364

Closed jeroenvandevelde closed 1 year ago

jeroenvandevelde commented 1 year ago

Is it possible to add support for Spring Framework 6? The first problem I found is that X-Ray for Spring uses the `javax.*` packages, while Spring Framework 6 uses the `jakarta.*` packages. For example, see the code below:

@Bean
public Filter TracingFilter() {
    // Registers the X-Ray servlet filter with a fixed segment naming strategy.
    // NOTE(review): AWSXRayServletFilter here is the javax-based filter, which is the
    // incompatibility with Spring Framework 6 (jakarta.*) that this issue reports.
    FixedSegmentNamingStrategy namingStrategy =
            new FixedSegmentNamingStrategy("defaultSegmentName");
    return new AWSXRayServletFilter(namingStrategy);
}

Willing to help, if help is wanted.

atshaw43 commented 1 year ago

Thanks for bringing this to our attention.

I will add this as a work item for us.

vnonchev-mentormate commented 1 year ago

Could you please share what the rough estimated timeline for this enhancement is, or where we could see progress on it? Thank you!

joshuavillano commented 1 year ago

Does this mean we won't be able to upgrade to Spring 6 / Boot 3 if we use these dependencies?

willarmiros commented 1 year ago

Hi all, unfortunately we're not able to provide timeline estimates, but we will be prioritizing this soon and will update this post when we do. In the meantime, any contributions are welcome!

mjgp2 commented 1 year ago

While the official fix is being worked on, you can patch it with something like this:

/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License").
 * You may not use this file except in compliance with the License.
 * A copy of the License is located at
 *
 *  http://aws.amazon.com/apache2.0
 *
 * or in the "license" file accompanying this file. This file is distributed
 * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
 * express or implied. See the License for the specific language governing
 * permissions and limitations under the License.
 */

package com.amazonaws.xray.jakarta.servlet;

import com.amazonaws.xray.AWSXRay;
import com.amazonaws.xray.AWSXRayRecorder;
import com.amazonaws.xray.entities.Entity;
import com.amazonaws.xray.entities.Segment;
import com.amazonaws.xray.entities.TraceHeader;
import com.amazonaws.xray.entities.TraceHeader.SampleDecision;
import com.amazonaws.xray.entities.TraceID;
import com.amazonaws.xray.strategy.sampling.SamplingRequest;
import com.amazonaws.xray.strategy.sampling.SamplingResponse;
import com.amazonaws.xray.strategy.sampling.SamplingStrategy;
import jakarta.servlet.AsyncEvent;
import jakarta.servlet.AsyncListener;
import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
import jakarta.servlet.ServletRequest;
import jakarta.servlet.ServletResponse;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

// TODO: open a PR - https://github.com/aws/aws-xray-sdk-java/issues/364
public class AWSXRayServletFilter implements jakarta.servlet.Filter {

    private class AWSXRayServletAsyncListener implements AsyncListener {

        public static final String ENTITY_ATTRIBUTE_KEY = "com.amazonaws.xray.entities.Entity";

        private void processEvent(AsyncEvent event) {
            Entity entity = (Entity) event.getSuppliedRequest().getAttribute(ENTITY_ATTRIBUTE_KEY);
            entity.run(
                    () -> {
                        if (event.getThrowable() != null) {
                            entity.addException(event.getThrowable());
                        }
                        AWSXRayServletFilter.this.postFilter(event.getSuppliedResponse());
                    });
        }

        @Override
        public void onComplete(AsyncEvent event) throws IOException {
            processEvent(event);
        }

        @Override
        public void onTimeout(AsyncEvent event) throws IOException {
            processEvent(event);
        }

        @Override
        public void onError(AsyncEvent event) throws IOException {
            processEvent(event);
        }

        @Override
        public void onStartAsync(AsyncEvent event) throws IOException {
            // DO NOTHING
        }
    }

    private static final Log logger = LogFactory.getLog(AWSXRayServletFilter.class);

    private final String segmentDefaultName;

    private final AWSXRayRecorder recorder;
    private final AWSXRayServletAsyncListener listener;

    public AWSXRayServletFilter(String fallbackName) {
        this(fallbackName, null);
    }

    public AWSXRayServletFilter(String fallbackName, AWSXRayRecorder recorder) {
        this.recorder = recorder == null ? AWSXRay.getGlobalRecorder() : recorder;
        this.listener = new AWSXRayServletAsyncListener();
        this.segmentDefaultName = fallbackName;
    }

    @Override
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
            throws IOException, ServletException {
        if (logger.isDebugEnabled()) {
            logger.debug(
                    "AWSXRayServletFilter is beginning to process request: " + request.toString());
        }
        Segment segment = preFilter(request, response);

        try {
            chain.doFilter(request, response);
        } catch (Throwable e) {
            segment.addException(e);
            throw e;
        } finally {
            if (request.isAsyncStarted()) {
                request.setAttribute(AWSXRayServletAsyncListener.ENTITY_ATTRIBUTE_KEY, segment);
                try {
                    request.getAsyncContext().addListener(listener);
                    if (recorder != null) {
                        recorder.clearTraceEntity();
                    }
                } catch (IllegalStateException ise) {
                    // race condition that occurs when async processing finishes before adding the
                    // listener
                    postFilter(response);
                }
            } else {
                postFilter(response);
            }

            if (logger.isDebugEnabled()) {
                logger.debug(
                        "AWSXRayServletFilter is finished processing request: "
                                + request.toString());
            }
        }
    }

    private HttpServletRequest castServletRequest(ServletRequest request) {
        try {
            return (HttpServletRequest) request;
        } catch (ClassCastException cce) {
            logger.warn("Unable to cast ServletRequest to HttpServletRequest.", cce);
        }
        return null;
    }

    private HttpServletResponse castServletResponse(ServletResponse response) {
        try {
            return (HttpServletResponse) response;
        } catch (ClassCastException cce) {
            logger.warn("Unable to cast ServletResponse to HttpServletResponse.", cce);
        }
        return null;
    }

    private Optional<TraceHeader> getTraceHeader(HttpServletRequest request) {
        String traceHeaderString = request.getHeader(TraceHeader.HEADER_KEY);
        if (null != traceHeaderString) {
            return Optional.of(TraceHeader.fromString(traceHeaderString));
        }
        return Optional.empty();
    }

    private Optional<String> getHost(HttpServletRequest request) {
        return Optional.ofNullable(request.getHeader("Host"));
    }

    private Optional<String> getClientIp(HttpServletRequest request) {
        return Optional.ofNullable(request.getRemoteAddr());
    }

    private Optional<String> getXForwardedFor(HttpServletRequest request) {
        String forwarded = request.getHeader("X-Forwarded-For");
        if (forwarded != null) {
            return Optional.of(forwarded.split(",")[0].trim());
        }
        return Optional.empty();
    }

    private Optional<String> getUserAgent(HttpServletRequest request) {
        String userAgentHeaderString = request.getHeader("User-Agent");
        if (null != userAgentHeaderString) {
            return Optional.of(userAgentHeaderString);
        }
        return Optional.empty();
    }

    private Optional<Integer> getContentLength(HttpServletResponse response) {
        String contentLengthString = response.getHeader("Content-Length");
        if (null != contentLengthString && !contentLengthString.isEmpty()) {
            try {
                return Optional.of(Integer.parseInt(contentLengthString));
            } catch (NumberFormatException nfe) {
                logger.debug(
                        "Unable to parse Content-Length header from HttpServletResponse.", nfe);
            }
        }
        return Optional.empty();
    }

    private String getSegmentName(HttpServletRequest httpServletRequest) {
        Optional<String> hostHeaderValue =
                Optional.ofNullable(httpServletRequest.getHeader("Host"));
        if (hostHeaderValue.isPresent()) {
            return hostHeaderValue.get();
        }
        return segmentDefaultName;
    }

    private SamplingResponse fromSamplingStrategy(HttpServletRequest httpServletRequest) {
        SamplingRequest samplingRequest =
                new SamplingRequest(
                        getSegmentName(httpServletRequest),
                        getHost(httpServletRequest).orElse(null),
                        httpServletRequest.getRequestURI(),
                        httpServletRequest.getMethod(),
                        recorder.getOrigin());
        return recorder.getSamplingStrategy().shouldTrace(samplingRequest);
    }

    private SampleDecision getSampleDecision(SamplingResponse sample) {
        if (sample.isSampled()) {
            logger.debug("Sampling strategy decided SAMPLED.");
            return SampleDecision.SAMPLED;
        } else {
            logger.debug("Sampling strategy decided NOT_SAMPLED.");
            return SampleDecision.NOT_SAMPLED;
        }
    }

    public Segment preFilter(ServletRequest request, ServletResponse response) {
        HttpServletRequest httpServletRequest = castServletRequest(request);
        if (httpServletRequest == null) {
            logger.warn("Null value for incoming HttpServletRequest. Beginning NoOpSegment.");
            return recorder.beginNoOpSegment();
        }

        Optional<TraceHeader> incomingHeader = getTraceHeader(httpServletRequest);
        SamplingStrategy samplingStrategy = recorder.getSamplingStrategy();

        if (logger.isDebugEnabled() && incomingHeader.isPresent()) {
            logger.debug("Incoming trace header received: " + incomingHeader.get().toString());
        }

        SamplingResponse samplingResponse = fromSamplingStrategy(httpServletRequest);

        SampleDecision sampleDecision =
                incomingHeader.isPresent()
                        ? incomingHeader.get().getSampled()
                        : getSampleDecision(samplingResponse);
        if (SampleDecision.REQUESTED.equals(sampleDecision)
                || SampleDecision.UNKNOWN.equals(sampleDecision)) {
            sampleDecision = getSampleDecision(samplingResponse);
        }

        TraceID traceId = null;
        String parentId = null;
        if (incomingHeader.isPresent()) {
            TraceHeader header = incomingHeader.get();
            traceId = header.getRootTraceId();
            parentId = header.getParentId();
        }

        final Segment created;
        if (SampleDecision.SAMPLED.equals(sampleDecision)) {
            String segmentName = getSegmentName(httpServletRequest);
            created =
                    traceId != null
                            ? recorder.beginSegment(segmentName, traceId, parentId)
                            : recorder.beginSegment(segmentName);
            if (samplingResponse.getRuleName().isPresent()) {
                logger.debug(
                        "Sampling strategy decided to use rule named: "
                                + samplingResponse.getRuleName().get()
                                + ".");
                created.setRuleName(samplingResponse.getRuleName().get());
            }
        } else { // NOT_SAMPLED
            String segmentName = getSegmentName(httpServletRequest);
            if (samplingStrategy.isForcedSamplingSupported()) {
                created =
                        traceId != null
                                ? recorder.beginSegment(segmentName, traceId, parentId)
                                : recorder.beginSegment(segmentName);
                created.setSampled(false);
            } else {
                logger.debug("Creating Dummy Segment");
                created =
                        traceId != null
                                ? recorder.beginNoOpSegment(traceId)
                                : recorder.beginNoOpSegment();
            }
        }

        Map<String, Object> requestAttributes = new HashMap<>();
        requestAttributes.put("url", httpServletRequest.getRequestURL().toString());
        requestAttributes.put("method", httpServletRequest.getMethod());

        Optional<String> userAgent = getUserAgent(httpServletRequest);
        if (userAgent.isPresent()) {
            requestAttributes.put("user_agent", userAgent.get());
        }

        Optional<String> xForwardedFor = getXForwardedFor(httpServletRequest);
        if (xForwardedFor.isPresent()) {
            requestAttributes.put("client_ip", xForwardedFor.get());
            requestAttributes.put("x_forwarded_for", true);
        } else {
            Optional<String> clientIp = getClientIp(httpServletRequest);
            if (clientIp.isPresent()) {
                requestAttributes.put("client_ip", clientIp.get());
            }
        }

        created.putHttp("request", requestAttributes);

        HttpServletResponse httpServletResponse = castServletResponse(response);
        if (httpServletResponse == null) {
            return created;
        }

        final TraceHeader responseHeader;
        if (incomingHeader.isPresent()) {
            // create a new header, and use the incoming header so we know what to do in regards to
            // sending back the sampling
            // decision.
            responseHeader = new TraceHeader(created.getTraceId());
            if (SampleDecision.REQUESTED == incomingHeader.get().getSampled()) {
                responseHeader.setSampled(
                        created.isSampled() ? SampleDecision.SAMPLED : SampleDecision.NOT_SAMPLED);
            }
        } else {
            // Create a new header, we're the tracing root. We wont return the sampling decision.
            responseHeader = new TraceHeader(created.getTraceId());
        }
        httpServletResponse.addHeader(TraceHeader.HEADER_KEY, responseHeader.toString());

        return created;
    }

    public void postFilter(ServletResponse response) {
        Segment segment = recorder.getCurrentSegment();
        if (null != segment) {
            HttpServletResponse httpServletResponse = castServletResponse(response);

            if (null != httpServletResponse) {
                Map<String, Object> responseAttributes = new HashMap<String, Object>();

                int responseCode = httpServletResponse.getStatus();
                switch (responseCode / 100) {
                    case 4:
                        segment.setError(true);
                        if (responseCode == 429) {
                            segment.setThrottle(true);
                        }
                        break;
                    case 5:
                        segment.setFault(true);
                        break;
                    default:
                        break;
                }
                responseAttributes.put("status", responseCode);

                Optional<Integer> contentLength = getContentLength(httpServletResponse);
                if (contentLength.isPresent()) {
                    responseAttributes.put("content_length", contentLength.get());
                }

                segment.putHttp("response", responseAttributes);
            }

            recorder.endSegment();
        }
    }
}
brianforkan commented 1 year ago

Any idea when this will be released, @srprash?

willarmiros commented 1 year ago

Hi @brianforkan - we will be trying to get a release out for this as soon as possible. Unfortunately we can't provide exact dates, but feel free to "watch" for the next release on the repo.

srprash commented 1 year ago

Support has been released in v2.14.0.