spring-cloud / spring-cloud-gateway

An API Gateway built on Spring Framework and Spring Boot providing routing and more.
http://cloud.spring.io
Apache License 2.0

IllegalArgumentException: Unable to find GatewayFilterFactory with name RequestRateLimiter after upgrading to Spring Boot 3.1.6 #3258

Status: Closed (Abhishekj2096 closed this issue 5 months ago)

Abhishekj2096 commented 7 months ago

I am using the following versions:

- Spring Boot 3.1.6
- spring-cloud-dependencies 2022.0.3
- Java 17
- Kotlin 1.8.20
- Gradle 7.4.2
- docker-compose 0.17.5
- JUnit Jupiter 5.7.2

I am running unit tests for rate limiting backed by Redis. The Redis server runs in a Docker container created with docker-compose. The tests pass with these versions:

- Spring Boot 2.7.14
- spring-cloud-dependencies 2021.0.7
- Java 11
- Kotlin 1.7.0
- Gradle 7.4.2
- docker-compose 0.17.5
- JUnit Jupiter 5.7.2

After upgrading to Spring Boot 3 (the versions listed at the top), the unit tests started failing with this error:

java.lang.IllegalStateException: Failed to load ApplicationContext for [ReactiveWebMergedContextConfiguration@356e3c0f testClass = com.test.lnt.strix.redis.RedisRateLimiterTest, locations = [], classes = [com.test.lnt.strix.lntGatewayApplication], contextInitializerClasses = [com.test.lnt.strix.ApplicationContextInitializerImpl], activeProfiles = ["redis-test"], propertySourceLocations = ["classpath:/com/test/lnt/strix/auth/hmac.properties"], propertySourceProperties = ["org.springframework.boot.test.context.SpringBootTestContextBootstrapper=true", "server.port=0"], contextCustomizers = [org.springframework.boot.test.autoconfigure.actuate.observability.ObservabilityContextCustomizerFactory$DisableObservabilityContextCustomizer@1f, org.springframework.boot.test.autoconfigure.properties.PropertyMappingContextCustomizer@9fdfac6d, org.springframework.boot.test.autoconfigure.web.servlet.WebDriverContextCustomizer@78e68401, [ImportsContextCustomizer@6adcccc7 key = [org.springframework.cloud.contract.wiremock.WireMockRestTemplateConfiguration, org.springframework.cloud.contract.wiremock.WireMockConfiguration]], org.springframework.boot.test.context.filter.ExcludeFilterContextCustomizer@7f3c0399, org.springframework.boot.test.json.DuplicateJsonObjectContextCustomizerFactory$DuplicateJsonObjectContextCustomizer@2c9d90fc, org.springframework.boot.test.mock.mockito.MockitoContextCustomizer@0, org.springframework.boot.test.web.client.TestRestTemplateContextCustomizer@634f58d2, org.springframework.boot.test.web.reactive.server.WebTestClientContextCustomizer@1a1f79ce, org.springframework.boot.test.context.SpringBootTestAnnotation@a17b40a1], contextLoader = org.springframework.boot.test.context.SpringBootContextLoader, parent = null]
    at org.springframework.test.context.cache.DefaultCacheAwareContextLoaderDelegate.loadContext(DefaultCacheAwareContextLoaderDelegate.java:143)
    at org.springframework.test.context.support.DefaultTestContext.getApplicationContext(DefaultTestContext.java:127)
    at org.springframework.test.context.support.DependencyInjectionTestExecutionListener.injectDependencies(DependencyInjectionTestExecutionListener.java:141)
    at org.springframework.test.context.support.DependencyInjectionTestExecutionListener.prepareTestInstance(DependencyInjectionTestExecutionListener.java:97)
    at org.springframework.test.context.TestContextManager.prepareTestInstance(TestContextManager.java:241)
    at org.springframework.test.context.junit.jupiter.SpringExtension.postProcessTestInstance(SpringExtension.java:138)
    at org.junit.jupiter.engine.descriptor.ClassBasedTestDescriptor.lambda$invokeTestInstancePostProcessors$10(ClassBasedTestDescriptor.java:377)
    at org.junit.jupiter.engine.descriptor.ClassBasedTestDescriptor.executeAndMaskThrowable(ClassBasedTestDescriptor.java:382)
    at org.junit.jupiter.engine.descriptor.ClassBasedTestDescriptor.lambda$invokeTestInstancePostProcessors$11(ClassBasedTestDescriptor.java:377)
    at java.base/java.util.stream.ReferencePipeline$3$1.accept(ReferencePipeline.java:197)
    at java.base/java.util.stream.ReferencePipeline$2$1.accept(ReferencePipeline.java:179)
    at java.base/java.util.ArrayList$ArrayListSpliterator.forEachRemaining(ArrayList.java:1625)
    at java.base/java.util.stream.AbstractPipeline.copyInto(AbstractPipeline.java:509)
    at java.base/java.util.stream.AbstractPipeline.wrapAndCopyInto(AbstractPipeline.java:499)
    at java.base/java.util.stream.StreamSpliterators$WrappingSpliterator.forEachRemaining(StreamSpliterators.java:310)
    at java.base/java.util.stream.Streams$ConcatSpliterator.forEachRemaining(Streams.java:735)
    at java.base/java.util.stream.Streams$ConcatSpliterator.forEachRemaining(Streams.java:734)
    at java.base/java.util.stream.ReferencePipeline$Head.forEach(ReferencePipeline.java:762)
    at org.junit.jupiter.engine.descriptor.ClassBasedTestDescriptor.invokeTestInstancePostProcessors(ClassBasedTestDescriptor.java:376)
    at org.junit.jupiter.engine.descriptor.ClassBasedTestDescriptor.lambda$instantiateAndPostProcessTestInstance$6(ClassBasedTestDescriptor.java:289)
    at org.junit.platform.engine.support.hierarchical.ThrowableCollector.execute(ThrowableCollector.java:73)
    at org.junit.jupiter.engine.descriptor.ClassBasedTestDescriptor.instantiateAndPostProcessTestInstance(ClassBasedTestDescriptor.java:288)
    at org.junit.jupiter.engine.descriptor.ClassBasedTestDescriptor.lambda$testInstancesProvider$4(ClassBasedTestDescriptor.java:278)
    at java.base/java.util.Optional.orElseGet(Optional.java:364)
    at org.junit.jupiter.engine.descriptor.ClassBasedTestDescriptor.lambda$testInstancesProvider$5(ClassBasedTestDescriptor.java:277)
    at org.junit.jupiter.engine.execution.TestInstancesProvider.getTestInstances(TestInstancesProvider.java:31)
    at org.junit.jupiter.engine.descriptor.TestMethodTestDescriptor.lambda$prepare$0(TestMethodTestDescriptor.java:105)
    at org.junit.platform.engine.support.hierarchical.ThrowableCollector.execute(ThrowableCollector.java:73)
    at org.junit.jupiter.engine.descriptor.TestMethodTestDescriptor.prepare(TestMethodTestDescriptor.java:104)
    at org.junit.jupiter.engine.descriptor.TestMethodTestDescriptor.prepare(TestMethodTestDescriptor.java:68)
    at org.junit.platform.engine.support.hierarchical.NodeTestTask.lambda$prepare$2(NodeTestTask.java:123)
    at org.junit.platform.engine.support.hierarchical.ThrowableCollector.execute(ThrowableCollector.java:73)
    at org.junit.platform.engine.support.hierarchical.NodeTestTask.prepare(NodeTestTask.java:123)
    at org.junit.platform.engine.support.hierarchical.NodeTestTask.execute(NodeTestTask.java:90)
    at java.base/java.util.ArrayList.forEach(ArrayList.java:1511)
    at org.junit.platform.engine.support.hierarchical.SameThreadHierarchicalTestExecutorService.invokeAll(SameThreadHierarchicalTestExecutorService.java:41)
    at org.junit.platform.engine.support.hierarchical.NodeTestTask.lambda$executeRecursively$6(NodeTestTask.java:155)
    at org.junit.platform.engine.support.hierarchical.ThrowableCollector.execute(ThrowableCollector.java:73)
    at org.junit.platform.engine.support.hierarchical.NodeTestTask.lambda$executeRecursively$8(NodeTestTask.java:141)
    at org.junit.platform.engine.support.hierarchical.Node.around(Node.java:137)
    at org.junit.platform.engine.support.hierarchical.NodeTestTask.lambda$executeRecursively$9(NodeTestTask.java:139)
    at org.junit.platform.engine.support.hierarchical.ThrowableCollector.execute(ThrowableCollector.java:73)
    at org.junit.platform.engine.support.hierarchical.NodeTestTask.executeRecursively(NodeTestTask.java:138)
    at org.junit.platform.engine.support.hierarchical.NodeTestTask.execute(NodeTestTask.java:95)
    at java.base/java.util.ArrayList.forEach(ArrayList.java:1511)
    at org.junit.platform.engine.support.hierarchical.SameThreadHierarchicalTestExecutorService.invokeAll(SameThreadHierarchicalTestExecutorService.java:41)
    at org.junit.platform.engine.support.hierarchical.NodeTestTask.lambda$executeRecursively$6(NodeTestTask.java:155)
    at org.junit.platform.engine.support.hierarchical.ThrowableCollector.execute(ThrowableCollector.java:73)
    at org.junit.platform.engine.support.hierarchical.NodeTestTask.lambda$executeRecursively$8(NodeTestTask.java:141)
    at org.junit.platform.engine.support.hierarchical.Node.around(Node.java:137)
    at org.junit.platform.engine.support.hierarchical.NodeTestTask.lambda$executeRecursively$9(NodeTestTask.java:139)
    at org.junit.platform.engine.support.hierarchical.ThrowableCollector.execute(ThrowableCollector.java:73)
    at org.junit.platform.engine.support.hierarchical.NodeTestTask.executeRecursively(NodeTestTask.java:138)
    at org.junit.platform.engine.support.hierarchical.NodeTestTask.execute(NodeTestTask.java:95)
    at org.junit.platform.engine.support.hierarchical.SameThreadHierarchicalTestExecutorService.submit(SameThreadHierarchicalTestExecutorService.java:35)
    at org.junit.platform.engine.support.hierarchical.HierarchicalTestExecutor.execute(HierarchicalTestExecutor.java:57)
    at org.junit.platform.engine.support.hierarchical.HierarchicalTestEngine.execute(HierarchicalTestEngine.java:54)
    at org.junit.platform.launcher.core.EngineExecutionOrchestrator.execute(EngineExecutionOrchestrator.java:108)
    at org.junit.platform.launcher.core.EngineExecutionOrchestrator.execute(EngineExecutionOrchestrator.java:88)
    at org.junit.platform.launcher.core.EngineExecutionOrchestrator.lambda$execute$0(EngineExecutionOrchestrator.java:54)
    at org.junit.platform.launcher.core.EngineExecutionOrchestrator.withInterceptedStreams(EngineExecutionOrchestrator.java:67)
    at org.junit.platform.launcher.core.EngineExecutionOrchestrator.execute(EngineExecutionOrchestrator.java:52)
    at org.junit.platform.launcher.core.DefaultLauncher.execute(DefaultLauncher.java:96)
    at org.junit.platform.launcher.core.DefaultLauncher.execute(DefaultLauncher.java:75)
    at org.gradle.api.internal.tasks.testing.junitplatform.JUnitPlatformTestClassProcessor$CollectAllTestClassesExecutor.processAllTestClasses(JUnitPlatformTestClassProcessor.java:99)
    at org.gradle.api.internal.tasks.testing.junitplatform.JUnitPlatformTestClassProcessor$CollectAllTestClassesExecutor.access$000(JUnitPlatformTestClassProcessor.java:79)
    at org.gradle.api.internal.tasks.testing.junitplatform.JUnitPlatformTestClassProcessor.stop(JUnitPlatformTestClassProcessor.java:75)
    at org.gradle.api.internal.tasks.testing.SuiteTestClassProcessor.stop(SuiteTestClassProcessor.java:61)
    at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:77)
    at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.base/java.lang.reflect.Method.invoke(Method.java:568)
    at org.gradle.internal.dispatch.ReflectionDispatch.dispatch(ReflectionDispatch.java:36)
    at org.gradle.internal.dispatch.ReflectionDispatch.dispatch(ReflectionDispatch.java:24)
    at org.gradle.internal.dispatch.ContextClassLoaderDispatch.dispatch(ContextClassLoaderDispatch.java:33)
    at org.gradle.internal.dispatch.ProxyDispatchAdapter$DispatchingInvocationHandler.invoke(ProxyDispatchAdapter.java:94)
    at jdk.proxy2/jdk.proxy2.$Proxy5.stop(Unknown Source)
    at org.gradle.api.internal.tasks.testing.worker.TestWorker$3.run(TestWorker.java:193)
    at org.gradle.api.internal.tasks.testing.worker.TestWorker.executeAndMaintainThreadName(TestWorker.java:129)
    at org.gradle.api.internal.tasks.testing.worker.TestWorker.execute(TestWorker.java:100)
    at org.gradle.api.internal.tasks.testing.worker.TestWorker.execute(TestWorker.java:60)
    at org.gradle.process.internal.worker.child.ActionExecutionWorker.execute(ActionExecutionWorker.java:56)
    at org.gradle.process.internal.worker.child.SystemApplicationClassLoaderWorker.call(SystemApplicationClassLoaderWorker.java:133)
    at org.gradle.process.internal.worker.child.SystemApplicationClassLoaderWorker.call(SystemApplicationClassLoaderWorker.java:71)
    at worker.org.gradle.process.internal.worker.GradleWorkerMain.run(GradleWorkerMain.java:69)
    at worker.org.gradle.process.internal.worker.GradleWorkerMain.main(GradleWorkerMain.java:74)
Caused by: java.lang.IllegalArgumentException: Unable to find GatewayFilterFactory with name RequestRateLimiter
    at org.springframework.cloud.gateway.route.RouteDefinitionRouteLocator.loadGatewayFilters(RouteDefinitionRouteLocator.java:145)
    at org.springframework.cloud.gateway.route.RouteDefinitionRouteLocator.getFilters(RouteDefinitionRouteLocator.java:192)
    at org.springframework.cloud.gateway.route.RouteDefinitionRouteLocator.convertToRoute(RouteDefinitionRouteLocator.java:132)
    at reactor.core.publisher.FluxMapFuseable$MapFuseableSubscriber.onNext(FluxMapFuseable.java:113)
    at reactor.core.publisher.FluxFlattenIterable$FlattenIterableSubscriber.drainSync(FluxFlattenIterable.java:640)
    at reactor.core.publisher.FluxFlattenIterable$FlattenIterableSubscriber.drain(FluxFlattenIterable.java:721)
    at reactor.core.publisher.FluxFlattenIterable$FlattenIterableSubscriber.request(FluxFlattenIterable.java:303)
    at reactor.core.publisher.FluxMapFuseable$MapFuseableSubscriber.request(FluxMapFuseable.java:171)
    at reactor.core.publisher.FluxMapFuseable$MapFuseableSubscriber.request(FluxMapFuseable.java:171)
    at reactor.core.publisher.Operators$BaseFluxToMonoOperator.request(Operators.java:2041)
    at reactor.core.publisher.MonoFlatMapMany$FlatMapManyMain.onSubscribe(MonoFlatMapMany.java:141)
    at reactor.core.publisher.Operators$BaseFluxToMonoOperator.onSubscribe(Operators.java:2025)
    at reactor.core.publisher.FluxMapFuseable$MapFuseableSubscriber.onSubscribe(FluxMapFuseable.java:96)
    at reactor.core.publisher.FluxMapFuseable$MapFuseableSubscriber.onSubscribe(FluxMapFuseable.java:96)
    at reactor.core.publisher.FluxFlattenIterable$FlattenIterableSubscriber.onSubscribe(FluxFlattenIterable.java:222)
    at reactor.core.publisher.FluxIterable.subscribe(FluxIterable.java:201)
    at reactor.core.publisher.FluxIterable.subscribe(FluxIterable.java:105)
    at reactor.core.publisher.MonoFlattenIterable.subscribeOrReturn(MonoFlattenIterable.java:79)
    at reactor.core.publisher.Flux.subscribe(Flux.java:8759)
    at reactor.core.publisher.FluxMergeSequential$MergeSequentialMain.onNext(FluxMergeSequential.java:237)
    at reactor.core.publisher.FluxIterable$IterableSubscription.slowPath(FluxIterable.java:335)
    at reactor.core.publisher.FluxIterable$IterableSubscription.request(FluxIterable.java:294)
    at reactor.core.publisher.FluxMergeSequential$MergeSequentialMain.onSubscribe(FluxMergeSequential.java:198)
    at reactor.core.publisher.FluxIterable.subscribe(FluxIterable.java:201)
    at reactor.core.publisher.FluxIterable.subscribe(FluxIterable.java:83)
    at reactor.core.publisher.InternalFluxOperator.subscribe(InternalFluxOperator.java:62)
    at reactor.core.publisher.FluxDefer.subscribe(FluxDefer.java:54)
    at reactor.core.publisher.Flux.subscribe(Flux.java:8773)
    at reactor.core.publisher.Flux.blockLast(Flux.java:2752)
    at org.springframework.cloud.gateway.filter.WeightCalculatorWebFilter.lambda$onApplicationEvent$0(WeightCalculatorWebFilter.java:134)
    at org.springframework.beans.factory.support.DefaultListableBeanFactory$DependencyObjectProvider.ifAvailable(DefaultListableBeanFactory.java:2070)
    at org.springframework.cloud.gateway.filter.WeightCalculatorWebFilter.onApplicationEvent(WeightCalculatorWebFilter.java:134)
    at org.springframework.context.event.SimpleApplicationEventMulticaster.doInvokeListener(SimpleApplicationEventMulticaster.java:174)
    at org.springframework.context.event.SimpleApplicationEventMulticaster.invokeListener(SimpleApplicationEventMulticaster.java:167)
    at org.springframework.context.event.SimpleApplicationEventMulticaster.multicastEvent(SimpleApplicationEventMulticaster.java:145)
    at org.springframework.context.support.AbstractApplicationContext.publishEvent(AbstractApplicationContext.java:445)
    at org.springframework.context.support.AbstractApplicationContext.publishEvent(AbstractApplicationContext.java:378)
    at org.springframework.cloud.gateway.route.RouteRefreshListener.reset(RouteRefreshListener.java:73)
    at org.springframework.cloud.gateway.route.RouteRefreshListener.onApplicationEvent(RouteRefreshListener.java:50)
    at org.springframework.context.event.SimpleApplicationEventMulticaster.doInvokeListener(SimpleApplicationEventMulticaster.java:174)
    at org.springframework.context.event.SimpleApplicationEventMulticaster.invokeListener(SimpleApplicationEventMulticaster.java:167)
    at org.springframework.context.event.SimpleApplicationEventMulticaster.multicastEvent(SimpleApplicationEventMulticaster.java:145)
    at org.springframework.context.support.AbstractApplicationContext.publishEvent(AbstractApplicationContext.java:445)
    at org.springframework.context.support.AbstractApplicationContext.publishEvent(AbstractApplicationContext.java:378)
    at org.springframework.context.support.AbstractApplicationContext.finishRefresh(AbstractApplicationContext.java:969)
    at org.springframework.context.support.AbstractApplicationContext.refresh(AbstractApplicationContext.java:619)
    at org.springframework.boot.web.reactive.context.ReactiveWebServerApplicationContext.refresh(ReactiveWebServerApplicationContext.java:66)
    at org.springframework.boot.SpringApplication.refresh(SpringApplication.java:738)
    at org.springframework.boot.SpringApplication.refreshContext(SpringApplication.java:440)
    at org.springframework.boot.SpringApplication.run(SpringApplication.java:316)
    at org.springframework.boot.test.context.SpringBootContextLoader.lambda$loadContext$3(SpringBootContextLoader.java:137)
    at org.springframework.util.function.ThrowingSupplier.get(ThrowingSupplier.java:58)
    at org.springframework.util.function.ThrowingSupplier.get(ThrowingSupplier.java:46)
    at org.springframework.boot.SpringApplication.withHook(SpringApplication.java:1406)
    at org.springframework.boot.test.context.SpringBootContextLoader$ContextLoaderHook.run(SpringBootContextLoader.java:545)
    at org.springframework.boot.test.context.SpringBootContextLoader.loadContext(SpringBootContextLoader.java:137)
    at org.springframework.boot.test.context.SpringBootContextLoader.loadContext(SpringBootContextLoader.java:108)
    at org.springframework.test.context.cache.DefaultCacheAwareContextLoaderDelegate.loadContextInternal(DefaultCacheAwareContextLoaderDelegate.java:187)
    at org.springframework.test.context.cache.DefaultCacheAwareContextLoaderDelegate.loadContext(DefaultCacheAwareContextLoaderDelegate.java:119)
    ... 85 more
    Suppressed: java.lang.Exception: #block terminated with an error
        at reactor.core.publisher.BlockingSingleSubscriber.blockingGet(BlockingSingleSubscriber.java:103)
        at reactor.core.publisher.Flux.blockLast(Flux.java:2753)
        ... 115 more

=================================================================================================

The rate limiter is a Kotlin class that implements rate limiting with Redis as the storage backend:

class RedisRateLimiter(
    private val redisTemplate: ReactiveStringRedisTemplate,
    private val redisScript: RedisScript<List<Long>>,
    configurationService: ConfigurationService?,
) : AbstractRateLimiter<Config>(Config::class.java, "redis-rate-limiter", configurationService),
    RateLimiter<Config> {

    companion object {
        private val LOGGER = logger<RedisRateLimiter>()
    }

    private var defaultConfig: Config? = null

    var remainingHeader = REMAINING_HEADER

    var replenishRateHeader = REPLENISH_RATE_HEADER

    var burstCapacityHeader = BURST_CAPACITY_HEADER

    var requestedTokensHeader = REQUESTED_TOKENS_HEADER

    var includeHeaders: Boolean = true

    private fun getKeys(id: String): List<String> {
        // use `{}` around keys to use Redis Key hash tags
        // this allows for using redis cluster

        // Make a unique key per user.
        val prefix = "request_rate_limiter.{$id"

        // Need two Redis keys for Token Bucket.
        return listOf("$prefix}.tokens", "$prefix}.timestamp")
    }

    /**
     * This uses a basic token bucket algorithm and relies on the fact that Redis scripts
     * execute atomically. No other operations can run between fetching the count and
     * writing the new count.
     */
    override fun isAllowed(routeId: String, id: String): Mono<RateLimiter.Response>? {
        val routeConfig = loadConfiguration(routeId, id)
        val keys = getKeys(if (routeConfig.perRoute) "$routeId-$id" else id)

        // Arguments to the Lua script; the current time is passed as Unix epoch seconds.
        val scriptArgs = listOf(routeConfig.replenishRate.toString(), routeConfig.burstCapacity.toString(),
            Instant.now().epochSecond.toString(), routeConfig.requestedTokens.toString())
        return kotlin.runCatching {
            this.redisTemplate.execute(redisScript, keys, scriptArgs)
                .onErrorResume { e: Throwable ->
                    LOGGER.log(Level.ERROR, e.asLogMap() + mapOf("status" to "failure"),
                        "RateLimiter: failed to execute LUA script", e)
                    Flux.just(listOf(1L, -1L))
                }
                .flatMapIterable(Function.identity())
                .collectList()
                .map { (allowed, tokensLeft) ->
                    RateLimiter.Response(allowed == 1L, getHeaders(routeConfig, tokensLeft))
                }
        }.getOrElse { e ->
            LOGGER.log(Level.ERROR, e.asLogMap() + mapOf("status" to "failure"),
                "RateLimiter: Error determining if user allowed from redis", e)
            RateLimiter.Response(true, getHeaders(routeConfig, -1L)).toMono()
        }
    }

    private fun loadConfiguration(routeId: String, id: String): RateLimiterConfig {
        val routeConfig: Config? = config.getOrDefault(routeId, defaultConfig)
            ?: config[RouteDefinitionRouteLocator.DEFAULT_FILTERS]
        requireNotNull(routeConfig) { "No Configuration found for route $routeId or defaultFilters" }
        return routeConfig.forClient(id)
    }

    private fun getHeaders(config: RateLimiterConfig, tokensLeft: Long): Map<String, String> =
        if (includeHeaders) mapOf(
            remainingHeader to tokensLeft.toString(),
            replenishRateHeader to config.replenishRate.toString(),
            burstCapacityHeader to config.burstCapacity.toString(),
            requestedTokensHeader to config.requestedTokens.toString()
        ) else emptyMap()

    @Validated
    open class RateLimiterConfig {
        @get:Min(1)
        var replenishRate by Delegates.notNull<Int>()

        @Min(0)
        var burstCapacity: Int = 1
        @Min(1)
        var requestedTokens: Int = 1

        /**
         * Specifies whether rate calculations are performed per client per route OR per client for all routes
         * together
         */
        var perRoute: Boolean = false
    }

    @Validated
    class Config : RateLimiterConfig(), HasRouteId {

        private lateinit var routeId: String

        @Valid
        var clients: Map<String, RateLimiterConfig> = emptyMap()

        fun forClient(clientID: String): RateLimiterConfig = clients.getOrDefault(clientID, this)

        override fun setRouteId(routeId: String) {
            this.routeId = routeId
        }

        override fun getRouteId(): String? = this.routeId
    }
}
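The KDoc on `isAllowed` says the limiter runs a token bucket atomically inside a Redis Lua script, but the script passed in as `redisScript` is not included in the issue. As a rough illustration, the hypothetical Java class below reproduces the refill-then-consume arithmetic, assuming the script follows the same logic as the `request_rate_limiter.lua` bundled with Spring Cloud Gateway: tokens refill at `replenishRate` per elapsed second up to `burstCapacity`, and each request costs `requestedTokens`. It is an in-memory sketch, not the reporter's code: a `synchronized` method stands in for the atomicity Redis provides, and a map stands in for the `.tokens` and `.timestamp` keys.

```java
import java.util.HashMap;
import java.util.Map;

/**
 * Hypothetical in-memory token bucket that illustrates the arithmetic only.
 * The synchronized method stands in for the atomicity a Redis Lua script provides.
 */
class InMemoryTokenBucket {

    /** Mirrors the [allowed, tokensLeft] pair the Kotlin code destructures. */
    record Result(boolean allowed, long tokensLeft) {}

    private final long replenishRate;   // tokens added per elapsed second
    private final long burstCapacity;   // maximum number of tokens held
    private final long requestedTokens; // tokens consumed per request

    // Per key: {tokens, lastRefreshedEpochSecond}, standing in for the ".tokens" and ".timestamp" keys.
    private final Map<String, long[]> state = new HashMap<>();

    InMemoryTokenBucket(long replenishRate, long burstCapacity, long requestedTokens) {
        this.replenishRate = replenishRate;
        this.burstCapacity = burstCapacity;
        this.requestedTokens = requestedTokens;
    }

    synchronized Result isAllowed(String key, long nowEpochSecond) {
        // An unseen key starts with a full bucket and a timestamp of 0.
        long[] s = state.getOrDefault(key, new long[] {burstCapacity, 0L});

        // Refill for the elapsed seconds, never above the burst capacity.
        long delta = Math.max(0L, nowEpochSecond - s[1]);
        long filled = Math.min(burstCapacity, s[0] + delta * replenishRate);

        // Consume only if enough tokens are available.
        boolean allowed = filled >= requestedTokens;
        long remaining = allowed ? filled - requestedTokens : filled;

        state.put(key, new long[] {remaining, nowEpochSecond});
        return new Result(allowed, remaining);
    }
}
```

For instance, with the `alice123` override from the gateway configuration (`replenish-rate: 1`, `burstCapacity: 30`, `requested-tokens: 30`), the first call empties the bucket and the next call is admitted only once 30 seconds have elapsed.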

=================================================================================================

Gateway configuration file 'application-redis-test.yml':

# Note: Spring Boot 3.0 moved the spring.redis.* properties to spring.data.redis.*
spring.redis:
  host: ${redis.host}
  port: ${redis.tcp.6379}

gateway:
  groups:
    redis:
      routes:
        simple:
          path: /ratelimit
          uri: http://localhost:${wiremock.server.port}
          filters:
            - name: RequestRateLimiter
              args:
                rate-limiter: "#{@redisRateLimiter}"
                key-resolver: "#{@principalOrRouteKeyResolver}"
                redis-rate-limiter:
                  replenish-rate: 1
        by-clients:
          path: /ratelimit/clients
          uri: http://localhost:${wiremock.server.port}
          filters:
            - name: AuthValidation
              args:
                roles: [ STS ]
                protocols:
                  hmac:
                    algorithms: [ hmac-sha512 ]
            - name: RequestRateLimiter
              args:
                rate-limiter: "#{@redisRateLimiter}"
                key-resolver: "#{@principalNameKeyResolver}"
                redis-rate-limiter:
                  #huge rate limit to pass 100 requests
                  replenish-rate: 400
                  clients:
                    alice123:
                      replenish-rate: 1
                      burstCapacity: 30
                      requested-tokens: 30
        non-global-clients1:
          path: /ratelimit/clients1
          uri: http://localhost:${wiremock.server.port}
          filters:
            - name: AuthValidation
              args:
                roles: [ STS ]
                protocols:
                  hmac:
                    algorithms: [ hmac-sha512 ]
            - name: RequestRateLimiter
              args:
                rate-limiter: "#{@redisRateLimiter}"
                key-resolver: "#{@principalNameKeyResolver}"
                redis-rate-limiter:
                  per-route: true
                  #huge rate limit to pass 100 requests
                  replenish-rate: 400
        non-global-clients2:
          path: /ratelimit/clients2
          uri: http://localhost:${wiremock.server.port}
          filters:
            - name: AuthValidation
              args:
                roles: [ STS ]
                protocols:
                  hmac:
                    algorithms: [ hmac-sha512 ]
            - name: RequestRateLimiter
              args:
                rate-limiter: "#{@redisRateLimiter}"
                key-resolver: "#{@principalNameKeyResolver}"
                redis-rate-limiter:
                  per-route: true
                  replenish-rate: 1
                  burst-capacity: 30
                  requested-tokens: 30

=================================================================================================

RedisRateLimiterTest is a Java test class with unit tests for the RedisRateLimiter class. It extends BaseRestTest and runs on the Spring Boot Test framework:

@TestPropertySource(locations = {
  "classpath:/com/test/lnt/strix/auth/hmac.properties",
})
@ActiveProfiles("redis-test")
public class RedisRateLimiterTest extends BaseRestTest
{

    @BeforeEach
    @Override
    public void setUp() throws Exception
    {
        super.setUp();
        stubFor(get(urlPathMatching("/ratelimit.*"))
                .willReturn(aResponse().withStatus(HttpStatus.OK.value())));
    }

    @Test
    public void testSimpleRateLimiter()
    {
        var count = 100;
        IntStream.range(0, count)
                .forEach(i -> given().get("/ratelimit")
                        .then()
                        .statusCode(Matchers.oneOf(HttpStatus.OK.value(), HttpStatus.TOO_MANY_REQUESTS.value())));
        var allowedRequestsCount = WireMock.findAll(getRequestedFor(urlEqualTo("/ratelimit"))).size();
        assertNotEquals(count, allowedRequestsCount);
        assertTrue(allowedRequestsCount >= 1 && allowedRequestsCount < count / 10);
    }

    @Test
    public void testClientBasedRateLimiterNotReachedPerDefaultRouteConfig()
    {
        var count = 200;
        IntStream.range(0, count)
                .forEach(i -> given()
                        .header(HttpHeaders.AUTHORIZATION,
                                Sneaky.sneak(() -> HmacAuthTest.generateHmac("GET", "/ratelimit/clients",
                                        "bob123", "secretBob", "hmac-sha512", Map.of("X-Index", Integer.toString(i)),
                                        true)))
                        .header("X-Index", i)
                        .header(HttpHeaders.DATE, ZonedDateTime.now().format(DateTimeFormatter.RFC_1123_DATE_TIME))
                        .get("/ratelimit/clients")
                        .then()
                        .statusCode(Matchers.oneOf(HttpStatus.OK.value())));
        var allowedRequestsCount = WireMock.findAll(getRequestedFor(urlEqualTo("/ratelimit/clients"))).size();
        assertEquals(count, allowedRequestsCount);
    }

    @Test
    public void testClientBasedRateLimiter()
    {
        var count = 100;
        IntStream.range(0, count)
                .forEach(i -> given()
                        .header(HttpHeaders.AUTHORIZATION,
                                Sneaky.sneak(() -> HmacAuthTest.generateHmac("GET", "/ratelimit/clients",
                                        "alice123", "secret", "hmac-sha512", Map.of("X-Index", Integer.toString(i)),
                                        true)))
                        .header("X-Index", i)
                        .header(HttpHeaders.DATE, ZonedDateTime.now().format(DateTimeFormatter.RFC_1123_DATE_TIME))
                        .get("/ratelimit/clients")
                        .then()
                        .statusCode(Matchers.oneOf(HttpStatus.OK.value(), HttpStatus.TOO_MANY_REQUESTS.value())));
        var allowedRequestsCount = WireMock.findAll(getRequestedFor(urlEqualTo("/ratelimit/clients"))).size();
        assertNotEquals(count, allowedRequestsCount);
        assertEquals(1, allowedRequestsCount);
    }

    @Test
    public void testSeparateClientRatePerApi()
    {
        var count = 100;
        IntStream.range(0, count)
                .forEach(i ->
                {
                    var rqSpec = given()
                            .header(HttpHeaders.AUTHORIZATION,
                                    Sneaky.sneak(() -> HmacAuthTest.generateHmac("GET", "/ratelimit/clients",
                                            "alice123", "secret", "hmac-sha512", Map.of("X-Index", Integer.toString(i)),
                                            false)))
                            .header("X-Index", i)
                            .header(HttpHeaders.DATE, ZonedDateTime.now().format(DateTimeFormatter.RFC_1123_DATE_TIME));

                    rqSpec.get("/ratelimit/clients1")
                            .then()
                            .statusCode(Matchers.oneOf(HttpStatus.OK.value()));

                    rqSpec.get("/ratelimit/clients2")
                            .then()
                            .statusCode(Matchers.oneOf(HttpStatus.OK.value(), HttpStatus.TOO_MANY_REQUESTS.value()));
                });
        var allowedRequestsCount = WireMock.findAll(getRequestedFor(urlEqualTo("/ratelimit/clients1"))).size();
        assertEquals(count, allowedRequestsCount);

        allowedRequestsCount = WireMock.findAll(getRequestedFor(urlEqualTo("/ratelimit/clients2"))).size();
        assertNotEquals(count, allowedRequestsCount);
        assertEquals(1, allowedRequestsCount);
    }

}
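The exact-count assertions above depend on how many tokens can exist during a test run. As a back-of-the-envelope bound, assuming the Lua script refills at most `replenishRate` tokens per elapsed second, never holds more than `burstCapacity`, and executes successfully (on error the Kotlin code falls back to allowing the request), the number of requests one bucket can admit is bounded as follows, where T is the difference in epoch seconds between the first and last request:

```latex
N_{\text{allowed}}(T) \le \left\lfloor \frac{B + R\,T}{r} \right\rfloor
\quad\text{where } B = \texttt{burstCapacity},\ R = \texttt{replenishRate},\ r = \texttt{requestedTokens}

\text{alice123: } B = 30,\ R = 1,\ r = 30 \;\Rightarrow\; N_{\text{allowed}}(T) \le \left\lfloor \frac{30 + T}{30} \right\rfloor = 1 \ \text{for } T < 30

\text{simple route (Kotlin default } B = 1\text{): } R = 1,\ r = 1 \;\Rightarrow\; N_{\text{allowed}}(T) \le 1 + T
```

So the bound guarantees `assertEquals(1, allowedRequestsCount)` in `testClientBasedRateLimiter` as long as the 100 requests finish within 30 seconds, and guarantees `allowedRequestsCount < count / 10` (at most 9) in `testSimpleRateLimiter` as long as the loop spans no more than 8 epoch seconds.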

=================================================================================================

BaseRestTest is an abstract base class for the unit tests. It sets up a test environment for the REST API using the Spring Boot Test framework and WireMock:

@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT)
@ContextConfiguration(initializers = ApplicationContextInitializerImpl.class)
@AutoConfigureWireMock(port = 0)
public abstract class BaseRestTest
{
    @LocalServerPort
    protected int port;

    @BeforeEach
    public void setUp() throws Exception
    {
        RestAssuredWebTestClient.webTestClient(WebTestClient.bindToServer()
                .responseTimeout(Duration.ofDays(1))
                .baseUrl("http://localhost:" + port)
                .build());
        WireMock.resetAllRequests();
    }

    @BeforeAll
    public static void setUpGlobal()
    {
        RestAssuredWebTestClient.enableLoggingOfRequestAndResponseIfValidationFails(LogDetail.ALL);
    }

}

=================================================================================================

ApplicationContextInitializerImpl is a Kotlin class implementing Spring Framework's ApplicationContextInitializer interface. It registers a BeanFactoryPostProcessor that removes the cachedCompositeRouteLocator bean definition:

open class ApplicationContextInitializerImpl : ApplicationContextInitializer<ConfigurableApplicationContext> {

    companion object {
        // remove when fixed in Spring to have ConcurrentHashMap
        const val CACHING_ROUTE_LOCATOR_BEAN_DEF = "cachedCompositeRouteLocator"
    }

    override fun initialize(applicationContext: ConfigurableApplicationContext) {
        applicationContext.addBeanFactoryPostProcessor { proc ->
            (proc as? BeanDefinitionRegistry)?.let {
                if (it.containsBeanDefinition(CACHING_ROUTE_LOCATOR_BEAN_DEF)) {
                    it.removeBeanDefinition(CACHING_ROUTE_LOCATOR_BEAN_DEF)
                }
            }
        }
    }
}

=================================================================================================

This is the console log from the test run:

console.txt

spencergibb commented 6 months ago

If you'd like us to spend some time investigating, please take the time to provide a complete, minimal, verifiable sample (something that we can unzip attached to this issue or git clone, build, and deploy) that reproduces the problem.

spring-cloud-issues commented 5 months ago

If you would like us to look at this issue, please provide the requested information. If the information is not provided within the next 7 days this issue will be closed.

spring-cloud-issues commented 5 months ago

Closing due to lack of requested feedback. If you would like us to look at this issue, please provide the requested information and we will re-open the issue.