summaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
authorBendeguz Csirmaz <csirmazbendeguz@gmail.com>2024-04-07 10:32:16 +0800
committerSarah Boyce <42296566+sarahboyce@users.noreply.github.com>2024-11-29 11:23:04 +0100
commit978aae4334fa71ba78a3e94408f0f3aebde8d07c (patch)
treedd1cc322769441a3dd28b952ce52e07c3f72f90a /tests
parent86661f2449fb0903f72b3522c68e146934013377 (diff)
Fixed #373 -- Added CompositePrimaryKey.
Thanks Lily Foote and Simon Charette for reviews and mentoring this Google Summer of Code 2024 project. Co-authored-by: Simon Charette <charette.s@gmail.com> Co-authored-by: Lily Foote <code@lilyf.org>
Diffstat (limited to 'tests')
-rw-r--r--tests/admin_registration/models.py6
-rw-r--r--tests/admin_registration/tests.py10
-rw-r--r--tests/composite_pk/__init__.py0
-rw-r--r--tests/composite_pk/fixtures/tenant.json75
-rw-r--r--tests/composite_pk/models/__init__.py9
-rw-r--r--tests/composite_pk/models/tenant.py50
-rw-r--r--tests/composite_pk/test_aggregate.py139
-rw-r--r--tests/composite_pk/test_checks.py242
-rw-r--r--tests/composite_pk/test_create.py138
-rw-r--r--tests/composite_pk/test_delete.py83
-rw-r--r--tests/composite_pk/test_filter.py412
-rw-r--r--tests/composite_pk/test_get.py126
-rw-r--r--tests/composite_pk/test_models.py153
-rw-r--r--tests/composite_pk/test_names_to_path.py134
-rw-r--r--tests/composite_pk/test_update.py135
-rw-r--r--tests/composite_pk/test_values.py212
-rw-r--r--tests/composite_pk/tests.py345
-rw-r--r--tests/migrations/test_autodetector.py89
-rw-r--r--tests/migrations/test_operations.py55
-rw-r--r--tests/migrations/test_state.py22
-rw-r--r--tests/migrations/test_writer.py19
21 files changed, 2453 insertions, 1 deletions
diff --git a/tests/admin_registration/models.py b/tests/admin_registration/models.py
index 0ae9251133..2231c236de 100644
--- a/tests/admin_registration/models.py
+++ b/tests/admin_registration/models.py
@@ -20,3 +20,9 @@ class Location(models.Model):
class Place(Location):
name = models.CharField(max_length=200)
+
+
+class Guest(models.Model):
+ pk = models.CompositePrimaryKey("traveler", "place")
+ traveler = models.ForeignKey(Traveler, on_delete=models.CASCADE)
+ place = models.ForeignKey(Place, on_delete=models.CASCADE)
diff --git a/tests/admin_registration/tests.py b/tests/admin_registration/tests.py
index 3b0e656f5f..0a881caf65 100644
--- a/tests/admin_registration/tests.py
+++ b/tests/admin_registration/tests.py
@@ -5,7 +5,7 @@ from django.contrib.admin.sites import site
from django.core.exceptions import ImproperlyConfigured
from django.test import SimpleTestCase
-from .models import Location, Person, Place, Traveler
+from .models import Guest, Location, Person, Place, Traveler
class NameAdmin(admin.ModelAdmin):
@@ -92,6 +92,14 @@ class TestRegistration(SimpleTestCase):
with self.assertRaisesMessage(ImproperlyConfigured, msg):
self.site.register(Location)
+ def test_composite_pk_model(self):
+ msg = (
+ "The model Guest has a composite primary key, so it cannot be registered "
+ "with admin."
+ )
+ with self.assertRaisesMessage(ImproperlyConfigured, msg):
+ self.site.register(Guest)
+
def test_is_registered_model(self):
"Checks for registered models should return true."
self.site.register(Person)
diff --git a/tests/composite_pk/__init__.py b/tests/composite_pk/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/tests/composite_pk/__init__.py
diff --git a/tests/composite_pk/fixtures/tenant.json b/tests/composite_pk/fixtures/tenant.json
new file mode 100644
index 0000000000..3eeff42fef
--- /dev/null
+++ b/tests/composite_pk/fixtures/tenant.json
@@ -0,0 +1,75 @@
+[
+ {
+ "pk": 1,
+ "model": "composite_pk.tenant",
+ "fields": {
+ "id": 1,
+ "name": "Tenant 1"
+ }
+ },
+ {
+ "pk": 2,
+ "model": "composite_pk.tenant",
+ "fields": {
+ "id": 2,
+ "name": "Tenant 2"
+ }
+ },
+ {
+ "pk": 3,
+ "model": "composite_pk.tenant",
+ "fields": {
+ "id": 3,
+ "name": "Tenant 3"
+ }
+ },
+ {
+ "pk": [1, 1],
+ "model": "composite_pk.user",
+ "fields": {
+ "tenant_id": 1,
+ "id": 1,
+ "email": "user0001@example.com"
+ }
+ },
+ {
+ "pk": [1, 2],
+ "model": "composite_pk.user",
+ "fields": {
+ "tenant_id": 1,
+ "id": 2,
+ "email": "user0002@example.com"
+ }
+ },
+ {
+ "pk": [2, 3],
+ "model": "composite_pk.user",
+ "fields": {
+ "email": "user0003@example.com"
+ }
+ },
+ {
+ "model": "composite_pk.user",
+ "fields": {
+ "tenant_id": 2,
+ "id": 4,
+ "email": "user0004@example.com"
+ }
+ },
+ {
+ "pk": [2, "11111111-1111-1111-1111-111111111111"],
+ "model": "composite_pk.post",
+ "fields": {
+ "tenant_id": 2,
+ "id": "11111111-1111-1111-1111-111111111111"
+ }
+ },
+ {
+ "pk": [2, "ffffffff-ffff-ffff-ffff-ffffffffffff"],
+ "model": "composite_pk.post",
+ "fields": {
+ "tenant_id": 2,
+ "id": "ffffffff-ffff-ffff-ffff-ffffffffffff"
+ }
+ }
+]
diff --git a/tests/composite_pk/models/__init__.py b/tests/composite_pk/models/__init__.py
new file mode 100644
index 0000000000..35c3943716
--- /dev/null
+++ b/tests/composite_pk/models/__init__.py
@@ -0,0 +1,9 @@
+from .tenant import Comment, Post, Tenant, Token, User
+
+__all__ = [
+ "Comment",
+ "Post",
+ "Tenant",
+ "Token",
+ "User",
+]
diff --git a/tests/composite_pk/models/tenant.py b/tests/composite_pk/models/tenant.py
new file mode 100644
index 0000000000..ac0b3d9715
--- /dev/null
+++ b/tests/composite_pk/models/tenant.py
@@ -0,0 +1,50 @@
+from django.db import models
+
+
+class Tenant(models.Model):
+ name = models.CharField(max_length=10, default="", blank=True)
+
+
+class Token(models.Model):
+ pk = models.CompositePrimaryKey("tenant_id", "id")
+ tenant = models.ForeignKey(Tenant, on_delete=models.CASCADE, related_name="tokens")
+ id = models.SmallIntegerField()
+ secret = models.CharField(max_length=10, default="", blank=True)
+
+
+class BaseModel(models.Model):
+ pk = models.CompositePrimaryKey("tenant_id", "id")
+ tenant = models.ForeignKey(Tenant, on_delete=models.CASCADE)
+ id = models.SmallIntegerField(unique=True)
+
+ class Meta:
+ abstract = True
+
+
+class User(BaseModel):
+ email = models.EmailField(unique=True)
+
+
+class Comment(models.Model):
+ pk = models.CompositePrimaryKey("tenant", "id")
+ tenant = models.ForeignKey(
+ Tenant,
+ on_delete=models.CASCADE,
+ related_name="comments",
+ )
+ id = models.SmallIntegerField(unique=True, db_column="comment_id")
+ user_id = models.SmallIntegerField()
+ user = models.ForeignObject(
+ User,
+ on_delete=models.CASCADE,
+ from_fields=("tenant_id", "user_id"),
+ to_fields=("tenant_id", "id"),
+ related_name="comments",
+ )
+ text = models.TextField(default="", blank=True)
+
+
+class Post(models.Model):
+ pk = models.CompositePrimaryKey("tenant_id", "id")
+ tenant = models.ForeignKey(Tenant, on_delete=models.CASCADE)
+ id = models.UUIDField()
diff --git a/tests/composite_pk/test_aggregate.py b/tests/composite_pk/test_aggregate.py
new file mode 100644
index 0000000000..b5474c5218
--- /dev/null
+++ b/tests/composite_pk/test_aggregate.py
@@ -0,0 +1,139 @@
+from django.db import NotSupportedError
+from django.db.models import Count, Q
+from django.test import TestCase
+
+from .models import Comment, Tenant, User
+
+
+class CompositePKAggregateTests(TestCase):
+ @classmethod
+ def setUpTestData(cls):
+ cls.tenant_1 = Tenant.objects.create()
+ cls.tenant_2 = Tenant.objects.create()
+ cls.user_1 = User.objects.create(
+ tenant=cls.tenant_1,
+ id=1,
+ email="user0001@example.com",
+ )
+ cls.user_2 = User.objects.create(
+ tenant=cls.tenant_1,
+ id=2,
+ email="user0002@example.com",
+ )
+ cls.user_3 = User.objects.create(
+ tenant=cls.tenant_2,
+ id=3,
+ email="user0003@example.com",
+ )
+ cls.comment_1 = Comment.objects.create(id=1, user=cls.user_2, text="foo")
+ cls.comment_2 = Comment.objects.create(id=2, user=cls.user_1, text="bar")
+ cls.comment_3 = Comment.objects.create(id=3, user=cls.user_1, text="foobar")
+ cls.comment_4 = Comment.objects.create(id=4, user=cls.user_3, text="foobarbaz")
+ cls.comment_5 = Comment.objects.create(id=5, user=cls.user_3, text="barbaz")
+ cls.comment_6 = Comment.objects.create(id=6, user=cls.user_3, text="baz")
+
+ def test_users_annotated_with_comments_id_count(self):
+ user_1, user_2, user_3 = User.objects.annotate(Count("comments__id")).order_by(
+ "pk"
+ )
+
+ self.assertEqual(user_1, self.user_1)
+ self.assertEqual(user_1.comments__id__count, 2)
+ self.assertEqual(user_2, self.user_2)
+ self.assertEqual(user_2.comments__id__count, 1)
+ self.assertEqual(user_3, self.user_3)
+ self.assertEqual(user_3.comments__id__count, 3)
+
+ def test_users_annotated_with_aliased_comments_id_count(self):
+ user_1, user_2, user_3 = User.objects.annotate(
+ comments_count=Count("comments__id")
+ ).order_by("pk")
+
+ self.assertEqual(user_1, self.user_1)
+ self.assertEqual(user_1.comments_count, 2)
+ self.assertEqual(user_2, self.user_2)
+ self.assertEqual(user_2.comments_count, 1)
+ self.assertEqual(user_3, self.user_3)
+ self.assertEqual(user_3.comments_count, 3)
+
+ def test_users_annotated_with_comments_count(self):
+ user_1, user_2, user_3 = User.objects.annotate(Count("comments")).order_by("pk")
+
+ self.assertEqual(user_1, self.user_1)
+ self.assertEqual(user_1.comments__count, 2)
+ self.assertEqual(user_2, self.user_2)
+ self.assertEqual(user_2.comments__count, 1)
+ self.assertEqual(user_3, self.user_3)
+ self.assertEqual(user_3.comments__count, 3)
+
+ def test_users_annotated_with_comments_count_filter(self):
+ user_1, user_2, user_3 = User.objects.annotate(
+ comments__count=Count(
+ "comments", filter=Q(pk__in=[self.user_1.pk, self.user_2.pk])
+ )
+ ).order_by("pk")
+
+ self.assertEqual(user_1, self.user_1)
+ self.assertEqual(user_1.comments__count, 2)
+ self.assertEqual(user_2, self.user_2)
+ self.assertEqual(user_2.comments__count, 1)
+ self.assertEqual(user_3, self.user_3)
+ self.assertEqual(user_3.comments__count, 0)
+
+ def test_count_distinct_not_supported(self):
+ with self.assertRaisesMessage(
+ NotSupportedError, "COUNT(DISTINCT) doesn't support composite primary keys"
+ ):
+ self.assertIsNone(
+ User.objects.annotate(comments__count=Count("comments", distinct=True))
+ )
+
+ def test_user_values_annotated_with_comments_id_count(self):
+ self.assertSequenceEqual(
+ User.objects.values("pk").annotate(Count("comments__id")).order_by("pk"),
+ (
+ {"pk": self.user_1.pk, "comments__id__count": 2},
+ {"pk": self.user_2.pk, "comments__id__count": 1},
+ {"pk": self.user_3.pk, "comments__id__count": 3},
+ ),
+ )
+
+ def test_user_values_annotated_with_filtered_comments_id_count(self):
+ self.assertSequenceEqual(
+ User.objects.values("pk")
+ .annotate(
+ comments_count=Count(
+ "comments__id",
+ filter=Q(comments__text__icontains="foo"),
+ )
+ )
+ .order_by("pk"),
+ (
+ {"pk": self.user_1.pk, "comments_count": 1},
+ {"pk": self.user_2.pk, "comments_count": 1},
+ {"pk": self.user_3.pk, "comments_count": 1},
+ ),
+ )
+
+ def test_filter_and_count_users_by_comments_fields(self):
+ users = User.objects.filter(comments__id__gt=2).order_by("pk")
+ self.assertEqual(users.count(), 4)
+ self.assertSequenceEqual(
+ users, (self.user_1, self.user_3, self.user_3, self.user_3)
+ )
+
+ users = User.objects.filter(comments__text__icontains="foo").order_by("pk")
+ self.assertEqual(users.count(), 3)
+ self.assertSequenceEqual(users, (self.user_1, self.user_2, self.user_3))
+
+ users = User.objects.filter(comments__text__icontains="baz").order_by("pk")
+ self.assertEqual(users.count(), 3)
+ self.assertSequenceEqual(users, (self.user_3, self.user_3, self.user_3))
+
+ def test_order_by_comments_id_count(self):
+ self.assertSequenceEqual(
+ User.objects.annotate(comments_count=Count("comments__id")).order_by(
+ "-comments_count"
+ ),
+ (self.user_3, self.user_1, self.user_2),
+ )
diff --git a/tests/composite_pk/test_checks.py b/tests/composite_pk/test_checks.py
new file mode 100644
index 0000000000..02a162c31d
--- /dev/null
+++ b/tests/composite_pk/test_checks.py
@@ -0,0 +1,242 @@
+from django.core import checks
+from django.db import connection, models
+from django.db.models import F
+from django.test import TestCase
+from django.test.utils import isolate_apps
+
+
+@isolate_apps("composite_pk")
+class CompositePKChecksTests(TestCase):
+ maxDiff = None
+
+ def test_composite_pk_must_be_unique_strings(self):
+ test_cases = (
+ (),
+ (0,),
+ (1,),
+ ("id", False),
+ ("id", "id"),
+ (("id",),),
+ )
+
+ for i, args in enumerate(test_cases):
+ with (
+ self.subTest(args=args),
+ self.assertRaisesMessage(
+ ValueError, "CompositePrimaryKey args must be unique strings."
+ ),
+ ):
+ models.CompositePrimaryKey(*args)
+
+ def test_composite_pk_must_include_at_least_2_fields(self):
+ expected_message = "CompositePrimaryKey must include at least two fields."
+ with self.assertRaisesMessage(ValueError, expected_message):
+ models.CompositePrimaryKey("id")
+
+ def test_composite_pk_cannot_have_a_default(self):
+ expected_message = "CompositePrimaryKey cannot have a default."
+ with self.assertRaisesMessage(ValueError, expected_message):
+ models.CompositePrimaryKey("tenant_id", "id", default=(1, 1))
+
+ def test_composite_pk_cannot_have_a_database_default(self):
+ expected_message = "CompositePrimaryKey cannot have a database default."
+ with self.assertRaisesMessage(ValueError, expected_message):
+ models.CompositePrimaryKey("tenant_id", "id", db_default=models.F("id"))
+
+ def test_composite_pk_cannot_be_editable(self):
+ expected_message = "CompositePrimaryKey cannot be editable."
+ with self.assertRaisesMessage(ValueError, expected_message):
+ models.CompositePrimaryKey("tenant_id", "id", editable=True)
+
+ def test_composite_pk_must_be_a_primary_key(self):
+ expected_message = "CompositePrimaryKey must be a primary key."
+ with self.assertRaisesMessage(ValueError, expected_message):
+ models.CompositePrimaryKey("tenant_id", "id", primary_key=False)
+
+ def test_composite_pk_must_be_blank(self):
+ expected_message = "CompositePrimaryKey must be blank."
+ with self.assertRaisesMessage(ValueError, expected_message):
+ models.CompositePrimaryKey("tenant_id", "id", blank=False)
+
+ def test_composite_pk_must_not_have_other_pk_field(self):
+ class Foo(models.Model):
+ pk = models.CompositePrimaryKey("foo_id", "id")
+ foo_id = models.IntegerField()
+ id = models.IntegerField(primary_key=True)
+
+ self.assertEqual(
+ Foo.check(databases=self.databases),
+ [
+ checks.Error(
+ "The model cannot have more than one field with "
+ "'primary_key=True'.",
+ obj=Foo,
+ id="models.E026",
+ ),
+ ],
+ )
+
+ def test_composite_pk_cannot_include_nullable_field(self):
+ class Foo(models.Model):
+ pk = models.CompositePrimaryKey("foo_id", "id")
+ foo_id = models.IntegerField()
+ id = models.IntegerField(null=True)
+
+ self.assertEqual(
+ Foo.check(databases=self.databases),
+ [
+ checks.Error(
+ "'id' cannot be included in the composite primary key.",
+ hint="'id' field may not set 'null=True'.",
+ obj=Foo,
+ id="models.E042",
+ ),
+ ],
+ )
+
+ def test_composite_pk_can_include_fk_name(self):
+ class Foo(models.Model):
+ pass
+
+ class Bar(models.Model):
+ pk = models.CompositePrimaryKey("foo", "id")
+ foo = models.ForeignKey(Foo, on_delete=models.CASCADE)
+ id = models.SmallIntegerField()
+
+ self.assertEqual(Foo.check(databases=self.databases), [])
+ self.assertEqual(Bar.check(databases=self.databases), [])
+
+ def test_composite_pk_cannot_include_same_field(self):
+ class Foo(models.Model):
+ pass
+
+ class Bar(models.Model):
+ pk = models.CompositePrimaryKey("foo", "foo_id")
+ foo = models.ForeignKey(Foo, on_delete=models.CASCADE)
+ id = models.SmallIntegerField()
+
+ self.assertEqual(Foo.check(databases=self.databases), [])
+ self.assertEqual(
+ Bar.check(databases=self.databases),
+ [
+ checks.Error(
+ "'foo_id' cannot be included in the composite primary key.",
+ hint="'foo_id' and 'foo' are the same fields.",
+ obj=Bar,
+ id="models.E042",
+ ),
+ ],
+ )
+
+ def test_composite_pk_cannot_include_composite_pk_field(self):
+ class Foo(models.Model):
+ pk = models.CompositePrimaryKey("id", "pk")
+ id = models.SmallIntegerField()
+
+ self.assertEqual(
+ Foo.check(databases=self.databases),
+ [
+ checks.Error(
+ "'pk' cannot be included in the composite primary key.",
+ hint="'pk' field has no column.",
+ obj=Foo,
+ id="models.E042",
+ ),
+ ],
+ )
+
+ def test_composite_pk_cannot_include_db_column(self):
+ class Foo(models.Model):
+ pk = models.CompositePrimaryKey("foo", "bar")
+ foo = models.SmallIntegerField(db_column="foo_id")
+ bar = models.SmallIntegerField(db_column="bar_id")
+
+ class Bar(models.Model):
+ pk = models.CompositePrimaryKey("foo_id", "bar_id")
+ foo = models.SmallIntegerField(db_column="foo_id")
+ bar = models.SmallIntegerField(db_column="bar_id")
+
+ self.assertEqual(Foo.check(databases=self.databases), [])
+ self.assertEqual(
+ Bar.check(databases=self.databases),
+ [
+ checks.Error(
+ "'foo_id' cannot be included in the composite primary key.",
+ hint="'foo_id' is not a valid field.",
+ obj=Bar,
+ id="models.E042",
+ ),
+ checks.Error(
+ "'bar_id' cannot be included in the composite primary key.",
+ hint="'bar_id' is not a valid field.",
+ obj=Bar,
+ id="models.E042",
+ ),
+ ],
+ )
+
+ def test_foreign_object_can_refer_composite_pk(self):
+ class Foo(models.Model):
+ pass
+
+ class Bar(models.Model):
+ pk = models.CompositePrimaryKey("foo_id", "id")
+ foo = models.ForeignKey(Foo, on_delete=models.CASCADE)
+ id = models.IntegerField()
+
+ class Baz(models.Model):
+ pk = models.CompositePrimaryKey("foo_id", "id")
+ foo = models.ForeignKey(Foo, on_delete=models.CASCADE)
+ id = models.IntegerField()
+ bar_id = models.IntegerField()
+ bar = models.ForeignObject(
+ Bar,
+ on_delete=models.CASCADE,
+ from_fields=("foo_id", "bar_id"),
+ to_fields=("foo_id", "id"),
+ )
+
+ self.assertEqual(Foo.check(databases=self.databases), [])
+ self.assertEqual(Bar.check(databases=self.databases), [])
+ self.assertEqual(Baz.check(databases=self.databases), [])
+
+ def test_composite_pk_must_be_named_pk(self):
+ class Foo(models.Model):
+ primary_key = models.CompositePrimaryKey("foo_id", "id")
+ foo_id = models.IntegerField()
+ id = models.IntegerField()
+
+ self.assertEqual(
+ Foo.check(databases=self.databases),
+ [
+ checks.Error(
+ "'CompositePrimaryKey' must be named 'pk'.",
+ obj=Foo._meta.get_field("primary_key"),
+ id="fields.E013",
+ ),
+ ],
+ )
+
+ def test_composite_pk_cannot_include_generated_field(self):
+ is_oracle = connection.vendor == "oracle"
+
+ class Foo(models.Model):
+ pk = models.CompositePrimaryKey("id", "foo")
+ id = models.IntegerField()
+ foo = models.GeneratedField(
+ expression=F("id"),
+ output_field=models.IntegerField(),
+ db_persist=not is_oracle,
+ )
+
+ self.assertEqual(
+ Foo.check(databases=self.databases),
+ [
+ checks.Error(
+ "'foo' cannot be included in the composite primary key.",
+ hint="'foo' field is a generated field.",
+ obj=Foo,
+ id="models.E042",
+ ),
+ ],
+ )
diff --git a/tests/composite_pk/test_create.py b/tests/composite_pk/test_create.py
new file mode 100644
index 0000000000..7c9925b946
--- /dev/null
+++ b/tests/composite_pk/test_create.py
@@ -0,0 +1,138 @@
+from django.test import TestCase
+
+from .models import Tenant, User
+
+
+class CompositePKCreateTests(TestCase):
+ maxDiff = None
+
+ @classmethod
+ def setUpTestData(cls):
+ cls.tenant = Tenant.objects.create()
+ cls.user = User.objects.create(
+ tenant=cls.tenant,
+ id=1,
+ email="user0001@example.com",
+ )
+
+ def test_create_user(self):
+ test_cases = (
+ {"tenant": self.tenant, "id": 2412, "email": "user2412@example.com"},
+ {"tenant_id": self.tenant.id, "id": 5316, "email": "user5316@example.com"},
+ {"pk": (self.tenant.id, 7424), "email": "user7424@example.com"},
+ )
+
+ for fields in test_cases:
+ with self.subTest(fields=fields):
+ count = User.objects.count()
+ user = User(**fields)
+ obj = User.objects.create(**fields)
+ self.assertEqual(obj.tenant_id, self.tenant.id)
+ self.assertEqual(obj.id, user.id)
+ self.assertEqual(obj.pk, (self.tenant.id, user.id))
+ self.assertEqual(obj.email, user.email)
+ self.assertEqual(count + 1, User.objects.count())
+
+ def test_save_user(self):
+ test_cases = (
+ {"tenant": self.tenant, "id": 9241, "email": "user9241@example.com"},
+ {"tenant_id": self.tenant.id, "id": 5132, "email": "user5132@example.com"},
+ {"pk": (self.tenant.id, 3014), "email": "user3014@example.com"},
+ )
+
+ for fields in test_cases:
+ with self.subTest(fields=fields):
+ count = User.objects.count()
+ user = User(**fields)
+ self.assertIsNotNone(user.id)
+ self.assertIsNotNone(user.email)
+ user.save()
+ self.assertEqual(user.tenant_id, self.tenant.id)
+ self.assertEqual(user.tenant, self.tenant)
+ self.assertIsNotNone(user.id)
+ self.assertEqual(user.pk, (self.tenant.id, user.id))
+ self.assertEqual(user.email, fields["email"])
+ self.assertEqual(user.email, f"user{user.id}@example.com")
+ self.assertEqual(count + 1, User.objects.count())
+
+ def test_bulk_create_users(self):
+ objs = [
+ User(tenant=self.tenant, id=8291, email="user8291@example.com"),
+ User(tenant_id=self.tenant.id, id=4021, email="user4021@example.com"),
+ User(pk=(self.tenant.id, 8214), email="user8214@example.com"),
+ ]
+
+ obj_1, obj_2, obj_3 = User.objects.bulk_create(objs)
+
+ self.assertEqual(obj_1.tenant_id, self.tenant.id)
+ self.assertEqual(obj_1.id, 8291)
+ self.assertEqual(obj_1.pk, (obj_1.tenant_id, obj_1.id))
+ self.assertEqual(obj_1.email, "user8291@example.com")
+ self.assertEqual(obj_2.tenant_id, self.tenant.id)
+ self.assertEqual(obj_2.id, 4021)
+ self.assertEqual(obj_2.pk, (obj_2.tenant_id, obj_2.id))
+ self.assertEqual(obj_2.email, "user4021@example.com")
+ self.assertEqual(obj_3.tenant_id, self.tenant.id)
+ self.assertEqual(obj_3.id, 8214)
+ self.assertEqual(obj_3.pk, (obj_3.tenant_id, obj_3.id))
+ self.assertEqual(obj_3.email, "user8214@example.com")
+
+ def test_get_or_create_user(self):
+ test_cases = (
+ {
+ "pk": (self.tenant.id, 8314),
+ "defaults": {"email": "user8314@example.com"},
+ },
+ {
+ "tenant": self.tenant,
+ "id": 3142,
+ "defaults": {"email": "user3142@example.com"},
+ },
+ {
+ "tenant_id": self.tenant.id,
+ "id": 4218,
+ "defaults": {"email": "user4218@example.com"},
+ },
+ )
+
+ for fields in test_cases:
+ with self.subTest(fields=fields):
+ count = User.objects.count()
+ user, created = User.objects.get_or_create(**fields)
+ self.assertIs(created, True)
+ self.assertIsNotNone(user.id)
+ self.assertEqual(user.pk, (self.tenant.id, user.id))
+ self.assertEqual(user.tenant_id, self.tenant.id)
+ self.assertEqual(user.email, fields["defaults"]["email"])
+ self.assertEqual(user.email, f"user{user.id}@example.com")
+ self.assertEqual(count + 1, User.objects.count())
+
+ def test_update_or_create_user(self):
+ test_cases = (
+ {
+ "pk": (self.tenant.id, 2931),
+ "defaults": {"email": "user2931@example.com"},
+ },
+ {
+ "tenant": self.tenant,
+ "id": 6428,
+ "defaults": {"email": "user6428@example.com"},
+ },
+ {
+ "tenant_id": self.tenant.id,
+ "id": 5278,
+ "defaults": {"email": "user5278@example.com"},
+ },
+ )
+
+ for fields in test_cases:
+ with self.subTest(fields=fields):
+ count = User.objects.count()
+ user, created = User.objects.update_or_create(**fields)
+ self.assertIs(created, True)
+ self.assertIsNotNone(user.id)
+ self.assertEqual(user.pk, (self.tenant.id, user.id))
+ self.assertEqual(user.tenant_id, self.tenant.id)
+ self.assertEqual(user.email, fields["defaults"]["email"])
+ self.assertEqual(user.email, f"user{user.id}@example.com")
+ self.assertEqual(count + 1, User.objects.count())
diff --git a/tests/composite_pk/test_delete.py b/tests/composite_pk/test_delete.py
new file mode 100644
index 0000000000..9a14deb813
--- /dev/null
+++ b/tests/composite_pk/test_delete.py
@@ -0,0 +1,83 @@
+from django.test import TestCase
+
+from .models import Comment, Tenant, User
+
+
+class CompositePKDeleteTests(TestCase):
+ maxDiff = None
+
+ @classmethod
+ def setUpTestData(cls):
+ cls.tenant_1 = Tenant.objects.create()
+ cls.tenant_2 = Tenant.objects.create()
+ cls.user_1 = User.objects.create(
+ tenant=cls.tenant_1,
+ id=1,
+ email="user0001@example.com",
+ )
+ cls.user_2 = User.objects.create(
+ tenant=cls.tenant_2,
+ id=2,
+ email="user0002@example.com",
+ )
+ cls.comment_1 = Comment.objects.create(id=1, user=cls.user_1)
+ cls.comment_2 = Comment.objects.create(id=2, user=cls.user_2)
+ cls.comment_3 = Comment.objects.create(id=3, user=cls.user_2)
+
+ def test_delete_tenant_by_pk(self):
+ result = Tenant.objects.filter(pk=self.tenant_1.pk).delete()
+
+ self.assertEqual(
+ result,
+ (
+ 3,
+ {
+ "composite_pk.Comment": 1,
+ "composite_pk.User": 1,
+ "composite_pk.Tenant": 1,
+ },
+ ),
+ )
+
+ self.assertIs(Tenant.objects.filter(pk=self.tenant_1.pk).exists(), False)
+ self.assertIs(Tenant.objects.filter(pk=self.tenant_2.pk).exists(), True)
+ self.assertIs(User.objects.filter(pk=self.user_1.pk).exists(), False)
+ self.assertIs(User.objects.filter(pk=self.user_2.pk).exists(), True)
+ self.assertIs(Comment.objects.filter(pk=self.comment_1.pk).exists(), False)
+ self.assertIs(Comment.objects.filter(pk=self.comment_2.pk).exists(), True)
+ self.assertIs(Comment.objects.filter(pk=self.comment_3.pk).exists(), True)
+
+ def test_delete_user_by_pk(self):
+ result = User.objects.filter(pk=self.user_1.pk).delete()
+
+ self.assertEqual(
+ result, (2, {"composite_pk.User": 1, "composite_pk.Comment": 1})
+ )
+
+ self.assertIs(User.objects.filter(pk=self.user_1.pk).exists(), False)
+ self.assertIs(User.objects.filter(pk=self.user_2.pk).exists(), True)
+ self.assertIs(Comment.objects.filter(pk=self.comment_1.pk).exists(), False)
+ self.assertIs(Comment.objects.filter(pk=self.comment_2.pk).exists(), True)
+ self.assertIs(Comment.objects.filter(pk=self.comment_3.pk).exists(), True)
+
+ def test_delete_comments_by_user(self):
+ result = Comment.objects.filter(user=self.user_2).delete()
+
+ self.assertEqual(result, (2, {"composite_pk.Comment": 2}))
+
+ self.assertIs(Comment.objects.filter(pk=self.comment_1.pk).exists(), True)
+ self.assertIs(Comment.objects.filter(pk=self.comment_2.pk).exists(), False)
+ self.assertIs(Comment.objects.filter(pk=self.comment_3.pk).exists(), False)
+
+ def test_delete_without_pk(self):
+ msg = (
+ "Comment object can't be deleted because its pk attribute is set "
+ "to None."
+ )
+
+ with self.assertRaisesMessage(ValueError, msg):
+ Comment().delete()
+ with self.assertRaisesMessage(ValueError, msg):
+ Comment(tenant_id=1).delete()
+ with self.assertRaisesMessage(ValueError, msg):
+ Comment(id=1).delete()
diff --git a/tests/composite_pk/test_filter.py b/tests/composite_pk/test_filter.py
new file mode 100644
index 0000000000..7e361c5925
--- /dev/null
+++ b/tests/composite_pk/test_filter.py
@@ -0,0 +1,412 @@
+from django.test import TestCase
+
+from .models import Comment, Tenant, User
+
+
+class CompositePKFilterTests(TestCase):
+ maxDiff = None
+
+ @classmethod
+ def setUpTestData(cls):
+ cls.tenant_1 = Tenant.objects.create()
+ cls.tenant_2 = Tenant.objects.create()
+ cls.tenant_3 = Tenant.objects.create()
+ cls.user_1 = User.objects.create(
+ tenant=cls.tenant_1,
+ id=1,
+ email="user0001@example.com",
+ )
+ cls.user_2 = User.objects.create(
+ tenant=cls.tenant_1,
+ id=2,
+ email="user0002@example.com",
+ )
+ cls.user_3 = User.objects.create(
+ tenant=cls.tenant_2,
+ id=3,
+ email="user0003@example.com",
+ )
+ cls.user_4 = User.objects.create(
+ tenant=cls.tenant_3,
+ id=4,
+ email="user0004@example.com",
+ )
+ cls.comment_1 = Comment.objects.create(id=1, user=cls.user_1)
+ cls.comment_2 = Comment.objects.create(id=2, user=cls.user_1)
+ cls.comment_3 = Comment.objects.create(id=3, user=cls.user_2)
+ cls.comment_4 = Comment.objects.create(id=4, user=cls.user_3)
+ cls.comment_5 = Comment.objects.create(id=5, user=cls.user_1)
+
+ def test_filter_and_count_user_by_pk(self):
+ test_cases = (
+ ({"pk": self.user_1.pk}, 1),
+ ({"pk": self.user_2.pk}, 1),
+ ({"pk": self.user_3.pk}, 1),
+ ({"pk": (self.tenant_1.id, self.user_1.id)}, 1),
+ ({"pk": (self.tenant_1.id, self.user_2.id)}, 1),
+ ({"pk": (self.tenant_2.id, self.user_3.id)}, 1),
+ ({"pk": (self.tenant_1.id, self.user_3.id)}, 0),
+ ({"pk": (self.tenant_2.id, self.user_1.id)}, 0),
+ ({"pk": (self.tenant_2.id, self.user_2.id)}, 0),
+ )
+
+ for lookup, count in test_cases:
+ with self.subTest(lookup=lookup, count=count):
+ self.assertEqual(User.objects.filter(**lookup).count(), count)
+
+ def test_order_comments_by_pk_asc(self):
+ self.assertSequenceEqual(
+ Comment.objects.order_by("pk"),
+ (
+ self.comment_1, # (1, 1)
+ self.comment_2, # (1, 2)
+ self.comment_3, # (1, 3)
+ self.comment_5, # (1, 5)
+ self.comment_4, # (2, 4)
+ ),
+ )
+
+ def test_order_comments_by_pk_desc(self):
+ self.assertSequenceEqual(
+ Comment.objects.order_by("-pk"),
+ (
+ self.comment_4, # (2, 4)
+ self.comment_5, # (1, 5)
+ self.comment_3, # (1, 3)
+ self.comment_2, # (1, 2)
+ self.comment_1, # (1, 1)
+ ),
+ )
+
+ def test_filter_comments_by_pk_gt(self):
+ c11, c12, c13, c24, c15 = (
+ self.comment_1,
+ self.comment_2,
+ self.comment_3,
+ self.comment_4,
+ self.comment_5,
+ )
+ test_cases = (
+ (c11, (c12, c13, c15, c24)),
+ (c12, (c13, c15, c24)),
+ (c13, (c15, c24)),
+ (c15, (c24,)),
+ (c24, ()),
+ )
+
+ for obj, objs in test_cases:
+ with self.subTest(obj=obj, objs=objs):
+ self.assertSequenceEqual(
+ Comment.objects.filter(pk__gt=obj.pk).order_by("pk"), objs
+ )
+
+ def test_filter_comments_by_pk_gte(self):
+ c11, c12, c13, c24, c15 = (
+ self.comment_1,
+ self.comment_2,
+ self.comment_3,
+ self.comment_4,
+ self.comment_5,
+ )
+ test_cases = (
+ (c11, (c11, c12, c13, c15, c24)),
+ (c12, (c12, c13, c15, c24)),
+ (c13, (c13, c15, c24)),
+ (c15, (c15, c24)),
+ (c24, (c24,)),
+ )
+
+ for obj, objs in test_cases:
+ with self.subTest(obj=obj, objs=objs):
+ self.assertSequenceEqual(
+ Comment.objects.filter(pk__gte=obj.pk).order_by("pk"), objs
+ )
+
+ def test_filter_comments_by_pk_lt(self):
+ c11, c12, c13, c24, c15 = (
+ self.comment_1,
+ self.comment_2,
+ self.comment_3,
+ self.comment_4,
+ self.comment_5,
+ )
+ test_cases = (
+ (c24, (c11, c12, c13, c15)),
+ (c15, (c11, c12, c13)),
+ (c13, (c11, c12)),
+ (c12, (c11,)),
+ (c11, ()),
+ )
+
+ for obj, objs in test_cases:
+ with self.subTest(obj=obj, objs=objs):
+ self.assertSequenceEqual(
+ Comment.objects.filter(pk__lt=obj.pk).order_by("pk"), objs
+ )
+
+ def test_filter_comments_by_pk_lte(self):
+ c11, c12, c13, c24, c15 = (
+ self.comment_1,
+ self.comment_2,
+ self.comment_3,
+ self.comment_4,
+ self.comment_5,
+ )
+ test_cases = (
+ (c24, (c11, c12, c13, c15, c24)),
+ (c15, (c11, c12, c13, c15)),
+ (c13, (c11, c12, c13)),
+ (c12, (c11, c12)),
+ (c11, (c11,)),
+ )
+
+ for obj, objs in test_cases:
+ with self.subTest(obj=obj, objs=objs):
+ self.assertSequenceEqual(
+ Comment.objects.filter(pk__lte=obj.pk).order_by("pk"), objs
+ )
+
+ def test_filter_comments_by_pk_in(self):
+ test_cases = (
+ (),
+ (self.comment_1,),
+ (self.comment_1, self.comment_4),
+ )
+
+ for objs in test_cases:
+ with self.subTest(objs=objs):
+ pks = [obj.pk for obj in objs]
+ self.assertSequenceEqual(
+ Comment.objects.filter(pk__in=pks).order_by("pk"), objs
+ )
+
+ def test_filter_comments_by_user_and_order_by_pk_asc(self):
+ self.assertSequenceEqual(
+ Comment.objects.filter(user=self.user_1).order_by("pk"),
+ (self.comment_1, self.comment_2, self.comment_5),
+ )
+
+ def test_filter_comments_by_user_and_order_by_pk_desc(self):
+ self.assertSequenceEqual(
+ Comment.objects.filter(user=self.user_1).order_by("-pk"),
+ (self.comment_5, self.comment_2, self.comment_1),
+ )
+
+ def test_filter_comments_by_user_and_exclude_by_pk(self):
+ self.assertSequenceEqual(
+ Comment.objects.filter(user=self.user_1)
+ .exclude(pk=self.comment_1.pk)
+ .order_by("pk"),
+ (self.comment_2, self.comment_5),
+ )
+
+ def test_filter_comments_by_user_and_contains(self):
+ self.assertIs(
+ Comment.objects.filter(user=self.user_1).contains(self.comment_1), True
+ )
+
+ def test_filter_users_by_comments_in(self):
+ c1, c2, c3, c4, c5 = (
+ self.comment_1,
+ self.comment_2,
+ self.comment_3,
+ self.comment_4,
+ self.comment_5,
+ )
+ u1, u2, u3 = (
+ self.user_1,
+ self.user_2,
+ self.user_3,
+ )
+ test_cases = (
+ ((), ()),
+ ((c1,), (u1,)),
+ ((c1, c2), (u1, u1)),
+ ((c1, c2, c3), (u1, u1, u2)),
+ ((c1, c2, c3, c4), (u1, u1, u2, u3)),
+ ((c1, c2, c3, c4, c5), (u1, u1, u1, u2, u3)),
+ )
+
+ for comments, users in test_cases:
+ with self.subTest(comments=comments, users=users):
+ self.assertSequenceEqual(
+ User.objects.filter(comments__in=comments).order_by("pk"), users
+ )
+
+ def test_filter_users_by_comments_lt(self):
+ c11, c12, c13, c24, c15 = (
+ self.comment_1,
+ self.comment_2,
+ self.comment_3,
+ self.comment_4,
+ self.comment_5,
+ )
+ u1, u2 = (
+ self.user_1,
+ self.user_2,
+ )
+ test_cases = (
+ (c11, ()),
+ (c12, (u1,)),
+ (c13, (u1, u1)),
+ (c15, (u1, u1, u2)),
+ (c24, (u1, u1, u1, u2)),
+ )
+
+ for comment, users in test_cases:
+ with self.subTest(comment=comment, users=users):
+ self.assertSequenceEqual(
+ User.objects.filter(comments__lt=comment).order_by("pk"), users
+ )
+
+ def test_filter_users_by_comments_lte(self):
+ c11, c12, c13, c24, c15 = (
+ self.comment_1,
+ self.comment_2,
+ self.comment_3,
+ self.comment_4,
+ self.comment_5,
+ )
+ u1, u2, u3 = (
+ self.user_1,
+ self.user_2,
+ self.user_3,
+ )
+ test_cases = (
+ (c11, (u1,)),
+ (c12, (u1, u1)),
+ (c13, (u1, u1, u2)),
+ (c15, (u1, u1, u1, u2)),
+ (c24, (u1, u1, u1, u2, u3)),
+ )
+
+ for comment, users in test_cases:
+ with self.subTest(comment=comment, users=users):
+ self.assertSequenceEqual(
+ User.objects.filter(comments__lte=comment).order_by("pk"), users
+ )
+
+ def test_filter_users_by_comments_gt(self):
+ c11, c12, c13, c24, c15 = (
+ self.comment_1,
+ self.comment_2,
+ self.comment_3,
+ self.comment_4,
+ self.comment_5,
+ )
+ u1, u2, u3 = (
+ self.user_1,
+ self.user_2,
+ self.user_3,
+ )
+ test_cases = (
+ (c11, (u1, u1, u2, u3)),
+ (c12, (u1, u2, u3)),
+ (c13, (u1, u3)),
+ (c15, (u3,)),
+ (c24, ()),
+ )
+
+ for comment, users in test_cases:
+ with self.subTest(comment=comment, users=users):
+ self.assertSequenceEqual(
+ User.objects.filter(comments__gt=comment).order_by("pk"), users
+ )
+
+ def test_filter_users_by_comments_gte(self):
+ c11, c12, c13, c24, c15 = (
+ self.comment_1,
+ self.comment_2,
+ self.comment_3,
+ self.comment_4,
+ self.comment_5,
+ )
+ u1, u2, u3 = (
+ self.user_1,
+ self.user_2,
+ self.user_3,
+ )
+ test_cases = (
+ (c11, (u1, u1, u1, u2, u3)),
+ (c12, (u1, u1, u2, u3)),
+ (c13, (u1, u2, u3)),
+ (c15, (u1, u3)),
+ (c24, (u3,)),
+ )
+
+ for comment, users in test_cases:
+ with self.subTest(comment=comment, users=users):
+ self.assertSequenceEqual(
+ User.objects.filter(comments__gte=comment).order_by("pk"), users
+ )
+
+ def test_filter_users_by_comments_exact(self):
+ c11, c12, c13, c24, c15 = (
+ self.comment_1,
+ self.comment_2,
+ self.comment_3,
+ self.comment_4,
+ self.comment_5,
+ )
+ u1, u2, u3 = (
+ self.user_1,
+ self.user_2,
+ self.user_3,
+ )
+ test_cases = (
+ (c11, (u1,)),
+ (c12, (u1,)),
+ (c13, (u2,)),
+ (c15, (u1,)),
+ (c24, (u3,)),
+ )
+
+ for comment, users in test_cases:
+ with self.subTest(comment=comment, users=users):
+ self.assertSequenceEqual(
+ User.objects.filter(comments=comment).order_by("pk"), users
+ )
+
+ def test_filter_users_by_comments_isnull(self):
+ u1, u2, u3, u4 = (
+ self.user_1,
+ self.user_2,
+ self.user_3,
+ self.user_4,
+ )
+
+ with self.subTest("comments__isnull=True"):
+ self.assertSequenceEqual(
+ User.objects.filter(comments__isnull=True).order_by("pk"),
+ (u4,),
+ )
+ with self.subTest("comments__isnull=False"):
+ self.assertSequenceEqual(
+ User.objects.filter(comments__isnull=False).order_by("pk"),
+ (u1, u1, u1, u2, u3),
+ )
+
+ def test_filter_comments_by_pk_isnull(self):
+ c11, c12, c13, c24, c15 = (
+ self.comment_1,
+ self.comment_2,
+ self.comment_3,
+ self.comment_4,
+ self.comment_5,
+ )
+
+ with self.subTest("pk__isnull=True"):
+ self.assertSequenceEqual(
+ Comment.objects.filter(pk__isnull=True).order_by("pk"),
+ (),
+ )
+ with self.subTest("pk__isnull=False"):
+ self.assertSequenceEqual(
+ Comment.objects.filter(pk__isnull=False).order_by("pk"),
+ (c11, c12, c13, c15, c24),
+ )
+
+ def test_filter_users_by_comments_subquery(self):
+ subquery = Comment.objects.filter(id=3).only("pk")
+ queryset = User.objects.filter(comments__in=subquery)
+ self.assertSequenceEqual(queryset, (self.user_2,))
diff --git a/tests/composite_pk/test_get.py b/tests/composite_pk/test_get.py
new file mode 100644
index 0000000000..c896ec26ed
--- /dev/null
+++ b/tests/composite_pk/test_get.py
@@ -0,0 +1,126 @@
+from django.test import TestCase
+
+from .models import Comment, Tenant, User
+
+
+class CompositePKGetTests(TestCase):
+ maxDiff = None
+
+ @classmethod
+ def setUpTestData(cls):
+ cls.tenant_1 = Tenant.objects.create()
+ cls.tenant_2 = Tenant.objects.create()
+ cls.user_1 = User.objects.create(
+ tenant=cls.tenant_1,
+ id=1,
+ email="user0001@example.com",
+ )
+ cls.user_2 = User.objects.create(
+ tenant=cls.tenant_1,
+ id=2,
+ email="user0002@example.com",
+ )
+ cls.user_3 = User.objects.create(
+ tenant=cls.tenant_2,
+ id=3,
+ email="user0003@example.com",
+ )
+ cls.comment_1 = Comment.objects.create(id=1, user=cls.user_1)
+
+ def test_get_user(self):
+ test_cases = (
+ {"pk": self.user_1.pk},
+ {"pk": (self.tenant_1.id, self.user_1.id)},
+ {"id": self.user_1.id},
+ )
+
+ for lookup in test_cases:
+ with self.subTest(lookup=lookup):
+ self.assertEqual(User.objects.get(**lookup), self.user_1)
+
+ def test_get_comment(self):
+ test_cases = (
+ {"pk": self.comment_1.pk},
+ {"pk": (self.tenant_1.id, self.comment_1.id)},
+ {"id": self.comment_1.id},
+ {"user": self.user_1},
+ {"user_id": self.user_1.id},
+ {"user__id": self.user_1.id},
+ {"user__pk": self.user_1.pk},
+ {"tenant": self.tenant_1},
+ {"tenant_id": self.tenant_1.id},
+ {"tenant__id": self.tenant_1.id},
+ {"tenant__pk": self.tenant_1.pk},
+ )
+
+ for lookup in test_cases:
+ with self.subTest(lookup=lookup):
+ self.assertEqual(Comment.objects.get(**lookup), self.comment_1)
+
+ def test_get_or_create_user(self):
+ test_cases = (
+ {
+ "pk": self.user_1.pk,
+ "defaults": {"email": "user9201@example.com"},
+ },
+ {
+ "pk": (self.tenant_1.id, self.user_1.id),
+ "defaults": {"email": "user9201@example.com"},
+ },
+ {
+ "tenant": self.tenant_1,
+ "id": self.user_1.id,
+ "defaults": {"email": "user3512@example.com"},
+ },
+ {
+ "tenant_id": self.tenant_1.id,
+ "id": self.user_1.id,
+ "defaults": {"email": "user8239@example.com"},
+ },
+ )
+
+ for fields in test_cases:
+ with self.subTest(fields=fields):
+ count = User.objects.count()
+ user, created = User.objects.get_or_create(**fields)
+ self.assertIs(created, False)
+ self.assertEqual(user.id, self.user_1.id)
+ self.assertEqual(user.pk, (self.tenant_1.id, self.user_1.id))
+ self.assertEqual(user.tenant_id, self.tenant_1.id)
+ self.assertEqual(user.email, self.user_1.email)
+ self.assertEqual(count, User.objects.count())
+
+ def test_lookup_errors(self):
+ m_tuple = "'%s' lookup of 'pk' must be a tuple or a list"
+ m_2_elements = "'%s' lookup of 'pk' must have 2 elements"
+ m_tuple_collection = (
+ "'in' lookup of 'pk' must be a collection of tuples or lists"
+ )
+ m_2_elements_each = "'in' lookup of 'pk' must have 2 elements each"
+ test_cases = (
+ ({"pk": 1}, m_tuple % "exact"),
+ ({"pk": (1, 2, 3)}, m_2_elements % "exact"),
+ ({"pk__exact": 1}, m_tuple % "exact"),
+ ({"pk__exact": (1, 2, 3)}, m_2_elements % "exact"),
+ ({"pk__in": 1}, m_tuple % "in"),
+ ({"pk__in": (1, 2, 3)}, m_tuple_collection),
+ ({"pk__in": ((1, 2, 3),)}, m_2_elements_each),
+ ({"pk__gt": 1}, m_tuple % "gt"),
+ ({"pk__gt": (1, 2, 3)}, m_2_elements % "gt"),
+ ({"pk__gte": 1}, m_tuple % "gte"),
+ ({"pk__gte": (1, 2, 3)}, m_2_elements % "gte"),
+ ({"pk__lt": 1}, m_tuple % "lt"),
+ ({"pk__lt": (1, 2, 3)}, m_2_elements % "lt"),
+ ({"pk__lte": 1}, m_tuple % "lte"),
+ ({"pk__lte": (1, 2, 3)}, m_2_elements % "lte"),
+ )
+
+ for kwargs, message in test_cases:
+ with (
+ self.subTest(kwargs=kwargs),
+ self.assertRaisesMessage(ValueError, message),
+ ):
+ Comment.objects.get(**kwargs)
+
+ def test_get_user_by_comments(self):
+ self.assertEqual(User.objects.get(comments=self.comment_1), self.user_1)
diff --git a/tests/composite_pk/test_models.py b/tests/composite_pk/test_models.py
new file mode 100644
index 0000000000..ca6ad8b5dc
--- /dev/null
+++ b/tests/composite_pk/test_models.py
@@ -0,0 +1,153 @@
+from django.contrib.contenttypes.models import ContentType
+from django.core.exceptions import ValidationError
+from django.test import TestCase
+
+from .models import Comment, Tenant, Token, User
+
+
+class CompositePKModelsTests(TestCase):
+ @classmethod
+ def setUpTestData(cls):
+ cls.tenant_1 = Tenant.objects.create()
+ cls.tenant_2 = Tenant.objects.create()
+ cls.user_1 = User.objects.create(
+ tenant=cls.tenant_1,
+ id=1,
+ email="user0001@example.com",
+ )
+ cls.user_2 = User.objects.create(
+ tenant=cls.tenant_1,
+ id=2,
+ email="user0002@example.com",
+ )
+ cls.user_3 = User.objects.create(
+ tenant=cls.tenant_2,
+ id=3,
+ email="user0003@example.com",
+ )
+ cls.comment_1 = Comment.objects.create(id=1, user=cls.user_1)
+ cls.comment_2 = Comment.objects.create(id=2, user=cls.user_1)
+ cls.comment_3 = Comment.objects.create(id=3, user=cls.user_2)
+ cls.comment_4 = Comment.objects.create(id=4, user=cls.user_3)
+
+ def test_fields(self):
+ # tenant_1
+ self.assertSequenceEqual(
+ self.tenant_1.user_set.order_by("pk"),
+ [self.user_1, self.user_2],
+ )
+ self.assertSequenceEqual(
+ self.tenant_1.comments.order_by("pk"),
+ [self.comment_1, self.comment_2, self.comment_3],
+ )
+
+ # tenant_2
+ self.assertSequenceEqual(self.tenant_2.user_set.order_by("pk"), [self.user_3])
+ self.assertSequenceEqual(
+ self.tenant_2.comments.order_by("pk"), [self.comment_4]
+ )
+
+ # user_1
+ self.assertEqual(self.user_1.id, 1)
+ self.assertEqual(self.user_1.tenant_id, self.tenant_1.id)
+ self.assertEqual(self.user_1.tenant, self.tenant_1)
+ self.assertEqual(self.user_1.pk, (self.tenant_1.id, self.user_1.id))
+ self.assertSequenceEqual(
+ self.user_1.comments.order_by("pk"), [self.comment_1, self.comment_2]
+ )
+
+ # user_2
+ self.assertEqual(self.user_2.id, 2)
+ self.assertEqual(self.user_2.tenant_id, self.tenant_1.id)
+ self.assertEqual(self.user_2.tenant, self.tenant_1)
+ self.assertEqual(self.user_2.pk, (self.tenant_1.id, self.user_2.id))
+ self.assertSequenceEqual(self.user_2.comments.order_by("pk"), [self.comment_3])
+
+ # comment_1
+ self.assertEqual(self.comment_1.id, 1)
+ self.assertEqual(self.comment_1.user_id, self.user_1.id)
+ self.assertEqual(self.comment_1.user, self.user_1)
+ self.assertEqual(self.comment_1.tenant_id, self.tenant_1.id)
+ self.assertEqual(self.comment_1.tenant, self.tenant_1)
+ self.assertEqual(self.comment_1.pk, (self.tenant_1.id, self.user_1.id))
+
+ def test_full_clean_success(self):
+ test_cases = (
+ # 1, 1234, {}
+ ({"tenant": self.tenant_1, "id": 1234}, {}),
+ ({"tenant_id": self.tenant_1.id, "id": 1234}, {}),
+ ({"pk": (self.tenant_1.id, 1234)}, {}),
+ # 1, 1, {"id"}
+ ({"tenant": self.tenant_1, "id": 1}, {"id"}),
+ ({"tenant_id": self.tenant_1.id, "id": 1}, {"id"}),
+ ({"pk": (self.tenant_1.id, 1)}, {"id"}),
+ # 1, 1, {"tenant", "id"}
+ ({"tenant": self.tenant_1, "id": 1}, {"tenant", "id"}),
+ ({"tenant_id": self.tenant_1.id, "id": 1}, {"tenant", "id"}),
+ ({"pk": (self.tenant_1.id, 1)}, {"tenant", "id"}),
+ )
+
+ for kwargs, exclude in test_cases:
+ with self.subTest(kwargs):
+ kwargs["email"] = "user0004@example.com"
+ User(**kwargs).full_clean(exclude=exclude)
+
+ def test_full_clean_failure(self):
+ e_tenant_and_id = "User with this Tenant and Id already exists."
+ e_id = "User with this Id already exists."
+ test_cases = (
+ # 1, 1, {}
+ ({"tenant": self.tenant_1, "id": 1}, {}, (e_tenant_and_id, e_id)),
+ ({"tenant_id": self.tenant_1.id, "id": 1}, {}, (e_tenant_and_id, e_id)),
+ ({"pk": (self.tenant_1.id, 1)}, {}, (e_tenant_and_id, e_id)),
+ # 2, 1, {}
+ ({"tenant": self.tenant_2, "id": 1}, {}, (e_id,)),
+ ({"tenant_id": self.tenant_2.id, "id": 1}, {}, (e_id,)),
+ ({"pk": (self.tenant_2.id, 1)}, {}, (e_id,)),
+ # 1, 1, {"tenant"}
+ ({"tenant": self.tenant_1, "id": 1}, {"tenant"}, (e_id,)),
+ ({"tenant_id": self.tenant_1.id, "id": 1}, {"tenant"}, (e_id,)),
+ ({"pk": (self.tenant_1.id, 1)}, {"tenant"}, (e_id,)),
+ )
+
+ for kwargs, exclude, messages in test_cases:
+ with self.subTest(kwargs):
+ with self.assertRaises(ValidationError) as ctx:
+ kwargs["email"] = "user0004@example.com"
+ User(**kwargs).full_clean(exclude=exclude)
+
+ self.assertSequenceEqual(ctx.exception.messages, messages)
+
+ def test_field_conflicts(self):
+ test_cases = (
+ ({"pk": (1, 1), "id": 2}, (1, 1)),
+ ({"id": 2, "pk": (1, 1)}, (1, 1)),
+ ({"pk": (1, 1), "tenant_id": 2}, (1, 1)),
+ ({"tenant_id": 2, "pk": (1, 1)}, (1, 1)),
+ ({"pk": (2, 2), "tenant_id": 3, "id": 4}, (2, 2)),
+ ({"tenant_id": 3, "id": 4, "pk": (2, 2)}, (2, 2)),
+ )
+
+ for kwargs, pk in test_cases:
+ with self.subTest(kwargs=kwargs):
+ user = User(**kwargs)
+ self.assertEqual(user.pk, pk)
+
+ def test_validate_unique(self):
+ user = User.objects.get(pk=self.user_1.pk)
+ user.id = None
+
+ with self.assertRaises(ValidationError) as ctx:
+ user.validate_unique()
+
+ self.assertSequenceEqual(
+ ctx.exception.messages, ("User with this Email already exists.",)
+ )
+
+ def test_permissions(self):
+ token = ContentType.objects.get_for_model(Token)
+ user = ContentType.objects.get_for_model(User)
+ comment = ContentType.objects.get_for_model(Comment)
+ self.assertEqual(4, token.permission_set.count())
+ self.assertEqual(4, user.permission_set.count())
+ self.assertEqual(4, comment.permission_set.count())
diff --git a/tests/composite_pk/test_names_to_path.py b/tests/composite_pk/test_names_to_path.py
new file mode 100644
index 0000000000..de4a04f4cb
--- /dev/null
+++ b/tests/composite_pk/test_names_to_path.py
@@ -0,0 +1,134 @@
+from django.db.models.query_utils import PathInfo
+from django.db.models.sql import Query
+from django.test import TestCase
+
+from .models import Comment, Tenant, User
+
+
+class NamesToPathTests(TestCase):
+ def test_id(self):
+ query = Query(User)
+ path, final_field, targets, rest = query.names_to_path(["id"], User._meta)
+
+ self.assertEqual(path, [])
+ self.assertEqual(final_field, User._meta.get_field("id"))
+ self.assertEqual(targets, (User._meta.get_field("id"),))
+ self.assertEqual(rest, [])
+
+ def test_pk(self):
+ query = Query(User)
+ path, final_field, targets, rest = query.names_to_path(["pk"], User._meta)
+
+ self.assertEqual(path, [])
+ self.assertEqual(final_field, User._meta.get_field("pk"))
+ self.assertEqual(targets, (User._meta.get_field("pk"),))
+ self.assertEqual(rest, [])
+
+ def test_tenant_id(self):
+ query = Query(User)
+ path, final_field, targets, rest = query.names_to_path(
+ ["tenant", "id"], User._meta
+ )
+
+ self.assertEqual(
+ path,
+ [
+ PathInfo(
+ from_opts=User._meta,
+ to_opts=Tenant._meta,
+ target_fields=(Tenant._meta.get_field("id"),),
+ join_field=User._meta.get_field("tenant"),
+ m2m=False,
+ direct=True,
+ filtered_relation=None,
+ ),
+ ],
+ )
+ self.assertEqual(final_field, Tenant._meta.get_field("id"))
+ self.assertEqual(targets, (Tenant._meta.get_field("id"),))
+ self.assertEqual(rest, [])
+
+ def test_user_id(self):
+ query = Query(Comment)
+ path, final_field, targets, rest = query.names_to_path(
+ ["user", "id"], Comment._meta
+ )
+
+ self.assertEqual(
+ path,
+ [
+ PathInfo(
+ from_opts=Comment._meta,
+ to_opts=User._meta,
+ target_fields=(
+ User._meta.get_field("tenant"),
+ User._meta.get_field("id"),
+ ),
+ join_field=Comment._meta.get_field("user"),
+ m2m=False,
+ direct=True,
+ filtered_relation=None,
+ ),
+ ],
+ )
+ self.assertEqual(final_field, User._meta.get_field("id"))
+ self.assertEqual(targets, (User._meta.get_field("id"),))
+ self.assertEqual(rest, [])
+
+ def test_user_tenant_id(self):
+ query = Query(Comment)
+ path, final_field, targets, rest = query.names_to_path(
+ ["user", "tenant", "id"], Comment._meta
+ )
+
+ self.assertEqual(
+ path,
+ [
+ PathInfo(
+ from_opts=Comment._meta,
+ to_opts=User._meta,
+ target_fields=(
+ User._meta.get_field("tenant"),
+ User._meta.get_field("id"),
+ ),
+ join_field=Comment._meta.get_field("user"),
+ m2m=False,
+ direct=True,
+ filtered_relation=None,
+ ),
+ PathInfo(
+ from_opts=User._meta,
+ to_opts=Tenant._meta,
+ target_fields=(Tenant._meta.get_field("id"),),
+ join_field=User._meta.get_field("tenant"),
+ m2m=False,
+ direct=True,
+ filtered_relation=None,
+ ),
+ ],
+ )
+ self.assertEqual(final_field, Tenant._meta.get_field("id"))
+ self.assertEqual(targets, (Tenant._meta.get_field("id"),))
+ self.assertEqual(rest, [])
+
+ def test_comments(self):
+ query = Query(User)
+ path, final_field, targets, rest = query.names_to_path(["comments"], User._meta)
+
+ self.assertEqual(
+ path,
+ [
+ PathInfo(
+ from_opts=User._meta,
+ to_opts=Comment._meta,
+ target_fields=(Comment._meta.get_field("pk"),),
+ join_field=User._meta.get_field("comments"),
+ m2m=True,
+ direct=False,
+ filtered_relation=None,
+ ),
+ ],
+ )
+ self.assertEqual(final_field, User._meta.get_field("comments"))
+ self.assertEqual(targets, (Comment._meta.get_field("pk"),))
+ self.assertEqual(rest, [])
diff --git a/tests/composite_pk/test_update.py b/tests/composite_pk/test_update.py
new file mode 100644
index 0000000000..e711745447
--- /dev/null
+++ b/tests/composite_pk/test_update.py
@@ -0,0 +1,135 @@
+from django.test import TestCase
+
+from .models import Comment, Tenant, Token, User
+
+
+class CompositePKUpdateTests(TestCase):
+ maxDiff = None
+
+ @classmethod
+ def setUpTestData(cls):
+ cls.tenant_1 = Tenant.objects.create(name="A")
+ cls.tenant_2 = Tenant.objects.create(name="B")
+ cls.user_1 = User.objects.create(
+ tenant=cls.tenant_1,
+ id=1,
+ email="user0001@example.com",
+ )
+ cls.user_2 = User.objects.create(
+ tenant=cls.tenant_1,
+ id=2,
+ email="user0002@example.com",
+ )
+ cls.user_3 = User.objects.create(
+ tenant=cls.tenant_2,
+ id=3,
+ email="user0003@example.com",
+ )
+ cls.comment_1 = Comment.objects.create(id=1, user=cls.user_1)
+ cls.comment_2 = Comment.objects.create(id=2, user=cls.user_1)
+ cls.comment_3 = Comment.objects.create(id=3, user=cls.user_2)
+ cls.token_1 = Token.objects.create(id=1, tenant=cls.tenant_1)
+ cls.token_2 = Token.objects.create(id=2, tenant=cls.tenant_2)
+ cls.token_3 = Token.objects.create(id=3, tenant=cls.tenant_1)
+ cls.token_4 = Token.objects.create(id=4, tenant=cls.tenant_2)
+
+ def test_update_user(self):
+ email = "user9315@example.com"
+ result = User.objects.filter(pk=self.user_1.pk).update(email=email)
+ self.assertEqual(result, 1)
+ user = User.objects.get(pk=self.user_1.pk)
+ self.assertEqual(user.email, email)
+
+ def test_save_user(self):
+ count = User.objects.count()
+ email = "user9314@example.com"
+ user = User.objects.get(pk=self.user_1.pk)
+ user.email = email
+ user.save()
+ user.refresh_from_db()
+ self.assertEqual(user.email, email)
+ user = User.objects.get(pk=self.user_1.pk)
+ self.assertEqual(user.email, email)
+ self.assertEqual(count, User.objects.count())
+
+ def test_bulk_update_comments(self):
+ comment_1 = Comment.objects.get(pk=self.comment_1.pk)
+ comment_2 = Comment.objects.get(pk=self.comment_2.pk)
+ comment_3 = Comment.objects.get(pk=self.comment_3.pk)
+ comment_1.text = "foo"
+ comment_2.text = "bar"
+ comment_3.text = "baz"
+
+ result = Comment.objects.bulk_update(
+ [comment_1, comment_2, comment_3], ["text"]
+ )
+
+ self.assertEqual(result, 3)
+ comment_1 = Comment.objects.get(pk=self.comment_1.pk)
+ comment_2 = Comment.objects.get(pk=self.comment_2.pk)
+ comment_3 = Comment.objects.get(pk=self.comment_3.pk)
+ self.assertEqual(comment_1.text, "foo")
+ self.assertEqual(comment_2.text, "bar")
+ self.assertEqual(comment_3.text, "baz")
+
+ def test_update_or_create_user(self):
+ test_cases = (
+ {
+ "pk": self.user_1.pk,
+ "defaults": {"email": "user3914@example.com"},
+ },
+ {
+ "pk": (self.tenant_1.id, self.user_1.id),
+ "defaults": {"email": "user9375@example.com"},
+ },
+ {
+ "tenant": self.tenant_1,
+ "id": self.user_1.id,
+ "defaults": {"email": "user3517@example.com"},
+ },
+ {
+ "tenant_id": self.tenant_1.id,
+ "id": self.user_1.id,
+ "defaults": {"email": "user8391@example.com"},
+ },
+ )
+
+ for fields in test_cases:
+ with self.subTest(fields=fields):
+ count = User.objects.count()
+ user, created = User.objects.update_or_create(**fields)
+ self.assertIs(created, False)
+ self.assertEqual(user.id, self.user_1.id)
+ self.assertEqual(user.pk, (self.tenant_1.id, self.user_1.id))
+ self.assertEqual(user.tenant_id, self.tenant_1.id)
+ self.assertEqual(user.email, fields["defaults"]["email"])
+ self.assertEqual(count, User.objects.count())
+
+ def test_update_comment_by_user_email(self):
+ result = Comment.objects.filter(user__email=self.user_1.email).update(
+ text="foo"
+ )
+
+ self.assertEqual(result, 2)
+ comment_1 = Comment.objects.get(pk=self.comment_1.pk)
+ comment_2 = Comment.objects.get(pk=self.comment_2.pk)
+ self.assertEqual(comment_1.text, "foo")
+ self.assertEqual(comment_2.text, "foo")
+
+ def test_update_token_by_tenant_name(self):
+ result = Token.objects.filter(tenant__name="A").update(secret="bar")
+
+ self.assertEqual(result, 2)
+ token_1 = Token.objects.get(pk=self.token_1.pk)
+ self.assertEqual(token_1.secret, "bar")
+ token_3 = Token.objects.get(pk=self.token_3.pk)
+ self.assertEqual(token_3.secret, "bar")
+
+ def test_cant_update_to_unsaved_object(self):
+ msg = (
+ "Unsaved model instance <User: User object ((None, None))> cannot be used "
+ "in an ORM query."
+ )
+
+ with self.assertRaisesMessage(ValueError, msg):
+ Comment.objects.update(user=User())
diff --git a/tests/composite_pk/test_values.py b/tests/composite_pk/test_values.py
new file mode 100644
index 0000000000..a3c7a589cc
--- /dev/null
+++ b/tests/composite_pk/test_values.py
@@ -0,0 +1,212 @@
+from collections import namedtuple
+from uuid import UUID
+
+from django.test import TestCase
+
+from .models import Post, Tenant, User
+
+
+class CompositePKValuesTests(TestCase):
+ USER_1_EMAIL = "user0001@example.com"
+ USER_2_EMAIL = "user0002@example.com"
+ USER_3_EMAIL = "user0003@example.com"
+ POST_1_ID = "77777777-7777-7777-7777-777777777777"
+ POST_2_ID = "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb"
+ POST_3_ID = "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
+
+ @classmethod
+ def setUpTestData(cls):
+ super().setUpTestData()
+ cls.tenant_1 = Tenant.objects.create()
+ cls.tenant_2 = Tenant.objects.create()
+ cls.user_1 = User.objects.create(
+ tenant=cls.tenant_1, id=1, email=cls.USER_1_EMAIL
+ )
+ cls.user_2 = User.objects.create(
+ tenant=cls.tenant_1, id=2, email=cls.USER_2_EMAIL
+ )
+ cls.user_3 = User.objects.create(
+ tenant=cls.tenant_2, id=3, email=cls.USER_3_EMAIL
+ )
+ cls.post_1 = Post.objects.create(tenant=cls.tenant_1, id=cls.POST_1_ID)
+ cls.post_2 = Post.objects.create(tenant=cls.tenant_1, id=cls.POST_2_ID)
+ cls.post_3 = Post.objects.create(tenant=cls.tenant_2, id=cls.POST_3_ID)
+
+ def test_values_list(self):
+ with self.subTest('User.objects.values_list("pk")'):
+ self.assertSequenceEqual(
+ User.objects.values_list("pk").order_by("pk"),
+ (
+ (self.user_1.pk,),
+ (self.user_2.pk,),
+ (self.user_3.pk,),
+ ),
+ )
+ with self.subTest('User.objects.values_list("pk", "email")'):
+ self.assertSequenceEqual(
+ User.objects.values_list("pk", "email").order_by("pk"),
+ (
+ (self.user_1.pk, self.USER_1_EMAIL),
+ (self.user_2.pk, self.USER_2_EMAIL),
+ (self.user_3.pk, self.USER_3_EMAIL),
+ ),
+ )
+ with self.subTest('User.objects.values_list("pk", "id")'):
+ self.assertSequenceEqual(
+ User.objects.values_list("pk", "id").order_by("pk"),
+ (
+ (self.user_1.pk, self.user_1.id),
+ (self.user_2.pk, self.user_2.id),
+ (self.user_3.pk, self.user_3.id),
+ ),
+ )
+ with self.subTest('User.objects.values_list("pk", "tenant_id", "id")'):
+ self.assertSequenceEqual(
+ User.objects.values_list("pk", "tenant_id", "id").order_by("pk"),
+ (
+ (self.user_1.pk, self.user_1.tenant_id, self.user_1.id),
+ (self.user_2.pk, self.user_2.tenant_id, self.user_2.id),
+ (self.user_3.pk, self.user_3.tenant_id, self.user_3.id),
+ ),
+ )
+ with self.subTest('User.objects.values_list("pk", flat=True)'):
+ self.assertSequenceEqual(
+ User.objects.values_list("pk", flat=True).order_by("pk"),
+ (
+ self.user_1.pk,
+ self.user_2.pk,
+ self.user_3.pk,
+ ),
+ )
+ with self.subTest('Post.objects.values_list("pk", flat=True)'):
+ self.assertSequenceEqual(
+ Post.objects.values_list("pk", flat=True).order_by("pk"),
+ (
+ (self.tenant_1.id, UUID(self.POST_1_ID)),
+ (self.tenant_1.id, UUID(self.POST_2_ID)),
+ (self.tenant_2.id, UUID(self.POST_3_ID)),
+ ),
+ )
+ with self.subTest('Post.objects.values_list("pk")'):
+ self.assertSequenceEqual(
+ Post.objects.values_list("pk").order_by("pk"),
+ (
+ ((self.tenant_1.id, UUID(self.POST_1_ID)),),
+ ((self.tenant_1.id, UUID(self.POST_2_ID)),),
+ ((self.tenant_2.id, UUID(self.POST_3_ID)),),
+ ),
+ )
+ with self.subTest('Post.objects.values_list("pk", "id")'):
+ self.assertSequenceEqual(
+ Post.objects.values_list("pk", "id").order_by("pk"),
+ (
+ ((self.tenant_1.id, UUID(self.POST_1_ID)), UUID(self.POST_1_ID)),
+ ((self.tenant_1.id, UUID(self.POST_2_ID)), UUID(self.POST_2_ID)),
+ ((self.tenant_2.id, UUID(self.POST_3_ID)), UUID(self.POST_3_ID)),
+ ),
+ )
+ with self.subTest('Post.objects.values_list("id", "pk")'):
+ self.assertSequenceEqual(
+ Post.objects.values_list("id", "pk").order_by("pk"),
+ (
+ (UUID(self.POST_1_ID), (self.tenant_1.id, UUID(self.POST_1_ID))),
+ (UUID(self.POST_2_ID), (self.tenant_1.id, UUID(self.POST_2_ID))),
+ (UUID(self.POST_3_ID), (self.tenant_2.id, UUID(self.POST_3_ID))),
+ ),
+ )
+ with self.subTest('User.objects.values_list("pk", named=True)'):
+ Row = namedtuple("Row", ["pk"])
+ self.assertSequenceEqual(
+ User.objects.values_list("pk", named=True).order_by("pk"),
+ (
+ Row(pk=self.user_1.pk),
+ Row(pk=self.user_2.pk),
+ Row(pk=self.user_3.pk),
+ ),
+ )
+ with self.subTest('User.objects.values_list("pk", "pk")'):
+ self.assertSequenceEqual(
+ User.objects.values_list("pk", "pk").order_by("pk"),
+ (
+ (self.user_1.pk,),
+ (self.user_2.pk,),
+ (self.user_3.pk,),
+ ),
+ )
+ with self.subTest('User.objects.values_list("pk", "id", "pk", "id")'):
+ self.assertSequenceEqual(
+ User.objects.values_list("pk", "id", "pk", "id").order_by("pk"),
+ (
+ (self.user_1.pk, self.user_1.id),
+ (self.user_2.pk, self.user_2.id),
+ (self.user_3.pk, self.user_3.id),
+ ),
+ )
+
+ def test_values(self):
+ with self.subTest('User.objects.values("pk")'):
+ self.assertSequenceEqual(
+ User.objects.values("pk").order_by("pk"),
+ (
+ {"pk": self.user_1.pk},
+ {"pk": self.user_2.pk},
+ {"pk": self.user_3.pk},
+ ),
+ )
+ with self.subTest('User.objects.values("pk", "email")'):
+ self.assertSequenceEqual(
+ User.objects.values("pk", "email").order_by("pk"),
+ (
+ {"pk": self.user_1.pk, "email": self.USER_1_EMAIL},
+ {"pk": self.user_2.pk, "email": self.USER_2_EMAIL},
+ {"pk": self.user_3.pk, "email": self.USER_3_EMAIL},
+ ),
+ )
+ with self.subTest('User.objects.values("pk", "id")'):
+ self.assertSequenceEqual(
+ User.objects.values("pk", "id").order_by("pk"),
+ (
+ {"pk": self.user_1.pk, "id": self.user_1.id},
+ {"pk": self.user_2.pk, "id": self.user_2.id},
+ {"pk": self.user_3.pk, "id": self.user_3.id},
+ ),
+ )
+ with self.subTest('User.objects.values("pk", "tenant_id", "id")'):
+ self.assertSequenceEqual(
+ User.objects.values("pk", "tenant_id", "id").order_by("pk"),
+ (
+ {
+ "pk": self.user_1.pk,
+ "tenant_id": self.user_1.tenant_id,
+ "id": self.user_1.id,
+ },
+ {
+ "pk": self.user_2.pk,
+ "tenant_id": self.user_2.tenant_id,
+ "id": self.user_2.id,
+ },
+ {
+ "pk": self.user_3.pk,
+ "tenant_id": self.user_3.tenant_id,
+ "id": self.user_3.id,
+ },
+ ),
+ )
+ with self.subTest('User.objects.values("pk", "pk")'):
+ self.assertSequenceEqual(
+ User.objects.values("pk", "pk").order_by("pk"),
+ (
+ {"pk": self.user_1.pk},
+ {"pk": self.user_2.pk},
+ {"pk": self.user_3.pk},
+ ),
+ )
+ with self.subTest('User.objects.values("pk", "id", "pk", "id")'):
+ self.assertSequenceEqual(
+ User.objects.values("pk", "id", "pk", "id").order_by("pk"),
+ (
+ {"pk": self.user_1.pk, "id": self.user_1.id},
+ {"pk": self.user_2.pk, "id": self.user_2.id},
+ {"pk": self.user_3.pk, "id": self.user_3.id},
+ ),
+ )
diff --git a/tests/composite_pk/tests.py b/tests/composite_pk/tests.py
new file mode 100644
index 0000000000..71522cb836
--- /dev/null
+++ b/tests/composite_pk/tests.py
@@ -0,0 +1,345 @@
+import json
+import unittest
+from uuid import UUID
+
+import yaml
+
+from django import forms
+from django.core import serializers
+from django.core.exceptions import FieldError
+from django.db import IntegrityError, connection
+from django.db.models import CompositePrimaryKey
+from django.forms import modelform_factory
+from django.test import TestCase
+
+from .models import Comment, Post, Tenant, User
+
+
+class CommentForm(forms.ModelForm):
+ class Meta:
+ model = Comment
+ fields = "__all__"
+
+
+class CompositePKTests(TestCase):
+ maxDiff = None
+
+ @classmethod
+ def setUpTestData(cls):
+ cls.tenant = Tenant.objects.create()
+ cls.user = User.objects.create(
+ tenant=cls.tenant,
+ id=1,
+ email="user0001@example.com",
+ )
+ cls.comment = Comment.objects.create(tenant=cls.tenant, id=1, user=cls.user)
+
+ @staticmethod
+ def get_constraints(table):
+ with connection.cursor() as cursor:
+ return connection.introspection.get_constraints(cursor, table)
+
+ def test_pk_updated_if_field_updated(self):
+ user = User.objects.get(pk=self.user.pk)
+ self.assertEqual(user.pk, (self.tenant.id, self.user.id))
+ self.assertIs(user._is_pk_set(), True)
+ user.tenant_id = 9831
+ self.assertEqual(user.pk, (9831, self.user.id))
+ self.assertIs(user._is_pk_set(), True)
+ user.id = 4321
+ self.assertEqual(user.pk, (9831, 4321))
+ self.assertIs(user._is_pk_set(), True)
+ user.pk = (9132, 3521)
+ self.assertEqual(user.tenant_id, 9132)
+ self.assertEqual(user.id, 3521)
+ self.assertIs(user._is_pk_set(), True)
+ user.id = None
+ self.assertEqual(user.pk, (9132, None))
+ self.assertEqual(user.tenant_id, 9132)
+ self.assertIsNone(user.id)
+ self.assertIs(user._is_pk_set(), False)
+
+ def test_hash(self):
+ self.assertEqual(hash(User(pk=(1, 2))), hash((1, 2)))
+ self.assertEqual(hash(User(tenant_id=2, id=3)), hash((2, 3)))
+ msg = "Model instances without primary key value are unhashable"
+
+ with self.assertRaisesMessage(TypeError, msg):
+ hash(User())
+ with self.assertRaisesMessage(TypeError, msg):
+ hash(User(tenant_id=1))
+ with self.assertRaisesMessage(TypeError, msg):
+ hash(User(id=1))
+
+ def test_pk_must_be_list_or_tuple(self):
+ user = User.objects.get(pk=self.user.pk)
+ test_cases = [
+ "foo",
+ 1000,
+ 3.14,
+ True,
+ False,
+ ]
+
+ for pk in test_cases:
+ with self.assertRaisesMessage(
+ ValueError, "'pk' must be a list or a tuple."
+ ):
+ user.pk = pk
+
+ def test_pk_must_have_2_elements(self):
+ user = User.objects.get(pk=self.user.pk)
+ test_cases = [
+ (),
+ [],
+ (1000,),
+ [1000],
+ (1, 2, 3),
+ [1, 2, 3],
+ ]
+
+ for pk in test_cases:
+ with self.assertRaisesMessage(ValueError, "'pk' must have 2 elements."):
+ user.pk = pk
+
+ def test_composite_pk_in_fields(self):
+ user_fields = {f.name for f in User._meta.get_fields()}
+ self.assertEqual(user_fields, {"pk", "tenant", "id", "email", "comments"})
+
+ comment_fields = {f.name for f in Comment._meta.get_fields()}
+ self.assertEqual(
+ comment_fields,
+ {"pk", "tenant", "id", "user_id", "user", "text"},
+ )
+
+ def test_pk_field(self):
+ pk = User._meta.get_field("pk")
+ self.assertIsInstance(pk, CompositePrimaryKey)
+ self.assertIs(User._meta.pk, pk)
+
+ def test_error_on_user_pk_conflict(self):
+ with self.assertRaises(IntegrityError):
+ User.objects.create(tenant=self.tenant, id=self.user.id)
+
+ def test_error_on_comment_pk_conflict(self):
+ with self.assertRaises(IntegrityError):
+ Comment.objects.create(tenant=self.tenant, id=self.comment.id)
+
+ @unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL specific test")
+ def test_get_constraints_postgresql(self):
+ user_constraints = self.get_constraints(User._meta.db_table)
+ user_pk = user_constraints["composite_pk_user_pkey"]
+ self.assertEqual(user_pk["columns"], ["tenant_id", "id"])
+ self.assertIs(user_pk["primary_key"], True)
+
+ comment_constraints = self.get_constraints(Comment._meta.db_table)
+ comment_pk = comment_constraints["composite_pk_comment_pkey"]
+ self.assertEqual(comment_pk["columns"], ["tenant_id", "comment_id"])
+ self.assertIs(comment_pk["primary_key"], True)
+
+ @unittest.skipUnless(connection.vendor == "sqlite", "SQLite specific test")
+ def test_get_constraints_sqlite(self):
+ user_constraints = self.get_constraints(User._meta.db_table)
+ user_pk = user_constraints["__primary__"]
+ self.assertEqual(user_pk["columns"], ["tenant_id", "id"])
+ self.assertIs(user_pk["primary_key"], True)
+
+ comment_constraints = self.get_constraints(Comment._meta.db_table)
+ comment_pk = comment_constraints["__primary__"]
+ self.assertEqual(comment_pk["columns"], ["tenant_id", "comment_id"])
+ self.assertIs(comment_pk["primary_key"], True)
+
+ @unittest.skipUnless(connection.vendor == "mysql", "MySQL specific test")
+ def test_get_constraints_mysql(self):
+ user_constraints = self.get_constraints(User._meta.db_table)
+ user_pk = user_constraints["PRIMARY"]
+ self.assertEqual(user_pk["columns"], ["tenant_id", "id"])
+ self.assertIs(user_pk["primary_key"], True)
+
+ comment_constraints = self.get_constraints(Comment._meta.db_table)
+ comment_pk = comment_constraints["PRIMARY"]
+ self.assertEqual(comment_pk["columns"], ["tenant_id", "comment_id"])
+ self.assertIs(comment_pk["primary_key"], True)
+
+ @unittest.skipUnless(connection.vendor == "oracle", "Oracle specific test")
+ def test_get_constraints_oracle(self):
+ user_constraints = self.get_constraints(User._meta.db_table)
+ user_pk = next(c for c in user_constraints.values() if c["primary_key"])
+ self.assertEqual(user_pk["columns"], ["tenant_id", "id"])
+ self.assertEqual(user_pk["primary_key"], 1)
+
+ comment_constraints = self.get_constraints(Comment._meta.db_table)
+ comment_pk = next(c for c in comment_constraints.values() if c["primary_key"])
+ self.assertEqual(comment_pk["columns"], ["tenant_id", "comment_id"])
+ self.assertEqual(comment_pk["primary_key"], 1)
+
+ def test_in_bulk(self):
+ """
+ Test the .in_bulk() method of composite_pk models.
+ """
+ result = Comment.objects.in_bulk()
+ self.assertEqual(result, {self.comment.pk: self.comment})
+
+ result = Comment.objects.in_bulk([self.comment.pk])
+ self.assertEqual(result, {self.comment.pk: self.comment})
+
+ def test_iterator(self):
+ """
+ Test the .iterator() method of composite_pk models.
+ """
+ result = list(Comment.objects.iterator())
+ self.assertEqual(result, [self.comment])
+
+ def test_query(self):
+ users = User.objects.values_list("pk").order_by("pk")
+ self.assertNotIn('AS "pk"', str(users.query))
+
+ def test_only(self):
+ users = User.objects.only("pk")
+ self.assertSequenceEqual(users, (self.user,))
+ user = users[0]
+
+ with self.assertNumQueries(0):
+ self.assertEqual(user.pk, (self.user.tenant_id, self.user.id))
+ self.assertEqual(user.tenant_id, self.user.tenant_id)
+ self.assertEqual(user.id, self.user.id)
+ with self.assertNumQueries(1):
+ self.assertEqual(user.email, self.user.email)
+
+ def test_model_forms(self):
+ fields = ["tenant", "id", "user_id", "text"]
+ self.assertEqual(list(CommentForm.base_fields), fields)
+
+ form = modelform_factory(Comment, fields="__all__")
+ self.assertEqual(list(form().fields), fields)
+
+ with self.assertRaisesMessage(
+ FieldError, "Unknown field(s) (pk) specified for Comment"
+ ):
+ self.assertIsNone(modelform_factory(Comment, fields=["pk"]))
+
+
+class CompositePKFixturesTests(TestCase):
+ fixtures = ["tenant"]
+
+ def test_objects(self):
+ tenant_1, tenant_2, tenant_3 = Tenant.objects.order_by("pk")
+ self.assertEqual(tenant_1.id, 1)
+ self.assertEqual(tenant_1.name, "Tenant 1")
+ self.assertEqual(tenant_2.id, 2)
+ self.assertEqual(tenant_2.name, "Tenant 2")
+ self.assertEqual(tenant_3.id, 3)
+ self.assertEqual(tenant_3.name, "Tenant 3")
+
+ user_1, user_2, user_3, user_4 = User.objects.order_by("pk")
+ self.assertEqual(user_1.id, 1)
+ self.assertEqual(user_1.tenant_id, 1)
+ self.assertEqual(user_1.pk, (user_1.tenant_id, user_1.id))
+ self.assertEqual(user_1.email, "user0001@example.com")
+ self.assertEqual(user_2.id, 2)
+ self.assertEqual(user_2.tenant_id, 1)
+ self.assertEqual(user_2.pk, (user_2.tenant_id, user_2.id))
+ self.assertEqual(user_2.email, "user0002@example.com")
+ self.assertEqual(user_3.id, 3)
+ self.assertEqual(user_3.tenant_id, 2)
+ self.assertEqual(user_3.pk, (user_3.tenant_id, user_3.id))
+ self.assertEqual(user_3.email, "user0003@example.com")
+ self.assertEqual(user_4.id, 4)
+ self.assertEqual(user_4.tenant_id, 2)
+ self.assertEqual(user_4.pk, (user_4.tenant_id, user_4.id))
+ self.assertEqual(user_4.email, "user0004@example.com")
+
+ post_1, post_2 = Post.objects.order_by("pk")
+ self.assertEqual(post_1.id, UUID("11111111-1111-1111-1111-111111111111"))
+ self.assertEqual(post_1.tenant_id, 2)
+ self.assertEqual(post_1.pk, (post_1.tenant_id, post_1.id))
+ self.assertEqual(post_2.id, UUID("ffffffff-ffff-ffff-ffff-ffffffffffff"))
+ self.assertEqual(post_2.tenant_id, 2)
+ self.assertEqual(post_2.pk, (post_2.tenant_id, post_2.id))
+
+ def test_serialize_user_json(self):
+ users = User.objects.filter(pk=(1, 1))
+ result = serializers.serialize("json", users)
+ self.assertEqual(
+ json.loads(result),
+ [
+ {
+ "model": "composite_pk.user",
+ "pk": [1, 1],
+ "fields": {
+ "email": "user0001@example.com",
+ "id": 1,
+ "tenant": 1,
+ },
+ }
+ ],
+ )
+
+ def test_serialize_user_jsonl(self):
+ users = User.objects.filter(pk=(1, 2))
+ result = serializers.serialize("jsonl", users)
+ self.assertEqual(
+ json.loads(result),
+ {
+ "model": "composite_pk.user",
+ "pk": [1, 2],
+ "fields": {
+ "email": "user0002@example.com",
+ "id": 2,
+ "tenant": 1,
+ },
+ },
+ )
+
+ def test_serialize_user_yaml(self):
+ users = User.objects.filter(pk=(2, 3))
+ result = serializers.serialize("yaml", users)
+ self.assertEqual(
+ yaml.safe_load(result),
+ [
+ {
+ "model": "composite_pk.user",
+ "pk": [2, 3],
+ "fields": {
+ "email": "user0003@example.com",
+ "id": 3,
+ "tenant": 2,
+ },
+ },
+ ],
+ )
+
+ def test_serialize_user_python(self):
+ users = User.objects.filter(pk=(2, 4))
+ result = serializers.serialize("python", users)
+ self.assertEqual(
+ result,
+ [
+ {
+ "model": "composite_pk.user",
+ "pk": [2, 4],
+ "fields": {
+ "email": "user0004@example.com",
+ "id": 4,
+ "tenant": 2,
+ },
+ },
+ ],
+ )
+
+ def test_serialize_post_uuid(self):
+ posts = Post.objects.filter(pk=(2, "11111111-1111-1111-1111-111111111111"))
+ result = serializers.serialize("json", posts)
+ self.assertEqual(
+ json.loads(result),
+ [
+ {
+ "model": "composite_pk.post",
+ "pk": [2, "11111111-1111-1111-1111-111111111111"],
+ "fields": {
+ "id": "11111111-1111-1111-1111-111111111111",
+ "tenant": 2,
+ },
+ },
+ ],
+ )
diff --git a/tests/migrations/test_autodetector.py b/tests/migrations/test_autodetector.py
index de62170eb3..33196ea6f4 100644
--- a/tests/migrations/test_autodetector.py
+++ b/tests/migrations/test_autodetector.py
@@ -5059,6 +5059,95 @@ class AutodetectorTests(BaseAutodetectorTests):
self.assertOperationTypes(changes, "testapp", 0, ["CreateModel"])
self.assertOperationAttributes(changes, "testapp", 0, 0, name="Book")
+ @mock.patch(
+ "django.db.migrations.questioner.MigrationQuestioner.ask_not_null_addition"
+ )
+ def test_add_composite_pk(self, mocked_ask_method):
+ before = [
+ ModelState(
+ "app",
+ "foo",
+ [
+ ("id", models.AutoField(primary_key=True)),
+ ],
+ ),
+ ]
+ after = [
+ ModelState(
+ "app",
+ "foo",
+ [
+ ("pk", models.CompositePrimaryKey("foo_id", "bar_id")),
+ ("id", models.IntegerField()),
+ ],
+ ),
+ ]
+
+ changes = self.get_changes(before, after)
+ self.assertEqual(mocked_ask_method.call_count, 0)
+ self.assertNumberMigrations(changes, "app", 1)
+ self.assertOperationTypes(changes, "app", 0, ["AddField", "AlterField"])
+ self.assertOperationAttributes(
+ changes,
+ "app",
+ 0,
+ 0,
+ name="pk",
+ model_name="foo",
+ preserve_default=True,
+ )
+ self.assertOperationAttributes(
+ changes,
+ "app",
+ 0,
+ 1,
+ name="id",
+ model_name="foo",
+ preserve_default=True,
+ )
+
+ def test_remove_composite_pk(self):
+ before = [
+ ModelState(
+ "app",
+ "foo",
+ [
+ ("pk", models.CompositePrimaryKey("foo_id", "bar_id")),
+ ("id", models.IntegerField()),
+ ],
+ ),
+ ]
+ after = [
+ ModelState(
+ "app",
+ "foo",
+ [
+ ("id", models.AutoField(primary_key=True)),
+ ],
+ ),
+ ]
+
+ changes = self.get_changes(before, after)
+ self.assertNumberMigrations(changes, "app", 1)
+ self.assertOperationTypes(changes, "app", 0, ["RemoveField", "AlterField"])
+ self.assertOperationAttributes(
+ changes,
+ "app",
+ 0,
+ 0,
+ name="pk",
+ model_name="foo",
+ )
+ self.assertOperationAttributes(
+ changes,
+ "app",
+ 0,
+ 1,
+ name="id",
+ model_name="foo",
+ preserve_default=True,
+ )
+
class MigrationSuggestNameTests(SimpleTestCase):
def test_no_operations(self):
diff --git a/tests/migrations/test_operations.py b/tests/migrations/test_operations.py
index da0ec93dcd..6312a7d4a2 100644
--- a/tests/migrations/test_operations.py
+++ b/tests/migrations/test_operations.py
@@ -6287,6 +6287,61 @@ class OperationTests(OperationTestBase):
self.assertEqual(pony_new.generated, 1)
self.assertEqual(pony_new.static, 2)
+ def test_composite_pk_operations(self):
+ app_label = "test_d8d90af6"
+ project_state = self.set_up_test_model(app_label)
+ operation_1 = migrations.AddField(
+ "Pony", "pk", models.CompositePrimaryKey("id", "pink")
+ )
+ operation_2 = migrations.AlterField("Pony", "id", models.IntegerField())
+ operation_3 = migrations.RemoveField("Pony", "pk")
+ table_name = f"{app_label}_pony"
+
+ # 1. Add field (pk).
+ new_state = project_state.clone()
+ operation_1.state_forwards(app_label, new_state)
+ with connection.schema_editor() as editor:
+ operation_1.database_forwards(app_label, editor, project_state, new_state)
+ self.assertColumnNotExists(table_name, "pk")
+ Pony = new_state.apps.get_model(app_label, "pony")
+ obj_1 = Pony.objects.create(weight=1)
+ msg = (
+ f"obj_1={obj_1}, "
+ f"obj_1.id={obj_1.id}, "
+ f"obj_1.pink={obj_1.pink}, "
+ f"obj_1.pk={obj_1.pk}, "
+ f"Pony._meta.pk={repr(Pony._meta.pk)}, "
+ f"Pony._meta.get_field('id')={repr(Pony._meta.get_field('id'))}"
+ )
+ self.assertEqual(obj_1.pink, 3, msg)
+ self.assertEqual(obj_1.pk, (obj_1.id, obj_1.pink), msg)
+
+ # 2. Alter field (id -> IntegerField()).
+ project_state, new_state = new_state, new_state.clone()
+ operation_2.state_forwards(app_label, new_state)
+ with connection.schema_editor() as editor:
+ operation_2.database_forwards(app_label, editor, project_state, new_state)
+ Pony = new_state.apps.get_model(app_label, "pony")
+ obj_1 = Pony.objects.get(id=obj_1.id)
+ self.assertEqual(obj_1.pink, 3)
+ self.assertEqual(obj_1.pk, (obj_1.id, obj_1.pink))
+ obj_2 = Pony.objects.create(id=2, weight=2)
+ self.assertEqual(obj_2.id, 2)
+ self.assertEqual(obj_2.pink, 3)
+ self.assertEqual(obj_2.pk, (obj_2.id, obj_2.pink))
+
+ # 3. Remove field (pk).
+ project_state, new_state = new_state, new_state.clone()
+ operation_3.state_forwards(app_label, new_state)
+ with connection.schema_editor() as editor:
+ operation_3.database_forwards(app_label, editor, project_state, new_state)
+ Pony = new_state.apps.get_model(app_label, "pony")
+ obj_1 = Pony.objects.get(id=obj_1.id)
+ self.assertEqual(obj_1.pk, obj_1.id)
+ obj_2 = Pony.objects.get(id=obj_2.id)
+ self.assertEqual(obj_2.id, 2)
+ self.assertEqual(obj_2.pk, obj_2.id)
+
class SwappableOperationTests(OperationTestBase):
"""
diff --git a/tests/migrations/test_state.py b/tests/migrations/test_state.py
index dbbdf77734..d6ecaa1c5d 100644
--- a/tests/migrations/test_state.py
+++ b/tests/migrations/test_state.py
@@ -1206,6 +1206,28 @@ class StateTests(SimpleTestCase):
choices_field = Author._meta.get_field("choice")
self.assertEqual(list(choices_field.choices), choices)
+ def test_composite_pk_state(self):
+ new_apps = Apps(["migrations"])
+
+ class Foo(models.Model):
+ pk = models.CompositePrimaryKey("account_id", "id")
+ account_id = models.SmallIntegerField()
+ id = models.SmallIntegerField()
+
+ class Meta:
+ app_label = "migrations"
+ apps = new_apps
+
+ project_state = ProjectState.from_apps(new_apps)
+ model_state = project_state.models["migrations", "foo"]
+ self.assertEqual(len(model_state.options), 2)
+ self.assertEqual(model_state.options["constraints"], [])
+ self.assertEqual(model_state.options["indexes"], [])
+ self.assertEqual(len(model_state.fields), 3)
+ self.assertIn("pk", model_state.fields)
+ self.assertIn("account_id", model_state.fields)
+ self.assertIn("id", model_state.fields)
+
class StateRelationsTests(SimpleTestCase):
def get_base_project_state(self):
diff --git a/tests/migrations/test_writer.py b/tests/migrations/test_writer.py
index 51783b7346..953a3cdb6c 100644
--- a/tests/migrations/test_writer.py
+++ b/tests/migrations/test_writer.py
@@ -1138,3 +1138,22 @@ class WriterTests(SimpleTestCase):
ValueError, "'TestModel1' must inherit from 'BaseSerializer'."
):
MigrationWriter.register_serializer(complex, TestModel1)
+
+ def test_composite_pk_import(self):
+ migration = type(
+ "Migration",
+ (migrations.Migration,),
+ {
+ "operations": [
+ migrations.AddField(
+ "foo",
+ "bar",
+ models.CompositePrimaryKey("foo_id", "bar_id"),
+ ),
+ ],
+ },
+ )
+ writer = MigrationWriter(migration)
+ output = writer.as_string()
+ self.assertEqual(output.count("import"), 1)
+ self.assertIn("from django.db import migrations, models", output)