| author | Daniyal <abbasi.daniyal98@gmail.com> | 2021-05-24 05:31:50 +0530 |
|---|---|---|
| committer | Mariusz Felisiak <felisiak.mariusz@gmail.com> | 2021-09-14 15:50:08 +0200 |
| commit | ec212c66167759a2a40b13d5efc47d883816d4da (patch) | |
| tree | 740bed1fef217361ef973c732045de73bb6f64d1 /django | |
| parent | 676bd084f2509f4201561d5c77ed4ecbd157bfa0 (diff) | |
Fixed #33012 -- Added Redis cache backend.
Thanks Carlton Gibson, Chris Jerdonek, David Smith, Keryn Knight,
Mariusz Felisiak, and Nick Pope for reviews and mentoring this
Google Summer of Code 2021 project.
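
For anyone trying the patch out, a minimal sketch of enabling the new backend in a project's settings; the redis:// URL below is an illustrative local instance, not something taken from this commit:

# settings.py -- illustrative only; assumes a Redis server listening
# on the default local port.
CACHES = {
    'default': {
        'BACKEND': 'django.core.cache.backends.redis.RedisCache',
        'LOCATION': 'redis://127.0.0.1:6379',
    },
}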
Diffstat (limited to 'django')
| -rw-r--r-- | django/core/cache/backends/redis.py | 224 |
1 file changed, 224 insertions, 0 deletions
diff --git a/django/core/cache/backends/redis.py b/django/core/cache/backends/redis.py
new file mode 100644
index 0000000000..16556b1ded
--- /dev/null
+++ b/django/core/cache/backends/redis.py
@@ -0,0 +1,224 @@
+"""Redis cache backend."""
+
+import random
+import re
+
+from django.core.cache.backends.base import DEFAULT_TIMEOUT, BaseCache
+from django.core.serializers.base import PickleSerializer
+from django.utils.functional import cached_property
+from django.utils.module_loading import import_string
+
+
+class RedisSerializer(PickleSerializer):
+    def dumps(self, obj):
+        if isinstance(obj, int):
+            return obj
+        return super().dumps(obj)
+
+    def loads(self, data):
+        try:
+            return int(data)
+        except ValueError:
+            return super().loads(data)
+
+
+class RedisCacheClient:
+    def __init__(
+        self,
+        servers,
+        serializer=None,
+        db=None,
+        pool_class=None,
+        parser_class=None,
+    ):
+        import redis
+
+        self._lib = redis
+        self._servers = servers
+        self._pools = {}
+
+        self._client = self._lib.Redis
+
+        if isinstance(pool_class, str):
+            pool_class = import_string(pool_class)
+        self._pool_class = pool_class or self._lib.ConnectionPool
+
+        if isinstance(serializer, str):
+            serializer = import_string(serializer)
+        if callable(serializer):
+            serializer = serializer()
+        self._serializer = serializer or RedisSerializer()
+
+        if isinstance(parser_class, str):
+            parser_class = import_string(parser_class)
+        parser_class = parser_class or self._lib.connection.DefaultParser
+
+        self._pool_options = {'parser_class': parser_class, 'db': db}
+
+    def _get_connection_pool_index(self, write):
+        # Write to the first server. Read from other servers if there are
+        # more, otherwise read from the first server.
+        if write or len(self._servers) == 1:
+            return 0
+        return random.randint(1, len(self._servers) - 1)
+
+    def _get_connection_pool(self, write):
+        index = self._get_connection_pool_index(write)
+        if index not in self._pools:
+            self._pools[index] = self._pool_class.from_url(
+                self._servers[index], **self._pool_options,
+            )
+        return self._pools[index]
+
+    def get_client(self, key=None, *, write=False):
+        # key is used so that the method signature remains the same and custom
+        # cache client can be implemented which might require the key to
+        # select the server, e.g. sharding.
+        pool = self._get_connection_pool(write)
+        return self._client(connection_pool=pool)
+
+    def add(self, key, value, timeout):
+        client = self.get_client(key, write=True)
+        value = self._serializer.dumps(value)
+
+        if timeout == 0:
+            if ret := bool(client.set(key, value, nx=True)):
+                client.delete(key)
+            return ret
+        else:
+            return bool(client.set(key, value, ex=timeout, nx=True))
+
+    def get(self, key, default):
+        client = self.get_client(key)
+        value = client.get(key)
+        return default if value is None else self._serializer.loads(value)
+
+    def set(self, key, value, timeout):
+        client = self.get_client(key, write=True)
+        value = self._serializer.dumps(value)
+        if timeout == 0:
+            client.delete(key)
+        else:
+            client.set(key, value, ex=timeout)
+
+    def touch(self, key, timeout):
+        client = self.get_client(key, write=True)
+        if timeout is None:
+            return bool(client.persist(key))
+        else:
+            return bool(client.expire(key, timeout))
+
+    def delete(self, key):
+        client = self.get_client(key, write=True)
+        return bool(client.delete(key))
+
+    def get_many(self, keys):
+        client = self.get_client(None)
+        ret = client.mget(keys)
+        return {
+            k: self._serializer.loads(v) for k, v in zip(keys, ret) if v is not None
+        }
+
+    def has_key(self, key):
+        client = self.get_client(key)
+        return bool(client.exists(key))
+
+    def incr(self, key, delta):
+        client = self.get_client(key)
+        if not client.exists(key):
+            raise ValueError("Key '%s' not found." % key)
+        return client.incr(key, delta)
+
+    def set_many(self, data, timeout):
+        client = self.get_client(None, write=True)
+        pipeline = client.pipeline()
+        pipeline.mset({k: self._serializer.dumps(v) for k, v in data.items()})
+
+        if timeout is not None:
+            # Setting timeout for each key as redis does not support timeout
+            # with mset().
+            for key in data:
+                pipeline.expire(key, timeout)
+        pipeline.execute()
+
+    def delete_many(self, keys):
+        client = self.get_client(None, write=True)
+        client.delete(*keys)
+
+    def clear(self):
+        client = self.get_client(None, write=True)
+        return bool(client.flushdb())
+
+
+class RedisCache(BaseCache):
+    def __init__(self, server, params):
+        super().__init__(params)
+        if isinstance(server, str):
+            self._servers = re.split('[;,]', server)
+        else:
+            self._servers = server
+
+        self._class = RedisCacheClient
+        self._options = params.get('OPTIONS', {})
+
+    @cached_property
+    def _cache(self):
+        return self._class(self._servers, **self._options)
+
+    def get_backend_timeout(self, timeout=DEFAULT_TIMEOUT):
+        if timeout == DEFAULT_TIMEOUT:
+            timeout = self.default_timeout
+        # The key will be made persistent if None used as a timeout.
+        # Non-positive values will cause the key to be deleted.
+        return None if timeout is None else max(0, int(timeout))
+
+    def add(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
+        key = self.make_and_validate_key(key, version=version)
+        return self._cache.add(key, value, self.get_backend_timeout(timeout))
+
+    def get(self, key, default=None, version=None):
+        key = self.make_and_validate_key(key, version=version)
+        return self._cache.get(key, default)
+
+    def set(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
+        key = self.make_and_validate_key(key, version=version)
+        self._cache.set(key, value, self.get_backend_timeout(timeout))
+
+    def touch(self, key, timeout=DEFAULT_TIMEOUT, version=None):
+        key = self.make_and_validate_key(key, version=version)
+        return self._cache.touch(key, self.get_backend_timeout(timeout))
+
+    def delete(self, key, version=None):
+        key = self.make_and_validate_key(key, version=version)
+        return self._cache.delete(key)
+
+    def get_many(self, keys, version=None):
+        key_map = {self.make_and_validate_key(key, version=version): key for key in keys}
+        ret = self._cache.get_many(key_map.keys())
+        return {key_map[k]: v for k, v in ret.items()}
+
+    def has_key(self, key, version=None):
+        key = self.make_and_validate_key(key, version=version)
+        return self._cache.has_key(key)
+
+    def incr(self, key, delta=1, version=None):
+        key = self.make_and_validate_key(key, version=version)
+        return self._cache.incr(key, delta)
+
+    def set_many(self, data, timeout=DEFAULT_TIMEOUT, version=None):
+        safe_data = {}
+        for key, value in data.items():
+            key = self.make_and_validate_key(key, version=version)
+            safe_data[key] = value
+        self._cache.set_many(safe_data, self.get_backend_timeout(timeout))
+        return []
+
+    def delete_many(self, keys, version=None):
+        safe_keys = []
+        for key in keys:
+            key = self.make_and_validate_key(key, version=version)
+            safe_keys.append(key)
+        self._cache.delete_many(safe_keys)
+
+    def clear(self):
+        return self._cache.clear()
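
A note on the client above: writes always go to the first server in LOCATION and reads are picked at random from the remaining ones (see _get_connection_pool_index()), and everything under OPTIONS is forwarded to RedisCacheClient.__init__(). A hypothetical configuration exercising those parameters; the host names and db number are placeholders, not values from this commit:

# Hypothetical settings sketch; URLs and db number are made up.
CACHES = {
    'default': {
        'BACKEND': 'django.core.cache.backends.redis.RedisCache',
        # The first server receives all writes; reads are spread across
        # the remaining servers.
        'LOCATION': [
            'redis://primary.example.com:6379',
            'redis://replica.example.com:6379',
        ],
        'OPTIONS': {
            'db': 10,  # forwarded into the connection pool options
            'pool_class': 'redis.BlockingConnectionPool',
            # 'serializer' and 'parser_class' also accept dotted paths.
        },
    },
}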
