summaryrefslogtreecommitdiff
path: root/django/db/models/sql
diff options
context:
space:
mode:
Diffstat (limited to 'django/db/models/sql')
-rw-r--r--django/db/models/sql/compiler.py78
-rw-r--r--django/db/models/sql/constants.py2
-rw-r--r--django/db/models/sql/subqueries.py8
3 files changed, 64 insertions, 24 deletions
diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py
index 123427cf8b..536a66d139 100644
--- a/django/db/models/sql/compiler.py
+++ b/django/db/models/sql/compiler.py
@@ -5,8 +5,8 @@ from django.core.exceptions import FieldError
from django.db.backends.utils import truncate_name
from django.db.models.constants import LOOKUP_SEP
from django.db.models.query_utils import select_related_descend, QueryWrapper
-from django.db.models.sql.constants import (SINGLE, MULTI, ORDER_DIR,
- GET_ITERATOR_CHUNK_SIZE, SelectInfo)
+from django.db.models.sql.constants import (CURSOR, SINGLE, MULTI, NO_RESULTS,
+ ORDER_DIR, GET_ITERATOR_CHUNK_SIZE, SelectInfo)
from django.db.models.sql.datastructures import EmptyResultSet
from django.db.models.sql.expressions import SQLEvaluator
from django.db.models.sql.query import get_order_dir, Query
@@ -762,6 +762,8 @@ class SQLCompiler(object):
is needed, as the filters describe an empty set. In that case, None is
returned, to avoid any unnecessary database interaction.
"""
+ if not result_type:
+ result_type = NO_RESULTS
try:
sql, params = self.as_sql()
if not sql:
@@ -773,27 +775,44 @@ class SQLCompiler(object):
return
cursor = self.connection.cursor()
- cursor.execute(sql, params)
+ try:
+ cursor.execute(sql, params)
+ except:
+ cursor.close()
+ raise
- if not result_type:
+ if result_type == CURSOR:
+ # Caller didn't specify a result_type, so just give them back the
+ # cursor to process (and close).
return cursor
if result_type == SINGLE:
- if self.ordering_aliases:
- return cursor.fetchone()[:-len(self.ordering_aliases)]
- return cursor.fetchone()
+ try:
+ if self.ordering_aliases:
+ return cursor.fetchone()[:-len(self.ordering_aliases)]
+ return cursor.fetchone()
+ finally:
+ # done with the cursor
+ cursor.close()
+ if result_type == NO_RESULTS:
+ cursor.close()
+ return
# The MULTI case.
if self.ordering_aliases:
result = order_modified_iter(cursor, len(self.ordering_aliases),
self.connection.features.empty_fetchmany_value)
else:
- result = iter((lambda: cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)),
- self.connection.features.empty_fetchmany_value)
+ result = cursor_iter(cursor,
+ self.connection.features.empty_fetchmany_value)
if not self.connection.features.can_use_chunked_reads:
- # If we are using non-chunked reads, we return the same data
- # structure as normally, but ensure it is all read into memory
- # before going any further.
- return list(result)
+ try:
+ # If we are using non-chunked reads, we return the same data
+ # structure as normally, but ensure it is all read into memory
+ # before going any further.
+ return list(result)
+ finally:
+ # done with the cursor
+ cursor.close()
return result
def as_subquery_condition(self, alias, columns, qn):
@@ -970,12 +989,15 @@ class SQLUpdateCompiler(SQLCompiler):
related queries are not available.
"""
cursor = super(SQLUpdateCompiler, self).execute_sql(result_type)
- rows = cursor.rowcount if cursor else 0
- is_empty = cursor is None
- del cursor
+ try:
+ rows = cursor.rowcount if cursor else 0
+ is_empty = cursor is None
+ finally:
+ if cursor:
+ cursor.close()
for query in self.query.get_related_updates():
aux_rows = query.get_compiler(self.using).execute_sql(result_type)
- if is_empty:
+ if is_empty and aux_rows:
rows = aux_rows
is_empty = False
return rows
@@ -1111,6 +1133,19 @@ class SQLDateTimeCompiler(SQLCompiler):
yield datetime
+def cursor_iter(cursor, sentinel):
+ """
+ Yields blocks of rows from a cursor and ensures the cursor is closed when
+ done.
+ """
+ try:
+ for rows in iter((lambda: cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)),
+ sentinel):
+ yield rows
+ finally:
+ cursor.close()
+
+
def order_modified_iter(cursor, trim, sentinel):
"""
Yields blocks of rows from a cursor. We use this iterator in the special
@@ -1118,6 +1153,9 @@ def order_modified_iter(cursor, trim, sentinel):
requirements. We must trim those extra columns before anything else can use
the results, since they're only needed to make the SQL valid.
"""
- for rows in iter((lambda: cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)),
- sentinel):
- yield [r[:-trim] for r in rows]
+ try:
+ for rows in iter((lambda: cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)),
+ sentinel):
+ yield [r[:-trim] for r in rows]
+ finally:
+ cursor.close()
diff --git a/django/db/models/sql/constants.py b/django/db/models/sql/constants.py
index 904f7b2c8b..36aab23bae 100644
--- a/django/db/models/sql/constants.py
+++ b/django/db/models/sql/constants.py
@@ -33,6 +33,8 @@ SelectInfo = namedtuple('SelectInfo', 'col field')
# How many results to expect from a cursor.execute call
MULTI = 'multi'
SINGLE = 'single'
+CURSOR = 'cursor'
+NO_RESULTS = 'no results'
ORDER_PATTERN = re.compile(r'\?|[-+]?[.\w]+$')
ORDER_DIR = {
diff --git a/django/db/models/sql/subqueries.py b/django/db/models/sql/subqueries.py
index 86b1efd3f8..cfda1f552c 100644
--- a/django/db/models/sql/subqueries.py
+++ b/django/db/models/sql/subqueries.py
@@ -8,7 +8,7 @@ from django.db import connections
from django.db.models.query_utils import Q
from django.db.models.constants import LOOKUP_SEP
from django.db.models.fields import DateField, DateTimeField, FieldDoesNotExist
-from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, SelectInfo
+from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, NO_RESULTS, SelectInfo
from django.db.models.sql.datastructures import Date, DateTime
from django.db.models.sql.query import Query
from django.utils import six
@@ -30,7 +30,7 @@ class DeleteQuery(Query):
def do_query(self, table, where, using):
self.tables = [table]
self.where = where
- self.get_compiler(using).execute_sql(None)
+ self.get_compiler(using).execute_sql(NO_RESULTS)
def delete_batch(self, pk_list, using, field=None):
"""
@@ -82,7 +82,7 @@ class DeleteQuery(Query):
values = innerq
self.where = self.where_class()
self.add_q(Q(pk__in=values))
- self.get_compiler(using).execute_sql(None)
+ self.get_compiler(using).execute_sql(NO_RESULTS)
class UpdateQuery(Query):
@@ -116,7 +116,7 @@ class UpdateQuery(Query):
for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
self.where = self.where_class()
self.add_q(Q(pk__in=pk_list[offset: offset + GET_ITERATOR_CHUNK_SIZE]))
- self.get_compiler(using).execute_sql(None)
+ self.get_compiler(using).execute_sql(NO_RESULTS)
def add_update_values(self, values):
"""