Source code for baseplate.lib.ratelimit.backends.redis

from redis import ConnectionPool

from baseplate import Span
from baseplate.clients import ContextFactory
from baseplate.clients.redis import MonitoredRedisConnection
from baseplate.clients.redis import RedisContextFactory
from baseplate.lib.ratelimit.backends import _get_current_bucket
from baseplate.lib.ratelimit.backends import RateLimitBackend


class RedisRateLimitBackendContextFactory(ContextFactory):
    """Context factory that produces :py:class:`RedisRateLimitBackend` objects.

    Internally a :py:class:`~baseplate.clients.redis.RedisContextFactory` is
    built around ``redis_pool``; each backend handed out by
    :py:meth:`make_object_for_context` wraps the Redis connection that factory
    creates for the given context, and namespaces its keys with ``prefix``.

    :param redis_pool: The redis pool to back this ratelimit.
    :param prefix: A prefix to add to keys during rate limiting. This is
        useful if you will have two different rate limiters that will receive
        the same keys.

    """

    def __init__(self, redis_pool: ConnectionPool, prefix: str = "rl:"):
        # Stored as-is and forwarded to every backend this factory creates.
        self.prefix = prefix
        # Delegate connection creation (and its span instrumentation) to the
        # stock Redis context factory.
        self.redis_context_factory = RedisContextFactory(redis_pool)

    def make_object_for_context(self, name: str, span: Span) -> "RedisRateLimitBackend":
        """Build a rate-limit backend for one context.

        :param name: The context attribute name, passed through unchanged to
            the underlying Redis context factory.
        :param span: The current span, passed through unchanged to the
            underlying Redis context factory.
        :returns: A :py:class:`RedisRateLimitBackend` wrapping the connection
            returned by the Redis context factory, using this factory's prefix.

        """
        connection_factory = self.redis_context_factory
        connection = connection_factory.make_object_for_context(name, span)
        return RedisRateLimitBackend(connection, prefix=self.prefix)
class RedisRateLimitBackend(RateLimitBackend):
    """Fixed-window rate-limit backend that keeps its counters in Redis.

    :param redis: An instance of
        :py:class:`baseplate.clients.redis.MonitoredRedisConnection`.
    :param prefix: A prefix to add to keys during rate limiting. This is
        useful if you will have two different rate limiters that will receive
        the same keys.

    """

    def __init__(self, redis: MonitoredRedisConnection, prefix: str = "rl:"):
        self.prefix = prefix
        self.redis = redis

    def consume(self, key: str, amount: int, allowance: int, interval: int) -> bool:
        """Consume the given `amount` from the allowance for the given `key`.

        The counter for the current time bucket is incremented by ``amount``
        unconditionally (it is charged even when the call ends up over the
        limit), and its TTL is refreshed to twice ``interval``. Both commands
        are sent in a single pipeline named ``"ratelimit"``.

        :param key: The name of the rate limit bucket to consume from.
        :param amount: The amount to consume from the rate limit bucket.
        :param allowance: The maximum allowance for the rate limit bucket.
        :param interval: The interval to reset the allowance.
        :returns: ``True`` if the counter, after adding ``amount``, is still
            at or below ``allowance``; ``False`` otherwise.

        """
        bucket_id = _get_current_bucket(interval)
        # NOTE(review): prefix, key and bucket id are joined with no
        # separator, so a key ending in digits could in principle collide with
        # another key/interval pair. Changing the format would orphan live
        # counters, so it is left as-is here.
        redis_key = "".join((self.prefix, key, bucket_id))

        with self.redis.pipeline("ratelimit") as pipe:
            pipe.incr(redis_key, amount)
            # Keep the key alive past the end of its window; refreshed on
            # every call.
            pipe.expire(redis_key, time=interval * 2)
            results = pipe.execute()

        # results[0] is the INCR reply, i.e. the post-increment count.
        consumed_so_far = results[0]
        return consumed_so_far <= allowance