summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorElizabethU <elizabeth.uselton@gmail.com>2019-09-02 19:09:31 -0700
committerMariusz Felisiak <felisiak.mariusz@gmail.com>2019-10-01 17:58:19 +0200
commit54ea290e5bbd19d87bd8dba807738eeeaf01a362 (patch)
treed83c186bde9f50faa13840e6ee227e3bb1e02ad6
parent6475e6318c970359a2f02798910a917229ee17d7 (diff)
Fixed #30651 -- Made __eq__() methods return NotImplemented for not implemented comparisons.
Changed __eq__ to return NotImplemented instead of False if compared to an object of the same type, as is recommended by the Python data model reference. Now these models can be compared to ANY (or other objects with __eq__ overwritten) without returning False automatically.
-rw-r--r--django/contrib/messages/storage/base.py5
-rw-r--r--django/contrib/postgres/constraints.py15
-rw-r--r--django/core/validators.py3
-rw-r--r--django/db/models/base.py2
-rw-r--r--django/db/models/constraints.py21
-rw-r--r--django/db/models/expressions.py4
-rw-r--r--django/db/models/indexes.py4
-rw-r--r--django/db/models/query.py4
-rw-r--r--django/db/models/query_utils.py3
-rw-r--r--django/template/context.py10
-rw-r--r--tests/basic/tests.py2
-rw-r--r--tests/constraints/tests.py7
-rw-r--r--tests/expressions/tests.py2
-rw-r--r--tests/filtered_relation/tests.py5
-rw-r--r--tests/messages_tests/tests.py3
-rw-r--r--tests/model_indexes/tests.py3
-rw-r--r--tests/postgres_tests/test_constraints.py2
-rw-r--r--tests/prefetch_related/tests.py3
-rw-r--r--tests/template_tests/test_context.py3
-rw-r--r--tests/validators/tests.py3
20 files changed, 71 insertions, 33 deletions
diff --git a/django/contrib/messages/storage/base.py b/django/contrib/messages/storage/base.py
index fd5d0c24aa..b2eeac77f4 100644
--- a/django/contrib/messages/storage/base.py
+++ b/django/contrib/messages/storage/base.py
@@ -25,8 +25,9 @@ class Message:
self.extra_tags = str(self.extra_tags) if self.extra_tags is not None else None
def __eq__(self, other):
- return isinstance(other, Message) and self.level == other.level and \
- self.message == other.message
+ if not isinstance(other, Message):
+ return NotImplemented
+ return self.level == other.level and self.message == other.message
def __str__(self):
return str(self.message)
diff --git a/django/contrib/postgres/constraints.py b/django/contrib/postgres/constraints.py
index 2fcb076ecf..67e415ddcf 100644
--- a/django/contrib/postgres/constraints.py
+++ b/django/contrib/postgres/constraints.py
@@ -89,13 +89,14 @@ class ExclusionConstraint(BaseConstraint):
return path, args, kwargs
def __eq__(self, other):
- return (
- isinstance(other, self.__class__) and
- self.name == other.name and
- self.index_type == other.index_type and
- self.expressions == other.expressions and
- self.condition == other.condition
- )
+ if isinstance(other, self.__class__):
+ return (
+ self.name == other.name and
+ self.index_type == other.index_type and
+ self.expressions == other.expressions and
+ self.condition == other.condition
+ )
+ return super().__eq__(other)
def __repr__(self):
return '<%s: index_type=%s, expressions=%s%s>' % (
diff --git a/django/core/validators.py b/django/core/validators.py
index 2e00ca3ff3..38345a844f 100644
--- a/django/core/validators.py
+++ b/django/core/validators.py
@@ -324,8 +324,9 @@ class BaseValidator:
raise ValidationError(self.message, code=self.code, params=params)
def __eq__(self, other):
+ if not isinstance(other, self.__class__):
+ return NotImplemented
return (
- isinstance(other, self.__class__) and
self.limit_value == other.limit_value and
self.message == other.message and
self.code == other.code
diff --git a/django/db/models/base.py b/django/db/models/base.py
index 0b8425aa85..0a5e5ff673 100644
--- a/django/db/models/base.py
+++ b/django/db/models/base.py
@@ -522,7 +522,7 @@ class Model(metaclass=ModelBase):
def __eq__(self, other):
if not isinstance(other, Model):
- return False
+ return NotImplemented
if self._meta.concrete_model != other._meta.concrete_model:
return False
my_pk = self.pk
diff --git a/django/db/models/constraints.py b/django/db/models/constraints.py
index e7f81d3ee9..fe0d42a168 100644
--- a/django/db/models/constraints.py
+++ b/django/db/models/constraints.py
@@ -54,11 +54,9 @@ class CheckConstraint(BaseConstraint):
return "<%s: check='%s' name=%r>" % (self.__class__.__name__, self.check, self.name)
def __eq__(self, other):
- return (
- isinstance(other, CheckConstraint) and
- self.name == other.name and
- self.check == other.check
- )
+ if isinstance(other, CheckConstraint):
+ return self.name == other.name and self.check == other.check
+ return super().__eq__(other)
def deconstruct(self):
path, args, kwargs = super().deconstruct()
@@ -106,12 +104,13 @@ class UniqueConstraint(BaseConstraint):
)
def __eq__(self, other):
- return (
- isinstance(other, UniqueConstraint) and
- self.name == other.name and
- self.fields == other.fields and
- self.condition == other.condition
- )
+ if isinstance(other, UniqueConstraint):
+ return (
+ self.name == other.name and
+ self.fields == other.fields and
+ self.condition == other.condition
+ )
+ return super().__eq__(other)
def deconstruct(self):
path, args, kwargs = super().deconstruct()
diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py
index 2b59dd301a..5df765b626 100644
--- a/django/db/models/expressions.py
+++ b/django/db/models/expressions.py
@@ -401,7 +401,9 @@ class BaseExpression:
return tuple(identity)
def __eq__(self, other):
- return isinstance(other, BaseExpression) and other.identity == self.identity
+ if not isinstance(other, BaseExpression):
+ return NotImplemented
+ return other.identity == self.identity
def __hash__(self):
return hash(self.identity)
diff --git a/django/db/models/indexes.py b/django/db/models/indexes.py
index b156366764..49f4989462 100644
--- a/django/db/models/indexes.py
+++ b/django/db/models/indexes.py
@@ -112,4 +112,6 @@ class Index:
)
def __eq__(self, other):
- return (self.__class__ == other.__class__) and (self.deconstruct() == other.deconstruct())
+ if self.__class__ == other.__class__:
+ return self.deconstruct() == other.deconstruct()
+ return NotImplemented
diff --git a/django/db/models/query.py b/django/db/models/query.py
index 4417c17592..794e0faae7 100644
--- a/django/db/models/query.py
+++ b/django/db/models/query.py
@@ -1543,7 +1543,9 @@ class Prefetch:
return None
def __eq__(self, other):
- return isinstance(other, Prefetch) and self.prefetch_to == other.prefetch_to
+ if not isinstance(other, Prefetch):
+ return NotImplemented
+ return self.prefetch_to == other.prefetch_to
def __hash__(self):
return hash((self.__class__, self.prefetch_to))
diff --git a/django/db/models/query_utils.py b/django/db/models/query_utils.py
index 7a667814f4..189fb4fa44 100644
--- a/django/db/models/query_utils.py
+++ b/django/db/models/query_utils.py
@@ -309,8 +309,9 @@ class FilteredRelation:
self.path = []
def __eq__(self, other):
+ if not isinstance(other, self.__class__):
+ return NotImplemented
return (
- isinstance(other, self.__class__) and
self.relation_name == other.relation_name and
self.alias == other.alias and
self.condition == other.condition
diff --git a/django/template/context.py b/django/template/context.py
index 8f349a3a96..f0a0cf2a00 100644
--- a/django/template/context.py
+++ b/django/template/context.py
@@ -124,12 +124,10 @@ class BaseContext:
"""
Compare two contexts by comparing theirs 'dicts' attributes.
"""
- return (
- isinstance(other, BaseContext) and
- # because dictionaries can be put in different order
- # we have to flatten them like in templates
- self.flatten() == other.flatten()
- )
+ if not isinstance(other, BaseContext):
+ return NotImplemented
+ # flatten dictionaries because they can be put in a different order.
+ return self.flatten() == other.flatten()
class Context(BaseContext):
diff --git a/tests/basic/tests.py b/tests/basic/tests.py
index 89f6048c96..5eada343e1 100644
--- a/tests/basic/tests.py
+++ b/tests/basic/tests.py
@@ -1,5 +1,6 @@
import threading
from datetime import datetime, timedelta
+from unittest import mock
from django.core.exceptions import MultipleObjectsReturned, ObjectDoesNotExist
from django.db import DEFAULT_DB_ALIAS, DatabaseError, connections, models
@@ -354,6 +355,7 @@ class ModelTest(TestCase):
self.assertNotEqual(object(), Article(id=1))
a = Article()
self.assertEqual(a, a)
+ self.assertEqual(a, mock.ANY)
self.assertNotEqual(Article(), a)
def test_hash(self):
diff --git a/tests/constraints/tests.py b/tests/constraints/tests.py
index 3b28c99e7f..8e2eb11e2a 100644
--- a/tests/constraints/tests.py
+++ b/tests/constraints/tests.py
@@ -1,3 +1,5 @@
+from unittest import mock
+
from django.core.exceptions import ValidationError
from django.db import IntegrityError, connection, models
from django.db.models.constraints import BaseConstraint
@@ -39,6 +41,7 @@ class CheckConstraintTests(TestCase):
models.CheckConstraint(check=check1, name='price'),
models.CheckConstraint(check=check1, name='price'),
)
+ self.assertEqual(models.CheckConstraint(check=check1, name='price'), mock.ANY)
self.assertNotEqual(
models.CheckConstraint(check=check1, name='price'),
models.CheckConstraint(check=check1, name='price2'),
@@ -102,6 +105,10 @@ class UniqueConstraintTests(TestCase):
models.UniqueConstraint(fields=['foo', 'bar'], name='unique'),
models.UniqueConstraint(fields=['foo', 'bar'], name='unique'),
)
+ self.assertEqual(
+ models.UniqueConstraint(fields=['foo', 'bar'], name='unique'),
+ mock.ANY,
+ )
self.assertNotEqual(
models.UniqueConstraint(fields=['foo', 'bar'], name='unique'),
models.UniqueConstraint(fields=['foo', 'bar'], name='unique2'),
diff --git a/tests/expressions/tests.py b/tests/expressions/tests.py
index f50c634014..094b738792 100644
--- a/tests/expressions/tests.py
+++ b/tests/expressions/tests.py
@@ -3,6 +3,7 @@ import pickle
import unittest
import uuid
from copy import deepcopy
+from unittest import mock
from django.core.exceptions import FieldError
from django.db import DatabaseError, connection, models
@@ -965,6 +966,7 @@ class SimpleExpressionTests(SimpleTestCase):
Expression(models.IntegerField()),
Expression(output_field=models.IntegerField())
)
+ self.assertEqual(Expression(models.IntegerField()), mock.ANY)
self.assertNotEqual(
Expression(models.IntegerField()),
Expression(models.CharField())
diff --git a/tests/filtered_relation/tests.py b/tests/filtered_relation/tests.py
index 52fe64dfa5..48154413a5 100644
--- a/tests/filtered_relation/tests.py
+++ b/tests/filtered_relation/tests.py
@@ -1,3 +1,5 @@
+from unittest import mock
+
from django.db import connection, transaction
from django.db.models import Case, Count, F, FilteredRelation, Q, When
from django.test import TestCase
@@ -323,6 +325,9 @@ class FilteredRelationTests(TestCase):
[self.book1]
)
+ def test_eq(self):
+ self.assertEqual(FilteredRelation('book', condition=Q(book__title='b')), mock.ANY)
+
class FilteredRelationAggregationTests(TestCase):
diff --git a/tests/messages_tests/tests.py b/tests/messages_tests/tests.py
index 1464783b33..eea07c9c41 100644
--- a/tests/messages_tests/tests.py
+++ b/tests/messages_tests/tests.py
@@ -1,3 +1,5 @@
+from unittest import mock
+
from django.contrib.messages import constants
from django.contrib.messages.storage.base import Message
from django.test import SimpleTestCase
@@ -9,6 +11,7 @@ class MessageTests(SimpleTestCase):
msg_2 = Message(constants.INFO, 'Test message 2')
msg_3 = Message(constants.WARNING, 'Test message 1')
self.assertEqual(msg_1, msg_1)
+ self.assertEqual(msg_1, mock.ANY)
self.assertNotEqual(msg_1, msg_2)
self.assertNotEqual(msg_1, msg_3)
self.assertNotEqual(msg_2, msg_3)
diff --git a/tests/model_indexes/tests.py b/tests/model_indexes/tests.py
index ade27e1a4b..6a31109031 100644
--- a/tests/model_indexes/tests.py
+++ b/tests/model_indexes/tests.py
@@ -1,3 +1,5 @@
+from unittest import mock
+
from django.conf import settings
from django.db import connection, models
from django.db.models.query_utils import Q
@@ -28,6 +30,7 @@ class SimpleIndexesTests(SimpleTestCase):
same_index.model = Book
another_index.model = Book
self.assertEqual(index, same_index)
+ self.assertEqual(index, mock.ANY)
self.assertNotEqual(index, another_index)
def test_index_fields_type(self):
diff --git a/tests/postgres_tests/test_constraints.py b/tests/postgres_tests/test_constraints.py
index d8665f59f6..b22821294a 100644
--- a/tests/postgres_tests/test_constraints.py
+++ b/tests/postgres_tests/test_constraints.py
@@ -1,4 +1,5 @@
import datetime
+from unittest import mock
from django.db import connection, transaction
from django.db.models import F, Func, Q
@@ -175,6 +176,7 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
condition=Q(cancelled=False),
)
self.assertEqual(constraint_1, constraint_1)
+ self.assertEqual(constraint_1, mock.ANY)
self.assertNotEqual(constraint_1, constraint_2)
self.assertNotEqual(constraint_1, constraint_3)
self.assertNotEqual(constraint_2, constraint_3)
diff --git a/tests/prefetch_related/tests.py b/tests/prefetch_related/tests.py
index 9ae939dcdf..930ba9fbc8 100644
--- a/tests/prefetch_related/tests.py
+++ b/tests/prefetch_related/tests.py
@@ -1,3 +1,5 @@
+from unittest import mock
+
from django.contrib.contenttypes.models import ContentType
from django.core.exceptions import ObjectDoesNotExist
from django.db import connection
@@ -243,6 +245,7 @@ class PrefetchRelatedTests(TestDataMixin, TestCase):
prefetch_1 = Prefetch('authors', queryset=Author.objects.all())
prefetch_2 = Prefetch('books', queryset=Book.objects.all())
self.assertEqual(prefetch_1, prefetch_1)
+ self.assertEqual(prefetch_1, mock.ANY)
self.assertNotEqual(prefetch_1, prefetch_2)
def test_forward_m2m_to_attr_conflict(self):
diff --git a/tests/template_tests/test_context.py b/tests/template_tests/test_context.py
index 8c6fc98b42..1150a14639 100644
--- a/tests/template_tests/test_context.py
+++ b/tests/template_tests/test_context.py
@@ -1,3 +1,5 @@
+from unittest import mock
+
from django.http import HttpRequest
from django.template import (
Context, Engine, RequestContext, Template, Variable, VariableDoesNotExist,
@@ -18,6 +20,7 @@ class ContextTests(SimpleTestCase):
self.assertEqual(c.pop(), {"a": 2})
self.assertEqual(c["a"], 1)
self.assertEqual(c.get("foo", 42), 42)
+ self.assertEqual(c, mock.ANY)
def test_push_context_manager(self):
c = Context({"a": 1})
diff --git a/tests/validators/tests.py b/tests/validators/tests.py
index 36d0b2a520..295c6c899f 100644
--- a/tests/validators/tests.py
+++ b/tests/validators/tests.py
@@ -3,7 +3,7 @@ import re
import types
from datetime import datetime, timedelta
from decimal import Decimal
-from unittest import TestCase
+from unittest import TestCase, mock
from django.core.exceptions import ValidationError
from django.core.files.base import ContentFile
@@ -424,6 +424,7 @@ class TestValidatorEquality(TestCase):
MaxValueValidator(44),
MaxValueValidator(44),
)
+ self.assertEqual(MaxValueValidator(44), mock.ANY)
self.assertNotEqual(
MaxValueValidator(44),
MinValueValidator(44),