summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAnssi Kääriäinen <akaariai@gmail.com>2014-07-05 09:03:52 +0300
committerTim Graham <timograham@gmail.com>2014-11-28 06:54:00 -0500
commitc7175fcdfe94be60c04f3b1ceb6d0b2def2b6f09 (patch)
tree409248caf9fe722d53eb1d7654176bb8a5f5c631
parent912ad03226687dae91971ebd7e5cf87521f6b0de (diff)
Fixed #901 -- Added Model.refresh_from_db() method
Thanks to github aliases dbrgn, carljm, slurms, dfunckt, and timgraham for reviews.
-rw-r--r--django/db/models/base.py68
-rw-r--r--django/db/models/query_utils.py10
-rw-r--r--docs/ref/models/instances.txt67
-rw-r--r--docs/releases/1.8.txt6
-rw-r--r--tests/basic/tests.py59
-rw-r--r--tests/defer/models.py14
-rw-r--r--tests/defer/tests.py28
-rw-r--r--tests/field_subclassing/tests.py6
-rw-r--r--tests/multiple_database/tests.py18
9 files changed, 265 insertions, 11 deletions
diff --git a/django/db/models/base.py b/django/db/models/base.py
index e4f970af05..c2f711ea08 100644
--- a/django/db/models/base.py
+++ b/django/db/models/base.py
@@ -13,6 +13,8 @@ from django.core.exceptions import (ObjectDoesNotExist,
MultipleObjectsReturned, FieldError, ValidationError, NON_FIELD_ERRORS)
from django.db import (router, connections, transaction, DatabaseError,
DEFAULT_DB_ALIAS, DJANGO_VERSION_PICKLE_KEY)
+from django.db.models import signals
+from django.db.models.constants import LOOKUP_SEP
from django.db.models.deletion import Collector
from django.db.models.fields import AutoField, FieldDoesNotExist
from django.db.models.fields.related import (ForeignObjectRel, ManyToOneRel,
@@ -21,7 +23,6 @@ from django.db.models.manager import ensure_default_manager
from django.db.models.options import Options
from django.db.models.query import Q
from django.db.models.query_utils import DeferredAttribute, deferred_class_factory
-from django.db.models import signals
from django.utils import six
from django.utils.deprecation import RemovedInDjango19Warning
from django.utils.encoding import force_str, force_text
@@ -552,6 +553,71 @@ class Model(six.with_metaclass(ModelBase)):
pk = property(_get_pk_val, _set_pk_val)
+ def get_deferred_fields(self):
+ """
+ Returns a set containing names of deferred fields for this instance.
+ """
+ return {
+ f.attname for f in self._meta.concrete_fields
+ if isinstance(self.__class__.__dict__.get(f.attname), DeferredAttribute)
+ }
+
+ def refresh_from_db(self, using=None, fields=None, **kwargs):
+ """
+ Reloads field values from the database.
+
+ By default, the reloading happens from the database this instance was
+ loaded from, or by the read router if this instance wasn't loaded from
+ any database. The using parameter will override the default.
+
+ Fields can be used to specify which fields to reload. The fields
+ should be an iterable of field attnames. If fields is None, then
+ all non-deferred fields are reloaded.
+
+ When accessing deferred fields of an instance, the deferred loading
+ of the field will call this method.
+ """
+ if fields is not None:
+ if len(fields) == 0:
+ return
+ if any(LOOKUP_SEP in f for f in fields):
+ raise ValueError(
+ 'Found "%s" in fields argument. Relations and transforms '
+ 'are not allowed in fields.' % LOOKUP_SEP)
+
+ db = using if using is not None else self._state.db
+ if self._deferred:
+ non_deferred_model = self._meta.proxy_for_model
+ else:
+ non_deferred_model = self.__class__
+ db_instance_qs = non_deferred_model._default_manager.using(db).filter(pk=self.pk)
+
+ # Use provided fields, if not set then reload all non-deferred fields.
+ if fields is not None:
+ fields = list(fields)
+ db_instance_qs = db_instance_qs.only(*fields)
+ elif self._deferred:
+ deferred_fields = self.get_deferred_fields()
+ fields = [f.attname for f in self._meta.concrete_fields
+ if f.attname not in deferred_fields]
+ db_instance_qs = db_instance_qs.only(*fields)
+
+ db_instance = db_instance_qs.get()
+ non_loaded_fields = db_instance.get_deferred_fields()
+ for field in self._meta.concrete_fields:
+ if field.attname in non_loaded_fields:
+ # This field wasn't refreshed - skip ahead.
+ continue
+ setattr(self, field.attname, getattr(db_instance, field.attname))
+ # Throw away stale foreign key references.
+ if field.rel and field.get_cache_name() in self.__dict__:
+ rel_instance = getattr(self, field.get_cache_name())
+ local_val = getattr(db_instance, field.attname)
+ related_val = getattr(rel_instance, field.related_field.attname)
+ if local_val != related_val:
+ del self.__dict__[field.get_cache_name()]
+ self._state.db = db_instance._state.db
+
def serializable_value(self, field_name):
"""
Returns the value of the field name for this instance. If the field is
diff --git a/django/db/models/query_utils.py b/django/db/models/query_utils.py
index 6dbeb855d1..0fbe25cca7 100644
--- a/django/db/models/query_utils.py
+++ b/django/db/models/query_utils.py
@@ -109,14 +109,8 @@ class DeferredAttribute(object):
# might be able to reuse the already loaded value. Refs #18343.
val = self._check_parent_chain(instance, name)
if val is None:
- # We use only() instead of values() here because we want the
- # various data coercion methods (to_python(), etc.) to be
- # called here.
- val = getattr(
- non_deferred_model._base_manager.only(name).using(
- instance._state.db).get(pk=instance.pk),
- self.field_name
- )
+ instance.refresh_from_db(fields=[self.field_name])
+ val = getattr(instance, self.field_name)
data[self.field_name] = val
return data[self.field_name]
diff --git a/docs/ref/models/instances.txt b/docs/ref/models/instances.txt
index 4e8a1d1ee8..21dbb5af1a 100644
--- a/docs/ref/models/instances.txt
+++ b/docs/ref/models/instances.txt
@@ -116,6 +116,73 @@ The example above shows a full ``from_db()`` implementation to clarify how that
is done. In this case it would of course be possible to just use ``super()`` call
in the ``from_db()`` method.
+Refreshing objects from database
+================================
+
+.. method:: Model.refresh_from_db(using=None, fields=None, **kwargs)
+
+.. versionadded:: 1.8
+
+If you need to reload a model's values from the database, you can use the
+``refresh_from_db()`` method. When this method is called without arguments the
+following is done:
+
+1. All non-deferred fields of the model are updated to the values currently
+ present in the database.
+2. The previously loaded related instances for which the relation's value is no
+ longer valid are removed from the reloaded instance. For example, if you have
+ a foreign key from the reloaded instance to another model with name
+ ``Author``, then if ``obj.author_id != obj.author.id``, ``obj.author`` will
+ be thrown away, and when next accessed it will be reloaded with the value of
+ ``obj.author_id``.
+
+Note that only fields of the model are reloaded from the database. Other
+database dependent values such as annotations are not reloaded.
+
+The reloading happens from the database the instance was loaded from, or from
+the default database if the instance wasn't loaded from the database. The
+``using`` argument can be used to force the database used for reloading.
+
+It is possible to force the set of fields to be loaded by using the ``fields``
+argument.
+
+For example, to test that an ``update()`` call resulted in the expected
+update, you could write a test similar to this::
+
+ def test_update_result(self):
+ obj = MyModel.objects.create(val=1)
+ MyModel.objects.filter(pk=obj.pk).update(val=F('val') + 1)
+ # At this point obj.val is still 1, but the value in the database
+ # was updated to 2. The object's updated value needs to be reloaded
+ # from the database.
+ obj.refresh_from_db()
+ self.assertEqual(obj.val, 2)
+
+Note that when deferred fields are accessed, the loading of the deferred
+field's value happens through this method. Thus it is possible to customize
+the way deferred loading happens. The example below shows how one can reload
+all of the instance's fields when a deferred field is reloaded::
+
+ class ExampleModel(models.Model):
+ def refresh_from_db(self, using=None, fields=None, **kwargs):
+ # fields contains the name of the deferred field to be
+ # loaded.
+ if fields is not None:
+ fields = set(fields)
+ deferred_fields = self.get_deferred_fields()
+ # If any deferred field is going to be loaded
+ if fields.intersection(deferred_fields):
+ # then load all of them
+ fields = fields.union(deferred_fields)
+ super(ExampleModel, self).refresh_from_db(using, fields, **kwargs)
+
+.. method:: Model.get_deferred_fields()
+
+.. versionadded:: 1.8
+
+A helper method that returns a set containing the attribute names of all those
+fields that are currently deferred for this model.
+
.. _validating-objects:
Validating objects
diff --git a/docs/releases/1.8.txt b/docs/releases/1.8.txt
index fb4e44576c..5df0cc1bd1 100644
--- a/docs/releases/1.8.txt
+++ b/docs/releases/1.8.txt
@@ -391,6 +391,12 @@ Models
by the database, which can lead to somewhat complex queries involving nested
``REPLACE`` function calls.
+* You can now refresh model instances by using :meth:`Model.refresh_from_db()
+ <django.db.models.Model.refresh_from_db>`.
+
+* You can now get the set of deferred fields for a model using
+ :meth:`Model.get_deferred_fields() <django.db.models.Model.get_deferred_fields>`.
+
Signals
^^^^^^^
diff --git a/tests/basic/tests.py b/tests/basic/tests.py
index a10e38f2a4..31e8b724bc 100644
--- a/tests/basic/tests.py
+++ b/tests/basic/tests.py
@@ -1,6 +1,6 @@
from __future__ import unicode_literals
-from datetime import datetime
+from datetime import datetime, timedelta
import threading
from django.core.exceptions import ObjectDoesNotExist, MultipleObjectsReturned
@@ -713,3 +713,60 @@ class SelectOnSaveTests(TestCase):
asos.save(update_fields=['pub_date'])
finally:
Article._base_manager.__class__ = orig_class
+
+
+class ModelRefreshTests(TestCase):
+ def _truncate_ms(self, val):
+ # MySQL < 5.6.4 removes microseconds from the datetimes which can cause
+ # problems when comparing the original value to that loaded from DB
+ return val - timedelta(microseconds=val.microsecond)
+
+ def test_refresh(self):
+ a = Article.objects.create(pub_date=self._truncate_ms(datetime.now()))
+ Article.objects.create(pub_date=self._truncate_ms(datetime.now()))
+ Article.objects.filter(pk=a.pk).update(headline='new headline')
+ with self.assertNumQueries(1):
+ a.refresh_from_db()
+ self.assertEqual(a.headline, 'new headline')
+
+ orig_pub_date = a.pub_date
+ new_pub_date = a.pub_date + timedelta(10)
+ Article.objects.update(headline='new headline 2', pub_date=new_pub_date)
+ with self.assertNumQueries(1):
+ a.refresh_from_db(fields=['headline'])
+ self.assertEqual(a.headline, 'new headline 2')
+ self.assertEqual(a.pub_date, orig_pub_date)
+ with self.assertNumQueries(1):
+ a.refresh_from_db()
+ self.assertEqual(a.pub_date, new_pub_date)
+
+ def test_refresh_fk(self):
+ s1 = SelfRef.objects.create()
+ s2 = SelfRef.objects.create()
+ s3 = SelfRef.objects.create(selfref=s1)
+ s3_copy = SelfRef.objects.get(pk=s3.pk)
+ s3_copy.selfref.touched = True
+ s3.selfref = s2
+ s3.save()
+ with self.assertNumQueries(1):
+ s3_copy.refresh_from_db()
+ with self.assertNumQueries(1):
+ # The old related instance was thrown away (the selfref_id has
+ # changed). It needs to be reloaded on access, so one query
+ # executed.
+ self.assertFalse(hasattr(s3_copy.selfref, 'touched'))
+ self.assertEqual(s3_copy.selfref, s2)
+
+ def test_refresh_unsaved(self):
+ pub_date = self._truncate_ms(datetime.now())
+ a = Article.objects.create(pub_date=pub_date)
+ a2 = Article(id=a.pk)
+ with self.assertNumQueries(1):
+ a2.refresh_from_db()
+ self.assertEqual(a2.pub_date, pub_date)
+ self.assertEqual(a2._state.db, "default")
+
+ def test_refresh_no_fields(self):
+ a = Article.objects.create(pub_date=self._truncate_ms(datetime.now()))
+ with self.assertNumQueries(0):
+ a.refresh_from_db(fields=[])
diff --git a/tests/defer/models.py b/tests/defer/models.py
index ffc8a0c2c7..ecf69c0d7f 100644
--- a/tests/defer/models.py
+++ b/tests/defer/models.py
@@ -32,3 +32,17 @@ class BigChild(Primary):
class ChildProxy(Child):
class Meta:
proxy = True
+
+
+class RefreshPrimaryProxy(Primary):
+ class Meta:
+ proxy = True
+
+ def refresh_from_db(self, using=None, fields=None, **kwargs):
+ # Reloads all deferred fields if any of the fields is deferred.
+ if fields is not None:
+ fields = set(fields)
+ deferred_fields = self.get_deferred_fields()
+ if fields.intersection(deferred_fields):
+ fields = fields.union(deferred_fields)
+ super(RefreshPrimaryProxy, self).refresh_from_db(using, fields, **kwargs)
diff --git a/tests/defer/tests.py b/tests/defer/tests.py
index 43a088f3e2..597f871cc8 100644
--- a/tests/defer/tests.py
+++ b/tests/defer/tests.py
@@ -3,7 +3,7 @@ from __future__ import unicode_literals
from django.db.models.query_utils import DeferredAttribute, InvalidQuery
from django.test import TestCase
-from .models import Secondary, Primary, Child, BigChild, ChildProxy
+from .models import Secondary, Primary, Child, BigChild, ChildProxy, RefreshPrimaryProxy
class DeferTests(TestCase):
@@ -189,3 +189,29 @@ class DeferTests(TestCase):
s1_defer = Secondary.objects.only('pk').get(pk=s1.pk)
self.assertEqual(s1, s1_defer)
self.assertEqual(s1_defer, s1)
+
+ def test_refresh_not_loading_deferred_fields(self):
+ s = Secondary.objects.create()
+ rf = Primary.objects.create(name='foo', value='bar', related=s)
+ rf2 = Primary.objects.only('related', 'value').get()
+ rf.name = 'new foo'
+ rf.value = 'new bar'
+ rf.save()
+ with self.assertNumQueries(1):
+ rf2.refresh_from_db()
+ self.assertEqual(rf2.value, 'new bar')
+ with self.assertNumQueries(1):
+ self.assertEqual(rf2.name, 'new foo')
+
+ def test_custom_refresh_on_deferred_loading(self):
+ s = Secondary.objects.create()
+ rf = RefreshPrimaryProxy.objects.create(name='foo', value='bar', related=s)
+ rf2 = RefreshPrimaryProxy.objects.only('related').get()
+ rf.name = 'new foo'
+ rf.value = 'new bar'
+ rf.save()
+ with self.assertNumQueries(1):
+ # Customized refresh_from_db() reloads all deferred fields on
+ # access of any of them.
+ self.assertEqual(rf2.name, 'new foo')
+ self.assertEqual(rf2.value, 'new bar')
diff --git a/tests/field_subclassing/tests.py b/tests/field_subclassing/tests.py
index 5c695a455c..9e40d92496 100644
--- a/tests/field_subclassing/tests.py
+++ b/tests/field_subclassing/tests.py
@@ -11,6 +11,12 @@ from .models import ChoicesModel, DataModel, MyModel, OtherModel
class CustomField(TestCase):
+ def test_refresh(self):
+ d = DataModel.objects.create(data=[1, 2, 3])
+ d.refresh_from_db(fields=['data'])
+ self.assertIsInstance(d.data, list)
+ self.assertEqual(d.data, [1, 2, 3])
+
def test_defer(self):
d = DataModel.objects.create(data=[1, 2, 3])
diff --git a/tests/multiple_database/tests.py b/tests/multiple_database/tests.py
index f230311eed..17a1ef90b1 100644
--- a/tests/multiple_database/tests.py
+++ b/tests/multiple_database/tests.py
@@ -112,6 +112,24 @@ class QueryTestCase(TestCase):
title="Dive into Python"
)
+ def test_refresh(self):
+ dive = Book()
+ dive.title = "Dive into Python"
+ dive = Book()
+ dive.title = "Dive into Python"
+ dive.published = datetime.date(2009, 5, 4)
+ dive.save(using='other')
+ dive.published = datetime.date(2009, 5, 4)
+ dive.save(using='other')
+ dive2 = Book.objects.using('other').get()
+ dive2.title = "Dive into Python (on default)"
+ dive2.save(using='default')
+ dive.refresh_from_db()
+ self.assertEqual(dive.title, "Dive into Python")
+ dive.refresh_from_db(using='default')
+ self.assertEqual(dive.title, "Dive into Python (on default)")
+ self.assertEqual(dive._state.db, "default")
+
def test_basic_queries(self):
"Queries are constrained to a single database"
dive = Book.objects.using('other').create(title="Dive into Python",