summaryrefslogtreecommitdiff
path: root/django/db/models/sql/compiler.py
diff options
context:
space:
mode:
authorBendeguz Csirmaz <csirmazbendeguz@gmail.com>2024-04-07 10:32:16 +0800
committerSarah Boyce <42296566+sarahboyce@users.noreply.github.com>2024-11-29 11:23:04 +0100
commit978aae4334fa71ba78a3e94408f0f3aebde8d07c (patch)
treedd1cc322769441a3dd28b952ce52e07c3f72f90a /django/db/models/sql/compiler.py
parent86661f2449fb0903f72b3522c68e146934013377 (diff)
Fixed #373 -- Added CompositePrimaryKey.
Thanks Lily Foote and Simon Charette for reviews and mentoring this Google Summer of Code 2024 project. Co-authored-by: Simon Charette <charette.s@gmail.com> Co-authored-by: Lily Foote <code@lilyf.org>
Diffstat (limited to 'django/db/models/sql/compiler.py')
-rw-r--r--django/db/models/sql/compiler.py75
1 files changed, 65 insertions, 10 deletions
diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py
index 49263d5944..053bdc09d5 100644
--- a/django/db/models/sql/compiler.py
+++ b/django/db/models/sql/compiler.py
@@ -7,7 +7,9 @@ from itertools import chain
from django.core.exceptions import EmptyResultSet, FieldError, FullResultSet
from django.db import DatabaseError, NotSupportedError
from django.db.models.constants import LOOKUP_SEP
-from django.db.models.expressions import F, OrderBy, RawSQL, Ref, Value
+from django.db.models.expressions import ColPairs, F, OrderBy, RawSQL, Ref, Value
+from django.db.models.fields import composite
+from django.db.models.fields.composite import CompositePrimaryKey
from django.db.models.functions import Cast, Random
from django.db.models.lookups import Lookup
from django.db.models.query_utils import select_related_descend
@@ -283,6 +285,9 @@ class SQLCompiler:
# Reference to a column.
elif isinstance(expression, int):
expression = cols[expression]
+ # ColPairs cannot be aliased.
+ if isinstance(expression, ColPairs):
+ alias = None
selected.append((alias, expression))
for select_idx, (alias, expression) in enumerate(selected):
@@ -997,6 +1002,7 @@ class SQLCompiler:
# alias for a given field. This also includes None -> start_alias to
# be used by local fields.
seen_models = {None: start_alias}
+ select_mask_fields = set(composite.unnest(select_mask))
for field in opts.concrete_fields:
model = field.model._meta.concrete_model
@@ -1017,7 +1023,7 @@ class SQLCompiler:
# parent model data is already present in the SELECT clause,
# and we want to avoid reloading the same data again.
continue
- if select_mask and field not in select_mask:
+ if select_mask and field not in select_mask_fields:
continue
alias = self.query.join_parent_model(opts, model, start_alias, seen_models)
column = field.get_col(alias)
@@ -1110,9 +1116,10 @@ class SQLCompiler:
)
return results
targets, alias, _ = self.query.trim_joins(targets, joins, path)
+ target_fields = composite.unnest(targets)
return [
(OrderBy(transform_function(t, alias), descending=descending), False)
- for t in targets
+ for t in target_fields
]
def _setup_joins(self, pieces, opts, alias):
@@ -1504,13 +1511,25 @@ class SQLCompiler:
return result
def get_converters(self, expressions):
+ i = 0
converters = {}
- for i, expression in enumerate(expressions):
- if expression:
+
+ for expression in expressions:
+ if isinstance(expression, ColPairs):
+ cols = expression.get_source_expressions()
+ cols_converters = self.get_converters(cols)
+ for j, (convs, col) in cols_converters.items():
+ converters[i + j] = (convs, col)
+ i += len(expression)
+ elif expression:
backend_converters = self.connection.ops.get_db_converters(expression)
field_converters = expression.get_db_converters(self.connection)
if backend_converters or field_converters:
converters[i] = (backend_converters + field_converters, expression)
+ i += 1
+ else:
+ i += 1
+
return converters
def apply_converters(self, rows, converters):
@@ -1524,6 +1543,24 @@ class SQLCompiler:
row[pos] = value
yield row
+ def has_composite_fields(self, expressions):
+ # Check for composite fields before calling the relatively costly
+ # composite_fields_to_tuples.
+ return any(isinstance(expression, ColPairs) for expression in expressions)
+
+ def composite_fields_to_tuples(self, rows, expressions):
+ col_pair_slices = [
+ slice(i, i + len(expression))
+ for i, expression in enumerate(expressions)
+ if isinstance(expression, ColPairs)
+ ]
+
+ for row in map(list, rows):
+ for pos in col_pair_slices:
+ row[pos] = (tuple(row[pos]),)
+
+ yield row
+
def results_iter(
self,
results=None,
@@ -1541,8 +1578,10 @@ class SQLCompiler:
rows = chain.from_iterable(results)
if converters:
rows = self.apply_converters(rows, converters)
- if tuple_expected:
- rows = map(tuple, rows)
+ if self.has_composite_fields(fields):
+ rows = self.composite_fields_to_tuples(rows, fields)
+ if tuple_expected:
+ rows = map(tuple, rows)
return rows
def has_results(self):
@@ -1863,6 +1902,18 @@ class SQLInsertCompiler(SQLCompiler):
)
]
cols = [field.get_col(opts.db_table) for field in self.returning_fields]
+ elif isinstance(opts.pk, CompositePrimaryKey):
+ returning_field = returning_fields[0]
+ cols = [returning_field.get_col(opts.db_table)]
+ rows = [
+ (
+ self.connection.ops.last_insert_id(
+ cursor,
+ opts.db_table,
+ returning_field.column,
+ ),
+ )
+ ]
else:
cols = [opts.pk.get_col(opts.db_table)]
rows = [
@@ -1876,8 +1927,10 @@ class SQLInsertCompiler(SQLCompiler):
]
converters = self.get_converters(cols)
if converters:
- rows = list(self.apply_converters(rows, converters))
- return rows
+ rows = self.apply_converters(rows, converters)
+ if self.has_composite_fields(cols):
+ rows = self.composite_fields_to_tuples(rows, cols)
+ return list(rows)
class SQLDeleteCompiler(SQLCompiler):
@@ -2065,6 +2118,7 @@ class SQLUpdateCompiler(SQLCompiler):
query.add_fields(fields)
super().pre_sql_setup()
+ is_composite_pk = meta.is_composite_pk
must_pre_select = (
count > 1 and not self.connection.features.update_can_self_select
)
@@ -2079,7 +2133,8 @@ class SQLUpdateCompiler(SQLCompiler):
idents = []
related_ids = collections.defaultdict(list)
for rows in query.get_compiler(self.using).execute_sql(MULTI):
- idents.extend(r[0] for r in rows)
+ pks = [row if is_composite_pk else row[0] for row in rows]
+ idents.extend(pks)
for parent, index in related_ids_index:
related_ids[parent].extend(r[index] for r in rows)
self.query.add_filter("pk__in", idents)