summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--django/db/backends/postgresql/compiler.py50
-rw-r--r--django/db/backends/postgresql/operations.py7
-rw-r--r--tests/backends/postgresql/test_compilation.py29
3 files changed, 86 insertions, 0 deletions
diff --git a/django/db/backends/postgresql/compiler.py b/django/db/backends/postgresql/compiler.py
new file mode 100644
index 0000000000..2394d90f55
--- /dev/null
+++ b/django/db/backends/postgresql/compiler.py
@@ -0,0 +1,50 @@
+from django.db.models.sql.compiler import (
+ SQLAggregateCompiler,
+ SQLCompiler,
+ SQLDeleteCompiler,
+)
+from django.db.models.sql.compiler import SQLInsertCompiler as BaseSQLInsertCompiler
+from django.db.models.sql.compiler import SQLUpdateCompiler
+
+__all__ = [
+ "SQLAggregateCompiler",
+ "SQLCompiler",
+ "SQLDeleteCompiler",
+ "SQLInsertCompiler",
+ "SQLUpdateCompiler",
+]
+
+
+class InsertUnnest(list):
+ """
+ Sentinel value to signal DatabaseOperations.bulk_insert_sql() that the
+ UNNEST strategy should be used for the bulk insert.
+ """
+
+ def __str__(self):
+ return "UNNEST(%s)" % ", ".join(self)
+
+
+class SQLInsertCompiler(BaseSQLInsertCompiler):
+ def assemble_as_sql(self, fields, value_rows):
+ # Specialize bulk-insertion of literal non-array values through
+ # UNNEST to reduce the time spent planning the query.
+ if (
+ # The optimization is not worth doing if there is a single
+ # row as it will result in the same number of placeholders.
+ len(value_rows) <= 1
+ # A lack of fields denotes the usage of the DEFAULT keyword
+ # for the insertion of empty rows.
+ or any(field is None for field in fields)
+ # Compilable values cannot be combined into an array of literal values.
+ or any(any(hasattr(value, "as_sql") for value in row) for row in value_rows)
+ ):
+ return super().assemble_as_sql(fields, value_rows)
+ db_types = [field.db_type(self.connection) for field in fields]
+ # Abort if any of the fields are arrays, as UNNEST indiscriminately
+ # flattens them instead of reducing their nesting by one.
+ if any(db_type.endswith("[]") for db_type in db_types):
+ return super().assemble_as_sql(fields, value_rows)
+ return InsertUnnest(["(%%s)::%s[]" % db_type for db_type in db_types]), [
+ list(map(list, zip(*value_rows)))
+ ]
diff --git a/django/db/backends/postgresql/operations.py b/django/db/backends/postgresql/operations.py
index 8a0ca36a29..9db755bb89 100644
--- a/django/db/backends/postgresql/operations.py
+++ b/django/db/backends/postgresql/operations.py
@@ -3,6 +3,7 @@ from functools import lru_cache, partial
from django.conf import settings
from django.db.backends.base.operations import BaseDatabaseOperations
+from django.db.backends.postgresql.compiler import InsertUnnest
from django.db.backends.postgresql.psycopg_any import (
Inet,
Jsonb,
@@ -24,6 +25,7 @@ def get_json_dumps(encoder):
class DatabaseOperations(BaseDatabaseOperations):
+ compiler_module = "django.db.backends.postgresql.compiler"
cast_char_field_without_max_length = "varchar"
explain_prefix = "EXPLAIN"
explain_options = frozenset(
@@ -148,6 +150,11 @@ class DatabaseOperations(BaseDatabaseOperations):
def deferrable_sql(self):
return " DEFERRABLE INITIALLY DEFERRED"
+ def bulk_insert_sql(self, fields, placeholder_rows):
+ if isinstance(placeholder_rows, InsertUnnest):
+ return f"SELECT * FROM {placeholder_rows}"
+ return super().bulk_insert_sql(fields, placeholder_rows)
+
def fetch_returned_insert_rows(self, cursor):
"""
Given a cursor object that has just performed an INSERT...RETURNING
diff --git a/tests/backends/postgresql/test_compilation.py b/tests/backends/postgresql/test_compilation.py
new file mode 100644
index 0000000000..67fe893e35
--- /dev/null
+++ b/tests/backends/postgresql/test_compilation.py
@@ -0,0 +1,29 @@
+import unittest
+
+from django.db import connection
+from django.db.models.expressions import RawSQL
+from django.test import TestCase
+
+from ..models import Square
+
+
+@unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL tests")
+class BulkCreateUnnestTests(TestCase):
+ def test_single_object(self):
+ with self.assertNumQueries(1) as ctx:
+ Square.objects.bulk_create([Square(root=2, square=4)])
+ self.assertNotIn("UNNEST", ctx[0]["sql"])
+
+ def test_non_literal(self):
+ with self.assertNumQueries(1) as ctx:
+ Square.objects.bulk_create(
+ [Square(root=2, square=RawSQL("%s", (4,))), Square(root=3, square=9)]
+ )
+ self.assertNotIn("UNNEST", ctx[0]["sql"])
+
+ def test_unnest_eligible(self):
+ with self.assertNumQueries(1) as ctx:
+ Square.objects.bulk_create(
+ [Square(root=2, square=4), Square(root=3, square=9)]
+ )
+ self.assertIn("UNNEST", ctx[0]["sql"])