summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNick Pope <nick.pope@flightdataservices.com>2018-12-01 23:49:38 +0000
committerTim Graham <timograham@gmail.com>2019-01-14 14:39:57 -0500
commit846624ed0858aec0e51baebaa5b397e135c6d1dc (patch)
treea98697ebce990b44ab89c4d2f6f8d32a9fc8c7a0
parent6d4efa8e6a4cc7be4ba957dec71f6f63cd58700d (diff)
Refs #28643 -- Extracted DurationField logic for Avg() and Sum() into mixin.
Also addresses Sum() not handling the filter option correctly.
-rw-r--r--django/db/models/aggregates.py38
-rw-r--r--django/db/models/functions/mixins.py19
2 files changed, 24 insertions, 33 deletions
diff --git a/django/db/models/aggregates.py b/django/db/models/aggregates.py
index 1b0f9d98af..ac0b62d0bf 100644
--- a/django/db/models/aggregates.py
+++ b/django/db/models/aggregates.py
@@ -4,7 +4,9 @@ Classes to represent the definitions of aggregate functions.
from django.core.exceptions import FieldError
from django.db.models.expressions import Case, Func, Star, When
from django.db.models.fields import IntegerField
-from django.db.models.functions.mixins import NumericOutputFieldMixin
+from django.db.models.functions.mixins import (
+ FixDurationInputMixin, NumericOutputFieldMixin,
+)
__all__ = [
'Aggregate', 'Avg', 'Count', 'Max', 'Min', 'StdDev', 'Sum', 'Variance',
@@ -94,25 +96,10 @@ class Aggregate(Func):
return options
-class Avg(NumericOutputFieldMixin, Aggregate):
+class Avg(FixDurationInputMixin, NumericOutputFieldMixin, Aggregate):
function = 'AVG'
name = 'Avg'
- def as_mysql(self, compiler, connection, **extra_context):
- sql, params = super().as_sql(compiler, connection, **extra_context)
- if self.output_field.get_internal_type() == 'DurationField':
- sql = 'CAST(%s as SIGNED)' % sql
- return sql, params
-
- def as_oracle(self, compiler, connection, **extra_context):
- if self.output_field.get_internal_type() == 'DurationField':
- expression = self.get_source_expressions()[0]
- from django.db.backends.oracle.functions import IntervalToSeconds, SecondsToInterval
- return compiler.compile(
- SecondsToInterval(Avg(IntervalToSeconds(expression), filter=self.filter))
- )
- return super().as_sql(compiler, connection, **extra_context)
-
class Count(Aggregate):
function = 'COUNT'
@@ -152,25 +139,10 @@ class StdDev(NumericOutputFieldMixin, Aggregate):
return {**super()._get_repr_options(), 'sample': self.function == 'STDDEV_SAMP'}
-class Sum(Aggregate):
+class Sum(FixDurationInputMixin, Aggregate):
function = 'SUM'
name = 'Sum'
- def as_mysql(self, compiler, connection, **extra_context):
- sql, params = super().as_sql(compiler, connection, **extra_context)
- if self.output_field.get_internal_type() == 'DurationField':
- sql = 'CAST(%s as SIGNED)' % sql
- return sql, params
-
- def as_oracle(self, compiler, connection, **extra_context):
- if self.output_field.get_internal_type() == 'DurationField':
- expression = self.get_source_expressions()[0]
- from django.db.backends.oracle.functions import IntervalToSeconds, SecondsToInterval
- return compiler.compile(
- SecondsToInterval(Sum(IntervalToSeconds(expression)))
- )
- return super().as_sql(compiler, connection, **extra_context)
-
class Variance(NumericOutputFieldMixin, Aggregate):
name = 'Variance'
diff --git a/django/db/models/functions/mixins.py b/django/db/models/functions/mixins.py
index 9b46987788..8486ddb005 100644
--- a/django/db/models/functions/mixins.py
+++ b/django/db/models/functions/mixins.py
@@ -20,6 +20,25 @@ class FixDecimalInputMixin:
return clone.as_sql(compiler, connection, **extra_context)
+class FixDurationInputMixin:
+
+ def as_mysql(self, compiler, connection, **extra_context):
+ sql, params = super().as_sql(compiler, connection, **extra_context)
+ if self.output_field.get_internal_type() == 'DurationField':
+ sql = 'CAST(%s AS SIGNED)' % sql
+ return sql, params
+
+ def as_oracle(self, compiler, connection, **extra_context):
+ if self.output_field.get_internal_type() == 'DurationField':
+ expression = self.get_source_expressions()[0]
+ options = self._get_repr_options()
+ from django.db.backends.oracle.functions import IntervalToSeconds, SecondsToInterval
+ return compiler.compile(
+ SecondsToInterval(self.__class__(IntervalToSeconds(expression), **options))
+ )
+ return super().as_sql(compiler, connection, **extra_context)
+
+
class NumericOutputFieldMixin:
def _resolve_output_field(self):