summaryrefslogtreecommitdiff
path: root/django/db/models/query_utils.py
diff options
context:
space:
mode:
authorAllen Jonathan David <allenajdjonathan@gmail.com>2022-06-23 12:02:53 +0530
committerMariusz Felisiak <felisiak.mariusz@gmail.com>2022-09-02 10:02:24 +0200
commitcd1afd553f9c175ebccfc0f50e72b43b9604bd97 (patch)
tree20910c2c7e1843d39cd3ca3bb7e0b0dbb4658706 /django/db/models/query_utils.py
parentfdf0f625216cc5a70d28a3ac9a41f41935f1827c (diff)
Fixed #29799 -- Allowed registering lookups per field instances.
Thanks Simon Charette and Mariusz Felisiak for reviews and mentoring this Google Summer of Code 2022 project.
Diffstat (limited to 'django/db/models/query_utils.py')
-rw-r--r--django/db/models/query_utils.py74
1 files changed, 61 insertions, 13 deletions
diff --git a/django/db/models/query_utils.py b/django/db/models/query_utils.py
index f4215ed48e..5562303e00 100644
--- a/django/db/models/query_utils.py
+++ b/django/db/models/query_utils.py
@@ -188,19 +188,42 @@ class DeferredAttribute:
return None
+class class_or_instance_method:
+ """
+ Hook used in RegisterLookupMixin to return partial functions depending on
+ the caller type (instance or class of models.Field).
+ """
+
+ def __init__(self, class_method, instance_method):
+ self.class_method = class_method
+ self.instance_method = instance_method
+
+ def __get__(self, instance, owner):
+ if instance is None:
+ return functools.partial(self.class_method, owner)
+ return functools.partial(self.instance_method, instance)
+
+
class RegisterLookupMixin:
- @classmethod
- def _get_lookup(cls, lookup_name):
- return cls.get_lookups().get(lookup_name, None)
+ def _get_lookup(self, lookup_name):
+ return self.get_lookups().get(lookup_name, None)
- @classmethod
@functools.lru_cache(maxsize=None)
- def get_lookups(cls):
+ def get_class_lookups(cls):
class_lookups = [
parent.__dict__.get("class_lookups", {}) for parent in inspect.getmro(cls)
]
return cls.merge_dicts(class_lookups)
+ def get_instance_lookups(self):
+ class_lookups = self.get_class_lookups()
+ if instance_lookups := getattr(self, "instance_lookups", None):
+ return {**class_lookups, **instance_lookups}
+ return class_lookups
+
+ get_lookups = class_or_instance_method(get_class_lookups, get_instance_lookups)
+ get_class_lookups = classmethod(get_class_lookups)
+
def get_lookup(self, lookup_name):
from django.db.models.lookups import Lookup
@@ -233,22 +256,33 @@ class RegisterLookupMixin:
return merged
@classmethod
- def _clear_cached_lookups(cls):
+ def _clear_cached_class_lookups(cls):
for subclass in subclasses(cls):
- subclass.get_lookups.cache_clear()
+ subclass.get_class_lookups.cache_clear()
- @classmethod
- def register_lookup(cls, lookup, lookup_name=None):
+ def register_class_lookup(cls, lookup, lookup_name=None):
if lookup_name is None:
lookup_name = lookup.lookup_name
if "class_lookups" not in cls.__dict__:
cls.class_lookups = {}
cls.class_lookups[lookup_name] = lookup
- cls._clear_cached_lookups()
+ cls._clear_cached_class_lookups()
return lookup
- @classmethod
- def _unregister_lookup(cls, lookup, lookup_name=None):
+ def register_instance_lookup(self, lookup, lookup_name=None):
+ if lookup_name is None:
+ lookup_name = lookup.lookup_name
+ if "instance_lookups" not in self.__dict__:
+ self.instance_lookups = {}
+ self.instance_lookups[lookup_name] = lookup
+ return lookup
+
+ register_lookup = class_or_instance_method(
+ register_class_lookup, register_instance_lookup
+ )
+ register_class_lookup = classmethod(register_class_lookup)
+
+ def _unregister_class_lookup(cls, lookup, lookup_name=None):
"""
Remove given lookup from cls lookups. For use in tests only as it's
not thread-safe.
@@ -256,7 +290,21 @@ class RegisterLookupMixin:
if lookup_name is None:
lookup_name = lookup.lookup_name
del cls.class_lookups[lookup_name]
- cls._clear_cached_lookups()
+ cls._clear_cached_class_lookups()
+
+ def _unregister_instance_lookup(self, lookup, lookup_name=None):
+ """
+ Remove given lookup from instance lookups. For use in tests only as
+ it's not thread-safe.
+ """
+ if lookup_name is None:
+ lookup_name = lookup.lookup_name
+ del self.instance_lookups[lookup_name]
+
+ _unregister_lookup = class_or_instance_method(
+ _unregister_class_lookup, _unregister_instance_lookup
+ )
+ _unregister_class_lookup = classmethod(_unregister_class_lookup)
def select_related_descend(field, restricted, requested, select_mask, reverse=False):