diff options
| author | Bendeguz Csirmaz <csirmazbendeguz@gmail.com> | 2024-04-07 10:32:16 +0800 |
|---|---|---|
| committer | Sarah Boyce <42296566+sarahboyce@users.noreply.github.com> | 2024-11-29 11:23:04 +0100 |
| commit | 978aae4334fa71ba78a3e94408f0f3aebde8d07c (patch) | |
| tree | dd1cc322769441a3dd28b952ce52e07c3f72f90a /django/db/models/sql | |
| parent | 86661f2449fb0903f72b3522c68e146934013377 (diff) | |
Fixed #373 -- Added CompositePrimaryKey.
Thanks Lily Foote and Simon Charette for reviews and mentoring
this Google Summer of Code 2024 project.
Co-authored-by: Simon Charette <charette.s@gmail.com>
Co-authored-by: Lily Foote <code@lilyf.org>
Diffstat (limited to 'django/db/models/sql')
| -rw-r--r-- | django/db/models/sql/compiler.py | 75 | ||||
| -rw-r--r-- | django/db/models/sql/query.py | 8 |
2 files changed, 71 insertions, 12 deletions
diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index 49263d5944..053bdc09d5 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -7,7 +7,9 @@ from itertools import chain from django.core.exceptions import EmptyResultSet, FieldError, FullResultSet from django.db import DatabaseError, NotSupportedError from django.db.models.constants import LOOKUP_SEP -from django.db.models.expressions import F, OrderBy, RawSQL, Ref, Value +from django.db.models.expressions import ColPairs, F, OrderBy, RawSQL, Ref, Value +from django.db.models.fields import composite +from django.db.models.fields.composite import CompositePrimaryKey from django.db.models.functions import Cast, Random from django.db.models.lookups import Lookup from django.db.models.query_utils import select_related_descend @@ -283,6 +285,9 @@ class SQLCompiler: # Reference to a column. elif isinstance(expression, int): expression = cols[expression] + # ColPairs cannot be aliased. + if isinstance(expression, ColPairs): + alias = None selected.append((alias, expression)) for select_idx, (alias, expression) in enumerate(selected): @@ -997,6 +1002,7 @@ class SQLCompiler: # alias for a given field. This also includes None -> start_alias to # be used by local fields. seen_models = {None: start_alias} + select_mask_fields = set(composite.unnest(select_mask)) for field in opts.concrete_fields: model = field.model._meta.concrete_model @@ -1017,7 +1023,7 @@ class SQLCompiler: # parent model data is already present in the SELECT clause, # and we want to avoid reloading the same data again. continue - if select_mask and field not in select_mask: + if select_mask and field not in select_mask_fields: continue alias = self.query.join_parent_model(opts, model, start_alias, seen_models) column = field.get_col(alias) @@ -1110,9 +1116,10 @@ class SQLCompiler: ) return results targets, alias, _ = self.query.trim_joins(targets, joins, path) + target_fields = composite.unnest(targets) return [ (OrderBy(transform_function(t, alias), descending=descending), False) - for t in targets + for t in target_fields ] def _setup_joins(self, pieces, opts, alias): @@ -1504,13 +1511,25 @@ class SQLCompiler: return result def get_converters(self, expressions): + i = 0 converters = {} - for i, expression in enumerate(expressions): - if expression: + + for expression in expressions: + if isinstance(expression, ColPairs): + cols = expression.get_source_expressions() + cols_converters = self.get_converters(cols) + for j, (convs, col) in cols_converters.items(): + converters[i + j] = (convs, col) + i += len(expression) + elif expression: backend_converters = self.connection.ops.get_db_converters(expression) field_converters = expression.get_db_converters(self.connection) if backend_converters or field_converters: converters[i] = (backend_converters + field_converters, expression) + i += 1 + else: + i += 1 + return converters def apply_converters(self, rows, converters): @@ -1524,6 +1543,24 @@ class SQLCompiler: row[pos] = value yield row + def has_composite_fields(self, expressions): + # Check for composite fields before calling the relatively costly + # composite_fields_to_tuples. + return any(isinstance(expression, ColPairs) for expression in expressions) + + def composite_fields_to_tuples(self, rows, expressions): + col_pair_slices = [ + slice(i, i + len(expression)) + for i, expression in enumerate(expressions) + if isinstance(expression, ColPairs) + ] + + for row in map(list, rows): + for pos in col_pair_slices: + row[pos] = (tuple(row[pos]),) + + yield row + def results_iter( self, results=None, @@ -1541,8 +1578,10 @@ class SQLCompiler: rows = chain.from_iterable(results) if converters: rows = self.apply_converters(rows, converters) - if tuple_expected: - rows = map(tuple, rows) + if self.has_composite_fields(fields): + rows = self.composite_fields_to_tuples(rows, fields) + if tuple_expected: + rows = map(tuple, rows) return rows def has_results(self): @@ -1863,6 +1902,18 @@ class SQLInsertCompiler(SQLCompiler): ) ] cols = [field.get_col(opts.db_table) for field in self.returning_fields] + elif isinstance(opts.pk, CompositePrimaryKey): + returning_field = returning_fields[0] + cols = [returning_field.get_col(opts.db_table)] + rows = [ + ( + self.connection.ops.last_insert_id( + cursor, + opts.db_table, + returning_field.column, + ), + ) + ] else: cols = [opts.pk.get_col(opts.db_table)] rows = [ @@ -1876,8 +1927,10 @@ class SQLInsertCompiler(SQLCompiler): ] converters = self.get_converters(cols) if converters: - rows = list(self.apply_converters(rows, converters)) - return rows + rows = self.apply_converters(rows, converters) + if self.has_composite_fields(cols): + rows = self.composite_fields_to_tuples(rows, cols) + return list(rows) class SQLDeleteCompiler(SQLCompiler): @@ -2065,6 +2118,7 @@ class SQLUpdateCompiler(SQLCompiler): query.add_fields(fields) super().pre_sql_setup() + is_composite_pk = meta.is_composite_pk must_pre_select = ( count > 1 and not self.connection.features.update_can_self_select ) @@ -2079,7 +2133,8 @@ class SQLUpdateCompiler(SQLCompiler): idents = [] related_ids = collections.defaultdict(list) for rows in query.get_compiler(self.using).execute_sql(MULTI): - idents.extend(r[0] for r in rows) + pks = [row if is_composite_pk else row[0] for row in rows] + idents.extend(pks) for parent, index in related_ids_index: related_ids[parent].extend(r[index] for r in rows) self.query.add_filter("pk__in", idents) diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index b7b93c235a..cca11bfcc2 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -627,8 +627,12 @@ class Query(BaseExpression): if result is None: result = empty_set_result else: - converters = compiler.get_converters(outer_query.annotation_select.values()) - result = next(compiler.apply_converters((result,), converters)) + cols = outer_query.annotation_select.values() + converters = compiler.get_converters(cols) + rows = compiler.apply_converters((result,), converters) + if compiler.has_composite_fields(cols): + rows = compiler.composite_fields_to_tuples(rows, cols) + result = next(rows) return dict(zip(outer_query.annotation_select, result)) |
