summaryrefslogtreecommitdiff
path: root/django/db
diff options
context:
space:
mode:
authorMichael Manfre <mmanfre@gmail.com>2014-01-09 10:05:15 -0500
committerMichael Manfre <mmanfre@gmail.com>2014-02-02 12:47:21 -0500
commit3ffeb931869cc68a8e0916219702ee282afc6e9d (patch)
treef24020307dd5b529989329bfabfc95effcecffe8 /django/db
parent0837eacc4e1fa7916e48135e8ba43f54a7a64997 (diff)
Ensure cursors are closed when no longer needed.
This commit touchs various parts of the code base and test framework. Any found usage of opening a cursor for the sake of initializing a connection has been replaced with 'ensure_connection()'.
Diffstat (limited to 'django/db')
-rw-r--r--django/db/backends/__init__.py30
-rw-r--r--django/db/backends/creation.py71
-rw-r--r--django/db/backends/mysql/base.py37
-rw-r--r--django/db/backends/oracle/base.py4
-rw-r--r--django/db/backends/postgresql_psycopg2/base.py6
-rw-r--r--django/db/backends/postgresql_psycopg2/version.py6
-rw-r--r--django/db/backends/schema.py8
-rw-r--r--django/db/backends/sqlite3/base.py16
-rw-r--r--django/db/models/query.py89
-rw-r--r--django/db/models/sql/compiler.py21
10 files changed, 149 insertions, 139 deletions
diff --git a/django/db/backends/__init__.py b/django/db/backends/__init__.py
index b96407056a..50b8745f47 100644
--- a/django/db/backends/__init__.py
+++ b/django/db/backends/__init__.py
@@ -194,13 +194,16 @@ class BaseDatabaseWrapper(object):
##### Backend-specific savepoint management methods #####
def _savepoint(self, sid):
- self.cursor().execute(self.ops.savepoint_create_sql(sid))
+ with self.cursor() as cursor:
+ cursor.execute(self.ops.savepoint_create_sql(sid))
def _savepoint_rollback(self, sid):
- self.cursor().execute(self.ops.savepoint_rollback_sql(sid))
+ with self.cursor() as cursor:
+ cursor.execute(self.ops.savepoint_rollback_sql(sid))
def _savepoint_commit(self, sid):
- self.cursor().execute(self.ops.savepoint_commit_sql(sid))
+ with self.cursor() as cursor:
+ cursor.execute(self.ops.savepoint_commit_sql(sid))
def _savepoint_allowed(self):
# Savepoints cannot be created outside a transaction
@@ -688,15 +691,15 @@ class BaseDatabaseFeatures(object):
# otherwise autocommit will cause the confimation to
# fail.
self.connection.enter_transaction_management()
- cursor = self.connection.cursor()
- cursor.execute('CREATE TABLE ROLLBACK_TEST (X INT)')
- self.connection.commit()
- cursor.execute('INSERT INTO ROLLBACK_TEST (X) VALUES (8)')
- self.connection.rollback()
- cursor.execute('SELECT COUNT(X) FROM ROLLBACK_TEST')
- count, = cursor.fetchone()
- cursor.execute('DROP TABLE ROLLBACK_TEST')
- self.connection.commit()
+ with self.connection.cursor() as cursor:
+ cursor.execute('CREATE TABLE ROLLBACK_TEST (X INT)')
+ self.connection.commit()
+ cursor.execute('INSERT INTO ROLLBACK_TEST (X) VALUES (8)')
+ self.connection.rollback()
+ cursor.execute('SELECT COUNT(X) FROM ROLLBACK_TEST')
+ count, = cursor.fetchone()
+ cursor.execute('DROP TABLE ROLLBACK_TEST')
+ self.connection.commit()
finally:
self.connection.leave_transaction_management()
return count == 0
@@ -1253,7 +1256,8 @@ class BaseDatabaseIntrospection(object):
in sorting order between databases.
"""
if cursor is None:
- cursor = self.connection.cursor()
+ with self.connection.cursor() as cursor:
+ return sorted(self.get_table_list(cursor))
return sorted(self.get_table_list(cursor))
def get_table_list(self, cursor):
diff --git a/django/db/backends/creation.py b/django/db/backends/creation.py
index ff62f30e71..3ee1e8448e 100644
--- a/django/db/backends/creation.py
+++ b/django/db/backends/creation.py
@@ -378,9 +378,8 @@ class BaseDatabaseCreation(object):
call_command('createcachetable', database=self.connection.alias)
- # Get a cursor (even though we don't need one yet). This has
- # the side effect of initializing the test database.
- self.connection.cursor()
+ # Ensure a connection for the side effect of initializing the test database.
+ self.connection.ensure_connection()
return test_database_name
@@ -406,34 +405,34 @@ class BaseDatabaseCreation(object):
qn = self.connection.ops.quote_name
# Create the test database and connect to it.
- cursor = self._nodb_connection.cursor()
- try:
- cursor.execute(
- "CREATE DATABASE %s %s" % (qn(test_database_name), suffix))
- except Exception as e:
- sys.stderr.write(
- "Got an error creating the test database: %s\n" % e)
- if not autoclobber:
- confirm = input(
- "Type 'yes' if you would like to try deleting the test "
- "database '%s', or 'no' to cancel: " % test_database_name)
- if autoclobber or confirm == 'yes':
- try:
- if verbosity >= 1:
- print("Destroying old test database '%s'..."
- % self.connection.alias)
- cursor.execute(
- "DROP DATABASE %s" % qn(test_database_name))
- cursor.execute(
- "CREATE DATABASE %s %s" % (qn(test_database_name),
- suffix))
- except Exception as e:
- sys.stderr.write(
- "Got an error recreating the test database: %s\n" % e)
- sys.exit(2)
- else:
- print("Tests cancelled.")
- sys.exit(1)
+ with self._nodb_connection.cursor() as cursor:
+ try:
+ cursor.execute(
+ "CREATE DATABASE %s %s" % (qn(test_database_name), suffix))
+ except Exception as e:
+ sys.stderr.write(
+ "Got an error creating the test database: %s\n" % e)
+ if not autoclobber:
+ confirm = input(
+ "Type 'yes' if you would like to try deleting the test "
+ "database '%s', or 'no' to cancel: " % test_database_name)
+ if autoclobber or confirm == 'yes':
+ try:
+ if verbosity >= 1:
+ print("Destroying old test database '%s'..."
+ % self.connection.alias)
+ cursor.execute(
+ "DROP DATABASE %s" % qn(test_database_name))
+ cursor.execute(
+ "CREATE DATABASE %s %s" % (qn(test_database_name),
+ suffix))
+ except Exception as e:
+ sys.stderr.write(
+ "Got an error recreating the test database: %s\n" % e)
+ sys.exit(2)
+ else:
+ print("Tests cancelled.")
+ sys.exit(1)
return test_database_name
@@ -461,11 +460,11 @@ class BaseDatabaseCreation(object):
# ourselves. Connect to the previous database (not the test database)
# to do so, because it's not allowed to delete a database while being
# connected to it.
- cursor = self._nodb_connection.cursor()
- # Wait to avoid "database is being accessed by other users" errors.
- time.sleep(1)
- cursor.execute("DROP DATABASE %s"
- % self.connection.ops.quote_name(test_database_name))
+ with self._nodb_connection.cursor() as cursor:
+ # Wait to avoid "database is being accessed by other users" errors.
+ time.sleep(1)
+ cursor.execute("DROP DATABASE %s"
+ % self.connection.ops.quote_name(test_database_name))
def set_autocommit(self):
"""
diff --git a/django/db/backends/mysql/base.py b/django/db/backends/mysql/base.py
index e7932dd800..9d3935dc54 100644
--- a/django/db/backends/mysql/base.py
+++ b/django/db/backends/mysql/base.py
@@ -180,15 +180,15 @@ class DatabaseFeatures(BaseDatabaseFeatures):
@cached_property
def _mysql_storage_engine(self):
"Internal method used in Django tests. Don't rely on this from your code"
- cursor = self.connection.cursor()
- cursor.execute('CREATE TABLE INTROSPECT_TEST (X INT)')
- # This command is MySQL specific; the second column
- # will tell you the default table type of the created
- # table. Since all Django's test tables will have the same
- # table type, that's enough to evaluate the feature.
- cursor.execute("SHOW TABLE STATUS WHERE Name='INTROSPECT_TEST'")
- result = cursor.fetchone()
- cursor.execute('DROP TABLE INTROSPECT_TEST')
+ with self.connection.cursor() as cursor:
+ cursor.execute('CREATE TABLE INTROSPECT_TEST (X INT)')
+ # This command is MySQL specific; the second column
+ # will tell you the default table type of the created
+ # table. Since all Django's test tables will have the same
+ # table type, that's enough to evaluate the feature.
+ cursor.execute("SHOW TABLE STATUS WHERE Name='INTROSPECT_TEST'")
+ result = cursor.fetchone()
+ cursor.execute('DROP TABLE INTROSPECT_TEST')
return result[1]
@cached_property
@@ -207,9 +207,9 @@ class DatabaseFeatures(BaseDatabaseFeatures):
return False
# Test if the time zone definitions are installed.
- cursor = self.connection.cursor()
- cursor.execute("SELECT 1 FROM mysql.time_zone LIMIT 1")
- return cursor.fetchone() is not None
+ with self.connection.cursor() as cursor:
+ cursor.execute("SELECT 1 FROM mysql.time_zone LIMIT 1")
+ return cursor.fetchone() is not None
class DatabaseOperations(BaseDatabaseOperations):
@@ -461,13 +461,12 @@ class DatabaseWrapper(BaseDatabaseWrapper):
return conn
def init_connection_state(self):
- cursor = self.connection.cursor()
- # SQL_AUTO_IS_NULL in MySQL controls whether an AUTO_INCREMENT column
- # on a recently-inserted row will return when the field is tested for
- # NULL. Disabling this value brings this aspect of MySQL in line with
- # SQL standards.
- cursor.execute('SET SQL_AUTO_IS_NULL = 0')
- cursor.close()
+ with self.connection.cursor() as cursor:
+ # SQL_AUTO_IS_NULL in MySQL controls whether an AUTO_INCREMENT column
+ # on a recently-inserted row will return when the field is tested for
+ # NULL. Disabling this value brings this aspect of MySQL in line with
+ # SQL standards.
+ cursor.execute('SET SQL_AUTO_IS_NULL = 0')
def create_cursor(self):
cursor = self.connection.cursor()
diff --git a/django/db/backends/oracle/base.py b/django/db/backends/oracle/base.py
index cdb101d20c..2495986a02 100644
--- a/django/db/backends/oracle/base.py
+++ b/django/db/backends/oracle/base.py
@@ -353,8 +353,8 @@ WHEN (new.%(col_name)s IS NULL)
def regex_lookup(self, lookup_type):
# If regex_lookup is called before it's been initialized, then create
# a cursor to initialize it and recur.
- self.connection.cursor()
- return self.connection.ops.regex_lookup(lookup_type)
+ with self.connection.cursor():
+ return self.connection.ops.regex_lookup(lookup_type)
def return_insert_id(self):
return "RETURNING %s INTO %%s", (InsertIdVar(),)
diff --git a/django/db/backends/postgresql_psycopg2/base.py b/django/db/backends/postgresql_psycopg2/base.py
index 33f885d50c..e89a4e604a 100644
--- a/django/db/backends/postgresql_psycopg2/base.py
+++ b/django/db/backends/postgresql_psycopg2/base.py
@@ -149,8 +149,10 @@ class DatabaseWrapper(BaseDatabaseWrapper):
if conn_tz != tz:
cursor = self.connection.cursor()
- cursor.execute(self.ops.set_time_zone_sql(), [tz])
- cursor.close()
+ try:
+ cursor.execute(self.ops.set_time_zone_sql(), [tz])
+ finally:
+ cursor.close()
# Commit after setting the time zone (see #17062)
if not self.get_autocommit():
self.connection.commit()
diff --git a/django/db/backends/postgresql_psycopg2/version.py b/django/db/backends/postgresql_psycopg2/version.py
index dae94f2dac..64fd7c8298 100644
--- a/django/db/backends/postgresql_psycopg2/version.py
+++ b/django/db/backends/postgresql_psycopg2/version.py
@@ -39,6 +39,6 @@ def get_version(connection):
if hasattr(connection, 'server_version'):
return connection.server_version
else:
- cursor = connection.cursor()
- cursor.execute("SELECT version()")
- return _parse_version(cursor.fetchone()[0])
+ with connection.cursor() as cursor:
+ cursor.execute("SELECT version()")
+ return _parse_version(cursor.fetchone()[0])
diff --git a/django/db/backends/schema.py b/django/db/backends/schema.py
index e690567738..88cc894437 100644
--- a/django/db/backends/schema.py
+++ b/django/db/backends/schema.py
@@ -86,14 +86,13 @@ class BaseDatabaseSchemaEditor(object):
"""
Executes the given SQL statement, with optional parameters.
"""
- # Get the cursor
- cursor = self.connection.cursor()
# Log the command we're running, then run it
logger.debug("%s; (params %r)" % (sql, params))
if self.collect_sql:
self.collected_sql.append((sql % tuple(map(self.connection.ops.quote_parameter, params))) + ";")
else:
- cursor.execute(sql, params)
+ with self.connection.cursor() as cursor:
+ cursor.execute(sql, params)
def quote_name(self, name):
return self.connection.ops.quote_name(name)
@@ -791,7 +790,8 @@ class BaseDatabaseSchemaEditor(object):
Returns all constraint names matching the columns and conditions
"""
column_names = list(column_names) if column_names else None
- constraints = self.connection.introspection.get_constraints(self.connection.cursor(), model._meta.db_table)
+ with self.connection.cursor() as cursor:
+ constraints = self.connection.introspection.get_constraints(cursor, model._meta.db_table)
result = []
for name, infodict in constraints.items():
if column_names is None or column_names == infodict['columns']:
diff --git a/django/db/backends/sqlite3/base.py b/django/db/backends/sqlite3/base.py
index 3c8f170b7d..2adfbacaa9 100644
--- a/django/db/backends/sqlite3/base.py
+++ b/django/db/backends/sqlite3/base.py
@@ -122,14 +122,14 @@ class DatabaseFeatures(BaseDatabaseFeatures):
rule out support for STDDEV. We need to manually check
whether the call works.
"""
- cursor = self.connection.cursor()
- cursor.execute('CREATE TABLE STDDEV_TEST (X INT)')
- try:
- cursor.execute('SELECT STDDEV(*) FROM STDDEV_TEST')
- has_support = True
- except utils.DatabaseError:
- has_support = False
- cursor.execute('DROP TABLE STDDEV_TEST')
+ with self.connection.cursor() as cursor:
+ cursor.execute('CREATE TABLE STDDEV_TEST (X INT)')
+ try:
+ cursor.execute('SELECT STDDEV(*) FROM STDDEV_TEST')
+ has_support = True
+ except utils.DatabaseError:
+ has_support = False
+ cursor.execute('DROP TABLE STDDEV_TEST')
return has_support
@cached_property
diff --git a/django/db/models/query.py b/django/db/models/query.py
index 353dd95794..6051b9f859 100644
--- a/django/db/models/query.py
+++ b/django/db/models/query.py
@@ -1522,54 +1522,59 @@ class RawQuerySet(object):
query = iter(self.query)
- # Find out which columns are model's fields, and which ones should be
- # annotated to the model.
- for pos, column in enumerate(self.columns):
- if column in self.model_fields:
- model_init_field_names[self.model_fields[column].attname] = pos
- else:
- annotation_fields.append((column, pos))
+ try:
+ # Find out which columns are model's fields, and which ones should be
+ # annotated to the model.
+ for pos, column in enumerate(self.columns):
+ if column in self.model_fields:
+ model_init_field_names[self.model_fields[column].attname] = pos
+ else:
+ annotation_fields.append((column, pos))
- # Find out which model's fields are not present in the query.
- skip = set()
- for field in self.model._meta.fields:
- if field.attname not in model_init_field_names:
- skip.add(field.attname)
- if skip:
- if self.model._meta.pk.attname in skip:
- raise InvalidQuery('Raw query must include the primary key')
- model_cls = deferred_class_factory(self.model, skip)
- else:
- model_cls = self.model
- # All model's fields are present in the query. So, it is possible
- # to use *args based model instantation. For each field of the model,
- # record the query column position matching that field.
- model_init_field_pos = []
+ # Find out which model's fields are not present in the query.
+ skip = set()
for field in self.model._meta.fields:
- model_init_field_pos.append(model_init_field_names[field.attname])
- if need_resolv_columns:
- fields = [self.model_fields.get(c, None) for c in self.columns]
- # Begin looping through the query values.
- for values in query:
- if need_resolv_columns:
- values = compiler.resolve_columns(values, fields)
- # Associate fields to values
+ if field.attname not in model_init_field_names:
+ skip.add(field.attname)
if skip:
- model_init_kwargs = {}
- for attname, pos in six.iteritems(model_init_field_names):
- model_init_kwargs[attname] = values[pos]
- instance = model_cls(**model_init_kwargs)
+ if self.model._meta.pk.attname in skip:
+ raise InvalidQuery('Raw query must include the primary key')
+ model_cls = deferred_class_factory(self.model, skip)
else:
- model_init_args = [values[pos] for pos in model_init_field_pos]
- instance = model_cls(*model_init_args)
- if annotation_fields:
- for column, pos in annotation_fields:
- setattr(instance, column, values[pos])
+ model_cls = self.model
+ # All model's fields are present in the query. So, it is possible
+ # to use *args based model instantation. For each field of the model,
+ # record the query column position matching that field.
+ model_init_field_pos = []
+ for field in self.model._meta.fields:
+ model_init_field_pos.append(model_init_field_names[field.attname])
+ if need_resolv_columns:
+ fields = [self.model_fields.get(c, None) for c in self.columns]
+ # Begin looping through the query values.
+ for values in query:
+ if need_resolv_columns:
+ values = compiler.resolve_columns(values, fields)
+ # Associate fields to values
+ if skip:
+ model_init_kwargs = {}
+ for attname, pos in six.iteritems(model_init_field_names):
+ model_init_kwargs[attname] = values[pos]
+ instance = model_cls(**model_init_kwargs)
+ else:
+ model_init_args = [values[pos] for pos in model_init_field_pos]
+ instance = model_cls(*model_init_args)
+ if annotation_fields:
+ for column, pos in annotation_fields:
+ setattr(instance, column, values[pos])
- instance._state.db = db
- instance._state.adding = False
+ instance._state.db = db
+ instance._state.adding = False
- yield instance
+ yield instance
+ finally:
+ # Done iterating the Query. If it has its own cursor, close it.
+ if hasattr(self.query, 'cursor') and self.query.cursor:
+ self.query.cursor.close()
def __repr__(self):
text = self.raw_query
diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py
index 536a66d139..d9161d820c 100644
--- a/django/db/models/sql/compiler.py
+++ b/django/db/models/sql/compiler.py
@@ -1,4 +1,5 @@
import datetime
+import sys
from django.conf import settings
from django.core.exceptions import FieldError
@@ -777,7 +778,7 @@ class SQLCompiler(object):
cursor = self.connection.cursor()
try:
cursor.execute(sql, params)
- except:
+ except Exception:
cursor.close()
raise
@@ -908,15 +909,15 @@ class SQLInsertCompiler(SQLCompiler):
def execute_sql(self, return_id=False):
assert not (return_id and len(self.query.objs) != 1)
self.return_id = return_id
- cursor = self.connection.cursor()
- for sql, params in self.as_sql():
- cursor.execute(sql, params)
- if not (return_id and cursor):
- return
- if self.connection.features.can_return_id_from_insert:
- return self.connection.ops.fetch_returned_insert_id(cursor)
- return self.connection.ops.last_insert_id(cursor,
- self.query.get_meta().db_table, self.query.get_meta().pk.column)
+ with self.connection.cursor() as cursor:
+ for sql, params in self.as_sql():
+ cursor.execute(sql, params)
+ if not (return_id and cursor):
+ return
+ if self.connection.features.can_return_id_from_insert:
+ return self.connection.ops.fetch_returned_insert_id(cursor)
+ return self.connection.ops.last_insert_id(cursor,
+ self.query.get_meta().db_table, self.query.get_meta().pk.column)
class SQLDeleteCompiler(SQLCompiler):