summaryrefslogtreecommitdiff
path: root/django/contrib/postgres/search.py
diff options
context:
space:
mode:
Diffstat (limited to 'django/contrib/postgres/search.py')
-rw-r--r--django/contrib/postgres/search.py141
1 files changed, 141 insertions, 0 deletions
diff --git a/django/contrib/postgres/search.py b/django/contrib/postgres/search.py
index 4ab27605cb..52e925d27a 100644
--- a/django/contrib/postgres/search.py
+++ b/django/contrib/postgres/search.py
@@ -1,3 +1,4 @@
+from django.db.backends.postgresql.psycopg_any import is_psycopg3
from django.db.models import (
CharField,
Expression,
@@ -10,9 +11,45 @@ from django.db.models import (
)
from django.db.models.expressions import CombinedExpression, register_combinable_fields
from django.db.models.functions import Cast, Coalesce
+from django.utils.regex_helper import _lazy_re_compile
from .utils import CheckPostgresInstalledMixin
+if is_psycopg3:
+ from psycopg.adapt import Dumper
+
+ class UTF8Dumper(Dumper):
+ def dump(self, obj):
+ return bytes(obj, "utf-8")
+
+ def quote_lexeme(value):
+ return UTF8Dumper(str).quote(psql_escape(value)).decode()
+
+else:
+ from psycopg2.extensions import adapt
+
+ def quote_lexeme(value):
+ adapter = adapt(psql_escape(value))
+ adapter.encoding = "utf-8"
+ return adapter.getquoted().decode()
+
+
+spec_chars_re = _lazy_re_compile(r"['\0\[\]()|&:*!@<>\\]")
+multiple_spaces_re = _lazy_re_compile(r"\s{2,}")
+
+
+def normalize_spaces(val):
+ """Convert multiple spaces to single and strip from both sides."""
+ if not (val := val.strip()):
+ return None
+ return multiple_spaces_re.sub(" ", val)
+
+
+def psql_escape(query):
+ """Replace chars not fit for use in search queries with a single space."""
+ query = spec_chars_re.sub(" ", query)
+ return normalize_spaces(query)
+
class SearchVectorExact(Lookup):
lookup_name = "exact"
@@ -205,6 +242,9 @@ class SearchQuery(SearchQueryCombinable, Func):
invert=False,
search_type="plain",
):
+ if isinstance(value, LexemeCombinable):
+ search_type = "raw"
+
self.function = self.SEARCH_TYPES.get(search_type)
if self.function is None:
raise ValueError("Unknown search_type argument '%s'." % search_type)
@@ -383,3 +423,104 @@ class TrigramWordSimilarity(TrigramWordBase):
class TrigramStrictWordSimilarity(TrigramWordBase):
function = "STRICT_WORD_SIMILARITY"
+
+
+class LexemeCombinable:
+ BITAND = "&"
+ BITOR = "|"
+
+ def _combine(self, other, connector, reversed):
+ if not isinstance(other, LexemeCombinable):
+ raise TypeError(
+ "A Lexeme can only be combined with another Lexeme, "
+ f"got {other.__class__.__name__}."
+ )
+ if reversed:
+ return CombinedLexeme(other, connector, self)
+ return CombinedLexeme(self, connector, other)
+
+ # On Combinable, these are not implemented to reduce confusion with Q. In
+ # this case we are actually (ab)using them to do logical combination so
+ # it's consistent with other usage in Django.
+ def __or__(self, other):
+ return self._combine(other, self.BITOR, False)
+
+ def __ror__(self, other):
+ return self._combine(other, self.BITOR, True)
+
+ def __and__(self, other):
+ return self._combine(other, self.BITAND, False)
+
+ def __rand__(self, other):
+ return self._combine(other, self.BITAND, True)
+
+
+class Lexeme(LexemeCombinable, Value):
+ _output_field = SearchQueryField()
+
+ def __init__(
+ self, value, output_field=None, *, invert=False, prefix=False, weight=None
+ ):
+ if value == "":
+ raise ValueError("Lexeme value cannot be empty.")
+
+ if not isinstance(value, str):
+ raise TypeError(
+ f"Lexeme value must be a string, got {value.__class__.__name__}."
+ )
+
+ if weight is not None and (
+ not isinstance(weight, str) or weight.lower() not in {"a", "b", "c", "d"}
+ ):
+ raise ValueError(
+ f"Weight must be one of 'A', 'B', 'C', and 'D', got {weight!r}."
+ )
+
+ self.prefix = prefix
+ self.invert = invert
+ self.weight = weight
+ super().__init__(value, output_field=output_field)
+
+ def as_sql(self, compiler, connection):
+ param = quote_lexeme(self.value)
+ label = ""
+ if self.prefix:
+ label += "*"
+ if self.weight:
+ label += self.weight
+
+ if label:
+ param = f"{param}:{label}"
+ if self.invert:
+ param = f"!{param}"
+
+ return "%s", (param,)
+
+ def __invert__(self):
+ cloned = self.copy()
+ cloned.invert = not self.invert
+ return cloned
+
+
+class CombinedLexeme(LexemeCombinable, CombinedExpression):
+ _output_field = SearchQueryField()
+
+ def as_sql(self, compiler, connection):
+ value_params = []
+ lsql, params = compiler.compile(self.lhs)
+ value_params.extend(params)
+
+ rsql, params = compiler.compile(self.rhs)
+ value_params.extend(params)
+
+ combined_sql = f"({lsql} {self.connector} {rsql})"
+ combined_value = combined_sql % tuple(value_params)
+ return "%s", (combined_value,)
+
+ def __invert__(self):
+ # Apply De Morgan's theorem.
+ cloned = self.copy()
+ cloned.connector = self.BITAND if self.connector == self.BITOR else self.BITOR
+ cloned.lhs = ~self.lhs
+ cloned.rhs = ~self.rhs
+ return cloned