summaryrefslogtreecommitdiff
path: root/tests/composite_pk/tests.py
diff options
context:
space:
mode:
authorBendeguz Csirmaz <csirmazbendeguz@gmail.com>2024-04-07 10:32:16 +0800
committerSarah Boyce <42296566+sarahboyce@users.noreply.github.com>2024-11-29 11:23:04 +0100
commit978aae4334fa71ba78a3e94408f0f3aebde8d07c (patch)
treedd1cc322769441a3dd28b952ce52e07c3f72f90a /tests/composite_pk/tests.py
parent86661f2449fb0903f72b3522c68e146934013377 (diff)
Fixed #373 -- Added CompositePrimaryKey.
Thanks Lily Foote and Simon Charette for reviews and mentoring this Google Summer of Code 2024 project. Co-authored-by: Simon Charette <charette.s@gmail.com> Co-authored-by: Lily Foote <code@lilyf.org>
Diffstat (limited to 'tests/composite_pk/tests.py')
-rw-r--r--tests/composite_pk/tests.py345
1 files changed, 345 insertions, 0 deletions
diff --git a/tests/composite_pk/tests.py b/tests/composite_pk/tests.py
new file mode 100644
index 0000000000..71522cb836
--- /dev/null
+++ b/tests/composite_pk/tests.py
@@ -0,0 +1,345 @@
+import json
+import unittest
+from uuid import UUID
+
+import yaml
+
+from django import forms
+from django.core import serializers
+from django.core.exceptions import FieldError
+from django.db import IntegrityError, connection
+from django.db.models import CompositePrimaryKey
+from django.forms import modelform_factory
+from django.test import TestCase
+
+from .models import Comment, Post, Tenant, User
+
+
+class CommentForm(forms.ModelForm):
+ class Meta:
+ model = Comment
+ fields = "__all__"
+
+
+class CompositePKTests(TestCase):
+ maxDiff = None
+
+ @classmethod
+ def setUpTestData(cls):
+ cls.tenant = Tenant.objects.create()
+ cls.user = User.objects.create(
+ tenant=cls.tenant,
+ id=1,
+ email="user0001@example.com",
+ )
+ cls.comment = Comment.objects.create(tenant=cls.tenant, id=1, user=cls.user)
+
+ @staticmethod
+ def get_constraints(table):
+ with connection.cursor() as cursor:
+ return connection.introspection.get_constraints(cursor, table)
+
+ def test_pk_updated_if_field_updated(self):
+ user = User.objects.get(pk=self.user.pk)
+ self.assertEqual(user.pk, (self.tenant.id, self.user.id))
+ self.assertIs(user._is_pk_set(), True)
+ user.tenant_id = 9831
+ self.assertEqual(user.pk, (9831, self.user.id))
+ self.assertIs(user._is_pk_set(), True)
+ user.id = 4321
+ self.assertEqual(user.pk, (9831, 4321))
+ self.assertIs(user._is_pk_set(), True)
+ user.pk = (9132, 3521)
+ self.assertEqual(user.tenant_id, 9132)
+ self.assertEqual(user.id, 3521)
+ self.assertIs(user._is_pk_set(), True)
+ user.id = None
+ self.assertEqual(user.pk, (9132, None))
+ self.assertEqual(user.tenant_id, 9132)
+ self.assertIsNone(user.id)
+ self.assertIs(user._is_pk_set(), False)
+
+ def test_hash(self):
+ self.assertEqual(hash(User(pk=(1, 2))), hash((1, 2)))
+ self.assertEqual(hash(User(tenant_id=2, id=3)), hash((2, 3)))
+ msg = "Model instances without primary key value are unhashable"
+
+ with self.assertRaisesMessage(TypeError, msg):
+ hash(User())
+ with self.assertRaisesMessage(TypeError, msg):
+ hash(User(tenant_id=1))
+ with self.assertRaisesMessage(TypeError, msg):
+ hash(User(id=1))
+
+ def test_pk_must_be_list_or_tuple(self):
+ user = User.objects.get(pk=self.user.pk)
+ test_cases = [
+ "foo",
+ 1000,
+ 3.14,
+ True,
+ False,
+ ]
+
+ for pk in test_cases:
+ with self.assertRaisesMessage(
+ ValueError, "'pk' must be a list or a tuple."
+ ):
+ user.pk = pk
+
+ def test_pk_must_have_2_elements(self):
+ user = User.objects.get(pk=self.user.pk)
+ test_cases = [
+ (),
+ [],
+ (1000,),
+ [1000],
+ (1, 2, 3),
+ [1, 2, 3],
+ ]
+
+ for pk in test_cases:
+ with self.assertRaisesMessage(ValueError, "'pk' must have 2 elements."):
+ user.pk = pk
+
+ def test_composite_pk_in_fields(self):
+ user_fields = {f.name for f in User._meta.get_fields()}
+ self.assertEqual(user_fields, {"pk", "tenant", "id", "email", "comments"})
+
+ comment_fields = {f.name for f in Comment._meta.get_fields()}
+ self.assertEqual(
+ comment_fields,
+ {"pk", "tenant", "id", "user_id", "user", "text"},
+ )
+
+ def test_pk_field(self):
+ pk = User._meta.get_field("pk")
+ self.assertIsInstance(pk, CompositePrimaryKey)
+ self.assertIs(User._meta.pk, pk)
+
+ def test_error_on_user_pk_conflict(self):
+ with self.assertRaises(IntegrityError):
+ User.objects.create(tenant=self.tenant, id=self.user.id)
+
+ def test_error_on_comment_pk_conflict(self):
+ with self.assertRaises(IntegrityError):
+ Comment.objects.create(tenant=self.tenant, id=self.comment.id)
+
+ @unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL specific test")
+ def test_get_constraints_postgresql(self):
+ user_constraints = self.get_constraints(User._meta.db_table)
+ user_pk = user_constraints["composite_pk_user_pkey"]
+ self.assertEqual(user_pk["columns"], ["tenant_id", "id"])
+ self.assertIs(user_pk["primary_key"], True)
+
+ comment_constraints = self.get_constraints(Comment._meta.db_table)
+ comment_pk = comment_constraints["composite_pk_comment_pkey"]
+ self.assertEqual(comment_pk["columns"], ["tenant_id", "comment_id"])
+ self.assertIs(comment_pk["primary_key"], True)
+
+ @unittest.skipUnless(connection.vendor == "sqlite", "SQLite specific test")
+ def test_get_constraints_sqlite(self):
+ user_constraints = self.get_constraints(User._meta.db_table)
+ user_pk = user_constraints["__primary__"]
+ self.assertEqual(user_pk["columns"], ["tenant_id", "id"])
+ self.assertIs(user_pk["primary_key"], True)
+
+ comment_constraints = self.get_constraints(Comment._meta.db_table)
+ comment_pk = comment_constraints["__primary__"]
+ self.assertEqual(comment_pk["columns"], ["tenant_id", "comment_id"])
+ self.assertIs(comment_pk["primary_key"], True)
+
+ @unittest.skipUnless(connection.vendor == "mysql", "MySQL specific test")
+ def test_get_constraints_mysql(self):
+ user_constraints = self.get_constraints(User._meta.db_table)
+ user_pk = user_constraints["PRIMARY"]
+ self.assertEqual(user_pk["columns"], ["tenant_id", "id"])
+ self.assertIs(user_pk["primary_key"], True)
+
+ comment_constraints = self.get_constraints(Comment._meta.db_table)
+ comment_pk = comment_constraints["PRIMARY"]
+ self.assertEqual(comment_pk["columns"], ["tenant_id", "comment_id"])
+ self.assertIs(comment_pk["primary_key"], True)
+
+ @unittest.skipUnless(connection.vendor == "oracle", "Oracle specific test")
+ def test_get_constraints_oracle(self):
+ user_constraints = self.get_constraints(User._meta.db_table)
+ user_pk = next(c for c in user_constraints.values() if c["primary_key"])
+ self.assertEqual(user_pk["columns"], ["tenant_id", "id"])
+ self.assertEqual(user_pk["primary_key"], 1)
+
+ comment_constraints = self.get_constraints(Comment._meta.db_table)
+ comment_pk = next(c for c in comment_constraints.values() if c["primary_key"])
+ self.assertEqual(comment_pk["columns"], ["tenant_id", "comment_id"])
+ self.assertEqual(comment_pk["primary_key"], 1)
+
+ def test_in_bulk(self):
+ """
+ Test the .in_bulk() method of composite_pk models.
+ """
+ result = Comment.objects.in_bulk()
+ self.assertEqual(result, {self.comment.pk: self.comment})
+
+ result = Comment.objects.in_bulk([self.comment.pk])
+ self.assertEqual(result, {self.comment.pk: self.comment})
+
+ def test_iterator(self):
+ """
+ Test the .iterator() method of composite_pk models.
+ """
+ result = list(Comment.objects.iterator())
+ self.assertEqual(result, [self.comment])
+
+ def test_query(self):
+ users = User.objects.values_list("pk").order_by("pk")
+ self.assertNotIn('AS "pk"', str(users.query))
+
+ def test_only(self):
+ users = User.objects.only("pk")
+ self.assertSequenceEqual(users, (self.user,))
+ user = users[0]
+
+ with self.assertNumQueries(0):
+ self.assertEqual(user.pk, (self.user.tenant_id, self.user.id))
+ self.assertEqual(user.tenant_id, self.user.tenant_id)
+ self.assertEqual(user.id, self.user.id)
+ with self.assertNumQueries(1):
+ self.assertEqual(user.email, self.user.email)
+
+ def test_model_forms(self):
+ fields = ["tenant", "id", "user_id", "text"]
+ self.assertEqual(list(CommentForm.base_fields), fields)
+
+ form = modelform_factory(Comment, fields="__all__")
+ self.assertEqual(list(form().fields), fields)
+
+ with self.assertRaisesMessage(
+ FieldError, "Unknown field(s) (pk) specified for Comment"
+ ):
+ self.assertIsNone(modelform_factory(Comment, fields=["pk"]))
+
+
+class CompositePKFixturesTests(TestCase):
+ fixtures = ["tenant"]
+
+ def test_objects(self):
+ tenant_1, tenant_2, tenant_3 = Tenant.objects.order_by("pk")
+ self.assertEqual(tenant_1.id, 1)
+ self.assertEqual(tenant_1.name, "Tenant 1")
+ self.assertEqual(tenant_2.id, 2)
+ self.assertEqual(tenant_2.name, "Tenant 2")
+ self.assertEqual(tenant_3.id, 3)
+ self.assertEqual(tenant_3.name, "Tenant 3")
+
+ user_1, user_2, user_3, user_4 = User.objects.order_by("pk")
+ self.assertEqual(user_1.id, 1)
+ self.assertEqual(user_1.tenant_id, 1)
+ self.assertEqual(user_1.pk, (user_1.tenant_id, user_1.id))
+ self.assertEqual(user_1.email, "user0001@example.com")
+ self.assertEqual(user_2.id, 2)
+ self.assertEqual(user_2.tenant_id, 1)
+ self.assertEqual(user_2.pk, (user_2.tenant_id, user_2.id))
+ self.assertEqual(user_2.email, "user0002@example.com")
+ self.assertEqual(user_3.id, 3)
+ self.assertEqual(user_3.tenant_id, 2)
+ self.assertEqual(user_3.pk, (user_3.tenant_id, user_3.id))
+ self.assertEqual(user_3.email, "user0003@example.com")
+ self.assertEqual(user_4.id, 4)
+ self.assertEqual(user_4.tenant_id, 2)
+ self.assertEqual(user_4.pk, (user_4.tenant_id, user_4.id))
+ self.assertEqual(user_4.email, "user0004@example.com")
+
+ post_1, post_2 = Post.objects.order_by("pk")
+ self.assertEqual(post_1.id, UUID("11111111-1111-1111-1111-111111111111"))
+ self.assertEqual(post_1.tenant_id, 2)
+ self.assertEqual(post_1.pk, (post_1.tenant_id, post_1.id))
+ self.assertEqual(post_2.id, UUID("ffffffff-ffff-ffff-ffff-ffffffffffff"))
+ self.assertEqual(post_2.tenant_id, 2)
+ self.assertEqual(post_2.pk, (post_2.tenant_id, post_2.id))
+
+ def test_serialize_user_json(self):
+ users = User.objects.filter(pk=(1, 1))
+ result = serializers.serialize("json", users)
+ self.assertEqual(
+ json.loads(result),
+ [
+ {
+ "model": "composite_pk.user",
+ "pk": [1, 1],
+ "fields": {
+ "email": "user0001@example.com",
+ "id": 1,
+ "tenant": 1,
+ },
+ }
+ ],
+ )
+
+ def test_serialize_user_jsonl(self):
+ users = User.objects.filter(pk=(1, 2))
+ result = serializers.serialize("jsonl", users)
+ self.assertEqual(
+ json.loads(result),
+ {
+ "model": "composite_pk.user",
+ "pk": [1, 2],
+ "fields": {
+ "email": "user0002@example.com",
+ "id": 2,
+ "tenant": 1,
+ },
+ },
+ )
+
+ def test_serialize_user_yaml(self):
+ users = User.objects.filter(pk=(2, 3))
+ result = serializers.serialize("yaml", users)
+ self.assertEqual(
+ yaml.safe_load(result),
+ [
+ {
+ "model": "composite_pk.user",
+ "pk": [2, 3],
+ "fields": {
+ "email": "user0003@example.com",
+ "id": 3,
+ "tenant": 2,
+ },
+ },
+ ],
+ )
+
+ def test_serialize_user_python(self):
+ users = User.objects.filter(pk=(2, 4))
+ result = serializers.serialize("python", users)
+ self.assertEqual(
+ result,
+ [
+ {
+ "model": "composite_pk.user",
+ "pk": [2, 4],
+ "fields": {
+ "email": "user0004@example.com",
+ "id": 4,
+ "tenant": 2,
+ },
+ },
+ ],
+ )
+
+ def test_serialize_post_uuid(self):
+ posts = Post.objects.filter(pk=(2, "11111111-1111-1111-1111-111111111111"))
+ result = serializers.serialize("json", posts)
+ self.assertEqual(
+ json.loads(result),
+ [
+ {
+ "model": "composite_pk.post",
+ "pk": [2, "11111111-1111-1111-1111-111111111111"],
+ "fields": {
+ "id": "11111111-1111-1111-1111-111111111111",
+ "tenant": 2,
+ },
+ },
+ ],
+ )