summaryrefslogtreecommitdiff
path: root/tests/composite_pk
diff options
context:
space:
mode:
authorMariusz Felisiak <felisiak.mariusz@gmail.com>2024-12-01 20:55:35 +0100
committerSarah Boyce <42296566+sarahboyce@users.noreply.github.com>2024-12-02 11:03:42 +0100
commit49761ac99a064236b4280ca55f97a896913109cd (patch)
tree7c0ed97dd73c53af15f0eac2413b3025be40b5cd /tests/composite_pk
parent81cf690111e49b9cf9d8a3b8a71767f3c8685d5b (diff)
Refs #373 -- Simplified DatabaseIntrospection.get_constraints() tests for composite primary keys.
Diffstat (limited to 'tests/composite_pk')
-rw-r--r--tests/composite_pk/tests.py61
1 files changed, 11 insertions, 50 deletions
diff --git a/tests/composite_pk/tests.py b/tests/composite_pk/tests.py
index 71522cb836..1a0a327baf 100644
--- a/tests/composite_pk/tests.py
+++ b/tests/composite_pk/tests.py
@@ -1,5 +1,4 @@
import json
-import unittest
from uuid import UUID
import yaml
@@ -35,9 +34,9 @@ class CompositePKTests(TestCase):
cls.comment = Comment.objects.create(tenant=cls.tenant, id=1, user=cls.user)
@staticmethod
- def get_constraints(table):
+ def get_primary_key_columns(table):
with connection.cursor() as cursor:
- return connection.introspection.get_constraints(cursor, table)
+ return connection.introspection.get_primary_key_columns(cursor, table)
def test_pk_updated_if_field_updated(self):
user = User.objects.get(pk=self.user.pk)
@@ -125,53 +124,15 @@ class CompositePKTests(TestCase):
with self.assertRaises(IntegrityError):
Comment.objects.create(tenant=self.tenant, id=self.comment.id)
- @unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL specific test")
- def test_get_constraints_postgresql(self):
- user_constraints = self.get_constraints(User._meta.db_table)
- user_pk = user_constraints["composite_pk_user_pkey"]
- self.assertEqual(user_pk["columns"], ["tenant_id", "id"])
- self.assertIs(user_pk["primary_key"], True)
-
- comment_constraints = self.get_constraints(Comment._meta.db_table)
- comment_pk = comment_constraints["composite_pk_comment_pkey"]
- self.assertEqual(comment_pk["columns"], ["tenant_id", "comment_id"])
- self.assertIs(comment_pk["primary_key"], True)
-
- @unittest.skipUnless(connection.vendor == "sqlite", "SQLite specific test")
- def test_get_constraints_sqlite(self):
- user_constraints = self.get_constraints(User._meta.db_table)
- user_pk = user_constraints["__primary__"]
- self.assertEqual(user_pk["columns"], ["tenant_id", "id"])
- self.assertIs(user_pk["primary_key"], True)
-
- comment_constraints = self.get_constraints(Comment._meta.db_table)
- comment_pk = comment_constraints["__primary__"]
- self.assertEqual(comment_pk["columns"], ["tenant_id", "comment_id"])
- self.assertIs(comment_pk["primary_key"], True)
-
- @unittest.skipUnless(connection.vendor == "mysql", "MySQL specific test")
- def test_get_constraints_mysql(self):
- user_constraints = self.get_constraints(User._meta.db_table)
- user_pk = user_constraints["PRIMARY"]
- self.assertEqual(user_pk["columns"], ["tenant_id", "id"])
- self.assertIs(user_pk["primary_key"], True)
-
- comment_constraints = self.get_constraints(Comment._meta.db_table)
- comment_pk = comment_constraints["PRIMARY"]
- self.assertEqual(comment_pk["columns"], ["tenant_id", "comment_id"])
- self.assertIs(comment_pk["primary_key"], True)
-
- @unittest.skipUnless(connection.vendor == "oracle", "Oracle specific test")
- def test_get_constraints_oracle(self):
- user_constraints = self.get_constraints(User._meta.db_table)
- user_pk = next(c for c in user_constraints.values() if c["primary_key"])
- self.assertEqual(user_pk["columns"], ["tenant_id", "id"])
- self.assertEqual(user_pk["primary_key"], 1)
-
- comment_constraints = self.get_constraints(Comment._meta.db_table)
- comment_pk = next(c for c in comment_constraints.values() if c["primary_key"])
- self.assertEqual(comment_pk["columns"], ["tenant_id", "comment_id"])
- self.assertEqual(comment_pk["primary_key"], 1)
+ def test_get_primary_key_columns(self):
+ self.assertEqual(
+ self.get_primary_key_columns(User._meta.db_table),
+ ["tenant_id", "id"],
+ )
+ self.assertEqual(
+ self.get_primary_key_columns(Comment._meta.db_table),
+ ["tenant_id", "comment_id"],
+ )
def test_in_bulk(self):
"""