summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNick Pope <nick.pope@flightdataservices.com>2018-12-01 23:43:27 +0000
committerTim Graham <timograham@gmail.com>2019-01-14 14:26:46 -0500
commit3d5e0f8394688d40036e27cfcfac295e6fe62269 (patch)
treee05e45c3fdb226538ac4230629d0cef648bd2a4b
parent7f1577d1efacea583d06f5786c2f4eee6878643a (diff)
Refs #28643 -- Moved db function mixins to a separate module.
-rw-r--r--django/db/models/functions/math.py63
-rw-r--r--django/db/models/functions/mixins.py27
2 files changed, 48 insertions, 42 deletions
diff --git a/django/db/models/functions/math.py b/django/db/models/functions/math.py
index c8760652b4..43cbc17a1e 100644
--- a/django/db/models/functions/math.py
+++ b/django/db/models/functions/math.py
@@ -1,56 +1,35 @@
import math
-import sys
from django.db.models.expressions import Func
-from django.db.models.fields import DecimalField, FloatField, IntegerField
+from django.db.models.fields import FloatField, IntegerField
from django.db.models.functions import Cast
+from django.db.models.functions.mixins import (
+ FixDecimalInputMixin, NumericOutputFieldMixin,
+)
from django.db.models.lookups import Transform
-class DecimalInputMixin:
-
- def as_postgresql(self, compiler, connection, **extra_context):
- # Cast FloatField to DecimalField as PostgreSQL doesn't support the
- # following function signatures:
- # - LOG(double, double)
- # - MOD(double, double)
- output_field = DecimalField(decimal_places=sys.float_info.dig, max_digits=1000)
- clone = self.copy()
- clone.set_source_expressions([
- Cast(expression, output_field) if isinstance(expression.output_field, FloatField)
- else expression for expression in self.get_source_expressions()
- ])
- return clone.as_sql(compiler, connection, **extra_context)
-
-
-class OutputFieldMixin:
-
- def _resolve_output_field(self):
- has_decimals = any(isinstance(s.output_field, DecimalField) for s in self.get_source_expressions())
- return DecimalField() if has_decimals else FloatField()
-
-
class Abs(Transform):
function = 'ABS'
lookup_name = 'abs'
-class ACos(OutputFieldMixin, Transform):
+class ACos(NumericOutputFieldMixin, Transform):
function = 'ACOS'
lookup_name = 'acos'
-class ASin(OutputFieldMixin, Transform):
+class ASin(NumericOutputFieldMixin, Transform):
function = 'ASIN'
lookup_name = 'asin'
-class ATan(OutputFieldMixin, Transform):
+class ATan(NumericOutputFieldMixin, Transform):
function = 'ATAN'
lookup_name = 'atan'
-class ATan2(OutputFieldMixin, Func):
+class ATan2(NumericOutputFieldMixin, Func):
function = 'ATAN2'
arity = 2
@@ -80,12 +59,12 @@ class Ceil(Transform):
return super().as_sql(compiler, connection, function='CEIL', **extra_context)
-class Cos(OutputFieldMixin, Transform):
+class Cos(NumericOutputFieldMixin, Transform):
function = 'COS'
lookup_name = 'cos'
-class Cot(OutputFieldMixin, Transform):
+class Cot(NumericOutputFieldMixin, Transform):
function = 'COT'
lookup_name = 'cot'
@@ -93,7 +72,7 @@ class Cot(OutputFieldMixin, Transform):
return super().as_sql(compiler, connection, template='(1 / TAN(%(expressions)s))', **extra_context)
-class Degrees(OutputFieldMixin, Transform):
+class Degrees(NumericOutputFieldMixin, Transform):
function = 'DEGREES'
lookup_name = 'degrees'
@@ -105,7 +84,7 @@ class Degrees(OutputFieldMixin, Transform):
)
-class Exp(OutputFieldMixin, Transform):
+class Exp(NumericOutputFieldMixin, Transform):
function = 'EXP'
lookup_name = 'exp'
@@ -115,12 +94,12 @@ class Floor(Transform):
lookup_name = 'floor'
-class Ln(OutputFieldMixin, Transform):
+class Ln(NumericOutputFieldMixin, Transform):
function = 'LN'
lookup_name = 'ln'
-class Log(DecimalInputMixin, OutputFieldMixin, Func):
+class Log(FixDecimalInputMixin, NumericOutputFieldMixin, Func):
function = 'LOG'
arity = 2
@@ -134,12 +113,12 @@ class Log(DecimalInputMixin, OutputFieldMixin, Func):
return clone.as_sql(compiler, connection, **extra_context)
-class Mod(DecimalInputMixin, OutputFieldMixin, Func):
+class Mod(FixDecimalInputMixin, NumericOutputFieldMixin, Func):
function = 'MOD'
arity = 2
-class Pi(OutputFieldMixin, Func):
+class Pi(NumericOutputFieldMixin, Func):
function = 'PI'
arity = 0
@@ -147,12 +126,12 @@ class Pi(OutputFieldMixin, Func):
return super().as_sql(compiler, connection, template=str(math.pi), **extra_context)
-class Power(OutputFieldMixin, Func):
+class Power(NumericOutputFieldMixin, Func):
function = 'POWER'
arity = 2
-class Radians(OutputFieldMixin, Transform):
+class Radians(NumericOutputFieldMixin, Transform):
function = 'RADIANS'
lookup_name = 'radians'
@@ -169,16 +148,16 @@ class Round(Transform):
lookup_name = 'round'
-class Sin(OutputFieldMixin, Transform):
+class Sin(NumericOutputFieldMixin, Transform):
function = 'SIN'
lookup_name = 'sin'
-class Sqrt(OutputFieldMixin, Transform):
+class Sqrt(NumericOutputFieldMixin, Transform):
function = 'SQRT'
lookup_name = 'sqrt'
-class Tan(OutputFieldMixin, Transform):
+class Tan(NumericOutputFieldMixin, Transform):
function = 'TAN'
lookup_name = 'tan'
diff --git a/django/db/models/functions/mixins.py b/django/db/models/functions/mixins.py
new file mode 100644
index 0000000000..1bf3d6cbd0
--- /dev/null
+++ b/django/db/models/functions/mixins.py
@@ -0,0 +1,27 @@
+import sys
+
+from django.db.models.fields import DecimalField, FloatField
+from django.db.models.functions import Cast
+
+
+class FixDecimalInputMixin:
+
+ def as_postgresql(self, compiler, connection, **extra_context):
+ # Cast FloatField to DecimalField as PostgreSQL doesn't support the
+ # following function signatures:
+ # - LOG(double, double)
+ # - MOD(double, double)
+ output_field = DecimalField(decimal_places=sys.float_info.dig, max_digits=1000)
+ clone = self.copy()
+ clone.set_source_expressions([
+ Cast(expression, output_field) if isinstance(expression.output_field, FloatField)
+ else expression for expression in self.get_source_expressions()
+ ])
+ return clone.as_sql(compiler, connection, **extra_context)
+
+
+class NumericOutputFieldMixin:
+
+ def _resolve_output_field(self):
+ has_decimals = any(isinstance(s.output_field, DecimalField) for s in self.get_source_expressions())
+ return DecimalField() if has_decimals else FloatField()