From e76765c679aa1f7cb7c1b3bd71bd26d1681ed676 Mon Sep 17 00:00:00 2001 From: bitterpanda Date: Mon, 18 May 2026 16:39:49 -0700 Subject: [PATCH] add retry-after header --- aikido_zen/middleware/asgi.py | 12 +- aikido_zen/middleware/django.py | 4 +- aikido_zen/middleware/flask.py | 1 + aikido_zen/middleware/init_test.py | 3 +- aikido_zen/middleware/should_block_request.py | 1 + aikido_zen/ratelimiting/__init__.py | 30 ++- aikido_zen/ratelimiting/init_test.py | 103 +++++----- aikido_zen/ratelimiting/rate_limiter.py | 16 +- aikido_zen/ratelimiting/rate_limiter_test.py | 188 ++++++++++-------- 9 files changed, 209 insertions(+), 149 deletions(-) diff --git a/aikido_zen/middleware/asgi.py b/aikido_zen/middleware/asgi.py index 236b71d70..e7f24d9df 100644 --- a/aikido_zen/middleware/asgi.py +++ b/aikido_zen/middleware/asgi.py @@ -19,7 +19,10 @@ async def __call__(self, scope, receive, send): message = "You are rate limited by Zen." if result["trigger"] == "ip" and result["ip"]: message += " (Your IP: " + result["ip"] + ")" - return await send_status_code_and_text(send, (message, 429)) + extra_headers = [ + (b"retry-after", str(result["retry_after_seconds"]).encode()) + ] + return await send_status_code_and_text(send, (message, 429), extra_headers) if result["type"] == "blocked": return await send_status_code_and_text( @@ -30,13 +33,16 @@ async def __call__(self, scope, receive, send): return await self.app(scope, receive, send) -async def send_status_code_and_text(send, pre_response): +async def send_status_code_and_text(send, pre_response, extra_headers=None): """Sends a status code and text""" + headers = [(b"content-type", b"text/plain")] + if extra_headers: + headers = headers + extra_headers await send( { "type": "http.response.start", "status": pre_response[1], - "headers": [(b"content-type", b"text/plain")], + "headers": headers, } ) await send( diff --git a/aikido_zen/middleware/django.py b/aikido_zen/middleware/django.py index 5bfd1767d..2a582ac01 100644 --- a/aikido_zen/middleware/django.py +++ b/aikido_zen/middleware/django.py @@ -27,7 +27,9 @@ def __call__(self, request): message = "You are rate limited by Zen." if result["trigger"] == "ip" and result["ip"]: message += " (Your IP: " + result["ip"] + ")" - return self.HttpResponse(message, content_type="text/plain", status=429) + response = self.HttpResponse(message, content_type="text/plain", status=429) + response["Retry-After"] = str(result["retry_after_seconds"]) + return response if result["type"] == "blocked": return self.HttpResponse( diff --git a/aikido_zen/middleware/flask.py b/aikido_zen/middleware/flask.py index 585718fbe..f25d0b061 100644 --- a/aikido_zen/middleware/flask.py +++ b/aikido_zen/middleware/flask.py @@ -28,6 +28,7 @@ def __call__(self, environ, start_response): if result["trigger"] == "ip" and result["ip"]: message += " (Your IP: " + result["ip"] + ")" res = self.Response(message, mimetype="text/plain", status=429) + res.headers["Retry-After"] = str(result["retry_after_seconds"]) return res(environ, start_response) if result["type"] == "blocked": diff --git a/aikido_zen/middleware/init_test.py b/aikido_zen/middleware/init_test.py index d7f4bf1e8..74f895aa6 100644 --- a/aikido_zen/middleware/init_test.py +++ b/aikido_zen/middleware/init_test.py @@ -152,7 +152,7 @@ def test_cache_comms_with_endpoints(): mock_comms.send_data_to_bg_process.return_value = { "success": True, - "data": {"block": True, "trigger": "my_trigger"}, + "data": {"block": True, "trigger": "my_trigger", "retry_after_seconds": 10}, } assert thread_cache.stats.rate_limited_hits == 0 assert should_block_request() == { @@ -160,5 +160,6 @@ def test_cache_comms_with_endpoints(): "ip": "1.1.1.1", "type": "ratelimited", "trigger": "my_trigger", + "retry_after_seconds": 10, } assert thread_cache.stats.rate_limited_hits == 1 diff --git a/aikido_zen/middleware/should_block_request.py b/aikido_zen/middleware/should_block_request.py index 434f06a67..762cee4f7 100644 --- a/aikido_zen/middleware/should_block_request.py +++ b/aikido_zen/middleware/should_block_request.py @@ -69,6 +69,7 @@ def should_block_request(): "type": "ratelimited", "trigger": ratelimit_res["data"]["trigger"], "ip": context.remote_address, + "retry_after_seconds": ratelimit_res["data"]["retry_after_seconds"], } except Exception as e: logger.debug("Exception occurred in should_block_request: %s", e) diff --git a/aikido_zen/ratelimiting/__init__.py b/aikido_zen/ratelimiting/__init__.py index 4d60bdd01..dca8330b8 100644 --- a/aikido_zen/ratelimiting/__init__.py +++ b/aikido_zen/ratelimiting/__init__.py @@ -28,34 +28,46 @@ def should_ratelimit_request( windows_size_in_ms = int(endpoint["rateLimiting"]["windowSizeInMS"]) if group: - allowed = connection_manager.rate_limiter.is_allowed( + result = connection_manager.rate_limiter.is_allowed( get_key_for_group(endpoint, group), windows_size_in_ms, max_requests, ) - if not allowed: - return {"block": True, "trigger": "group"} + if not result["allowed"]: + return { + "block": True, + "trigger": "group", + "retry_after_seconds": result["retry_after_seconds"], + } # Do not check IP or user rate limit if group is set return {"block": False} if user: - allowed = connection_manager.rate_limiter.is_allowed( + result = connection_manager.rate_limiter.is_allowed( get_key_for_user(endpoint, user), windows_size_in_ms, max_requests, ) - if not allowed: - return {"block": True, "trigger": "user"} + if not result["allowed"]: + return { + "block": True, + "trigger": "user", + "retry_after_seconds": result["retry_after_seconds"], + } # Do not check IP rate limit if user is set return {"block": False} if remote_address: - allowed = connection_manager.rate_limiter.is_allowed( + result = connection_manager.rate_limiter.is_allowed( get_key_for_ip(endpoint, remote_address), windows_size_in_ms, max_requests, ) - if not allowed: - return {"block": True, "trigger": "ip"} + if not result["allowed"]: + return { + "block": True, + "trigger": "ip", + "retry_after_seconds": result["retry_after_seconds"], + } return {"block": False} diff --git a/aikido_zen/ratelimiting/init_test.py b/aikido_zen/ratelimiting/init_test.py index dfd14acf6..f90ba05b8 100644 --- a/aikido_zen/ratelimiting/init_test.py +++ b/aikido_zen/ratelimiting/init_test.py @@ -60,10 +60,10 @@ def test_rate_limits_by_ip(): assert should_ratelimit_request(route_metadata, "1.2.3.4", None, cm) == { "block": False } - assert should_ratelimit_request(route_metadata, "1.2.3.4", None, cm) == { - "block": True, - "trigger": "ip", - } + result = should_ratelimit_request(route_metadata, "1.2.3.4", None, cm) + assert result["block"] is True + assert result["trigger"] == "ip" + assert result["retry_after_seconds"] >= 0 def test_rate_limiting_ip_allowed(): @@ -126,10 +126,10 @@ def test_rate_limiting_by_user(user): assert should_ratelimit_request(route_metadata, "1.2.3.6", user, cm) == { "block": False } - assert should_ratelimit_request(route_metadata, "1.2.3.7", user, cm) == { - "block": True, - "trigger": "user", - } + result = should_ratelimit_request(route_metadata, "1.2.3.7", user, cm) + assert result["block"] is True + assert result["trigger"] == "user" + assert result["retry_after_seconds"] >= 0 def test_rate_limiting_with_wildcard(): @@ -160,9 +160,12 @@ def test_rate_limiting_with_wildcard(): ) == {"block": False} # This request should trigger the rate limit - assert should_ratelimit_request( + result = should_ratelimit_request( create_route_metadata(route="/api/login"), "1.2.3.4", None, cm - ) == {"block": True, "trigger": "ip"} + ) + assert result["block"] is True + assert result["trigger"] == "ip" + assert result["retry_after_seconds"] >= 0 def test_rate_limiting_with_wildcard2(): @@ -193,10 +196,10 @@ def test_rate_limiting_with_wildcard2(): # This request should trigger the rate limit metadata = create_route_metadata(route="/api/login", method="GET") - assert should_ratelimit_request(metadata, "1.2.3.4", None, cm) == { - "block": True, - "trigger": "ip", - } + result = should_ratelimit_request(metadata, "1.2.3.4", None, cm) + assert result["block"] is True + assert result["trigger"] == "ip" + assert result["retry_after_seconds"] >= 0 def test_rate_limiting_by_user_with_same_ip(): @@ -228,10 +231,10 @@ def test_rate_limiting_by_user_with_same_ip(): } # This request should trigger the rate limit - assert should_ratelimit_request(metadata, "1.2.3.4", {"id": "123"}, cm) == { - "block": True, - "trigger": "user", - } + result = should_ratelimit_request(metadata, "1.2.3.4", {"id": "123"}, cm) + assert result["block"] is True + assert result["trigger"] == "user" + assert result["retry_after_seconds"] >= 0 def test_rate_limiting_by_user_with_different_ips(): @@ -267,10 +270,10 @@ def test_rate_limiting_by_user_with_different_ips(): } # This request from second IP should trigger the rate limit - assert should_ratelimit_request(metadata, "4.3.2.1", {"id": "123"}, cm) == { - "block": True, - "trigger": "user", - } + result = should_ratelimit_request(metadata, "4.3.2.1", {"id": "123"}, cm) + assert result["block"] is True + assert result["trigger"] == "user" + assert result["retry_after_seconds"] >= 0 def test_rate_limiting_same_ip_different_users(): @@ -385,10 +388,10 @@ def test_rate_limits_by_user_with_different_ips(): "block": False } # This request should trigger the rate limit by group - assert should_ratelimit_request(route_metadata, "4.3.2.1", user, cm, "group1") == { - "block": True, - "trigger": "group", - } + result = should_ratelimit_request(route_metadata, "4.3.2.1", user, cm, "group1") + assert result["block"] is True + assert result["trigger"] == "group" + assert result["retry_after_seconds"] >= 0 def test_rate_limits_different_users_in_same_group(): @@ -420,12 +423,12 @@ def test_rate_limits_different_users_in_same_group(): route_metadata, "1.2.3.4", {"id": "789"}, cm, "group1" ) == {"block": False} # This request should trigger the rate limit by group - assert should_ratelimit_request( + result = should_ratelimit_request( route_metadata, "4.3.2.1", {"id": "101112"}, cm, "group1" - ) == { - "block": True, - "trigger": "group", - } + ) + assert result["block"] is True + assert result["trigger"] == "group" + assert result["retry_after_seconds"] >= 0 def test_works_with_multiple_rate_limit_groups_and_different_users(): @@ -457,30 +460,30 @@ def test_works_with_multiple_rate_limit_groups_and_different_users(): route_metadata, "4.3.2.1", {"id": "101112"}, cm, "group2" ) == {"block": False} # This request should trigger the rate limit for group1 - assert should_ratelimit_request( + result = should_ratelimit_request( route_metadata, "1.2.3.4", {"id": "789"}, cm, "group1" - ) == { - "block": True, - "trigger": "group", - } + ) + assert result["block"] is True + assert result["trigger"] == "group" + assert result["retry_after_seconds"] >= 0 # This request should also trigger the rate limit for group1 - assert should_ratelimit_request( + result = should_ratelimit_request( route_metadata, "1.2.3.4", {"id": "4321"}, cm, "group1" - ) == { - "block": True, - "trigger": "group", - } + ) + assert result["block"] is True + assert result["trigger"] == "group" + assert result["retry_after_seconds"] >= 0 # First request from user 953, group2 assert should_ratelimit_request( route_metadata, "4.3.2.1", {"id": "953"}, cm, "group2" ) == {"block": False} # This request should trigger the rate limit for group2 - assert should_ratelimit_request( + result = should_ratelimit_request( route_metadata, "4.3.2.1", {"id": "1563"}, cm, "group2" - ) == { - "block": True, - "trigger": "group", - } + ) + assert result["block"] is True + assert result["trigger"] == "group" + assert result["retry_after_seconds"] >= 0 def test_rate_limits_by_group_if_user_is_not_set(): @@ -512,10 +515,10 @@ def test_rate_limits_by_group_if_user_is_not_set(): "block": False } # This request should trigger the rate limit by group - assert should_ratelimit_request(route_metadata, "4.3.2.1", None, cm, "group1") == { - "block": True, - "trigger": "group", - } + result = should_ratelimit_request(route_metadata, "4.3.2.1", None, cm, "group1") + assert result["block"] is True + assert result["trigger"] == "group" + assert result["retry_after_seconds"] >= 0 def test_does_not_rate_limit_excluded_users(): diff --git a/aikido_zen/ratelimiting/rate_limiter.py b/aikido_zen/ratelimiting/rate_limiter.py index 7efdca860..009ff6113 100644 --- a/aikido_zen/ratelimiting/rate_limiter.py +++ b/aikido_zen/ratelimiting/rate_limiter.py @@ -2,6 +2,7 @@ Mostly exports the class RateLimiter """ +import math from aikido_zen.helpers.get_current_unixtime_ms import get_unixtime_ms from .lru_cache import LRUCache @@ -18,7 +19,8 @@ def __init__(self, max_items, time_to_live_in_ms): def is_allowed(self, key, window_size_in_ms, max_requests): """ - Checks if the request is allowed given the history + Checks if the request is allowed given the history. + Returns {"allowed": True} or {"allowed": False, "retry_after_seconds": int}. """ current_time = get_unixtime_ms() request_timestamps = self.rate_limited_items.get(key) or [] @@ -39,5 +41,13 @@ def is_allowed(self, key, window_size_in_ms, max_requests): request_timestamps.append(current_time) self.rate_limited_items.set(key, request_timestamps) - # if the total amount of requests in the current window exceeds max requests, we rate-limit - return len(request_timestamps) <= max_requests + if len(request_timestamps) <= max_requests: + return {"allowed": True} + + retry_after_ms = max( + 0, request_timestamps[0] + window_size_in_ms - current_time + ) + return { + "allowed": False, + "retry_after_seconds": math.ceil(retry_after_ms / 1000), + } diff --git a/aikido_zen/ratelimiting/rate_limiter_test.py b/aikido_zen/ratelimiting/rate_limiter_test.py index 76ec3d9a4..fe2d516ef 100644 --- a/aikido_zen/ratelimiting/rate_limiter_test.py +++ b/aikido_zen/ratelimiting/rate_limiter_test.py @@ -13,16 +13,18 @@ def rate_limiter(): def test_allow_requests_within_limit(rate_limiter): key = "user1" for i in range(5): - assert rate_limiter.is_allowed( - key, 5000000, 5 - ), f"Request {i + 1} should be allowed" + assert rate_limiter.is_allowed(key, 5000000, 5)[ + "allowed" + ], f"Request {i + 1} should be allowed" def test_deny_requests_exceeding_limit(rate_limiter): key = "user2" for i in range(5): rate_limiter.is_allowed(key, 500, 5) - assert not rate_limiter.is_allowed(key, 500, 5), "Request 6 should not be allowed" + assert not rate_limiter.is_allowed(key, 500, 5)[ + "allowed" + ], "Request 6 should not be allowed" def test_clear_old_entries(rate_limiter): @@ -32,9 +34,9 @@ def test_clear_old_entries(rate_limiter): time.sleep(0.6) # Sleep to allow old entries to be cleared - assert rate_limiter.is_allowed( - key, 500, 5 - ), "New request should be allowed after clearing old entries" + assert rate_limiter.is_allowed(key, 500, 5)[ + "allowed" + ], "New request should be allowed after clearing old entries" def test_multiple_keys(rate_limiter): @@ -42,20 +44,20 @@ def test_multiple_keys(rate_limiter): key2 = "user5" for i in range(5): - assert rate_limiter.is_allowed( - key1, 500, 5 - ), f"Request {i + 1} for key1 should be allowed" + assert rate_limiter.is_allowed(key1, 500, 5)[ + "allowed" + ], f"Request {i + 1} for key1 should be allowed" for i in range(5): - assert rate_limiter.is_allowed( - key2, 500, 5 - ), f"Request {i + 1} for key2 should be allowed" + assert rate_limiter.is_allowed(key2, 500, 5)[ + "allowed" + ], f"Request {i + 1} for key2 should be allowed" - assert not rate_limiter.is_allowed( - key1, 500, 5 - ), "Request 6 for key1 should not be allowed" - assert not rate_limiter.is_allowed( - key2, 500, 5 - ), "Request 6 for key2 should not be allowed" + assert not rate_limiter.is_allowed(key1, 500, 5)[ + "allowed" + ], "Request 6 for key1 should not be allowed" + assert not rate_limiter.is_allowed(key2, 500, 5)[ + "allowed" + ], "Request 6 for key2 should not be allowed" def test_ttl_expiration(rate_limiter): @@ -65,16 +67,20 @@ def test_ttl_expiration(rate_limiter): time.sleep(1.1) # Sleep to allow the TTL to expire - assert rate_limiter.is_allowed(key, 500, 5), "Request after TTL should be allowed" + assert rate_limiter.is_allowed(key, 500, 5)[ + "allowed" + ], "Request after TTL should be allowed" def test_allow_requests_exactly_at_limit(rate_limiter): key = "user7" for i in range(5): - assert rate_limiter.is_allowed( - key, 500, 5 - ), f"Request {i + 1} should be allowed" - assert not rate_limiter.is_allowed(key, 500, 5), "Request 6 should not be allowed" + assert rate_limiter.is_allowed(key, 500, 5)[ + "allowed" + ], f"Request {i + 1} should be allowed" + assert not rate_limiter.is_allowed(key, 500, 5)[ + "allowed" + ], "Request 6 should not be allowed" def test_allow_requests_after_clearing_old_entries(rate_limiter): @@ -84,23 +90,23 @@ def test_allow_requests_after_clearing_old_entries(rate_limiter): time.sleep(0.6) # Sleep to allow old entries to be cleared - assert rate_limiter.is_allowed( - key, 500, 5 - ), "New request should be allowed after clearing old entries" + assert rate_limiter.is_allowed(key, 500, 5)[ + "allowed" + ], "New request should be allowed after clearing old entries" def test_multiple_rapid_requests(rate_limiter): key = "user9" for i in range(5): - assert rate_limiter.is_allowed( - key, 500, 5 - ), f"Request {i + 1} should be allowed" + assert rate_limiter.is_allowed(key, 500, 5)[ + "allowed" + ], f"Request {i + 1} should be allowed" time.sleep(0.1) # Sleep for 100 ms - assert not rate_limiter.is_allowed( - key, 500, 5 - ), "Request after rapid requests should not be allowed" + assert not rate_limiter.is_allowed(key, 500, 5)[ + "allowed" + ], "Request after rapid requests should not be allowed" def test_reset_after_ttl(rate_limiter): @@ -110,19 +116,21 @@ def test_reset_after_ttl(rate_limiter): time.sleep(1.1) # Sleep to allow the TTL to expire - assert rate_limiter.is_allowed(key, 500, 5), "Request after TTL should be allowed" + assert rate_limiter.is_allowed(key, 500, 5)[ + "allowed" + ], "Request after TTL should be allowed" def test_different_window_sizes(rate_limiter): key = "user11" different_window_size = 1000 # 1 second window for i in range(5): - assert rate_limiter.is_allowed( - key, different_window_size, 5 - ), f"Request {i + 1} should be allowed" - assert not rate_limiter.is_allowed( - key, different_window_size, 5 - ), "Request 6 should not be allowed" + assert rate_limiter.is_allowed(key, different_window_size, 5)[ + "allowed" + ], f"Request {i + 1} should be allowed" + assert not rate_limiter.is_allowed(key, different_window_size, 5)[ + "allowed" + ], "Request 6 should not be allowed" def test_sliding_window_with_intermittent_requests(rate_limiter): @@ -130,18 +138,18 @@ def test_sliding_window_with_intermittent_requests(rate_limiter): # Allow 5 requests in a 1-second window for i in range(5): - assert rate_limiter.is_allowed( - key, 500, 5 - ), f"Request {i + 1} should be allowed" + assert rate_limiter.is_allowed(key, 500, 5)[ + "allowed" + ], f"Request {i + 1} should be allowed" time.sleep(0.1) # Sleep 100 ms between requests # Sleep for 600 ms to allow the first requests to slide out of the window time.sleep(0.6) # Now we should be able to make a new request - assert rate_limiter.is_allowed( - key, 500, 5 - ), "New request should be allowed after sliding" + assert rate_limiter.is_allowed(key, 500, 5)[ + "allowed" + ], "New request should be allowed after sliding" def test_sliding_window_edge_case(rate_limiter): @@ -149,25 +157,25 @@ def test_sliding_window_edge_case(rate_limiter): # Allow 5 requests in a 1-second window for i in range(5): - assert rate_limiter.is_allowed( - key, 500, 5 - ), f"Request {i + 1} should be allowed" + assert rate_limiter.is_allowed(key, 500, 5)[ + "allowed" + ], f"Request {i + 1} should be allowed" # Sleep for 500 ms to simulate time passing time.sleep(0.6) # The next request should still be allowed as the window is sliding - assert rate_limiter.is_allowed( - key, 500, 5 - ), "Next request should be allowed as window slides" + assert rate_limiter.is_allowed(key, 500, 5)[ + "allowed" + ], "Next request should be allowed as window slides" # Sleep for another 500 ms to allow the first batch to expire time.sleep(0.6) # Now we should be able to make a new request - assert rate_limiter.is_allowed( - key, 500, 5 - ), "New request should be allowed after first batch expires" + assert rate_limiter.is_allowed(key, 500, 5)[ + "allowed" + ], "New request should be allowed after first batch expires" def test_sliding_window_with_burst_requests(rate_limiter): @@ -176,23 +184,23 @@ def test_sliding_window_with_burst_requests(rate_limiter): # Allow 5 requests in a 1-second window for i in range(5): - assert rate_limiter.is_allowed( - key, window_size_ms, 5 - ), f"Request {i + 1} should be allowed" + assert rate_limiter.is_allowed(key, window_size_ms, 5)[ + "allowed" + ], f"Request {i + 1} should be allowed" # Sleep for 250 ms to simulate time passing time.sleep((window_size_ms / 2) / 1000) # Add 3 more requests (should be denied) - assert not rate_limiter.is_allowed( - key, window_size_ms, 5 - ), "Request should not be allowed" - assert not rate_limiter.is_allowed( - key, window_size_ms, 5 - ), "Request should not be allowed" - assert not rate_limiter.is_allowed( - key, window_size_ms, 5 - ), "Request should not be allowed" + assert not rate_limiter.is_allowed(key, window_size_ms, 5)[ + "allowed" + ], "Request should not be allowed" + assert not rate_limiter.is_allowed(key, window_size_ms, 5)[ + "allowed" + ], "Request should not be allowed" + assert not rate_limiter.is_allowed(key, window_size_ms, 5)[ + "allowed" + ], "Request should not be allowed" time.sleep( (window_size_ms / 2 + 50) / 1000 @@ -200,20 +208,20 @@ def test_sliding_window_with_burst_requests(rate_limiter): # Make a burst of requests (should be allowed) for i in range(2): - assert rate_limiter.is_allowed( - key, window_size_ms, 5 - ), f"Burst request {i + 1} should be allowed" - assert not rate_limiter.is_allowed( - key, window_size_ms, 5 - ), "Burst request should not be allowed after limit" + assert rate_limiter.is_allowed(key, window_size_ms, 5)[ + "allowed" + ], f"Burst request {i + 1} should be allowed" + assert not rate_limiter.is_allowed(key, window_size_ms, 5)[ + "allowed" + ], "Burst request should not be allowed after limit" # Sleep for 500 ms to allow all batches to slide out of window time.sleep((window_size_ms + 100) / 1000) # Now we should be able to make a new request - assert rate_limiter.is_allowed( - key, window_size_ms, 5 - ), "New request should be allowed after all batches slide out" + assert rate_limiter.is_allowed(key, window_size_ms, 5)[ + "allowed" + ], "New request should be allowed after all batches slide out" def test_sliding_window_with_delayed_requests(rate_limiter): @@ -221,15 +229,31 @@ def test_sliding_window_with_delayed_requests(rate_limiter): # Allow 5 requests in a 1-second window for i in range(5): - assert rate_limiter.is_allowed( - key, 500, 5 - ), f"Request {i + 1} should be allowed" + assert rate_limiter.is_allowed(key, 500, 5)[ + "allowed" + ], f"Request {i + 1} should be allowed" time.sleep(0.1) # Sleep 100 ms between requests # Sleep for 600 ms to allow the first requests to slide out of the window time.sleep(0.6) # Now we should be able to make a new request - assert rate_limiter.is_allowed( - key, 500, 5 - ), "New request should be allowed after sliding" + assert rate_limiter.is_allowed(key, 500, 5)[ + "allowed" + ], "New request should be allowed after sliding" + + +def test_retry_after_seconds_when_rate_limited(rate_limiter): + key = "user18" + window_size_ms = 1000 + max_requests = 2 + limiter = RateLimiter(10, window_size_ms) + + assert limiter.is_allowed(key, window_size_ms, max_requests) == {"allowed": True} + assert limiter.is_allowed(key, window_size_ms, max_requests) == {"allowed": True} + + time.sleep(0.3) + + result = limiter.is_allowed(key, window_size_ms, max_requests) + assert result["allowed"] is False + assert result["retry_after_seconds"] > 0