diff options
| author | Michael Manfre <mmanfre@gmail.com> | 2014-01-09 10:05:15 -0500 |
|---|---|---|
| committer | Michael Manfre <mmanfre@gmail.com> | 2014-02-02 12:47:21 -0500 |
| commit | 3ffeb931869cc68a8e0916219702ee282afc6e9d (patch) | |
| tree | f24020307dd5b529989329bfabfc95effcecffe8 /django/db | |
| parent | 0837eacc4e1fa7916e48135e8ba43f54a7a64997 (diff) | |
Ensure cursors are closed when no longer needed.
This commit touchs various parts of the code base and test framework. Any
found usage of opening a cursor for the sake of initializing a connection
has been replaced with 'ensure_connection()'.
Diffstat (limited to 'django/db')
| -rw-r--r-- | django/db/backends/__init__.py | 30 | ||||
| -rw-r--r-- | django/db/backends/creation.py | 71 | ||||
| -rw-r--r-- | django/db/backends/mysql/base.py | 37 | ||||
| -rw-r--r-- | django/db/backends/oracle/base.py | 4 | ||||
| -rw-r--r-- | django/db/backends/postgresql_psycopg2/base.py | 6 | ||||
| -rw-r--r-- | django/db/backends/postgresql_psycopg2/version.py | 6 | ||||
| -rw-r--r-- | django/db/backends/schema.py | 8 | ||||
| -rw-r--r-- | django/db/backends/sqlite3/base.py | 16 | ||||
| -rw-r--r-- | django/db/models/query.py | 89 | ||||
| -rw-r--r-- | django/db/models/sql/compiler.py | 21 |
10 files changed, 149 insertions, 139 deletions
diff --git a/django/db/backends/__init__.py b/django/db/backends/__init__.py index b96407056a..50b8745f47 100644 --- a/django/db/backends/__init__.py +++ b/django/db/backends/__init__.py @@ -194,13 +194,16 @@ class BaseDatabaseWrapper(object): ##### Backend-specific savepoint management methods ##### def _savepoint(self, sid): - self.cursor().execute(self.ops.savepoint_create_sql(sid)) + with self.cursor() as cursor: + cursor.execute(self.ops.savepoint_create_sql(sid)) def _savepoint_rollback(self, sid): - self.cursor().execute(self.ops.savepoint_rollback_sql(sid)) + with self.cursor() as cursor: + cursor.execute(self.ops.savepoint_rollback_sql(sid)) def _savepoint_commit(self, sid): - self.cursor().execute(self.ops.savepoint_commit_sql(sid)) + with self.cursor() as cursor: + cursor.execute(self.ops.savepoint_commit_sql(sid)) def _savepoint_allowed(self): # Savepoints cannot be created outside a transaction @@ -688,15 +691,15 @@ class BaseDatabaseFeatures(object): # otherwise autocommit will cause the confimation to # fail. self.connection.enter_transaction_management() - cursor = self.connection.cursor() - cursor.execute('CREATE TABLE ROLLBACK_TEST (X INT)') - self.connection.commit() - cursor.execute('INSERT INTO ROLLBACK_TEST (X) VALUES (8)') - self.connection.rollback() - cursor.execute('SELECT COUNT(X) FROM ROLLBACK_TEST') - count, = cursor.fetchone() - cursor.execute('DROP TABLE ROLLBACK_TEST') - self.connection.commit() + with self.connection.cursor() as cursor: + cursor.execute('CREATE TABLE ROLLBACK_TEST (X INT)') + self.connection.commit() + cursor.execute('INSERT INTO ROLLBACK_TEST (X) VALUES (8)') + self.connection.rollback() + cursor.execute('SELECT COUNT(X) FROM ROLLBACK_TEST') + count, = cursor.fetchone() + cursor.execute('DROP TABLE ROLLBACK_TEST') + self.connection.commit() finally: self.connection.leave_transaction_management() return count == 0 @@ -1253,7 +1256,8 @@ class BaseDatabaseIntrospection(object): in sorting order between databases. """ if cursor is None: - cursor = self.connection.cursor() + with self.connection.cursor() as cursor: + return sorted(self.get_table_list(cursor)) return sorted(self.get_table_list(cursor)) def get_table_list(self, cursor): diff --git a/django/db/backends/creation.py b/django/db/backends/creation.py index ff62f30e71..3ee1e8448e 100644 --- a/django/db/backends/creation.py +++ b/django/db/backends/creation.py @@ -378,9 +378,8 @@ class BaseDatabaseCreation(object): call_command('createcachetable', database=self.connection.alias) - # Get a cursor (even though we don't need one yet). This has - # the side effect of initializing the test database. - self.connection.cursor() + # Ensure a connection for the side effect of initializing the test database. + self.connection.ensure_connection() return test_database_name @@ -406,34 +405,34 @@ class BaseDatabaseCreation(object): qn = self.connection.ops.quote_name # Create the test database and connect to it. - cursor = self._nodb_connection.cursor() - try: - cursor.execute( - "CREATE DATABASE %s %s" % (qn(test_database_name), suffix)) - except Exception as e: - sys.stderr.write( - "Got an error creating the test database: %s\n" % e) - if not autoclobber: - confirm = input( - "Type 'yes' if you would like to try deleting the test " - "database '%s', or 'no' to cancel: " % test_database_name) - if autoclobber or confirm == 'yes': - try: - if verbosity >= 1: - print("Destroying old test database '%s'..." - % self.connection.alias) - cursor.execute( - "DROP DATABASE %s" % qn(test_database_name)) - cursor.execute( - "CREATE DATABASE %s %s" % (qn(test_database_name), - suffix)) - except Exception as e: - sys.stderr.write( - "Got an error recreating the test database: %s\n" % e) - sys.exit(2) - else: - print("Tests cancelled.") - sys.exit(1) + with self._nodb_connection.cursor() as cursor: + try: + cursor.execute( + "CREATE DATABASE %s %s" % (qn(test_database_name), suffix)) + except Exception as e: + sys.stderr.write( + "Got an error creating the test database: %s\n" % e) + if not autoclobber: + confirm = input( + "Type 'yes' if you would like to try deleting the test " + "database '%s', or 'no' to cancel: " % test_database_name) + if autoclobber or confirm == 'yes': + try: + if verbosity >= 1: + print("Destroying old test database '%s'..." + % self.connection.alias) + cursor.execute( + "DROP DATABASE %s" % qn(test_database_name)) + cursor.execute( + "CREATE DATABASE %s %s" % (qn(test_database_name), + suffix)) + except Exception as e: + sys.stderr.write( + "Got an error recreating the test database: %s\n" % e) + sys.exit(2) + else: + print("Tests cancelled.") + sys.exit(1) return test_database_name @@ -461,11 +460,11 @@ class BaseDatabaseCreation(object): # ourselves. Connect to the previous database (not the test database) # to do so, because it's not allowed to delete a database while being # connected to it. - cursor = self._nodb_connection.cursor() - # Wait to avoid "database is being accessed by other users" errors. - time.sleep(1) - cursor.execute("DROP DATABASE %s" - % self.connection.ops.quote_name(test_database_name)) + with self._nodb_connection.cursor() as cursor: + # Wait to avoid "database is being accessed by other users" errors. + time.sleep(1) + cursor.execute("DROP DATABASE %s" + % self.connection.ops.quote_name(test_database_name)) def set_autocommit(self): """ diff --git a/django/db/backends/mysql/base.py b/django/db/backends/mysql/base.py index e7932dd800..9d3935dc54 100644 --- a/django/db/backends/mysql/base.py +++ b/django/db/backends/mysql/base.py @@ -180,15 +180,15 @@ class DatabaseFeatures(BaseDatabaseFeatures): @cached_property def _mysql_storage_engine(self): "Internal method used in Django tests. Don't rely on this from your code" - cursor = self.connection.cursor() - cursor.execute('CREATE TABLE INTROSPECT_TEST (X INT)') - # This command is MySQL specific; the second column - # will tell you the default table type of the created - # table. Since all Django's test tables will have the same - # table type, that's enough to evaluate the feature. - cursor.execute("SHOW TABLE STATUS WHERE Name='INTROSPECT_TEST'") - result = cursor.fetchone() - cursor.execute('DROP TABLE INTROSPECT_TEST') + with self.connection.cursor() as cursor: + cursor.execute('CREATE TABLE INTROSPECT_TEST (X INT)') + # This command is MySQL specific; the second column + # will tell you the default table type of the created + # table. Since all Django's test tables will have the same + # table type, that's enough to evaluate the feature. + cursor.execute("SHOW TABLE STATUS WHERE Name='INTROSPECT_TEST'") + result = cursor.fetchone() + cursor.execute('DROP TABLE INTROSPECT_TEST') return result[1] @cached_property @@ -207,9 +207,9 @@ class DatabaseFeatures(BaseDatabaseFeatures): return False # Test if the time zone definitions are installed. - cursor = self.connection.cursor() - cursor.execute("SELECT 1 FROM mysql.time_zone LIMIT 1") - return cursor.fetchone() is not None + with self.connection.cursor() as cursor: + cursor.execute("SELECT 1 FROM mysql.time_zone LIMIT 1") + return cursor.fetchone() is not None class DatabaseOperations(BaseDatabaseOperations): @@ -461,13 +461,12 @@ class DatabaseWrapper(BaseDatabaseWrapper): return conn def init_connection_state(self): - cursor = self.connection.cursor() - # SQL_AUTO_IS_NULL in MySQL controls whether an AUTO_INCREMENT column - # on a recently-inserted row will return when the field is tested for - # NULL. Disabling this value brings this aspect of MySQL in line with - # SQL standards. - cursor.execute('SET SQL_AUTO_IS_NULL = 0') - cursor.close() + with self.connection.cursor() as cursor: + # SQL_AUTO_IS_NULL in MySQL controls whether an AUTO_INCREMENT column + # on a recently-inserted row will return when the field is tested for + # NULL. Disabling this value brings this aspect of MySQL in line with + # SQL standards. + cursor.execute('SET SQL_AUTO_IS_NULL = 0') def create_cursor(self): cursor = self.connection.cursor() diff --git a/django/db/backends/oracle/base.py b/django/db/backends/oracle/base.py index cdb101d20c..2495986a02 100644 --- a/django/db/backends/oracle/base.py +++ b/django/db/backends/oracle/base.py @@ -353,8 +353,8 @@ WHEN (new.%(col_name)s IS NULL) def regex_lookup(self, lookup_type): # If regex_lookup is called before it's been initialized, then create # a cursor to initialize it and recur. - self.connection.cursor() - return self.connection.ops.regex_lookup(lookup_type) + with self.connection.cursor(): + return self.connection.ops.regex_lookup(lookup_type) def return_insert_id(self): return "RETURNING %s INTO %%s", (InsertIdVar(),) diff --git a/django/db/backends/postgresql_psycopg2/base.py b/django/db/backends/postgresql_psycopg2/base.py index 33f885d50c..e89a4e604a 100644 --- a/django/db/backends/postgresql_psycopg2/base.py +++ b/django/db/backends/postgresql_psycopg2/base.py @@ -149,8 +149,10 @@ class DatabaseWrapper(BaseDatabaseWrapper): if conn_tz != tz: cursor = self.connection.cursor() - cursor.execute(self.ops.set_time_zone_sql(), [tz]) - cursor.close() + try: + cursor.execute(self.ops.set_time_zone_sql(), [tz]) + finally: + cursor.close() # Commit after setting the time zone (see #17062) if not self.get_autocommit(): self.connection.commit() diff --git a/django/db/backends/postgresql_psycopg2/version.py b/django/db/backends/postgresql_psycopg2/version.py index dae94f2dac..64fd7c8298 100644 --- a/django/db/backends/postgresql_psycopg2/version.py +++ b/django/db/backends/postgresql_psycopg2/version.py @@ -39,6 +39,6 @@ def get_version(connection): if hasattr(connection, 'server_version'): return connection.server_version else: - cursor = connection.cursor() - cursor.execute("SELECT version()") - return _parse_version(cursor.fetchone()[0]) + with connection.cursor() as cursor: + cursor.execute("SELECT version()") + return _parse_version(cursor.fetchone()[0]) diff --git a/django/db/backends/schema.py b/django/db/backends/schema.py index e690567738..88cc894437 100644 --- a/django/db/backends/schema.py +++ b/django/db/backends/schema.py @@ -86,14 +86,13 @@ class BaseDatabaseSchemaEditor(object): """ Executes the given SQL statement, with optional parameters. """ - # Get the cursor - cursor = self.connection.cursor() # Log the command we're running, then run it logger.debug("%s; (params %r)" % (sql, params)) if self.collect_sql: self.collected_sql.append((sql % tuple(map(self.connection.ops.quote_parameter, params))) + ";") else: - cursor.execute(sql, params) + with self.connection.cursor() as cursor: + cursor.execute(sql, params) def quote_name(self, name): return self.connection.ops.quote_name(name) @@ -791,7 +790,8 @@ class BaseDatabaseSchemaEditor(object): Returns all constraint names matching the columns and conditions """ column_names = list(column_names) if column_names else None - constraints = self.connection.introspection.get_constraints(self.connection.cursor(), model._meta.db_table) + with self.connection.cursor() as cursor: + constraints = self.connection.introspection.get_constraints(cursor, model._meta.db_table) result = [] for name, infodict in constraints.items(): if column_names is None or column_names == infodict['columns']: diff --git a/django/db/backends/sqlite3/base.py b/django/db/backends/sqlite3/base.py index 3c8f170b7d..2adfbacaa9 100644 --- a/django/db/backends/sqlite3/base.py +++ b/django/db/backends/sqlite3/base.py @@ -122,14 +122,14 @@ class DatabaseFeatures(BaseDatabaseFeatures): rule out support for STDDEV. We need to manually check whether the call works. """ - cursor = self.connection.cursor() - cursor.execute('CREATE TABLE STDDEV_TEST (X INT)') - try: - cursor.execute('SELECT STDDEV(*) FROM STDDEV_TEST') - has_support = True - except utils.DatabaseError: - has_support = False - cursor.execute('DROP TABLE STDDEV_TEST') + with self.connection.cursor() as cursor: + cursor.execute('CREATE TABLE STDDEV_TEST (X INT)') + try: + cursor.execute('SELECT STDDEV(*) FROM STDDEV_TEST') + has_support = True + except utils.DatabaseError: + has_support = False + cursor.execute('DROP TABLE STDDEV_TEST') return has_support @cached_property diff --git a/django/db/models/query.py b/django/db/models/query.py index 353dd95794..6051b9f859 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -1522,54 +1522,59 @@ class RawQuerySet(object): query = iter(self.query) - # Find out which columns are model's fields, and which ones should be - # annotated to the model. - for pos, column in enumerate(self.columns): - if column in self.model_fields: - model_init_field_names[self.model_fields[column].attname] = pos - else: - annotation_fields.append((column, pos)) + try: + # Find out which columns are model's fields, and which ones should be + # annotated to the model. + for pos, column in enumerate(self.columns): + if column in self.model_fields: + model_init_field_names[self.model_fields[column].attname] = pos + else: + annotation_fields.append((column, pos)) - # Find out which model's fields are not present in the query. - skip = set() - for field in self.model._meta.fields: - if field.attname not in model_init_field_names: - skip.add(field.attname) - if skip: - if self.model._meta.pk.attname in skip: - raise InvalidQuery('Raw query must include the primary key') - model_cls = deferred_class_factory(self.model, skip) - else: - model_cls = self.model - # All model's fields are present in the query. So, it is possible - # to use *args based model instantation. For each field of the model, - # record the query column position matching that field. - model_init_field_pos = [] + # Find out which model's fields are not present in the query. + skip = set() for field in self.model._meta.fields: - model_init_field_pos.append(model_init_field_names[field.attname]) - if need_resolv_columns: - fields = [self.model_fields.get(c, None) for c in self.columns] - # Begin looping through the query values. - for values in query: - if need_resolv_columns: - values = compiler.resolve_columns(values, fields) - # Associate fields to values + if field.attname not in model_init_field_names: + skip.add(field.attname) if skip: - model_init_kwargs = {} - for attname, pos in six.iteritems(model_init_field_names): - model_init_kwargs[attname] = values[pos] - instance = model_cls(**model_init_kwargs) + if self.model._meta.pk.attname in skip: + raise InvalidQuery('Raw query must include the primary key') + model_cls = deferred_class_factory(self.model, skip) else: - model_init_args = [values[pos] for pos in model_init_field_pos] - instance = model_cls(*model_init_args) - if annotation_fields: - for column, pos in annotation_fields: - setattr(instance, column, values[pos]) + model_cls = self.model + # All model's fields are present in the query. So, it is possible + # to use *args based model instantation. For each field of the model, + # record the query column position matching that field. + model_init_field_pos = [] + for field in self.model._meta.fields: + model_init_field_pos.append(model_init_field_names[field.attname]) + if need_resolv_columns: + fields = [self.model_fields.get(c, None) for c in self.columns] + # Begin looping through the query values. + for values in query: + if need_resolv_columns: + values = compiler.resolve_columns(values, fields) + # Associate fields to values + if skip: + model_init_kwargs = {} + for attname, pos in six.iteritems(model_init_field_names): + model_init_kwargs[attname] = values[pos] + instance = model_cls(**model_init_kwargs) + else: + model_init_args = [values[pos] for pos in model_init_field_pos] + instance = model_cls(*model_init_args) + if annotation_fields: + for column, pos in annotation_fields: + setattr(instance, column, values[pos]) - instance._state.db = db - instance._state.adding = False + instance._state.db = db + instance._state.adding = False - yield instance + yield instance + finally: + # Done iterating the Query. If it has its own cursor, close it. + if hasattr(self.query, 'cursor') and self.query.cursor: + self.query.cursor.close() def __repr__(self): text = self.raw_query diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index 536a66d139..d9161d820c 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -1,4 +1,5 @@ import datetime +import sys from django.conf import settings from django.core.exceptions import FieldError @@ -777,7 +778,7 @@ class SQLCompiler(object): cursor = self.connection.cursor() try: cursor.execute(sql, params) - except: + except Exception: cursor.close() raise @@ -908,15 +909,15 @@ class SQLInsertCompiler(SQLCompiler): def execute_sql(self, return_id=False): assert not (return_id and len(self.query.objs) != 1) self.return_id = return_id - cursor = self.connection.cursor() - for sql, params in self.as_sql(): - cursor.execute(sql, params) - if not (return_id and cursor): - return - if self.connection.features.can_return_id_from_insert: - return self.connection.ops.fetch_returned_insert_id(cursor) - return self.connection.ops.last_insert_id(cursor, - self.query.get_meta().db_table, self.query.get_meta().pk.column) + with self.connection.cursor() as cursor: + for sql, params in self.as_sql(): + cursor.execute(sql, params) + if not (return_id and cursor): + return + if self.connection.features.can_return_id_from_insert: + return self.connection.ops.fetch_returned_insert_id(cursor) + return self.connection.ops.last_insert_id(cursor, + self.query.get_meta().db_table, self.query.get_meta().pk.column) class SQLDeleteCompiler(SQLCompiler): |
