summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAymeric Augustin <aymeric.augustin@m4x.org>2014-02-02 10:37:27 -0800
committerAymeric Augustin <aymeric.augustin@m4x.org>2014-02-02 10:37:27 -0800
commit54bfa4caab11d364eb208ba6639836fa22d69a04 (patch)
treef24020307dd5b529989329bfabfc95effcecffe8
parentab2f21080d8b3112c1ba9a0bf923eae733be4242 (diff)
parent3ffeb931869cc68a8e0916219702ee282afc6e9d (diff)
Merge pull request #2154 from manfre/close-cursors
Fixed #21751 -- Explicitly closed cursors.
-rw-r--r--django/contrib/gis/db/backends/postgis/creation.py14
-rw-r--r--django/contrib/gis/db/backends/spatialite/creation.py5
-rw-r--r--django/contrib/sites/management.py6
-rw-r--r--django/core/cache/backends/db.py129
-rw-r--r--django/core/management/commands/createcachetable.py18
-rw-r--r--django/core/management/commands/flush.py6
-rw-r--r--django/core/management/commands/inspectdb.py184
-rw-r--r--django/core/management/commands/loaddata.py7
-rw-r--r--django/core/management/commands/migrate.py181
-rw-r--r--django/core/management/sql.py55
-rw-r--r--django/db/backends/__init__.py30
-rw-r--r--django/db/backends/creation.py71
-rw-r--r--django/db/backends/mysql/base.py37
-rw-r--r--django/db/backends/oracle/base.py4
-rw-r--r--django/db/backends/postgresql_psycopg2/base.py6
-rw-r--r--django/db/backends/postgresql_psycopg2/version.py6
-rw-r--r--django/db/backends/schema.py8
-rw-r--r--django/db/backends/sqlite3/base.py16
-rw-r--r--django/db/models/query.py94
-rw-r--r--django/db/models/sql/compiler.py97
-rw-r--r--django/db/models/sql/constants.py2
-rw-r--r--django/db/models/sql/subqueries.py8
-rw-r--r--tests/backends/tests.py53
-rw-r--r--tests/cache/tests.py7
-rw-r--r--tests/custom_methods/models.py16
-rw-r--r--tests/initial_sql_regress/tests.py6
-rw-r--r--tests/introspection/tests.py60
-rw-r--r--tests/migrations/test_base.py35
-rw-r--r--tests/migrations/test_operations.py72
-rw-r--r--tests/requests/tests.py2
-rw-r--r--tests/schema/tests.py96
-rw-r--r--tests/transactions/tests.py13
-rw-r--r--tests/transactions_regress/tests.py12
33 files changed, 720 insertions, 636 deletions
diff --git a/django/contrib/gis/db/backends/postgis/creation.py b/django/contrib/gis/db/backends/postgis/creation.py
index 51ac197b8e..82be18cb65 100644
--- a/django/contrib/gis/db/backends/postgis/creation.py
+++ b/django/contrib/gis/db/backends/postgis/creation.py
@@ -11,10 +11,10 @@ class PostGISCreation(DatabaseCreation):
@cached_property
def template_postgis(self):
template_postgis = getattr(settings, 'POSTGIS_TEMPLATE', 'template_postgis')
- cursor = self.connection.cursor()
- cursor.execute('SELECT 1 FROM pg_database WHERE datname = %s LIMIT 1;', (template_postgis,))
- if cursor.fetchone():
- return template_postgis
+ with self.connection.cursor() as cursor:
+ cursor.execute('SELECT 1 FROM pg_database WHERE datname = %s LIMIT 1;', (template_postgis,))
+ if cursor.fetchone():
+ return template_postgis
return None
def sql_indexes_for_field(self, model, f, style):
@@ -88,8 +88,8 @@ class PostGISCreation(DatabaseCreation):
# Connect to the test database in order to create the postgis extension
self.connection.close()
self.connection.settings_dict["NAME"] = test_database_name
- cursor = self.connection.cursor()
- cursor.execute("CREATE EXTENSION IF NOT EXISTS postgis")
- cursor.connection.commit()
+ with self.connection.cursor() as cursor:
+ cursor.execute("CREATE EXTENSION IF NOT EXISTS postgis")
+ cursor.connection.commit()
return test_database_name
diff --git a/django/contrib/gis/db/backends/spatialite/creation.py b/django/contrib/gis/db/backends/spatialite/creation.py
index 521985259e..06f105d563 100644
--- a/django/contrib/gis/db/backends/spatialite/creation.py
+++ b/django/contrib/gis/db/backends/spatialite/creation.py
@@ -55,9 +55,8 @@ class SpatiaLiteCreation(DatabaseCreation):
call_command('createcachetable', database=self.connection.alias)
- # Get a cursor (even though we don't need one yet). This has
- # the side effect of initializing the test database.
- self.connection.cursor()
+ # Ensure a connection for the side effect of initializing the test database.
+ self.connection.ensure_connection()
return test_database_name
diff --git a/django/contrib/sites/management.py b/django/contrib/sites/management.py
index e7624e75cf..8353a6f496 100644
--- a/django/contrib/sites/management.py
+++ b/django/contrib/sites/management.py
@@ -33,9 +33,9 @@ def create_default_site(app_config, verbosity=2, interactive=True, db=DEFAULT_DB
if sequence_sql:
if verbosity >= 2:
print("Resetting sequence")
- cursor = connections[db].cursor()
- for command in sequence_sql:
- cursor.execute(command)
+ with connections[db].cursor() as cursor:
+ for command in sequence_sql:
+ cursor.execute(command)
Site.objects.clear_cache()
diff --git a/django/core/cache/backends/db.py b/django/core/cache/backends/db.py
index a21777aaba..959095026b 100644
--- a/django/core/cache/backends/db.py
+++ b/django/core/cache/backends/db.py
@@ -59,11 +59,11 @@ class DatabaseCache(BaseDatabaseCache):
self.validate_key(key)
db = router.db_for_read(self.cache_model_class)
table = connections[db].ops.quote_name(self._table)
- cursor = connections[db].cursor()
- cursor.execute("SELECT cache_key, value, expires FROM %s "
- "WHERE cache_key = %%s" % table, [key])
- row = cursor.fetchone()
+ with connections[db].cursor() as cursor:
+ cursor.execute("SELECT cache_key, value, expires FROM %s "
+ "WHERE cache_key = %%s" % table, [key])
+ row = cursor.fetchone()
if row is None:
return default
now = timezone.now()
@@ -75,9 +75,9 @@ class DatabaseCache(BaseDatabaseCache):
expires = typecast_timestamp(str(expires))
if expires < now:
db = router.db_for_write(self.cache_model_class)
- cursor = connections[db].cursor()
- cursor.execute("DELETE FROM %s "
- "WHERE cache_key = %%s" % table, [key])
+ with connections[db].cursor() as cursor:
+ cursor.execute("DELETE FROM %s "
+ "WHERE cache_key = %%s" % table, [key])
return default
value = connections[db].ops.process_clob(row[1])
return pickle.loads(base64.b64decode(force_bytes(value)))
@@ -96,55 +96,55 @@ class DatabaseCache(BaseDatabaseCache):
timeout = self.get_backend_timeout(timeout)
db = router.db_for_write(self.cache_model_class)
table = connections[db].ops.quote_name(self._table)
- cursor = connections[db].cursor()
- cursor.execute("SELECT COUNT(*) FROM %s" % table)
- num = cursor.fetchone()[0]
- now = timezone.now()
- now = now.replace(microsecond=0)
- if timeout is None:
- exp = datetime.max
- elif settings.USE_TZ:
- exp = datetime.utcfromtimestamp(timeout)
- else:
- exp = datetime.fromtimestamp(timeout)
- exp = exp.replace(microsecond=0)
- if num > self._max_entries:
- self._cull(db, cursor, now)
- pickled = pickle.dumps(value, pickle.HIGHEST_PROTOCOL)
- b64encoded = base64.b64encode(pickled)
- # The DB column is expecting a string, so make sure the value is a
- # string, not bytes. Refs #19274.
- if six.PY3:
- b64encoded = b64encoded.decode('latin1')
- try:
- # Note: typecasting for datetimes is needed by some 3rd party
- # database backends. All core backends work without typecasting,
- # so be careful about changes here - test suite will NOT pick
- # regressions.
- with transaction.atomic(using=db):
- cursor.execute("SELECT cache_key, expires FROM %s "
- "WHERE cache_key = %%s" % table, [key])
- result = cursor.fetchone()
- if result:
- current_expires = result[1]
- if (connections[db].features.needs_datetime_string_cast and not
- isinstance(current_expires, datetime)):
- current_expires = typecast_timestamp(str(current_expires))
- exp = connections[db].ops.value_to_db_datetime(exp)
- if result and (mode == 'set' or (mode == 'add' and current_expires < now)):
- cursor.execute("UPDATE %s SET value = %%s, expires = %%s "
- "WHERE cache_key = %%s" % table,
- [b64encoded, exp, key])
- else:
- cursor.execute("INSERT INTO %s (cache_key, value, expires) "
- "VALUES (%%s, %%s, %%s)" % table,
- [key, b64encoded, exp])
- except DatabaseError:
- # To be threadsafe, updates/inserts are allowed to fail silently
- return False
- else:
- return True
+ with connections[db].cursor() as cursor:
+ cursor.execute("SELECT COUNT(*) FROM %s" % table)
+ num = cursor.fetchone()[0]
+ now = timezone.now()
+ now = now.replace(microsecond=0)
+ if timeout is None:
+ exp = datetime.max
+ elif settings.USE_TZ:
+ exp = datetime.utcfromtimestamp(timeout)
+ else:
+ exp = datetime.fromtimestamp(timeout)
+ exp = exp.replace(microsecond=0)
+ if num > self._max_entries:
+ self._cull(db, cursor, now)
+ pickled = pickle.dumps(value, pickle.HIGHEST_PROTOCOL)
+ b64encoded = base64.b64encode(pickled)
+ # The DB column is expecting a string, so make sure the value is a
+ # string, not bytes. Refs #19274.
+ if six.PY3:
+ b64encoded = b64encoded.decode('latin1')
+ try:
+ # Note: typecasting for datetimes is needed by some 3rd party
+ # database backends. All core backends work without typecasting,
+ # so be careful about changes here - test suite will NOT pick
+ # regressions.
+ with transaction.atomic(using=db):
+ cursor.execute("SELECT cache_key, expires FROM %s "
+ "WHERE cache_key = %%s" % table, [key])
+ result = cursor.fetchone()
+ if result:
+ current_expires = result[1]
+ if (connections[db].features.needs_datetime_string_cast and not
+ isinstance(current_expires, datetime)):
+ current_expires = typecast_timestamp(str(current_expires))
+ exp = connections[db].ops.value_to_db_datetime(exp)
+ if result and (mode == 'set' or (mode == 'add' and current_expires < now)):
+ cursor.execute("UPDATE %s SET value = %%s, expires = %%s "
+ "WHERE cache_key = %%s" % table,
+ [b64encoded, exp, key])
+ else:
+ cursor.execute("INSERT INTO %s (cache_key, value, expires) "
+ "VALUES (%%s, %%s, %%s)" % table,
+ [key, b64encoded, exp])
+ except DatabaseError:
+ # To be threadsafe, updates/inserts are allowed to fail silently
+ return False
+ else:
+ return True
def delete(self, key, version=None):
key = self.make_key(key, version=version)
@@ -152,9 +152,9 @@ class DatabaseCache(BaseDatabaseCache):
db = router.db_for_write(self.cache_model_class)
table = connections[db].ops.quote_name(self._table)
- cursor = connections[db].cursor()
- cursor.execute("DELETE FROM %s WHERE cache_key = %%s" % table, [key])
+ with connections[db].cursor() as cursor:
+ cursor.execute("DELETE FROM %s WHERE cache_key = %%s" % table, [key])
def has_key(self, key, version=None):
key = self.make_key(key, version=version)
@@ -162,17 +162,18 @@ class DatabaseCache(BaseDatabaseCache):
db = router.db_for_read(self.cache_model_class)
table = connections[db].ops.quote_name(self._table)
- cursor = connections[db].cursor()
if settings.USE_TZ:
now = datetime.utcnow()
else:
now = datetime.now()
now = now.replace(microsecond=0)
- cursor.execute("SELECT cache_key FROM %s "
- "WHERE cache_key = %%s and expires > %%s" % table,
- [key, connections[db].ops.value_to_db_datetime(now)])
- return cursor.fetchone() is not None
+
+ with connections[db].cursor() as cursor:
+ cursor.execute("SELECT cache_key FROM %s "
+ "WHERE cache_key = %%s and expires > %%s" % table,
+ [key, connections[db].ops.value_to_db_datetime(now)])
+ return cursor.fetchone() is not None
def _cull(self, db, cursor, now):
if self._cull_frequency == 0:
@@ -197,8 +198,8 @@ class DatabaseCache(BaseDatabaseCache):
def clear(self):
db = router.db_for_write(self.cache_model_class)
table = connections[db].ops.quote_name(self._table)
- cursor = connections[db].cursor()
- cursor.execute('DELETE FROM %s' % table)
+ with connections[db].cursor() as cursor:
+ cursor.execute('DELETE FROM %s' % table)
# For backwards compatibility
diff --git a/django/core/management/commands/createcachetable.py b/django/core/management/commands/createcachetable.py
index 10506525fc..909a5d08c8 100644
--- a/django/core/management/commands/createcachetable.py
+++ b/django/core/management/commands/createcachetable.py
@@ -72,14 +72,14 @@ class Command(BaseCommand):
full_statement.append(' %s%s' % (line, ',' if i < len(table_output) - 1 else ''))
full_statement.append(');')
with transaction.commit_on_success_unless_managed():
- curs = connection.cursor()
- try:
- curs.execute("\n".join(full_statement))
- except DatabaseError as e:
- raise CommandError(
- "Cache table '%s' could not be created.\nThe error was: %s." %
- (tablename, force_text(e)))
- for statement in index_output:
- curs.execute(statement)
+ with connection.cursor() as curs:
+ try:
+ curs.execute("\n".join(full_statement))
+ except DatabaseError as e:
+ raise CommandError(
+ "Cache table '%s' could not be created.\nThe error was: %s." %
+ (tablename, force_text(e)))
+ for statement in index_output:
+ curs.execute(statement)
if self.verbosity > 1:
self.stdout.write("Cache table '%s' created." % tablename)
diff --git a/django/core/management/commands/flush.py b/django/core/management/commands/flush.py
index 4a3f7c2d8b..d99deb951e 100644
--- a/django/core/management/commands/flush.py
+++ b/django/core/management/commands/flush.py
@@ -64,9 +64,9 @@ Are you sure you want to do this?
if confirm == 'yes':
try:
with transaction.commit_on_success_unless_managed():
- cursor = connection.cursor()
- for sql in sql_list:
- cursor.execute(sql)
+ with connection.cursor() as cursor:
+ for sql in sql_list:
+ cursor.execute(sql)
except Exception as e:
new_msg = (
"Database %s couldn't be flushed. Possible reasons:\n"
diff --git a/django/core/management/commands/inspectdb.py b/django/core/management/commands/inspectdb.py
index 54fdad2001..4a51892e5a 100644
--- a/django/core/management/commands/inspectdb.py
+++ b/django/core/management/commands/inspectdb.py
@@ -37,108 +37,108 @@ class Command(NoArgsCommand):
table2model = lambda table_name: table_name.title().replace('_', '').replace(' ', '').replace('-', '')
strip_prefix = lambda s: s[1:] if s.startswith("u'") else s
- cursor = connection.cursor()
- yield "# This is an auto-generated Django model module."
- yield "# You'll have to do the following manually to clean this up:"
- yield "# * Rearrange models' order"
- yield "# * Make sure each model has one field with primary_key=True"
- yield "# * Remove `managed = False` lines for those models you wish to give write DB access"
- yield "# Feel free to rename the models, but don't rename db_table values or field names."
- yield "#"
- yield "# Also note: You'll have to insert the output of 'django-admin.py sqlcustom [app_label]'"
- yield "# into your database."
- yield "from __future__ import unicode_literals"
- yield ''
- yield 'from %s import models' % self.db_module
- known_models = []
- for table_name in connection.introspection.table_names(cursor):
- if table_name_filter is not None and callable(table_name_filter):
- if not table_name_filter(table_name):
- continue
+ with connection.cursor() as cursor:
+ yield "# This is an auto-generated Django model module."
+ yield "# You'll have to do the following manually to clean this up:"
+ yield "# * Rearrange models' order"
+ yield "# * Make sure each model has one field with primary_key=True"
+ yield "# * Remove `managed = False` lines for those models you wish to give write DB access"
+ yield "# Feel free to rename the models, but don't rename db_table values or field names."
+ yield "#"
+ yield "# Also note: You'll have to insert the output of 'django-admin.py sqlcustom [app_label]'"
+ yield "# into your database."
+ yield "from __future__ import unicode_literals"
yield ''
- yield ''
- yield 'class %s(models.Model):' % table2model(table_name)
- known_models.append(table2model(table_name))
- try:
- relations = connection.introspection.get_relations(cursor, table_name)
- except NotImplementedError:
- relations = {}
- try:
- indexes = connection.introspection.get_indexes(cursor, table_name)
- except NotImplementedError:
- indexes = {}
- used_column_names = [] # Holds column names used in the table so far
- for i, row in enumerate(connection.introspection.get_table_description(cursor, table_name)):
- comment_notes = [] # Holds Field notes, to be displayed in a Python comment.
- extra_params = OrderedDict() # Holds Field parameters such as 'db_column'.
- column_name = row[0]
- is_relation = i in relations
+ yield 'from %s import models' % self.db_module
+ known_models = []
+ for table_name in connection.introspection.table_names(cursor):
+ if table_name_filter is not None and callable(table_name_filter):
+ if not table_name_filter(table_name):
+ continue
+ yield ''
+ yield ''
+ yield 'class %s(models.Model):' % table2model(table_name)
+ known_models.append(table2model(table_name))
+ try:
+ relations = connection.introspection.get_relations(cursor, table_name)
+ except NotImplementedError:
+ relations = {}
+ try:
+ indexes = connection.introspection.get_indexes(cursor, table_name)
+ except NotImplementedError:
+ indexes = {}
+ used_column_names = [] # Holds column names used in the table so far
+ for i, row in enumerate(connection.introspection.get_table_description(cursor, table_name)):
+ comment_notes = [] # Holds Field notes, to be displayed in a Python comment.
+ extra_params = OrderedDict() # Holds Field parameters such as 'db_column'.
+ column_name = row[0]
+ is_relation = i in relations
- att_name, params, notes = self.normalize_col_name(
- column_name, used_column_names, is_relation)
- extra_params.update(params)
- comment_notes.extend(notes)
+ att_name, params, notes = self.normalize_col_name(
+ column_name, used_column_names, is_relation)
+ extra_params.update(params)
+ comment_notes.extend(notes)
- used_column_names.append(att_name)
+ used_column_names.append(att_name)
- # Add primary_key and unique, if necessary.
- if column_name in indexes:
- if indexes[column_name]['primary_key']:
- extra_params['primary_key'] = True
- elif indexes[column_name]['unique']:
- extra_params['unique'] = True
+ # Add primary_key and unique, if necessary.
+ if column_name in indexes:
+ if indexes[column_name]['primary_key']:
+ extra_params['primary_key'] = True
+ elif indexes[column_name]['unique']:
+ extra_params['unique'] = True
- if is_relation:
- rel_to = "self" if relations[i][1] == table_name else table2model(relations[i][1])
- if rel_to in known_models:
- field_type = 'ForeignKey(%s' % rel_to
+ if is_relation:
+ rel_to = "self" if relations[i][1] == table_name else table2model(relations[i][1])
+ if rel_to in known_models:
+ field_type = 'ForeignKey(%s' % rel_to
+ else:
+ field_type = "ForeignKey('%s'" % rel_to
else:
- field_type = "ForeignKey('%s'" % rel_to
- else:
- # Calling `get_field_type` to get the field type string and any
- # additional paramters and notes.
- field_type, field_params, field_notes = self.get_field_type(connection, table_name, row)
- extra_params.update(field_params)
- comment_notes.extend(field_notes)
+ # Calling `get_field_type` to get the field type string and any
+ # additional paramters and notes.
+ field_type, field_params, field_notes = self.get_field_type(connection, table_name, row)
+ extra_params.update(field_params)
+ comment_notes.extend(field_notes)
- field_type += '('
+ field_type += '('
- # Don't output 'id = meta.AutoField(primary_key=True)', because
- # that's assumed if it doesn't exist.
- if att_name == 'id' and extra_params == {'primary_key': True}:
- if field_type == 'AutoField(':
- continue
- elif field_type == 'IntegerField(' and not connection.features.can_introspect_autofield:
- comment_notes.append('AutoField?')
+ # Don't output 'id = meta.AutoField(primary_key=True)', because
+ # that's assumed if it doesn't exist.
+ if att_name == 'id' and extra_params == {'primary_key': True}:
+ if field_type == 'AutoField(':
+ continue
+ elif field_type == 'IntegerField(' and not connection.features.can_introspect_autofield:
+ comment_notes.append('AutoField?')
- # Add 'null' and 'blank', if the 'null_ok' flag was present in the
- # table description.
- if row[6]: # If it's NULL...
- if field_type == 'BooleanField(':
- field_type = 'NullBooleanField('
- else:
- extra_params['blank'] = True
- if not field_type in ('TextField(', 'CharField('):
- extra_params['null'] = True
+ # Add 'null' and 'blank', if the 'null_ok' flag was present in the
+ # table description.
+ if row[6]: # If it's NULL...
+ if field_type == 'BooleanField(':
+ field_type = 'NullBooleanField('
+ else:
+ extra_params['blank'] = True
+ if not field_type in ('TextField(', 'CharField('):
+ extra_params['null'] = True
- field_desc = '%s = %s%s' % (
- att_name,
- # Custom fields will have a dotted path
- '' if '.' in field_type else 'models.',
- field_type,
- )
- if extra_params:
- if not field_desc.endswith('('):
- field_desc += ', '
- field_desc += ', '.join([
- '%s=%s' % (k, strip_prefix(repr(v)))
- for k, v in extra_params.items()])
- field_desc += ')'
- if comment_notes:
- field_desc += ' # ' + ' '.join(comment_notes)
- yield ' %s' % field_desc
- for meta_line in self.get_meta(table_name):
- yield meta_line
+ field_desc = '%s = %s%s' % (
+ att_name,
+ # Custom fields will have a dotted path
+ '' if '.' in field_type else 'models.',
+ field_type,
+ )
+ if extra_params:
+ if not field_desc.endswith('('):
+ field_desc += ', '
+ field_desc += ', '.join([
+ '%s=%s' % (k, strip_prefix(repr(v)))
+ for k, v in extra_params.items()])
+ field_desc += ')'
+ if comment_notes:
+ field_desc += ' # ' + ' '.join(comment_notes)
+ yield ' %s' % field_desc
+ for meta_line in self.get_meta(table_name):
+ yield meta_line
def normalize_col_name(self, col_name, used_column_names, is_relation):
"""
diff --git a/django/core/management/commands/loaddata.py b/django/core/management/commands/loaddata.py
index 65bc96ba99..31a3af6225 100644
--- a/django/core/management/commands/loaddata.py
+++ b/django/core/management/commands/loaddata.py
@@ -100,10 +100,9 @@ class Command(BaseCommand):
if sequence_sql:
if self.verbosity >= 2:
self.stdout.write("Resetting sequences\n")
- cursor = connection.cursor()
- for line in sequence_sql:
- cursor.execute(line)
- cursor.close()
+ with connection.cursor() as cursor:
+ for line in sequence_sql:
+ cursor.execute(line)
if self.verbosity >= 1:
if self.fixture_object_count == self.loaded_object_count:
diff --git a/django/core/management/commands/migrate.py b/django/core/management/commands/migrate.py
index cb863509c8..60899ef09f 100644
--- a/django/core/management/commands/migrate.py
+++ b/django/core/management/commands/migrate.py
@@ -171,105 +171,110 @@ class Command(BaseCommand):
"Runs the old syncdb-style operation on a list of app_labels."
cursor = connection.cursor()
- # Get a list of already installed *models* so that references work right.
- tables = connection.introspection.table_names()
- seen_models = connection.introspection.installed_models(tables)
- created_models = set()
- pending_references = {}
+ try:
+ # Get a list of already installed *models* so that references work right.
+ tables = connection.introspection.table_names(cursor)
+ seen_models = connection.introspection.installed_models(tables)
+ created_models = set()
+ pending_references = {}
- # Build the manifest of apps and models that are to be synchronized
- all_models = [
- (app_config.label,
- router.get_migratable_models(app_config, connection.alias, include_auto_created=True))
- for app_config in apps.get_app_configs()
- if app_config.models_module is not None and app_config.label in app_labels
- ]
+ # Build the manifest of apps and models that are to be synchronized
+ all_models = [
+ (app_config.label,
+ router.get_migratable_models(app_config, connection.alias, include_auto_created=True))
+ for app_config in apps.get_app_configs()
+ if app_config.models_module is not None and app_config.label in app_labels
+ ]
- def model_installed(model):
- opts = model._meta
- converter = connection.introspection.table_name_converter
- # Note that if a model is unmanaged we short-circuit and never try to install it
- return not ((converter(opts.db_table) in tables) or
- (opts.auto_created and converter(opts.auto_created._meta.db_table) in tables))
+ def model_installed(model):
+ opts = model._meta
+ converter = connection.introspection.table_name_converter
+ # Note that if a model is unmanaged we short-circuit and never try to install it
+ return not ((converter(opts.db_table) in tables) or
+ (opts.auto_created and converter(opts.auto_created._meta.db_table) in tables))
- manifest = OrderedDict(
- (app_name, list(filter(model_installed, model_list)))
- for app_name, model_list in all_models
- )
+ manifest = OrderedDict(
+ (app_name, list(filter(model_installed, model_list)))
+ for app_name, model_list in all_models
+ )
- create_models = set(itertools.chain(*manifest.values()))
- emit_pre_migrate_signal(create_models, self.verbosity, self.interactive, connection.alias)
+ create_models = set(itertools.chain(*manifest.values()))
+ emit_pre_migrate_signal(create_models, self.verbosity, self.interactive, connection.alias)
- # Create the tables for each model
- if self.verbosity >= 1:
- self.stdout.write(" Creating tables...\n")
- with transaction.atomic(using=connection.alias, savepoint=False):
- for app_name, model_list in manifest.items():
- for model in model_list:
- # Create the model's database table, if it doesn't already exist.
- if self.verbosity >= 3:
- self.stdout.write(" Processing %s.%s model\n" % (app_name, model._meta.object_name))
- sql, references = connection.creation.sql_create_model(model, no_style(), seen_models)
- seen_models.add(model)
- created_models.add(model)
- for refto, refs in references.items():
- pending_references.setdefault(refto, []).extend(refs)
- if refto in seen_models:
- sql.extend(connection.creation.sql_for_pending_references(refto, no_style(), pending_references))
- sql.extend(connection.creation.sql_for_pending_references(model, no_style(), pending_references))
- if self.verbosity >= 1 and sql:
- self.stdout.write(" Creating table %s\n" % model._meta.db_table)
- for statement in sql:
- cursor.execute(statement)
- tables.append(connection.introspection.table_name_converter(model._meta.db_table))
+ # Create the tables for each model
+ if self.verbosity >= 1:
+ self.stdout.write(" Creating tables...\n")
+ with transaction.atomic(using=connection.alias, savepoint=False):
+ for app_name, model_list in manifest.items():
+ for model in model_list:
+ # Create the model's database table, if it doesn't already exist.
+ if self.verbosity >= 3:
+ self.stdout.write(" Processing %s.%s model\n" % (app_name, model._meta.object_name))
+ sql, references = connection.creation.sql_create_model(model, no_style(), seen_models)
+ seen_models.add(model)
+ created_models.add(model)
+ for refto, refs in references.items():
+ pending_references.setdefault(refto, []).extend(refs)
+ if refto in seen_models:
+ sql.extend(connection.creation.sql_for_pending_references(refto, no_style(), pending_references))
+ sql.extend(connection.creation.sql_for_pending_references(model, no_style(), pending_references))
+ if self.verbosity >= 1 and sql:
+ self.stdout.write(" Creating table %s\n" % model._meta.db_table)
+ for statement in sql:
+ cursor.execute(statement)
+ tables.append(connection.introspection.table_name_converter(model._meta.db_table))
- # We force a commit here, as that was the previous behaviour.
- # If you can prove we don't need this, remove it.
- transaction.set_dirty(using=connection.alias)
+ # We force a commit here, as that was the previous behaviour.
+ # If you can prove we don't need this, remove it.
+ transaction.set_dirty(using=connection.alias)
+ finally:
+ cursor.close()
# The connection may have been closed by a syncdb handler.
cursor = connection.cursor()
+ try:
+ # Install custom SQL for the app (but only if this
+ # is a model we've just created)
+ if self.verbosity >= 1:
+ self.stdout.write(" Installing custom SQL...\n")
+ for app_name, model_list in manifest.items():
+ for model in model_list:
+ if model in created_models:
+ custom_sql = custom_sql_for_model(model, no_style(), connection)
+ if custom_sql:
+ if self.verbosity >= 2:
+ self.stdout.write(" Installing custom SQL for %s.%s model\n" % (app_name, model._meta.object_name))
+ try:
+ with transaction.commit_on_success_unless_managed(using=connection.alias):
+ for sql in custom_sql:
+ cursor.execute(sql)
+ except Exception as e:
+ self.stderr.write(" Failed to install custom SQL for %s.%s model: %s\n" % (app_name, model._meta.object_name, e))
+ if self.show_traceback:
+ traceback.print_exc()
+ else:
+ if self.verbosity >= 3:
+ self.stdout.write(" No custom SQL for %s.%s model\n" % (app_name, model._meta.object_name))
- # Install custom SQL for the app (but only if this
- # is a model we've just created)
- if self.verbosity >= 1:
- self.stdout.write(" Installing custom SQL...\n")
- for app_name, model_list in manifest.items():
- for model in model_list:
- if model in created_models:
- custom_sql = custom_sql_for_model(model, no_style(), connection)
- if custom_sql:
- if self.verbosity >= 2:
- self.stdout.write(" Installing custom SQL for %s.%s model\n" % (app_name, model._meta.object_name))
- try:
- with transaction.commit_on_success_unless_managed(using=connection.alias):
- for sql in custom_sql:
- cursor.execute(sql)
- except Exception as e:
- self.stderr.write(" Failed to install custom SQL for %s.%s model: %s\n" % (app_name, model._meta.object_name, e))
- if self.show_traceback:
- traceback.print_exc()
- else:
- if self.verbosity >= 3:
- self.stdout.write(" No custom SQL for %s.%s model\n" % (app_name, model._meta.object_name))
-
- if self.verbosity >= 1:
- self.stdout.write(" Installing indexes...\n")
+ if self.verbosity >= 1:
+ self.stdout.write(" Installing indexes...\n")
- # Install SQL indices for all newly created models
- for app_name, model_list in manifest.items():
- for model in model_list:
- if model in created_models:
- index_sql = connection.creation.sql_indexes_for_model(model, no_style())
- if index_sql:
- if self.verbosity >= 2:
- self.stdout.write(" Installing index for %s.%s model\n" % (app_name, model._meta.object_name))
- try:
- with transaction.commit_on_success_unless_managed(using=connection.alias):
- for sql in index_sql:
- cursor.execute(sql)
- except Exception as e:
- self.stderr.write(" Failed to install index for %s.%s model: %s\n" % (app_name, model._meta.object_name, e))
+ # Install SQL indices for all newly created models
+ for app_name, model_list in manifest.items():
+ for model in model_list:
+ if model in created_models:
+ index_sql = connection.creation.sql_indexes_for_model(model, no_style())
+ if index_sql:
+ if self.verbosity >= 2:
+ self.stdout.write(" Installing index for %s.%s model\n" % (app_name, model._meta.object_name))
+ try:
+ with transaction.commit_on_success_unless_managed(using=connection.alias):
+ for sql in index_sql:
+ cursor.execute(sql)
+ except Exception as e:
+ self.stderr.write(" Failed to install index for %s.%s model: %s\n" % (app_name, model._meta.object_name, e))
+ finally:
+ cursor.close()
# Load initial_data fixtures (unless that has been disabled)
if self.load_initial_data:
diff --git a/django/core/management/sql.py b/django/core/management/sql.py
index ad91ca36c6..ccab11d2bd 100644
--- a/django/core/management/sql.py
+++ b/django/core/management/sql.py
@@ -67,38 +67,39 @@ def sql_delete(app_config, style, connection):
except Exception:
cursor = None
- # Figure out which tables already exist
- if cursor:
- table_names = connection.introspection.table_names(cursor)
- else:
- table_names = []
-
- output = []
+ try:
+ # Figure out which tables already exist
+ if cursor:
+ table_names = connection.introspection.table_names(cursor)
+ else:
+ table_names = []
- # Output DROP TABLE statements for standard application tables.
- to_delete = set()
+ output = []
- references_to_delete = {}
- app_models = router.get_migratable_models(app_config, connection.alias, include_auto_created=True)
- for model in app_models:
- if cursor and connection.introspection.table_name_converter(model._meta.db_table) in table_names:
- # The table exists, so it needs to be dropped
- opts = model._meta
- for f in opts.local_fields:
- if f.rel and f.rel.to not in to_delete:
- references_to_delete.setdefault(f.rel.to, []).append((model, f))
+ # Output DROP TABLE statements for standard application tables.
+ to_delete = set()
- to_delete.add(model)
+ references_to_delete = {}
+ app_models = router.get_migratable_models(app_config, connection.alias, include_auto_created=True)
+ for model in app_models:
+ if cursor and connection.introspection.table_name_converter(model._meta.db_table) in table_names:
+ # The table exists, so it needs to be dropped
+ opts = model._meta
+ for f in opts.local_fields:
+ if f.rel and f.rel.to not in to_delete:
+ references_to_delete.setdefault(f.rel.to, []).append((model, f))
- for model in app_models:
- if connection.introspection.table_name_converter(model._meta.db_table) in table_names:
- output.extend(connection.creation.sql_destroy_model(model, references_to_delete, style))
+ to_delete.add(model)
- # Close database connection explicitly, in case this output is being piped
- # directly into a database client, to avoid locking issues.
- if cursor:
- cursor.close()
- connection.close()
+ for model in app_models:
+ if connection.introspection.table_name_converter(model._meta.db_table) in table_names:
+ output.extend(connection.creation.sql_destroy_model(model, references_to_delete, style))
+ finally:
+ # Close database connection explicitly, in case this output is being piped
+ # directly into a database client, to avoid locking issues.
+ if cursor:
+ cursor.close()
+ connection.close()
return output[::-1] # Reverse it, to deal with table dependencies.
diff --git a/django/db/backends/__init__.py b/django/db/backends/__init__.py
index b96407056a..50b8745f47 100644
--- a/django/db/backends/__init__.py
+++ b/django/db/backends/__init__.py
@@ -194,13 +194,16 @@ class BaseDatabaseWrapper(object):
##### Backend-specific savepoint management methods #####
def _savepoint(self, sid):
- self.cursor().execute(self.ops.savepoint_create_sql(sid))
+ with self.cursor() as cursor:
+ cursor.execute(self.ops.savepoint_create_sql(sid))
def _savepoint_rollback(self, sid):
- self.cursor().execute(self.ops.savepoint_rollback_sql(sid))
+ with self.cursor() as cursor:
+ cursor.execute(self.ops.savepoint_rollback_sql(sid))
def _savepoint_commit(self, sid):
- self.cursor().execute(self.ops.savepoint_commit_sql(sid))
+ with self.cursor() as cursor:
+ cursor.execute(self.ops.savepoint_commit_sql(sid))
def _savepoint_allowed(self):
# Savepoints cannot be created outside a transaction
@@ -688,15 +691,15 @@ class BaseDatabaseFeatures(object):
# otherwise autocommit will cause the confimation to
# fail.
self.connection.enter_transaction_management()
- cursor = self.connection.cursor()
- cursor.execute('CREATE TABLE ROLLBACK_TEST (X INT)')
- self.connection.commit()
- cursor.execute('INSERT INTO ROLLBACK_TEST (X) VALUES (8)')
- self.connection.rollback()
- cursor.execute('SELECT COUNT(X) FROM ROLLBACK_TEST')
- count, = cursor.fetchone()
- cursor.execute('DROP TABLE ROLLBACK_TEST')
- self.connection.commit()
+ with self.connection.cursor() as cursor:
+ cursor.execute('CREATE TABLE ROLLBACK_TEST (X INT)')
+ self.connection.commit()
+ cursor.execute('INSERT INTO ROLLBACK_TEST (X) VALUES (8)')
+ self.connection.rollback()
+ cursor.execute('SELECT COUNT(X) FROM ROLLBACK_TEST')
+ count, = cursor.fetchone()
+ cursor.execute('DROP TABLE ROLLBACK_TEST')
+ self.connection.commit()
finally:
self.connection.leave_transaction_management()
return count == 0
@@ -1253,7 +1256,8 @@ class BaseDatabaseIntrospection(object):
in sorting order between databases.
"""
if cursor is None:
- cursor = self.connection.cursor()
+ with self.connection.cursor() as cursor:
+ return sorted(self.get_table_list(cursor))
return sorted(self.get_table_list(cursor))
def get_table_list(self, cursor):
diff --git a/django/db/backends/creation.py b/django/db/backends/creation.py
index ff62f30e71..3ee1e8448e 100644
--- a/django/db/backends/creation.py
+++ b/django/db/backends/creation.py
@@ -378,9 +378,8 @@ class BaseDatabaseCreation(object):
call_command('createcachetable', database=self.connection.alias)
- # Get a cursor (even though we don't need one yet). This has
- # the side effect of initializing the test database.
- self.connection.cursor()
+ # Ensure a connection for the side effect of initializing the test database.
+ self.connection.ensure_connection()
return test_database_name
@@ -406,34 +405,34 @@ class BaseDatabaseCreation(object):
qn = self.connection.ops.quote_name
# Create the test database and connect to it.
- cursor = self._nodb_connection.cursor()
- try:
- cursor.execute(
- "CREATE DATABASE %s %s" % (qn(test_database_name), suffix))
- except Exception as e:
- sys.stderr.write(
- "Got an error creating the test database: %s\n" % e)
- if not autoclobber:
- confirm = input(
- "Type 'yes' if you would like to try deleting the test "
- "database '%s', or 'no' to cancel: " % test_database_name)
- if autoclobber or confirm == 'yes':
- try:
- if verbosity >= 1:
- print("Destroying old test database '%s'..."
- % self.connection.alias)
- cursor.execute(
- "DROP DATABASE %s" % qn(test_database_name))
- cursor.execute(
- "CREATE DATABASE %s %s" % (qn(test_database_name),
- suffix))
- except Exception as e:
- sys.stderr.write(
- "Got an error recreating the test database: %s\n" % e)
- sys.exit(2)
- else:
- print("Tests cancelled.")
- sys.exit(1)
+ with self._nodb_connection.cursor() as cursor:
+ try:
+ cursor.execute(
+ "CREATE DATABASE %s %s" % (qn(test_database_name), suffix))
+ except Exception as e:
+ sys.stderr.write(
+ "Got an error creating the test database: %s\n" % e)
+ if not autoclobber:
+ confirm = input(
+ "Type 'yes' if you would like to try deleting the test "
+ "database '%s', or 'no' to cancel: " % test_database_name)
+ if autoclobber or confirm == 'yes':
+ try:
+ if verbosity >= 1:
+ print("Destroying old test database '%s'..."
+ % self.connection.alias)
+ cursor.execute(
+ "DROP DATABASE %s" % qn(test_database_name))
+ cursor.execute(
+ "CREATE DATABASE %s %s" % (qn(test_database_name),
+ suffix))
+ except Exception as e:
+ sys.stderr.write(
+ "Got an error recreating the test database: %s\n" % e)
+ sys.exit(2)
+ else:
+ print("Tests cancelled.")
+ sys.exit(1)
return test_database_name
@@ -461,11 +460,11 @@ class BaseDatabaseCreation(object):
# ourselves. Connect to the previous database (not the test database)
# to do so, because it's not allowed to delete a database while being
# connected to it.
- cursor = self._nodb_connection.cursor()
- # Wait to avoid "database is being accessed by other users" errors.
- time.sleep(1)
- cursor.execute("DROP DATABASE %s"
- % self.connection.ops.quote_name(test_database_name))
+ with self._nodb_connection.cursor() as cursor:
+ # Wait to avoid "database is being accessed by other users" errors.
+ time.sleep(1)
+ cursor.execute("DROP DATABASE %s"
+ % self.connection.ops.quote_name(test_database_name))
def set_autocommit(self):
"""
diff --git a/django/db/backends/mysql/base.py b/django/db/backends/mysql/base.py
index e7932dd800..9d3935dc54 100644
--- a/django/db/backends/mysql/base.py
+++ b/django/db/backends/mysql/base.py
@@ -180,15 +180,15 @@ class DatabaseFeatures(BaseDatabaseFeatures):
@cached_property
def _mysql_storage_engine(self):
"Internal method used in Django tests. Don't rely on this from your code"
- cursor = self.connection.cursor()
- cursor.execute('CREATE TABLE INTROSPECT_TEST (X INT)')
- # This command is MySQL specific; the second column
- # will tell you the default table type of the created
- # table. Since all Django's test tables will have the same
- # table type, that's enough to evaluate the feature.
- cursor.execute("SHOW TABLE STATUS WHERE Name='INTROSPECT_TEST'")
- result = cursor.fetchone()
- cursor.execute('DROP TABLE INTROSPECT_TEST')
+ with self.connection.cursor() as cursor:
+ cursor.execute('CREATE TABLE INTROSPECT_TEST (X INT)')
+ # This command is MySQL specific; the second column
+ # will tell you the default table type of the created
+ # table. Since all Django's test tables will have the same
+ # table type, that's enough to evaluate the feature.
+ cursor.execute("SHOW TABLE STATUS WHERE Name='INTROSPECT_TEST'")
+ result = cursor.fetchone()
+ cursor.execute('DROP TABLE INTROSPECT_TEST')
return result[1]
@cached_property
@@ -207,9 +207,9 @@ class DatabaseFeatures(BaseDatabaseFeatures):
return False
# Test if the time zone definitions are installed.
- cursor = self.connection.cursor()
- cursor.execute("SELECT 1 FROM mysql.time_zone LIMIT 1")
- return cursor.fetchone() is not None
+ with self.connection.cursor() as cursor:
+ cursor.execute("SELECT 1 FROM mysql.time_zone LIMIT 1")
+ return cursor.fetchone() is not None
class DatabaseOperations(BaseDatabaseOperations):
@@ -461,13 +461,12 @@ class DatabaseWrapper(BaseDatabaseWrapper):
return conn
def init_connection_state(self):
- cursor = self.connection.cursor()
- # SQL_AUTO_IS_NULL in MySQL controls whether an AUTO_INCREMENT column
- # on a recently-inserted row will return when the field is tested for
- # NULL. Disabling this value brings this aspect of MySQL in line with
- # SQL standards.
- cursor.execute('SET SQL_AUTO_IS_NULL = 0')
- cursor.close()
+ with self.connection.cursor() as cursor:
+ # SQL_AUTO_IS_NULL in MySQL controls whether an AUTO_INCREMENT column
+ # on a recently-inserted row will return when the field is tested for
+ # NULL. Disabling this value brings this aspect of MySQL in line with
+ # SQL standards.
+ cursor.execute('SET SQL_AUTO_IS_NULL = 0')
def create_cursor(self):
cursor = self.connection.cursor()
diff --git a/django/db/backends/oracle/base.py b/django/db/backends/oracle/base.py
index cdb101d20c..2495986a02 100644
--- a/django/db/backends/oracle/base.py
+++ b/django/db/backends/oracle/base.py
@@ -353,8 +353,8 @@ WHEN (new.%(col_name)s IS NULL)
def regex_lookup(self, lookup_type):
# If regex_lookup is called before it's been initialized, then create
# a cursor to initialize it and recur.
- self.connection.cursor()
- return self.connection.ops.regex_lookup(lookup_type)
+ with self.connection.cursor():
+ return self.connection.ops.regex_lookup(lookup_type)
def return_insert_id(self):
return "RETURNING %s INTO %%s", (InsertIdVar(),)
diff --git a/django/db/backends/postgresql_psycopg2/base.py b/django/db/backends/postgresql_psycopg2/base.py
index 33f885d50c..e89a4e604a 100644
--- a/django/db/backends/postgresql_psycopg2/base.py
+++ b/django/db/backends/postgresql_psycopg2/base.py
@@ -149,8 +149,10 @@ class DatabaseWrapper(BaseDatabaseWrapper):
if conn_tz != tz:
cursor = self.connection.cursor()
- cursor.execute(self.ops.set_time_zone_sql(), [tz])
- cursor.close()
+ try:
+ cursor.execute(self.ops.set_time_zone_sql(), [tz])
+ finally:
+ cursor.close()
# Commit after setting the time zone (see #17062)
if not self.get_autocommit():
self.connection.commit()
diff --git a/django/db/backends/postgresql_psycopg2/version.py b/django/db/backends/postgresql_psycopg2/version.py
index dae94f2dac..64fd7c8298 100644
--- a/django/db/backends/postgresql_psycopg2/version.py
+++ b/django/db/backends/postgresql_psycopg2/version.py
@@ -39,6 +39,6 @@ def get_version(connection):
if hasattr(connection, 'server_version'):
return connection.server_version
else:
- cursor = connection.cursor()
- cursor.execute("SELECT version()")
- return _parse_version(cursor.fetchone()[0])
+ with connection.cursor() as cursor:
+ cursor.execute("SELECT version()")
+ return _parse_version(cursor.fetchone()[0])
diff --git a/django/db/backends/schema.py b/django/db/backends/schema.py
index e690567738..88cc894437 100644
--- a/django/db/backends/schema.py
+++ b/django/db/backends/schema.py
@@ -86,14 +86,13 @@ class BaseDatabaseSchemaEditor(object):
"""
Executes the given SQL statement, with optional parameters.
"""
- # Get the cursor
- cursor = self.connection.cursor()
# Log the command we're running, then run it
logger.debug("%s; (params %r)" % (sql, params))
if self.collect_sql:
self.collected_sql.append((sql % tuple(map(self.connection.ops.quote_parameter, params))) + ";")
else:
- cursor.execute(sql, params)
+ with self.connection.cursor() as cursor:
+ cursor.execute(sql, params)
def quote_name(self, name):
return self.connection.ops.quote_name(name)
@@ -791,7 +790,8 @@ class BaseDatabaseSchemaEditor(object):
Returns all constraint names matching the columns and conditions
"""
column_names = list(column_names) if column_names else None
- constraints = self.connection.introspection.get_constraints(self.connection.cursor(), model._meta.db_table)
+ with self.connection.cursor() as cursor:
+ constraints = self.connection.introspection.get_constraints(cursor, model._meta.db_table)
result = []
for name, infodict in constraints.items():
if column_names is None or column_names == infodict['columns']:
diff --git a/django/db/backends/sqlite3/base.py b/django/db/backends/sqlite3/base.py
index 3c8f170b7d..2adfbacaa9 100644
--- a/django/db/backends/sqlite3/base.py
+++ b/django/db/backends/sqlite3/base.py
@@ -122,14 +122,14 @@ class DatabaseFeatures(BaseDatabaseFeatures):
rule out support for STDDEV. We need to manually check
whether the call works.
"""
- cursor = self.connection.cursor()
- cursor.execute('CREATE TABLE STDDEV_TEST (X INT)')
- try:
- cursor.execute('SELECT STDDEV(*) FROM STDDEV_TEST')
- has_support = True
- except utils.DatabaseError:
- has_support = False
- cursor.execute('DROP TABLE STDDEV_TEST')
+ with self.connection.cursor() as cursor:
+ cursor.execute('CREATE TABLE STDDEV_TEST (X INT)')
+ try:
+ cursor.execute('SELECT STDDEV(*) FROM STDDEV_TEST')
+ has_support = True
+ except utils.DatabaseError:
+ has_support = False
+ cursor.execute('DROP TABLE STDDEV_TEST')
return has_support
@cached_property
diff --git a/django/db/models/query.py b/django/db/models/query.py
index 48d295ccca..6051b9f859 100644
--- a/django/db/models/query.py
+++ b/django/db/models/query.py
@@ -14,6 +14,7 @@ from django.db.models.fields import AutoField, Empty
from django.db.models.query_utils import (Q, select_related_descend,
deferred_class_factory, InvalidQuery)
from django.db.models.deletion import Collector
+from django.db.models.sql.constants import CURSOR
from django.db.models import sql
from django.utils.functional import partition
from django.utils import six
@@ -574,7 +575,7 @@ class QuerySet(object):
query = self.query.clone(sql.UpdateQuery)
query.add_update_values(kwargs)
with transaction.commit_on_success_unless_managed(using=self.db):
- rows = query.get_compiler(self.db).execute_sql(None)
+ rows = query.get_compiler(self.db).execute_sql(CURSOR)
self._result_cache = None
return rows
update.alters_data = True
@@ -591,7 +592,7 @@ class QuerySet(object):
query = self.query.clone(sql.UpdateQuery)
query.add_update_fields(values)
self._result_cache = None
- return query.get_compiler(self.db).execute_sql(None)
+ return query.get_compiler(self.db).execute_sql(CURSOR)
_update.alters_data = True
_update.queryset_only = False
@@ -1521,54 +1522,59 @@ class RawQuerySet(object):
query = iter(self.query)
- # Find out which columns are model's fields, and which ones should be
- # annotated to the model.
- for pos, column in enumerate(self.columns):
- if column in self.model_fields:
- model_init_field_names[self.model_fields[column].attname] = pos
- else:
- annotation_fields.append((column, pos))
+ try:
+ # Find out which columns are model's fields, and which ones should be
+ # annotated to the model.
+ for pos, column in enumerate(self.columns):
+ if column in self.model_fields:
+ model_init_field_names[self.model_fields[column].attname] = pos
+ else:
+ annotation_fields.append((column, pos))
- # Find out which model's fields are not present in the query.
- skip = set()
- for field in self.model._meta.fields:
- if field.attname not in model_init_field_names:
- skip.add(field.attname)
- if skip:
- if self.model._meta.pk.attname in skip:
- raise InvalidQuery('Raw query must include the primary key')
- model_cls = deferred_class_factory(self.model, skip)
- else:
- model_cls = self.model
- # All model's fields are present in the query. So, it is possible
- # to use *args based model instantation. For each field of the model,
- # record the query column position matching that field.
- model_init_field_pos = []
+ # Find out which model's fields are not present in the query.
+ skip = set()
for field in self.model._meta.fields:
- model_init_field_pos.append(model_init_field_names[field.attname])
- if need_resolv_columns:
- fields = [self.model_fields.get(c, None) for c in self.columns]
- # Begin looping through the query values.
- for values in query:
- if need_resolv_columns:
- values = compiler.resolve_columns(values, fields)
- # Associate fields to values
+ if field.attname not in model_init_field_names:
+ skip.add(field.attname)
if skip:
- model_init_kwargs = {}
- for attname, pos in six.iteritems(model_init_field_names):
- model_init_kwargs[attname] = values[pos]
- instance = model_cls(**model_init_kwargs)
+ if self.model._meta.pk.attname in skip:
+ raise InvalidQuery('Raw query must include the primary key')
+ model_cls = deferred_class_factory(self.model, skip)
else:
- model_init_args = [values[pos] for pos in model_init_field_pos]
- instance = model_cls(*model_init_args)
- if annotation_fields:
- for column, pos in annotation_fields:
- setattr(instance, column, values[pos])
+ model_cls = self.model
+ # All model's fields are present in the query. So, it is possible
+ # to use *args based model instantation. For each field of the model,
+ # record the query column position matching that field.
+ model_init_field_pos = []
+ for field in self.model._meta.fields:
+ model_init_field_pos.append(model_init_field_names[field.attname])
+ if need_resolv_columns:
+ fields = [self.model_fields.get(c, None) for c in self.columns]
+ # Begin looping through the query values.
+ for values in query:
+ if need_resolv_columns:
+ values = compiler.resolve_columns(values, fields)
+ # Associate fields to values
+ if skip:
+ model_init_kwargs = {}
+ for attname, pos in six.iteritems(model_init_field_names):
+ model_init_kwargs[attname] = values[pos]
+ instance = model_cls(**model_init_kwargs)
+ else:
+ model_init_args = [values[pos] for pos in model_init_field_pos]
+ instance = model_cls(*model_init_args)
+ if annotation_fields:
+ for column, pos in annotation_fields:
+ setattr(instance, column, values[pos])
- instance._state.db = db
- instance._state.adding = False
+ instance._state.db = db
+ instance._state.adding = False
- yield instance
+ yield instance
+ finally:
+ # Done iterating the Query. If it has its own cursor, close it.
+ if hasattr(self.query, 'cursor') and self.query.cursor:
+ self.query.cursor.close()
def __repr__(self):
text = self.raw_query
diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py
index 123427cf8b..d9161d820c 100644
--- a/django/db/models/sql/compiler.py
+++ b/django/db/models/sql/compiler.py
@@ -1,12 +1,13 @@
import datetime
+import sys
from django.conf import settings
from django.core.exceptions import FieldError
from django.db.backends.utils import truncate_name
from django.db.models.constants import LOOKUP_SEP
from django.db.models.query_utils import select_related_descend, QueryWrapper
-from django.db.models.sql.constants import (SINGLE, MULTI, ORDER_DIR,
- GET_ITERATOR_CHUNK_SIZE, SelectInfo)
+from django.db.models.sql.constants import (CURSOR, SINGLE, MULTI, NO_RESULTS,
+ ORDER_DIR, GET_ITERATOR_CHUNK_SIZE, SelectInfo)
from django.db.models.sql.datastructures import EmptyResultSet
from django.db.models.sql.expressions import SQLEvaluator
from django.db.models.sql.query import get_order_dir, Query
@@ -762,6 +763,8 @@ class SQLCompiler(object):
is needed, as the filters describe an empty set. In that case, None is
returned, to avoid any unnecessary database interaction.
"""
+ if not result_type:
+ result_type = NO_RESULTS
try:
sql, params = self.as_sql()
if not sql:
@@ -773,27 +776,44 @@ class SQLCompiler(object):
return
cursor = self.connection.cursor()
- cursor.execute(sql, params)
+ try:
+ cursor.execute(sql, params)
+ except Exception:
+ cursor.close()
+ raise
- if not result_type:
+ if result_type == CURSOR:
+ # Caller didn't specify a result_type, so just give them back the
+ # cursor to process (and close).
return cursor
if result_type == SINGLE:
- if self.ordering_aliases:
- return cursor.fetchone()[:-len(self.ordering_aliases)]
- return cursor.fetchone()
+ try:
+ if self.ordering_aliases:
+ return cursor.fetchone()[:-len(self.ordering_aliases)]
+ return cursor.fetchone()
+ finally:
+ # done with the cursor
+ cursor.close()
+ if result_type == NO_RESULTS:
+ cursor.close()
+ return
# The MULTI case.
if self.ordering_aliases:
result = order_modified_iter(cursor, len(self.ordering_aliases),
self.connection.features.empty_fetchmany_value)
else:
- result = iter((lambda: cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)),
- self.connection.features.empty_fetchmany_value)
+ result = cursor_iter(cursor,
+ self.connection.features.empty_fetchmany_value)
if not self.connection.features.can_use_chunked_reads:
- # If we are using non-chunked reads, we return the same data
- # structure as normally, but ensure it is all read into memory
- # before going any further.
- return list(result)
+ try:
+ # If we are using non-chunked reads, we return the same data
+ # structure as normally, but ensure it is all read into memory
+ # before going any further.
+ return list(result)
+ finally:
+ # done with the cursor
+ cursor.close()
return result
def as_subquery_condition(self, alias, columns, qn):
@@ -889,15 +909,15 @@ class SQLInsertCompiler(SQLCompiler):
def execute_sql(self, return_id=False):
assert not (return_id and len(self.query.objs) != 1)
self.return_id = return_id
- cursor = self.connection.cursor()
- for sql, params in self.as_sql():
- cursor.execute(sql, params)
- if not (return_id and cursor):
- return
- if self.connection.features.can_return_id_from_insert:
- return self.connection.ops.fetch_returned_insert_id(cursor)
- return self.connection.ops.last_insert_id(cursor,
- self.query.get_meta().db_table, self.query.get_meta().pk.column)
+ with self.connection.cursor() as cursor:
+ for sql, params in self.as_sql():
+ cursor.execute(sql, params)
+ if not (return_id and cursor):
+ return
+ if self.connection.features.can_return_id_from_insert:
+ return self.connection.ops.fetch_returned_insert_id(cursor)
+ return self.connection.ops.last_insert_id(cursor,
+ self.query.get_meta().db_table, self.query.get_meta().pk.column)
class SQLDeleteCompiler(SQLCompiler):
@@ -970,12 +990,15 @@ class SQLUpdateCompiler(SQLCompiler):
related queries are not available.
"""
cursor = super(SQLUpdateCompiler, self).execute_sql(result_type)
- rows = cursor.rowcount if cursor else 0
- is_empty = cursor is None
- del cursor
+ try:
+ rows = cursor.rowcount if cursor else 0
+ is_empty = cursor is None
+ finally:
+ if cursor:
+ cursor.close()
for query in self.query.get_related_updates():
aux_rows = query.get_compiler(self.using).execute_sql(result_type)
- if is_empty:
+ if is_empty and aux_rows:
rows = aux_rows
is_empty = False
return rows
@@ -1111,6 +1134,19 @@ class SQLDateTimeCompiler(SQLCompiler):
yield datetime
+def cursor_iter(cursor, sentinel):
+ """
+ Yields blocks of rows from a cursor and ensures the cursor is closed when
+ done.
+ """
+ try:
+ for rows in iter((lambda: cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)),
+ sentinel):
+ yield rows
+ finally:
+ cursor.close()
+
+
def order_modified_iter(cursor, trim, sentinel):
"""
Yields blocks of rows from a cursor. We use this iterator in the special
@@ -1118,6 +1154,9 @@ def order_modified_iter(cursor, trim, sentinel):
requirements. We must trim those extra columns before anything else can use
the results, since they're only needed to make the SQL valid.
"""
- for rows in iter((lambda: cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)),
- sentinel):
- yield [r[:-trim] for r in rows]
+ try:
+ for rows in iter((lambda: cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)),
+ sentinel):
+ yield [r[:-trim] for r in rows]
+ finally:
+ cursor.close()
diff --git a/django/db/models/sql/constants.py b/django/db/models/sql/constants.py
index 904f7b2c8b..36aab23bae 100644
--- a/django/db/models/sql/constants.py
+++ b/django/db/models/sql/constants.py
@@ -33,6 +33,8 @@ SelectInfo = namedtuple('SelectInfo', 'col field')
# How many results to expect from a cursor.execute call
MULTI = 'multi'
SINGLE = 'single'
+CURSOR = 'cursor'
+NO_RESULTS = 'no results'
ORDER_PATTERN = re.compile(r'\?|[-+]?[.\w]+$')
ORDER_DIR = {
diff --git a/django/db/models/sql/subqueries.py b/django/db/models/sql/subqueries.py
index 86b1efd3f8..cfda1f552c 100644
--- a/django/db/models/sql/subqueries.py
+++ b/django/db/models/sql/subqueries.py
@@ -8,7 +8,7 @@ from django.db import connections
from django.db.models.query_utils import Q
from django.db.models.constants import LOOKUP_SEP
from django.db.models.fields import DateField, DateTimeField, FieldDoesNotExist
-from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, SelectInfo
+from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, NO_RESULTS, SelectInfo
from django.db.models.sql.datastructures import Date, DateTime
from django.db.models.sql.query import Query
from django.utils import six
@@ -30,7 +30,7 @@ class DeleteQuery(Query):
def do_query(self, table, where, using):
self.tables = [table]
self.where = where
- self.get_compiler(using).execute_sql(None)
+ self.get_compiler(using).execute_sql(NO_RESULTS)
def delete_batch(self, pk_list, using, field=None):
"""
@@ -82,7 +82,7 @@ class DeleteQuery(Query):
values = innerq
self.where = self.where_class()
self.add_q(Q(pk__in=values))
- self.get_compiler(using).execute_sql(None)
+ self.get_compiler(using).execute_sql(NO_RESULTS)
class UpdateQuery(Query):
@@ -116,7 +116,7 @@ class UpdateQuery(Query):
for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
self.where = self.where_class()
self.add_q(Q(pk__in=pk_list[offset: offset + GET_ITERATOR_CHUNK_SIZE]))
- self.get_compiler(using).execute_sql(None)
+ self.get_compiler(using).execute_sql(NO_RESULTS)
def add_update_values(self, values):
"""
diff --git a/tests/backends/tests.py b/tests/backends/tests.py
index 0ff3ad0bba..f3c38893f4 100644
--- a/tests/backends/tests.py
+++ b/tests/backends/tests.py
@@ -20,6 +20,7 @@ from django.db.backends.utils import format_number, CursorWrapper
from django.db.models import Sum, Avg, Variance, StdDev
from django.db.models.fields import (AutoField, DateField, DateTimeField,
DecimalField, IntegerField, TimeField)
+from django.db.models.sql.constants import CURSOR
from django.db.utils import ConnectionHandler
from django.test import (TestCase, TransactionTestCase, override_settings,
skipUnlessDBFeature, skipIfDBFeature)
@@ -58,9 +59,9 @@ class OracleChecks(unittest.TestCase):
# stored procedure through our cursor wrapper.
from django.db.backends.oracle.base import convert_unicode
- cursor = connection.cursor()
- cursor.callproc(convert_unicode('DBMS_SESSION.SET_IDENTIFIER'),
- [convert_unicode('_django_testing!')])
+ with connection.cursor() as cursor:
+ cursor.callproc(convert_unicode('DBMS_SESSION.SET_IDENTIFIER'),
+ [convert_unicode('_django_testing!')])
@unittest.skipUnless(connection.vendor == 'oracle',
"No need to check Oracle cursor semantics")
@@ -69,31 +70,31 @@ class OracleChecks(unittest.TestCase):
# as query parameters.
from django.db.backends.oracle.base import Database
- cursor = connection.cursor()
- var = cursor.var(Database.STRING)
- cursor.execute("BEGIN %s := 'X'; END; ", [var])
- self.assertEqual(var.getvalue(), 'X')
+ with connection.cursor() as cursor:
+ var = cursor.var(Database.STRING)
+ cursor.execute("BEGIN %s := 'X'; END; ", [var])
+ self.assertEqual(var.getvalue(), 'X')
@unittest.skipUnless(connection.vendor == 'oracle',
"No need to check Oracle cursor semantics")
def test_long_string(self):
# If the backend is Oracle, test that we can save a text longer
# than 4000 chars and read it properly
- c = connection.cursor()
- c.execute('CREATE TABLE ltext ("TEXT" NCLOB)')
- long_str = ''.join(six.text_type(x) for x in xrange(4000))
- c.execute('INSERT INTO ltext VALUES (%s)', [long_str])
- c.execute('SELECT text FROM ltext')
- row = c.fetchone()
- self.assertEqual(long_str, row[0].read())
- c.execute('DROP TABLE ltext')
+ with connection.cursor() as cursor:
+ cursor.execute('CREATE TABLE ltext ("TEXT" NCLOB)')
+ long_str = ''.join(six.text_type(x) for x in xrange(4000))
+ cursor.execute('INSERT INTO ltext VALUES (%s)', [long_str])
+ cursor.execute('SELECT text FROM ltext')
+ row = cursor.fetchone()
+ self.assertEqual(long_str, row[0].read())
+ cursor.execute('DROP TABLE ltext')
@unittest.skipUnless(connection.vendor == 'oracle',
"No need to check Oracle connection semantics")
def test_client_encoding(self):
# If the backend is Oracle, test that the client encoding is set
# correctly. This was broken under Cygwin prior to r14781.
- connection.cursor() # Ensure the connection is initialized.
+ self.connection.ensure_connection()
self.assertEqual(connection.connection.encoding, "UTF-8")
self.assertEqual(connection.connection.nencoding, "UTF-8")
@@ -102,12 +103,12 @@ class OracleChecks(unittest.TestCase):
def test_order_of_nls_parameters(self):
# an 'almost right' datetime should work with configured
# NLS parameters as per #18465.
- c = connection.cursor()
- query = "select 1 from dual where '1936-12-29 00:00' < sysdate"
- # Test that the query succeeds without errors - pre #18465 this
- # wasn't the case.
- c.execute(query)
- self.assertEqual(c.fetchone()[0], 1)
+ with connection.cursor() as cursor:
+ query = "select 1 from dual where '1936-12-29 00:00' < sysdate"
+ # Test that the query succeeds without errors - pre #18465 this
+ # wasn't the case.
+ cursor.execute(query)
+ self.assertEqual(cursor.fetchone()[0], 1)
class SQLiteTests(TestCase):
@@ -209,7 +210,7 @@ class LastExecutedQueryTest(TestCase):
"""
persons = models.Reporter.objects.filter(raw_data=b'\x00\x46 \xFE').extra(select={'föö': 1})
sql, params = persons.query.sql_with_params()
- cursor = persons.query.get_compiler('default').execute_sql(None)
+ cursor = persons.query.get_compiler('default').execute_sql(CURSOR)
last_sql = cursor.db.ops.last_executed_query(cursor, sql, params)
self.assertIsInstance(last_sql, six.text_type)
@@ -327,6 +328,12 @@ class PostgresVersionTest(TestCase):
def fetchone(self):
return ["PostgreSQL 8.3"]
+ def __enter__(self):
+ return self
+
+ def __exit__(self, type, value, traceback):
+ pass
+
class OlderConnectionMock(object):
"Mock of psycopg2 (< 2.0.12) connection"
def cursor(self):
diff --git a/tests/cache/tests.py b/tests/cache/tests.py
index 94790ed740..bc0f705375 100644
--- a/tests/cache/tests.py
+++ b/tests/cache/tests.py
@@ -896,10 +896,9 @@ class DBCacheTests(BaseCacheTests, TransactionTestCase):
management.call_command('createcachetable', verbosity=0, interactive=False)
def drop_table(self):
- cursor = connection.cursor()
- table_name = connection.ops.quote_name('test cache table')
- cursor.execute('DROP TABLE %s' % table_name)
- cursor.close()
+ with connection.cursor() as cursor:
+ table_name = connection.ops.quote_name('test cache table')
+ cursor.execute('DROP TABLE %s' % table_name)
def test_zero_cull(self):
self._perform_cull_test(caches['zero_cull'], 50, 18)
diff --git a/tests/custom_methods/models.py b/tests/custom_methods/models.py
index cef3fd722b..78e00a99b8 100644
--- a/tests/custom_methods/models.py
+++ b/tests/custom_methods/models.py
@@ -30,11 +30,11 @@ class Article(models.Model):
database query for the sake of demonstration.
"""
from django.db import connection
- cursor = connection.cursor()
- cursor.execute("""
- SELECT id, headline, pub_date
- FROM custom_methods_article
- WHERE pub_date = %s
- AND id != %s""", [connection.ops.value_to_db_date(self.pub_date),
- self.id])
- return [self.__class__(*row) for row in cursor.fetchall()]
+ with connection.cursor() as cursor:
+ cursor.execute("""
+ SELECT id, headline, pub_date
+ FROM custom_methods_article
+ WHERE pub_date = %s
+ AND id != %s""", [connection.ops.value_to_db_date(self.pub_date),
+ self.id])
+ return [self.__class__(*row) for row in cursor.fetchall()]
diff --git a/tests/initial_sql_regress/tests.py b/tests/initial_sql_regress/tests.py
index e725f4b102..428d993667 100644
--- a/tests/initial_sql_regress/tests.py
+++ b/tests/initial_sql_regress/tests.py
@@ -28,9 +28,9 @@ class InitialSQLTests(TestCase):
connection = connections[DEFAULT_DB_ALIAS]
custom_sql = custom_sql_for_model(Simple, no_style(), connection)
self.assertEqual(len(custom_sql), 9)
- cursor = connection.cursor()
- for sql in custom_sql:
- cursor.execute(sql)
+ with connection.cursor() as cursor:
+ for sql in custom_sql:
+ cursor.execute(sql)
self.assertEqual(Simple.objects.count(), 9)
self.assertEqual(
Simple.objects.get(name__contains='placeholders').name,
diff --git a/tests/introspection/tests.py b/tests/introspection/tests.py
index 8ec3d39903..0c339bc8ea 100644
--- a/tests/introspection/tests.py
+++ b/tests/introspection/tests.py
@@ -23,17 +23,17 @@ class IntrospectionTests(TestCase):
"'%s' isn't in table_list()." % Article._meta.db_table)
def test_django_table_names(self):
- cursor = connection.cursor()
- cursor.execute('CREATE TABLE django_ixn_test_table (id INTEGER);')
- tl = connection.introspection.django_table_names()
- cursor.execute("DROP TABLE django_ixn_test_table;")
- self.assertTrue('django_ixn_testcase_table' not in tl,
- "django_table_names() returned a non-Django table")
+ with connection.cursor() as cursor:
+ cursor.execute('CREATE TABLE django_ixn_test_table (id INTEGER);')
+ tl = connection.introspection.django_table_names()
+ cursor.execute("DROP TABLE django_ixn_test_table;")
+ self.assertTrue('django_ixn_testcase_table' not in tl,
+ "django_table_names() returned a non-Django table")
def test_django_table_names_retval_type(self):
# Ticket #15216
- cursor = connection.cursor()
- cursor.execute('CREATE TABLE django_ixn_test_table (id INTEGER);')
+ with connection.cursor() as cursor:
+ cursor.execute('CREATE TABLE django_ixn_test_table (id INTEGER);')
tl = connection.introspection.django_table_names(only_existing=True)
self.assertIs(type(tl), list)
@@ -53,14 +53,14 @@ class IntrospectionTests(TestCase):
'Reporter sequence not found in sequence_list()')
def test_get_table_description_names(self):
- cursor = connection.cursor()
- desc = connection.introspection.get_table_description(cursor, Reporter._meta.db_table)
+ with connection.cursor() as cursor:
+ desc = connection.introspection.get_table_description(cursor, Reporter._meta.db_table)
self.assertEqual([r[0] for r in desc],
[f.column for f in Reporter._meta.fields])
def test_get_table_description_types(self):
- cursor = connection.cursor()
- desc = connection.introspection.get_table_description(cursor, Reporter._meta.db_table)
+ with connection.cursor() as cursor:
+ desc = connection.introspection.get_table_description(cursor, Reporter._meta.db_table)
# The MySQL exception is due to the cursor.description returning the same constant for
# text and blob columns. TODO: use information_schema database to retrieve the proper
# field type on MySQL
@@ -75,8 +75,8 @@ class IntrospectionTests(TestCase):
# inspect the length of character columns).
@expectedFailureOnOracle
def test_get_table_description_col_lengths(self):
- cursor = connection.cursor()
- desc = connection.introspection.get_table_description(cursor, Reporter._meta.db_table)
+ with connection.cursor() as cursor:
+ desc = connection.introspection.get_table_description(cursor, Reporter._meta.db_table)
self.assertEqual(
[r[3] for r in desc if datatype(r[1], r) == 'CharField'],
[30, 30, 75]
@@ -87,8 +87,8 @@ class IntrospectionTests(TestCase):
# so its idea about null_ok in cursor.description is different from ours.
@skipIfDBFeature('interprets_empty_strings_as_nulls')
def test_get_table_description_nullable(self):
- cursor = connection.cursor()
- desc = connection.introspection.get_table_description(cursor, Reporter._meta.db_table)
+ with connection.cursor() as cursor:
+ desc = connection.introspection.get_table_description(cursor, Reporter._meta.db_table)
self.assertEqual(
[r[6] for r in desc],
[False, False, False, False, True, True]
@@ -97,15 +97,15 @@ class IntrospectionTests(TestCase):
# Regression test for #9991 - 'real' types in postgres
@skipUnlessDBFeature('has_real_datatype')
def test_postgresql_real_type(self):
- cursor = connection.cursor()
- cursor.execute("CREATE TABLE django_ixn_real_test_table (number REAL);")
- desc = connection.introspection.get_table_description(cursor, 'django_ixn_real_test_table')
- cursor.execute('DROP TABLE django_ixn_real_test_table;')
+ with connection.cursor() as cursor:
+ cursor.execute("CREATE TABLE django_ixn_real_test_table (number REAL);")
+ desc = connection.introspection.get_table_description(cursor, 'django_ixn_real_test_table')
+ cursor.execute('DROP TABLE django_ixn_real_test_table;')
self.assertEqual(datatype(desc[0][1], desc[0]), 'FloatField')
def test_get_relations(self):
- cursor = connection.cursor()
- relations = connection.introspection.get_relations(cursor, Article._meta.db_table)
+ with connection.cursor() as cursor:
+ relations = connection.introspection.get_relations(cursor, Article._meta.db_table)
# Older versions of MySQL don't have the chops to report on this stuff,
# so just skip it if no relations come back. If they do, though, we
@@ -117,21 +117,21 @@ class IntrospectionTests(TestCase):
@skipUnlessDBFeature('can_introspect_foreign_keys')
def test_get_key_columns(self):
- cursor = connection.cursor()
- key_columns = connection.introspection.get_key_columns(cursor, Article._meta.db_table)
+ with connection.cursor() as cursor:
+ key_columns = connection.introspection.get_key_columns(cursor, Article._meta.db_table)
self.assertEqual(
set(key_columns),
set([('reporter_id', Reporter._meta.db_table, 'id'),
('response_to_id', Article._meta.db_table, 'id')]))
def test_get_primary_key_column(self):
- cursor = connection.cursor()
- primary_key_column = connection.introspection.get_primary_key_column(cursor, Article._meta.db_table)
+ with connection.cursor() as cursor:
+ primary_key_column = connection.introspection.get_primary_key_column(cursor, Article._meta.db_table)
self.assertEqual(primary_key_column, 'id')
def test_get_indexes(self):
- cursor = connection.cursor()
- indexes = connection.introspection.get_indexes(cursor, Article._meta.db_table)
+ with connection.cursor() as cursor:
+ indexes = connection.introspection.get_indexes(cursor, Article._meta.db_table)
self.assertEqual(indexes['reporter_id'], {'unique': False, 'primary_key': False})
def test_get_indexes_multicol(self):
@@ -139,8 +139,8 @@ class IntrospectionTests(TestCase):
Test that multicolumn indexes are not included in the introspection
results.
"""
- cursor = connection.cursor()
- indexes = connection.introspection.get_indexes(cursor, Reporter._meta.db_table)
+ with connection.cursor() as cursor:
+ indexes = connection.introspection.get_indexes(cursor, Reporter._meta.db_table)
self.assertNotIn('first_name', indexes)
self.assertIn('id', indexes)
diff --git a/tests/migrations/test_base.py b/tests/migrations/test_base.py
index 7ab09b04a5..2dba30b2aa 100644
--- a/tests/migrations/test_base.py
+++ b/tests/migrations/test_base.py
@@ -9,33 +9,40 @@ class MigrationTestBase(TransactionTestCase):
available_apps = ["migrations"]
+ def get_table_description(self, table):
+ with connection.cursor() as cursor:
+ return connection.introspection.get_table_description(cursor, table)
+
def assertTableExists(self, table):
- self.assertIn(table, connection.introspection.get_table_list(connection.cursor()))
+ with connection.cursor() as cursor:
+ self.assertIn(table, connection.introspection.get_table_list(cursor))
def assertTableNotExists(self, table):
- self.assertNotIn(table, connection.introspection.get_table_list(connection.cursor()))
+ with connection.cursor() as cursor:
+ self.assertNotIn(table, connection.introspection.get_table_list(cursor))
def assertColumnExists(self, table, column):
- self.assertIn(column, [c.name for c in connection.introspection.get_table_description(connection.cursor(), table)])
+ self.assertIn(column, [c.name for c in self.get_table_description(table)])
def assertColumnNotExists(self, table, column):
- self.assertNotIn(column, [c.name for c in connection.introspection.get_table_description(connection.cursor(), table)])
+ self.assertNotIn(column, [c.name for c in self.get_table_description(table)])
def assertColumnNull(self, table, column):
- self.assertEqual([c.null_ok for c in connection.introspection.get_table_description(connection.cursor(), table) if c.name == column][0], True)
+ self.assertEqual([c.null_ok for c in self.get_table_description(table) if c.name == column][0], True)
def assertColumnNotNull(self, table, column):
- self.assertEqual([c.null_ok for c in connection.introspection.get_table_description(connection.cursor(), table) if c.name == column][0], False)
+ self.assertEqual([c.null_ok for c in self.get_table_description(table) if c.name == column][0], False)
def assertIndexExists(self, table, columns, value=True):
- self.assertEqual(
- value,
- any(
- c["index"]
- for c in connection.introspection.get_constraints(connection.cursor(), table).values()
- if c['columns'] == list(columns)
- ),
- )
+ with connection.cursor() as cursor:
+ self.assertEqual(
+ value,
+ any(
+ c["index"]
+ for c in connection.introspection.get_constraints(cursor, table).values()
+ if c['columns'] == list(columns)
+ ),
+ )
def assertIndexNotExists(self, table, columns):
return self.assertIndexExists(table, columns, False)
diff --git a/tests/migrations/test_operations.py b/tests/migrations/test_operations.py
index eda356fd5d..375a9ccc54 100644
--- a/tests/migrations/test_operations.py
+++ b/tests/migrations/test_operations.py
@@ -19,15 +19,15 @@ class OperationTests(MigrationTestBase):
Creates a test model state and database table.
"""
# Delete the tables if they already exist
- cursor = connection.cursor()
- try:
- cursor.execute("DROP TABLE %s_pony" % app_label)
- except:
- pass
- try:
- cursor.execute("DROP TABLE %s_stable" % app_label)
- except:
- pass
+ with connection.cursor() as cursor:
+ try:
+ cursor.execute("DROP TABLE %s_pony" % app_label)
+ except:
+ pass
+ try:
+ cursor.execute("DROP TABLE %s_stable" % app_label)
+ except:
+ pass
# Make the "current" state
operations = [migrations.CreateModel(
"Pony",
@@ -348,21 +348,21 @@ class OperationTests(MigrationTestBase):
operation.state_forwards("test_alflpkfk", new_state)
self.assertIsInstance(project_state.models["test_alflpkfk", "pony"].get_field_by_name("id"), models.AutoField)
self.assertIsInstance(new_state.models["test_alflpkfk", "pony"].get_field_by_name("id"), models.FloatField)
+
+ def assertIdTypeEqualsFkType(self):
+ with connection.cursor() as cursor:
+ id_type = [c.type_code for c in connection.introspection.get_table_description(cursor, "test_alflpkfk_pony") if c.name == "id"][0]
+ fk_type = [c.type_code for c in connection.introspection.get_table_description(cursor, "test_alflpkfk_rider") if c.name == "pony_id"][0]
+ self.assertEqual(id_type, fk_type)
+ assertIdTypeEqualsFkType()
# Test the database alteration
- id_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_pony") if c.name == "id"][0]
- fk_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_rider") if c.name == "pony_id"][0]
- self.assertEqual(id_type, fk_type)
with connection.schema_editor() as editor:
operation.database_forwards("test_alflpkfk", editor, project_state, new_state)
- id_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_pony") if c.name == "id"][0]
- fk_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_rider") if c.name == "pony_id"][0]
- self.assertEqual(id_type, fk_type)
+ assertIdTypeEqualsFkType()
# And test reversal
with connection.schema_editor() as editor:
operation.database_backwards("test_alflpkfk", editor, new_state, project_state)
- id_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_pony") if c.name == "id"][0]
- fk_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_rider") if c.name == "pony_id"][0]
- self.assertEqual(id_type, fk_type)
+ assertIdTypeEqualsFkType()
def test_rename_field(self):
"""
@@ -400,24 +400,24 @@ class OperationTests(MigrationTestBase):
self.assertEqual(len(project_state.models["test_alunto", "pony"].options.get("unique_together", set())), 0)
self.assertEqual(len(new_state.models["test_alunto", "pony"].options.get("unique_together", set())), 1)
# Make sure we can insert duplicate rows
- cursor = connection.cursor()
- cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (1, 1, 1)")
- cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (2, 1, 1)")
- cursor.execute("DELETE FROM test_alunto_pony")
- # Test the database alteration
- with connection.schema_editor() as editor:
- operation.database_forwards("test_alunto", editor, project_state, new_state)
- cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (1, 1, 1)")
- with self.assertRaises(IntegrityError):
- with atomic():
- cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (2, 1, 1)")
- cursor.execute("DELETE FROM test_alunto_pony")
- # And test reversal
- with connection.schema_editor() as editor:
- operation.database_backwards("test_alunto", editor, new_state, project_state)
- cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (1, 1, 1)")
- cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (2, 1, 1)")
- cursor.execute("DELETE FROM test_alunto_pony")
+ with connection.cursor() as cursor:
+ cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (1, 1, 1)")
+ cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (2, 1, 1)")
+ cursor.execute("DELETE FROM test_alunto_pony")
+ # Test the database alteration
+ with connection.schema_editor() as editor:
+ operation.database_forwards("test_alunto", editor, project_state, new_state)
+ cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (1, 1, 1)")
+ with self.assertRaises(IntegrityError):
+ with atomic():
+ cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (2, 1, 1)")
+ cursor.execute("DELETE FROM test_alunto_pony")
+ # And test reversal
+ with connection.schema_editor() as editor:
+ operation.database_backwards("test_alunto", editor, new_state, project_state)
+ cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (1, 1, 1)")
+ cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (2, 1, 1)")
+ cursor.execute("DELETE FROM test_alunto_pony")
# Test flat unique_together
operation = migrations.AlterUniqueTogether("Pony", ("pink", "weight"))
operation.state_forwards("test_alunto", new_state)
diff --git a/tests/requests/tests.py b/tests/requests/tests.py
index bb56369371..3e594376aa 100644
--- a/tests/requests/tests.py
+++ b/tests/requests/tests.py
@@ -725,7 +725,7 @@ class DatabaseConnectionHandlingTests(TransactionTestCase):
# request_finished signal.
response = self.client.get('/')
# Make sure there is an open connection
- connection.cursor()
+ self.connection.ensure_connection()
connection.enter_transaction_management()
signals.request_finished.send(sender=response._handler_class)
self.assertEqual(len(connection.transaction_state), 0)
diff --git a/tests/schema/tests.py b/tests/schema/tests.py
index 450362ecda..a502ef6002 100644
--- a/tests/schema/tests.py
+++ b/tests/schema/tests.py
@@ -37,38 +37,38 @@ class SchemaTests(TransactionTestCase):
def delete_tables(self):
"Deletes all model tables for our models for a clean test environment"
- cursor = connection.cursor()
- connection.disable_constraint_checking()
- table_names = connection.introspection.table_names(cursor)
- for model in self.models:
- # Remove any M2M tables first
- for field in model._meta.local_many_to_many:
+ with connection.cursor() as cursor:
+ connection.disable_constraint_checking()
+ table_names = connection.introspection.table_names(cursor)
+ for model in self.models:
+ # Remove any M2M tables first
+ for field in model._meta.local_many_to_many:
+ with atomic():
+ tbl = field.rel.through._meta.db_table
+ if tbl in table_names:
+ cursor.execute(connection.schema_editor().sql_delete_table % {
+ "table": connection.ops.quote_name(tbl),
+ })
+ table_names.remove(tbl)
+ # Then remove the main tables
with atomic():
- tbl = field.rel.through._meta.db_table
+ tbl = model._meta.db_table
if tbl in table_names:
cursor.execute(connection.schema_editor().sql_delete_table % {
"table": connection.ops.quote_name(tbl),
})
table_names.remove(tbl)
- # Then remove the main tables
- with atomic():
- tbl = model._meta.db_table
- if tbl in table_names:
- cursor.execute(connection.schema_editor().sql_delete_table % {
- "table": connection.ops.quote_name(tbl),
- })
- table_names.remove(tbl)
connection.enable_constraint_checking()
def column_classes(self, model):
- cursor = connection.cursor()
- columns = dict(
- (d[0], (connection.introspection.get_field_type(d[1], d), d))
- for d in connection.introspection.get_table_description(
- cursor,
- model._meta.db_table,
+ with connection.cursor() as cursor:
+ columns = dict(
+ (d[0], (connection.introspection.get_field_type(d[1], d), d))
+ for d in connection.introspection.get_table_description(
+ cursor,
+ model._meta.db_table,
+ )
)
- )
# SQLite has a different format for field_type
for name, (type, desc) in columns.items():
if isinstance(type, tuple):
@@ -78,6 +78,20 @@ class SchemaTests(TransactionTestCase):
raise DatabaseError("Table does not exist (empty pragma)")
return columns
+ def get_indexes(self, table):
+ """
+ Get the indexes on the table using a new cursor.
+ """
+ with connection.cursor() as cursor:
+ return connection.introspection.get_indexes(cursor, table)
+
+ def get_constraints(self, table):
+ """
+ Get the constraints on a table using a new cursor.
+ """
+ with connection.cursor() as cursor:
+ return connection.introspection.get_constraints(cursor, table)
+
# Tests
def test_creation_deletion(self):
@@ -127,7 +141,7 @@ class SchemaTests(TransactionTestCase):
strict=True,
)
# Make sure the new FK constraint is present
- constraints = connection.introspection.get_constraints(connection.cursor(), Book._meta.db_table)
+ constraints = self.get_constraints(Book._meta.db_table)
for name, details in constraints.items():
if details['columns'] == ["author_id"] and details['foreign_key']:
self.assertEqual(details['foreign_key'], ('schema_tag', 'id'))
@@ -342,7 +356,7 @@ class SchemaTests(TransactionTestCase):
editor.create_model(TagM2MTest)
editor.create_model(UniqueTest)
# Ensure the M2M exists and points to TagM2MTest
- constraints = connection.introspection.get_constraints(connection.cursor(), BookWithM2M._meta.get_field_by_name("tags")[0].rel.through._meta.db_table)
+ constraints = self.get_constraints(BookWithM2M._meta.get_field_by_name("tags")[0].rel.through._meta.db_table)
if connection.features.supports_foreign_keys:
for name, details in constraints.items():
if details['columns'] == ["tagm2mtest_id"] and details['foreign_key']:
@@ -363,7 +377,7 @@ class SchemaTests(TransactionTestCase):
# Ensure old M2M is gone
self.assertRaises(DatabaseError, self.column_classes, BookWithM2M._meta.get_field_by_name("tags")[0].rel.through)
# Ensure the new M2M exists and points to UniqueTest
- constraints = connection.introspection.get_constraints(connection.cursor(), new_field.rel.through._meta.db_table)
+ constraints = self.get_constraints(new_field.rel.through._meta.db_table)
if connection.features.supports_foreign_keys:
for name, details in constraints.items():
if details['columns'] == ["uniquetest_id"] and details['foreign_key']:
@@ -388,7 +402,7 @@ class SchemaTests(TransactionTestCase):
with connection.schema_editor() as editor:
editor.create_model(Author)
# Ensure the constraint exists
- constraints = connection.introspection.get_constraints(connection.cursor(), Author._meta.db_table)
+ constraints = self.get_constraints(Author._meta.db_table)
for name, details in constraints.items():
if details['columns'] == ["height"] and details['check']:
break
@@ -404,7 +418,7 @@ class SchemaTests(TransactionTestCase):
new_field,
strict=True,
)
- constraints = connection.introspection.get_constraints(connection.cursor(), Author._meta.db_table)
+ constraints = self.get_constraints(Author._meta.db_table)
for name, details in constraints.items():
if details['columns'] == ["height"] and details['check']:
self.fail("Check constraint for height found")
@@ -416,7 +430,7 @@ class SchemaTests(TransactionTestCase):
Author._meta.get_field_by_name("height")[0],
strict=True,
)
- constraints = connection.introspection.get_constraints(connection.cursor(), Author._meta.db_table)
+ constraints = self.get_constraints(Author._meta.db_table)
for name, details in constraints.items():
if details['columns'] == ["height"] and details['check']:
break
@@ -527,7 +541,7 @@ class SchemaTests(TransactionTestCase):
False,
any(
c["index"]
- for c in connection.introspection.get_constraints(connection.cursor(), "schema_tag").values()
+ for c in self.get_constraints("schema_tag").values()
if c['columns'] == ["slug", "title"]
),
)
@@ -543,7 +557,7 @@ class SchemaTests(TransactionTestCase):
True,
any(
c["index"]
- for c in connection.introspection.get_constraints(connection.cursor(), "schema_tag").values()
+ for c in self.get_constraints("schema_tag").values()
if c['columns'] == ["slug", "title"]
),
)
@@ -561,7 +575,7 @@ class SchemaTests(TransactionTestCase):
False,
any(
c["index"]
- for c in connection.introspection.get_constraints(connection.cursor(), "schema_tag").values()
+ for c in self.get_constraints("schema_tag").values()
if c['columns'] == ["slug", "title"]
),
)
@@ -578,7 +592,7 @@ class SchemaTests(TransactionTestCase):
True,
any(
c["index"]
- for c in connection.introspection.get_constraints(connection.cursor(), "schema_tagindexed").values()
+ for c in self.get_constraints("schema_tagindexed").values()
if c['columns'] == ["slug", "title"]
),
)
@@ -627,7 +641,7 @@ class SchemaTests(TransactionTestCase):
# Ensure the table is there and has the right index
self.assertIn(
"title",
- connection.introspection.get_indexes(connection.cursor(), Book._meta.db_table),
+ self.get_indexes(Book._meta.db_table),
)
# Alter to remove the index
new_field = CharField(max_length=100, db_index=False)
@@ -642,7 +656,7 @@ class SchemaTests(TransactionTestCase):
# Ensure the table is there and has no index
self.assertNotIn(
"title",
- connection.introspection.get_indexes(connection.cursor(), Book._meta.db_table),
+ self.get_indexes(Book._meta.db_table),
)
# Alter to re-add the index
with connection.schema_editor() as editor:
@@ -655,7 +669,7 @@ class SchemaTests(TransactionTestCase):
# Ensure the table is there and has the index again
self.assertIn(
"title",
- connection.introspection.get_indexes(connection.cursor(), Book._meta.db_table),
+ self.get_indexes(Book._meta.db_table),
)
# Add a unique column, verify that creates an implicit index
with connection.schema_editor() as editor:
@@ -665,7 +679,7 @@ class SchemaTests(TransactionTestCase):
)
self.assertIn(
"slug",
- connection.introspection.get_indexes(connection.cursor(), Book._meta.db_table),
+ self.get_indexes(Book._meta.db_table),
)
# Remove the unique, check the index goes with it
new_field2 = CharField(max_length=20, unique=False)
@@ -679,7 +693,7 @@ class SchemaTests(TransactionTestCase):
)
self.assertNotIn(
"slug",
- connection.introspection.get_indexes(connection.cursor(), Book._meta.db_table),
+ self.get_indexes(Book._meta.db_table),
)
def test_primary_key(self):
@@ -691,7 +705,7 @@ class SchemaTests(TransactionTestCase):
editor.create_model(Tag)
# Ensure the table is there and has the right PK
self.assertTrue(
- connection.introspection.get_indexes(connection.cursor(), Tag._meta.db_table)['id']['primary_key'],
+ self.get_indexes(Tag._meta.db_table)['id']['primary_key'],
)
# Alter to change the PK
new_field = SlugField(primary_key=True)
@@ -707,10 +721,10 @@ class SchemaTests(TransactionTestCase):
# Ensure the PK changed
self.assertNotIn(
'id',
- connection.introspection.get_indexes(connection.cursor(), Tag._meta.db_table),
+ self.get_indexes(Tag._meta.db_table),
)
self.assertTrue(
- connection.introspection.get_indexes(connection.cursor(), Tag._meta.db_table)['slug']['primary_key'],
+ self.get_indexes(Tag._meta.db_table)['slug']['primary_key'],
)
def test_context_manager_exit(self):
@@ -741,7 +755,7 @@ class SchemaTests(TransactionTestCase):
# Ensure the table is there and has an index on the column
self.assertIn(
column_name,
- connection.introspection.get_indexes(connection.cursor(), BookWithLongName._meta.db_table),
+ self.get_indexes(BookWithLongName._meta.db_table),
)
def test_creation_deletion_reserved_names(self):
diff --git a/tests/transactions/tests.py b/tests/transactions/tests.py
index 5c38bc8ef2..e7ce43cd93 100644
--- a/tests/transactions/tests.py
+++ b/tests/transactions/tests.py
@@ -202,8 +202,9 @@ class AtomicTests(TransactionTestCase):
# trigger a database error inside an inner atomic without savepoint
with self.assertRaises(DatabaseError):
with transaction.atomic(savepoint=False):
- connection.cursor().execute(
- "SELECT no_such_col FROM transactions_reporter")
+ with connection.cursor() as cursor:
+ cursor.execute(
+ "SELECT no_such_col FROM transactions_reporter")
# prevent atomic from rolling back since we're recovering manually
self.assertTrue(transaction.get_rollback())
transaction.set_rollback(False)
@@ -534,8 +535,8 @@ class TransactionRollbackTests(IgnoreDeprecationWarningsMixin, TransactionTestCa
available_apps = ['transactions']
def execute_bad_sql(self):
- cursor = connection.cursor()
- cursor.execute("INSERT INTO transactions_reporter (first_name, last_name) VALUES ('Douglas', 'Adams');")
+ with connection.cursor() as cursor:
+ cursor.execute("INSERT INTO transactions_reporter (first_name, last_name) VALUES ('Douglas', 'Adams');")
@skipUnlessDBFeature('requires_rollback_on_dirty_transaction')
def test_bad_sql(self):
@@ -678,6 +679,6 @@ class TransactionContextManagerTests(IgnoreDeprecationWarningsMixin, Transaction
"""
with self.assertRaises(IntegrityError):
with transaction.commit_on_success():
- cursor = connection.cursor()
- cursor.execute("INSERT INTO transactions_reporter (first_name, last_name) VALUES ('Douglas', 'Adams');")
+ with connection.cursor() as cursor:
+ cursor.execute("INSERT INTO transactions_reporter (first_name, last_name) VALUES ('Douglas', 'Adams');")
transaction.rollback()
diff --git a/tests/transactions_regress/tests.py b/tests/transactions_regress/tests.py
index cada46edb2..1f9f291307 100644
--- a/tests/transactions_regress/tests.py
+++ b/tests/transactions_regress/tests.py
@@ -54,8 +54,8 @@ class TestTransactionClosing(IgnoreDeprecationWarningsMixin, TransactionTestCase
@commit_on_success
def raw_sql():
"Write a record using raw sql under a commit_on_success decorator"
- cursor = connection.cursor()
- cursor.execute("INSERT into transactions_regress_mod (fld) values (18)")
+ with connection.cursor() as cursor:
+ cursor.execute("INSERT into transactions_regress_mod (fld) values (18)")
raw_sql()
# Rollback so that if the decorator didn't commit, the record is unwritten
@@ -143,10 +143,10 @@ class TestTransactionClosing(IgnoreDeprecationWarningsMixin, TransactionTestCase
(reference). All this under commit_on_success, so the second insert should
be committed.
"""
- cursor = connection.cursor()
- cursor.execute("INSERT into transactions_regress_mod (fld) values (2)")
- transaction.rollback()
- cursor.execute("INSERT into transactions_regress_mod (fld) values (2)")
+ with connection.cursor() as cursor:
+ cursor.execute("INSERT into transactions_regress_mod (fld) values (2)")
+ transaction.rollback()
+ cursor.execute("INSERT into transactions_regress_mod (fld) values (2)")
reuse_cursor_ref()
# Rollback so that if the decorator didn't commit, the record is unwritten