MarcGiffing / bucket4j-spring-boot-starter

Spring Boot Starter for Bucket4j
Apache License 2.0

Solution for getting ServletRequestFilter working in external Tomcat war deployments #309

Open dariosanna opened 3 months ago

dariosanna commented 3 months ago

As far as I understand, the configured ServletRequestFilters are only registered correctly when using an embedded Tomcat.

As a solution, I implemented my own filter, which works similarly to the Spring Security filter chain: it delegates requests to the list of ServletRequestFilters registered in the application context.

To give it access to the registered ServletRequestFilter instances, they are passed to the delegating filter after the application context has been refreshed (because my own filter does not have access to the list of ServletRequestFilters at creation time).

In addition, two points are currently only possible with workarounds:

This solution should also work with an embedded Tomcat (not verified). To get rid of the workarounds, the following enhancements would be helpful:

I would be very happy if the suggestions, or parts of them, could be taken into account in a future version.

Here are the implementations:

Spring Boot Configuration

import com.giffing.bucket4j.spring.boot.starter.filter.servlet.ServletRequestFilter;
import jakarta.servlet.DispatcherType;
import jakarta.servlet.Filter;
import lombok.extern.log4j.Log4j2;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.web.servlet.FilterRegistrationBean;
import org.springframework.context.ApplicationContext;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.event.ContextRefreshedEvent;
import org.springframework.context.event.EventListener;
import org.springframework.core.Ordered;
import java.util.LinkedList;
import java.util.List;

@Configuration
@Log4j2
@ConditionalOnProperty(name = "bucket4j.enabled")
public class RateLimitConfiguration {

    private RateLimitFilter rateLimitFilter;

    @Bean
    public FilterRegistrationBean<RateLimitFilter> registrationRateLimitFilter() {
        this.rateLimitFilter = new RateLimitFilter();

        FilterRegistrationBean<RateLimitFilter> filterRegistrationBean = new FilterRegistrationBean<>();

        filterRegistrationBean.setName("rateLimitFilter");
        filterRegistrationBean.setFilter(rateLimitFilter);
        filterRegistrationBean.setDispatcherTypes(DispatcherType.REQUEST);
        filterRegistrationBean.setOrder(Ordered.HIGHEST_PRECEDENCE + 10);

        return filterRegistrationBean;
    }

    // Hands the ServletRequestFilter beans to the delegating filter once the context is refreshed;
    // the filter created in registrationRateLimitFilter() cannot obtain them at creation time
    @EventListener
    public void handleContextRefreshedEvent(ContextRefreshedEvent event) {
        ApplicationContext applicationContext = event.getApplicationContext();

        List<ServletRequestFilter> servletRequestFilters = new LinkedList<>();

        // Collect every Filter bean that is one of the starter's ServletRequestFilters
        String[] filterBeanNames = applicationContext.getBeanNamesForType(Filter.class);
        for (String beanName : filterBeanNames) {
            Filter filter = applicationContext.getBean(beanName, Filter.class);
            if (filter instanceof ServletRequestFilter servletRequestFilter) {
                servletRequestFilters.add(servletRequestFilter);
            }
        }

        this.rateLimitFilter.setServletRequestFilters(servletRequestFilters);
    }

}
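
The late hand-over in handleContextRefreshedEvent can be looked at on its own, without Spring. The sketch below is only an assumption of how a more defensive holder for the filter list could look; it is not part of the starter or of the code above, and LateBoundFilters is a hypothetical name. Assuming each call passes the complete list, replacing it with a sorted, immutable snapshot in a volatile field means a repeated call does not register the same filters twice, and a request thread never observes a half-updated list.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

// Hypothetical holder for a filter list that only becomes known after the holder was created
// (for example, when the application context has been refreshed).
class LateBoundFilters<T> {

    // Immutable snapshot; volatile so request threads see the latest published list
    private volatile List<T> snapshot = List.of();

    private final Comparator<? super T> order;

    LateBoundFilters(Comparator<? super T> order) {
        this.order = order;
    }

    // Replaces (rather than appends to) the current list, so calling it twice is harmless
    void set(List<? extends T> filters) {
        List<T> copy = new ArrayList<>(filters);
        copy.sort(order);
        snapshot = List.copyOf(copy);
    }

    List<T> get() {
        return snapshot;
    }
}
```

In the delegating filter, T would be ServletRequestFilter and the comparator would compare getOrder(), mirroring the sort in setServletRequestFilters.
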

Filter

import com.giffing.bucket4j.spring.boot.starter.filter.servlet.ServletRequestFilter;
import jakarta.servlet.Filter;

import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
import jakarta.servlet.ServletRequest;
import jakarta.servlet.ServletResponse;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import lombok.extern.log4j.Log4j2;
import org.springframework.web.filter.OncePerRequestFilter;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

@Log4j2
public class RateLimitFilter extends OncePerRequestFilter {

    // ArrayList rather than LinkedList: VirtualFilterChain accesses the filters by index
    private final List<ServletRequestFilter> servletRequestFilters = new ArrayList<>();

    @Override
    protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException {
        // Nothing to delegate to (yet): continue with the regular chain
        if (this.servletRequestFilters.isEmpty()) {
            filterChain.doFilter(request, response);
            return;
        }

        new VirtualFilterChain(filterChain, this.servletRequestFilters).doFilter(request, response);
    }

    // Called after the application context refresh; keeps the delegates sorted by their order
    public void setServletRequestFilters(List<ServletRequestFilter> servletRequestFilters) {
        this.servletRequestFilters.addAll(servletRequestFilters);
        this.servletRequestFilters.sort(Comparator.comparing(ServletRequestFilter::getOrder));
    }

    // Runs the delegated ServletRequestFilters one after another, then hands over to the original chain
    private static class VirtualFilterChain implements FilterChain {

        private static final int TOO_MANY_REQUESTS = 429;

        private final FilterChain originalChain;

        private final List<? extends Filter> filters;

        private int currentPosition = 0;

        public VirtualFilterChain(FilterChain chain, List<? extends Filter> filters) {
            this.originalChain = chain;
            this.filters = filters;
        }

        @Override
        public void doFilter(final ServletRequest request, final ServletResponse response) throws IOException, ServletException {
            if (response instanceof HttpServletResponse httpServletResponse) {
                // Too many requests: the request was rejected, so stop here
                if (httpServletResponse.getStatus() == TOO_MANY_REQUESTS) {
                    return;
                }

                // A ServletRequestFilter matched and let the request through (detected via the
                // X-Rate-Limit-Remaining header it writes): skip the remaining rate-limit filters
                if (httpServletResponse.containsHeader("X-Rate-Limit-Remaining")) {
                    this.currentPosition = this.filters.size();
                }
            }

            if (this.currentPosition == this.filters.size()) {
                // All rate-limit filters passed or were skipped: continue with the regular chain
                this.originalChain.doFilter(request, response);
            } else {
                Filter nextFilter = this.filters.get(this.currentPosition++);
                nextFilter.doFilter(request, response, this);
            }
        }
    }
}
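
The control flow of VirtualFilterChain can be checked without a servlet container. The following is a minimal, JDK-only model of it, assuming simplified stand-ins for the servlet types: SketchChain, SketchFilter and SketchResponse are hypothetical, as is the rateLimit helper that imitates a ServletRequestFilter. A request stops when a filter rejects it with 429, and once a matching filter has set X-Rate-Limit-Remaining, the remaining rate-limit filters are skipped and the original chain runs.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical stand-ins for FilterChain, Filter and HttpServletResponse
interface SketchChain {
    void doFilter(SketchResponse response);
}

interface SketchFilter {
    void doFilter(SketchResponse response, SketchChain chain);
}

class SketchResponse {
    int status = 200;
    final Map<String, String> headers = new HashMap<>();
    final List<String> trace = new ArrayList<>(); // records which steps ran
}

// Same decisions as VirtualFilterChain: stop on 429, jump to the original chain after a match
class SketchVirtualChain implements SketchChain {
    private final SketchChain originalChain;
    private final List<? extends SketchFilter> filters;
    private int currentPosition = 0;

    SketchVirtualChain(SketchChain originalChain, List<? extends SketchFilter> filters) {
        this.originalChain = originalChain;
        this.filters = filters;
    }

    @Override
    public void doFilter(SketchResponse response) {
        if (response.status == 429) {
            return;
        }
        if (response.headers.containsKey("X-Rate-Limit-Remaining")) {
            currentPosition = filters.size();
        }
        if (currentPosition == filters.size()) {
            originalChain.doFilter(response);
        } else {
            SketchFilter next = filters.get(currentPosition++);
            next.doFilter(response, this);
        }
    }
}

class SketchDemo {
    // Hypothetical rate-limit filter: 'matches' stands for the URL check, 'remaining' for the bucket's tokens
    static SketchFilter rateLimit(String name, boolean matches, long remaining) {
        return (response, chain) -> {
            response.trace.add(name);
            if (!matches) {
                chain.doFilter(response); // not responsible: pass on
            } else if (remaining > 0) {
                response.headers.put("X-Rate-Limit-Remaining", Long.toString(remaining - 1));
                chain.doFilter(response);
            } else {
                response.status = 429; // rejected: the chain is not continued
            }
        };
    }

    static SketchResponse run(List<SketchFilter> filters) {
        SketchResponse response = new SketchResponse();
        SketchChain original = r -> r.trace.add("application");
        new SketchVirtualChain(original, filters).doFilter(response);
        return response;
    }
}
```

For example, with a non-matching filter a, a matching filter b and a further filter c, the trace is a, b, application: c is never consulted.
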

MarcGiffing commented 3 months ago

The provided solution has a limitation: only one ServletFilter is registered with a specific order.

filterRegistrationBean.setOrder(Ordered.HIGHEST_PRECEDENCE + 10);

The current solution registers multiple filters, each with a distinct order. This is necessary, for example, if you want a rate limit to be executed before and/or after a security filter.
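
The difference can be illustrated with a small, Spring-free model: each registration has an order, lower values run first (as with Spring's Ordered), and separate registrations can sit on either side of the security filter. The record, the filter names and the order values are illustrative assumptions; only Ordered.HIGHEST_PRECEDENCE + 10 comes from the snippet above, and Ordered.HIGHEST_PRECEDENCE is Integer.MIN_VALUE.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

// Hypothetical model of a servlet filter registration: a name and an order value
record OrderedRegistration(String name, int order) {}

class OrderSketch {
    // Returns the filter names in execution order (lowest order value first)
    static List<String> executionOrder(List<OrderedRegistration> registrations) {
        List<OrderedRegistration> sorted = new ArrayList<>(registrations);
        sorted.sort(Comparator.comparingInt(OrderedRegistration::order));
        List<String> names = new ArrayList<>();
        for (OrderedRegistration registration : sorted) {
            names.add(registration.name());
        }
        return names;
    }
}
```

With one registration per rate-limit filter, an IP-based limit can run before the security filter and a per-user limit after it; a single delegating registration at Ordered.HIGHEST_PRECEDENCE + 10 places all of its delegates before the security filter.
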

dariosanna commented 3 months ago

only one ServletFilter is registered with a specific order.

Yes, but the filter delegates to multiple ServletRequestFilters, also in a specific order (all of them before the other regular filters). To have ServletRequestFilters applied after other regular filters (like Spring Security), the delegating filter could easily be enhanced to delegate to further filters after

this.originalChain.doFilter(request, response);

To differentiate between ServletRequestFilters that should be applied "before regular filters" and those that should be applied "after regular filters", a new property could be introduced on ServletRequestFilter, or different order ranges could be used.

The above solution is not as fine-grained as the current one, but it covers a wide range of use cases. In fact, I can't see any use case that requires putting one ServletRequestFilter before regular filter one, another before regular filter two and another after regular filter three (for example). One filter execution chain before all regular filters and one execution chain after all regular filters would deliver a usable solution.
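
The order-range variant can be sketched as a partition step: the delegated filters are sorted by order and split at a threshold into a "before regular filters" group and an "after regular filters" group. RateLimitEntry, Phases and the threshold of 0 are hypothetical; the starter defines no such setting.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

// Hypothetical: a rate-limit filter reduced to its name and order value
record RateLimitEntry(String name, int order) {}

// Hypothetical split of the delegates into two execution groups by order range
record Phases(List<String> before, List<String> after) {

    // Illustrative boundary: negative orders run before the regular filters, the rest after them
    static final int AFTER_THRESHOLD = 0;

    static Phases split(List<RateLimitEntry> entries) {
        List<RateLimitEntry> sorted = new ArrayList<>(entries);
        sorted.sort(Comparator.comparingInt(RateLimitEntry::order));
        List<String> before = new ArrayList<>();
        List<String> after = new ArrayList<>();
        for (RateLimitEntry entry : sorted) {
            (entry.order() < AFTER_THRESHOLD ? before : after).add(entry.name());
        }
        return new Phases(List.copyOf(before), List.copyOf(after));
    }
}
```

One thing to keep in mind when wiring the second group: code placed after this.originalChain.doFilter(request, response) only runs once the rest of the chain, including the servlet, has returned. An "after" group that is meant to limit requests would therefore have to be invoked before the servlet, for example from a second delegating registration ordered after the security filter (an assumption, not something proposed in this thread).
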

The provided implementation could also be enhanced with, for example, @ConditionalOnProperty(name = "bucket4j.useDelegatingFilter", havingValue = "true") to enable or disable the delegating filter.

What about the enhancements to get rid of the workarounds?

dariosanna commented 2 months ago

Any updates here?