diff options
| author | Bendeguz Csirmaz <csirmazbendeguz@gmail.com> | 2024-04-07 10:32:16 +0800 |
|---|---|---|
| committer | Sarah Boyce <42296566+sarahboyce@users.noreply.github.com> | 2024-11-29 11:23:04 +0100 |
| commit | 978aae4334fa71ba78a3e94408f0f3aebde8d07c (patch) | |
| tree | dd1cc322769441a3dd28b952ce52e07c3f72f90a /tests/composite_pk/tests.py | |
| parent | 86661f2449fb0903f72b3522c68e146934013377 (diff) | |
Fixed #373 -- Added CompositePrimaryKey.
Thanks Lily Foote and Simon Charette for reviews and mentoring
this Google Summer of Code 2024 project.
Co-authored-by: Simon Charette <charette.s@gmail.com>
Co-authored-by: Lily Foote <code@lilyf.org>
Diffstat (limited to 'tests/composite_pk/tests.py')
| -rw-r--r-- | tests/composite_pk/tests.py | 345 |
1 files changed, 345 insertions, 0 deletions
diff --git a/tests/composite_pk/tests.py b/tests/composite_pk/tests.py new file mode 100644 index 0000000000..71522cb836 --- /dev/null +++ b/tests/composite_pk/tests.py @@ -0,0 +1,345 @@ +import json +import unittest +from uuid import UUID + +import yaml + +from django import forms +from django.core import serializers +from django.core.exceptions import FieldError +from django.db import IntegrityError, connection +from django.db.models import CompositePrimaryKey +from django.forms import modelform_factory +from django.test import TestCase + +from .models import Comment, Post, Tenant, User + + +class CommentForm(forms.ModelForm): + class Meta: + model = Comment + fields = "__all__" + + +class CompositePKTests(TestCase): + maxDiff = None + + @classmethod + def setUpTestData(cls): + cls.tenant = Tenant.objects.create() + cls.user = User.objects.create( + tenant=cls.tenant, + id=1, + email="user0001@example.com", + ) + cls.comment = Comment.objects.create(tenant=cls.tenant, id=1, user=cls.user) + + @staticmethod + def get_constraints(table): + with connection.cursor() as cursor: + return connection.introspection.get_constraints(cursor, table) + + def test_pk_updated_if_field_updated(self): + user = User.objects.get(pk=self.user.pk) + self.assertEqual(user.pk, (self.tenant.id, self.user.id)) + self.assertIs(user._is_pk_set(), True) + user.tenant_id = 9831 + self.assertEqual(user.pk, (9831, self.user.id)) + self.assertIs(user._is_pk_set(), True) + user.id = 4321 + self.assertEqual(user.pk, (9831, 4321)) + self.assertIs(user._is_pk_set(), True) + user.pk = (9132, 3521) + self.assertEqual(user.tenant_id, 9132) + self.assertEqual(user.id, 3521) + self.assertIs(user._is_pk_set(), True) + user.id = None + self.assertEqual(user.pk, (9132, None)) + self.assertEqual(user.tenant_id, 9132) + self.assertIsNone(user.id) + self.assertIs(user._is_pk_set(), False) + + def test_hash(self): + self.assertEqual(hash(User(pk=(1, 2))), hash((1, 2))) + self.assertEqual(hash(User(tenant_id=2, id=3)), hash((2, 3))) + msg = "Model instances without primary key value are unhashable" + + with self.assertRaisesMessage(TypeError, msg): + hash(User()) + with self.assertRaisesMessage(TypeError, msg): + hash(User(tenant_id=1)) + with self.assertRaisesMessage(TypeError, msg): + hash(User(id=1)) + + def test_pk_must_be_list_or_tuple(self): + user = User.objects.get(pk=self.user.pk) + test_cases = [ + "foo", + 1000, + 3.14, + True, + False, + ] + + for pk in test_cases: + with self.assertRaisesMessage( + ValueError, "'pk' must be a list or a tuple." + ): + user.pk = pk + + def test_pk_must_have_2_elements(self): + user = User.objects.get(pk=self.user.pk) + test_cases = [ + (), + [], + (1000,), + [1000], + (1, 2, 3), + [1, 2, 3], + ] + + for pk in test_cases: + with self.assertRaisesMessage(ValueError, "'pk' must have 2 elements."): + user.pk = pk + + def test_composite_pk_in_fields(self): + user_fields = {f.name for f in User._meta.get_fields()} + self.assertEqual(user_fields, {"pk", "tenant", "id", "email", "comments"}) + + comment_fields = {f.name for f in Comment._meta.get_fields()} + self.assertEqual( + comment_fields, + {"pk", "tenant", "id", "user_id", "user", "text"}, + ) + + def test_pk_field(self): + pk = User._meta.get_field("pk") + self.assertIsInstance(pk, CompositePrimaryKey) + self.assertIs(User._meta.pk, pk) + + def test_error_on_user_pk_conflict(self): + with self.assertRaises(IntegrityError): + User.objects.create(tenant=self.tenant, id=self.user.id) + + def test_error_on_comment_pk_conflict(self): + with self.assertRaises(IntegrityError): + Comment.objects.create(tenant=self.tenant, id=self.comment.id) + + @unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL specific test") + def test_get_constraints_postgresql(self): + user_constraints = self.get_constraints(User._meta.db_table) + user_pk = user_constraints["composite_pk_user_pkey"] + self.assertEqual(user_pk["columns"], ["tenant_id", "id"]) + self.assertIs(user_pk["primary_key"], True) + + comment_constraints = self.get_constraints(Comment._meta.db_table) + comment_pk = comment_constraints["composite_pk_comment_pkey"] + self.assertEqual(comment_pk["columns"], ["tenant_id", "comment_id"]) + self.assertIs(comment_pk["primary_key"], True) + + @unittest.skipUnless(connection.vendor == "sqlite", "SQLite specific test") + def test_get_constraints_sqlite(self): + user_constraints = self.get_constraints(User._meta.db_table) + user_pk = user_constraints["__primary__"] + self.assertEqual(user_pk["columns"], ["tenant_id", "id"]) + self.assertIs(user_pk["primary_key"], True) + + comment_constraints = self.get_constraints(Comment._meta.db_table) + comment_pk = comment_constraints["__primary__"] + self.assertEqual(comment_pk["columns"], ["tenant_id", "comment_id"]) + self.assertIs(comment_pk["primary_key"], True) + + @unittest.skipUnless(connection.vendor == "mysql", "MySQL specific test") + def test_get_constraints_mysql(self): + user_constraints = self.get_constraints(User._meta.db_table) + user_pk = user_constraints["PRIMARY"] + self.assertEqual(user_pk["columns"], ["tenant_id", "id"]) + self.assertIs(user_pk["primary_key"], True) + + comment_constraints = self.get_constraints(Comment._meta.db_table) + comment_pk = comment_constraints["PRIMARY"] + self.assertEqual(comment_pk["columns"], ["tenant_id", "comment_id"]) + self.assertIs(comment_pk["primary_key"], True) + + @unittest.skipUnless(connection.vendor == "oracle", "Oracle specific test") + def test_get_constraints_oracle(self): + user_constraints = self.get_constraints(User._meta.db_table) + user_pk = next(c for c in user_constraints.values() if c["primary_key"]) + self.assertEqual(user_pk["columns"], ["tenant_id", "id"]) + self.assertEqual(user_pk["primary_key"], 1) + + comment_constraints = self.get_constraints(Comment._meta.db_table) + comment_pk = next(c for c in comment_constraints.values() if c["primary_key"]) + self.assertEqual(comment_pk["columns"], ["tenant_id", "comment_id"]) + self.assertEqual(comment_pk["primary_key"], 1) + + def test_in_bulk(self): + """ + Test the .in_bulk() method of composite_pk models. + """ + result = Comment.objects.in_bulk() + self.assertEqual(result, {self.comment.pk: self.comment}) + + result = Comment.objects.in_bulk([self.comment.pk]) + self.assertEqual(result, {self.comment.pk: self.comment}) + + def test_iterator(self): + """ + Test the .iterator() method of composite_pk models. + """ + result = list(Comment.objects.iterator()) + self.assertEqual(result, [self.comment]) + + def test_query(self): + users = User.objects.values_list("pk").order_by("pk") + self.assertNotIn('AS "pk"', str(users.query)) + + def test_only(self): + users = User.objects.only("pk") + self.assertSequenceEqual(users, (self.user,)) + user = users[0] + + with self.assertNumQueries(0): + self.assertEqual(user.pk, (self.user.tenant_id, self.user.id)) + self.assertEqual(user.tenant_id, self.user.tenant_id) + self.assertEqual(user.id, self.user.id) + with self.assertNumQueries(1): + self.assertEqual(user.email, self.user.email) + + def test_model_forms(self): + fields = ["tenant", "id", "user_id", "text"] + self.assertEqual(list(CommentForm.base_fields), fields) + + form = modelform_factory(Comment, fields="__all__") + self.assertEqual(list(form().fields), fields) + + with self.assertRaisesMessage( + FieldError, "Unknown field(s) (pk) specified for Comment" + ): + self.assertIsNone(modelform_factory(Comment, fields=["pk"])) + + +class CompositePKFixturesTests(TestCase): + fixtures = ["tenant"] + + def test_objects(self): + tenant_1, tenant_2, tenant_3 = Tenant.objects.order_by("pk") + self.assertEqual(tenant_1.id, 1) + self.assertEqual(tenant_1.name, "Tenant 1") + self.assertEqual(tenant_2.id, 2) + self.assertEqual(tenant_2.name, "Tenant 2") + self.assertEqual(tenant_3.id, 3) + self.assertEqual(tenant_3.name, "Tenant 3") + + user_1, user_2, user_3, user_4 = User.objects.order_by("pk") + self.assertEqual(user_1.id, 1) + self.assertEqual(user_1.tenant_id, 1) + self.assertEqual(user_1.pk, (user_1.tenant_id, user_1.id)) + self.assertEqual(user_1.email, "user0001@example.com") + self.assertEqual(user_2.id, 2) + self.assertEqual(user_2.tenant_id, 1) + self.assertEqual(user_2.pk, (user_2.tenant_id, user_2.id)) + self.assertEqual(user_2.email, "user0002@example.com") + self.assertEqual(user_3.id, 3) + self.assertEqual(user_3.tenant_id, 2) + self.assertEqual(user_3.pk, (user_3.tenant_id, user_3.id)) + self.assertEqual(user_3.email, "user0003@example.com") + self.assertEqual(user_4.id, 4) + self.assertEqual(user_4.tenant_id, 2) + self.assertEqual(user_4.pk, (user_4.tenant_id, user_4.id)) + self.assertEqual(user_4.email, "user0004@example.com") + + post_1, post_2 = Post.objects.order_by("pk") + self.assertEqual(post_1.id, UUID("11111111-1111-1111-1111-111111111111")) + self.assertEqual(post_1.tenant_id, 2) + self.assertEqual(post_1.pk, (post_1.tenant_id, post_1.id)) + self.assertEqual(post_2.id, UUID("ffffffff-ffff-ffff-ffff-ffffffffffff")) + self.assertEqual(post_2.tenant_id, 2) + self.assertEqual(post_2.pk, (post_2.tenant_id, post_2.id)) + + def test_serialize_user_json(self): + users = User.objects.filter(pk=(1, 1)) + result = serializers.serialize("json", users) + self.assertEqual( + json.loads(result), + [ + { + "model": "composite_pk.user", + "pk": [1, 1], + "fields": { + "email": "user0001@example.com", + "id": 1, + "tenant": 1, + }, + } + ], + ) + + def test_serialize_user_jsonl(self): + users = User.objects.filter(pk=(1, 2)) + result = serializers.serialize("jsonl", users) + self.assertEqual( + json.loads(result), + { + "model": "composite_pk.user", + "pk": [1, 2], + "fields": { + "email": "user0002@example.com", + "id": 2, + "tenant": 1, + }, + }, + ) + + def test_serialize_user_yaml(self): + users = User.objects.filter(pk=(2, 3)) + result = serializers.serialize("yaml", users) + self.assertEqual( + yaml.safe_load(result), + [ + { + "model": "composite_pk.user", + "pk": [2, 3], + "fields": { + "email": "user0003@example.com", + "id": 3, + "tenant": 2, + }, + }, + ], + ) + + def test_serialize_user_python(self): + users = User.objects.filter(pk=(2, 4)) + result = serializers.serialize("python", users) + self.assertEqual( + result, + [ + { + "model": "composite_pk.user", + "pk": [2, 4], + "fields": { + "email": "user0004@example.com", + "id": 4, + "tenant": 2, + }, + }, + ], + ) + + def test_serialize_post_uuid(self): + posts = Post.objects.filter(pk=(2, "11111111-1111-1111-1111-111111111111")) + result = serializers.serialize("json", posts) + self.assertEqual( + json.loads(result), + [ + { + "model": "composite_pk.post", + "pk": [2, "11111111-1111-1111-1111-111111111111"], + "fields": { + "id": "11111111-1111-1111-1111-111111111111", + "tenant": 2, + }, + }, + ], + ) |
