diff options
| author | Adrian Holovaty <adrian@holovaty.com> | 2006-05-02 01:31:56 +0000 |
|---|---|---|
| committer | Adrian Holovaty <adrian@holovaty.com> | 2006-05-02 01:31:56 +0000 |
| commit | f69cf70ed813a8cd7e1f963a14ae39103e8d5265 (patch) | |
| tree | d3b32e84cd66573b3833ddf662af020f8ef2f7a8 /django/db/models | |
| parent | d5dbeaa9be359a4c794885c2e9f1b5a7e5e51fb8 (diff) | |
MERGED MAGIC-REMOVAL BRANCH TO TRUNK. This change is highly backwards-incompatible. Please read http://code.djangoproject.com/wiki/RemovingTheMagic for upgrade instructions.
git-svn-id: http://code.djangoproject.com/svn/django/trunk@2809 bcc190cf-cafb-0310-a4f2-bffc1f526a37
Diffstat (limited to 'django/db/models')
| -rw-r--r-- | django/db/models/__init__.py | 40 | ||||
| -rw-r--r-- | django/db/models/base.py | 401 | ||||
| -rw-r--r-- | django/db/models/fields/__init__.py | 788 | ||||
| -rw-r--r-- | django/db/models/fields/related.py | 718 | ||||
| -rw-r--r-- | django/db/models/loading.py | 71 | ||||
| -rw-r--r-- | django/db/models/manager.py | 101 | ||||
| -rw-r--r-- | django/db/models/manipulators.py | 330 | ||||
| -rw-r--r-- | django/db/models/options.py | 269 | ||||
| -rw-r--r-- | django/db/models/query.py | 888 | ||||
| -rw-r--r-- | django/db/models/related.py | 132 | ||||
| -rw-r--r-- | django/db/models/signals.py | 12 |
11 files changed, 3750 insertions, 0 deletions
diff --git a/django/db/models/__init__.py b/django/db/models/__init__.py new file mode 100644 index 0000000000..d708fa60bc --- /dev/null +++ b/django/db/models/__init__.py @@ -0,0 +1,40 @@ +from django.conf import settings +from django.core.exceptions import ObjectDoesNotExist, ImproperlyConfigured +from django.core import validators +from django.db import backend, connection +from django.db.models.loading import get_apps, get_app, get_models, get_model, register_models +from django.db.models.query import Q +from django.db.models.manager import Manager +from django.db.models.base import Model, AdminOptions +from django.db.models.fields import * +from django.db.models.fields.related import ForeignKey, OneToOneField, ManyToManyField, ManyToOneRel, ManyToManyRel, OneToOneRel, TABULAR, STACKED +from django.db.models import signals +from django.utils.functional import curry +from django.utils.text import capfirst + +# Admin stages. +ADD, CHANGE, BOTH = 1, 2, 3 + +class LazyDate: + """ + Use in limit_choices_to to compare the field to dates calculated at run time + instead of when the model is loaded. For example:: + + ... limit_choices_to = {'date__gt' : models.LazyDate(days=-3)} ... + + which will limit the choices to dates greater than three days ago. + """ + def __init__(self, **kwargs): + self.delta = datetime.timedelta(**kwargs) + + def __str__(self): + return str(self.__get_value__()) + + def __repr__(self): + return "<LazyDate: %s>" % self.delta + + def __get_value__(self): + return datetime.datetime.now() + self.delta + + def __getattr__(self, attr): + return getattr(self.__get_value__(), attr) diff --git a/django/db/models/base.py b/django/db/models/base.py new file mode 100644 index 0000000000..2185471e2b --- /dev/null +++ b/django/db/models/base.py @@ -0,0 +1,401 @@ +import django.db.models.manipulators +import django.db.models.manager +from django.core import validators +from django.core.exceptions import ObjectDoesNotExist +from django.db.models.fields import AutoField, ImageField, FieldDoesNotExist +from django.db.models.fields.related import OneToOneRel, ManyToOneRel +from django.db.models.related import RelatedObject +from django.db.models.query import orderlist2sql, delete_objects +from django.db.models.options import Options, AdminOptions +from django.db import connection, backend, transaction +from django.db.models import signals +from django.db.models.loading import register_models +from django.dispatch import dispatcher +from django.utils.datastructures import SortedDict +from django.utils.functional import curry +from django.conf import settings +import types +import sys +import os + +class ModelBase(type): + "Metaclass for all models" + def __new__(cls, name, bases, attrs): + # If this isn't a subclass of Model, don't do anything special. + if not bases or bases == (object,): + return type.__new__(cls, name, bases, attrs) + + # Create the class. + new_class = type.__new__(cls, name, bases, {'__module__': attrs.pop('__module__')}) + new_class.add_to_class('_meta', Options(attrs.pop('Meta', None))) + new_class.add_to_class('DoesNotExist', types.ClassType('DoesNotExist', (ObjectDoesNotExist,), {})) + + # Build complete list of parents + for base in bases: + # TODO: Checking for the presence of '_meta' is hackish. + if '_meta' in dir(base): + new_class._meta.parents.append(base) + new_class._meta.parents.extend(base._meta.parents) + + model_module = sys.modules[new_class.__module__] + + if getattr(new_class._meta, 'app_label', None) is None: + # Figure out the app_label by looking one level up. + # For 'django.contrib.sites.models', this would be 'sites'. + new_class._meta.app_label = model_module.__name__.split('.')[-2] + + # Add all attributes to the class. + for obj_name, obj in attrs.items(): + new_class.add_to_class(obj_name, obj) + + # Add Fields inherited from parents + for parent in new_class._meta.parents: + for field in parent._meta.fields: + # Only add parent fields if they aren't defined for this class. + try: + new_class._meta.get_field(field.name) + except FieldDoesNotExist: + field.contribute_to_class(new_class, field.name) + + new_class._prepare() + + register_models(new_class._meta.app_label, new_class) + return new_class + +class Model(object): + __metaclass__ = ModelBase + + def _get_pk_val(self): + return getattr(self, self._meta.pk.attname) + + def __repr__(self): + return '<%s object>' % self.__class__.__name__ + + def __eq__(self, other): + return isinstance(other, self.__class__) and self._get_pk_val() == other._get_pk_val() + + def __ne__(self, other): + return not self.__eq__(other) + + def __init__(self, *args, **kwargs): + dispatcher.send(signal=signals.pre_init, sender=self.__class__, args=args, kwargs=kwargs) + for f in self._meta.fields: + if isinstance(f.rel, ManyToOneRel): + try: + # Assume object instance was passed in. + rel_obj = kwargs.pop(f.name) + except KeyError: + try: + # Object instance wasn't passed in -- must be an ID. + val = kwargs.pop(f.attname) + except KeyError: + val = f.get_default() + else: + # Object instance was passed in. + # Special case: You can pass in "None" for related objects if it's allowed. + if rel_obj is None and f.null: + val = None + else: + try: + val = getattr(rel_obj, f.rel.get_related_field().attname) + except AttributeError: + raise TypeError, "Invalid value: %r should be a %s instance, not a %s" % (f.name, f.rel.to, type(rel_obj)) + setattr(self, f.attname, val) + else: + val = kwargs.pop(f.attname, f.get_default()) + setattr(self, f.attname, val) + if kwargs: + raise TypeError, "'%s' is an invalid keyword argument for this function" % kwargs.keys()[0] + for i, arg in enumerate(args): + setattr(self, self._meta.fields[i].attname, arg) + dispatcher.send(signal=signals.post_init, sender=self.__class__, instance=self) + + def add_to_class(cls, name, value): + if name == 'Admin': + assert type(value) == types.ClassType, "%r attribute of %s model must be a class, not a %s object" % (name, cls.__name__, type(value)) + value = AdminOptions(**dict([(k, v) for k, v in value.__dict__.items() if not k.startswith('_')])) + if hasattr(value, 'contribute_to_class'): + value.contribute_to_class(cls, name) + else: + setattr(cls, name, value) + add_to_class = classmethod(add_to_class) + + def _prepare(cls): + # Creates some methods once self._meta has been populated. + opts = cls._meta + opts._prepare(cls) + + if opts.order_with_respect_to: + cls.get_next_in_order = curry(cls._get_next_or_previous_in_order, is_next=True) + cls.get_previous_in_order = curry(cls._get_next_or_previous_in_order, is_next=False) + setattr(opts.order_with_respect_to.rel.to, 'get_%s_order' % cls.__name__.lower(), curry(method_get_order, cls)) + setattr(opts.order_with_respect_to.rel.to, 'set_%s_order' % cls.__name__.lower(), curry(method_set_order, cls)) + + # Give the class a docstring -- its definition. + if cls.__doc__ is None: + cls.__doc__ = "%s(%s)" % (cls.__name__, ", ".join([f.attname for f in opts.fields])) + + if hasattr(cls, 'get_absolute_url'): + cls.get_absolute_url = curry(get_absolute_url, opts, cls.get_absolute_url) + + dispatcher.send(signal=signals.class_prepared, sender=cls) + + _prepare = classmethod(_prepare) + + def save(self): + dispatcher.send(signal=signals.pre_save, sender=self.__class__, instance=self) + + non_pks = [f for f in self._meta.fields if not f.primary_key] + cursor = connection.cursor() + + # First, try an UPDATE. If that doesn't update anything, do an INSERT. + pk_val = self._get_pk_val() + pk_set = bool(pk_val) + record_exists = True + if pk_set: + # Determine whether a record with the primary key already exists. + cursor.execute("SELECT 1 FROM %s WHERE %s=%%s LIMIT 1" % \ + (backend.quote_name(self._meta.db_table), backend.quote_name(self._meta.pk.column)), [pk_val]) + # If it does already exist, do an UPDATE. + if cursor.fetchone(): + db_values = [f.get_db_prep_save(f.pre_save(getattr(self, f.attname), False)) for f in non_pks] + cursor.execute("UPDATE %s SET %s WHERE %s=%%s" % \ + (backend.quote_name(self._meta.db_table), + ','.join(['%s=%%s' % backend.quote_name(f.column) for f in non_pks]), + backend.quote_name(self._meta.pk.attname)), + db_values + [pk_val]) + else: + record_exists = False + if not pk_set or not record_exists: + field_names = [backend.quote_name(f.column) for f in self._meta.fields if not isinstance(f, AutoField)] + db_values = [f.get_db_prep_save(f.pre_save(getattr(self, f.attname), True)) for f in self._meta.fields if not isinstance(f, AutoField)] + # If the PK has been manually set, respect that. + if pk_set: + field_names += [f.column for f in self._meta.fields if isinstance(f, AutoField)] + db_values += [f.get_db_prep_save(f.pre_save(getattr(self, f.column), True)) for f in self._meta.fields if isinstance(f, AutoField)] + placeholders = ['%s'] * len(field_names) + if self._meta.order_with_respect_to: + field_names.append(backend.quote_name('_order')) + # TODO: This assumes the database supports subqueries. + placeholders.append('(SELECT COUNT(*) FROM %s WHERE %s = %%s)' % \ + (backend.quote_name(self._meta.db_table), backend.quote_name(self._meta.order_with_respect_to.column))) + db_values.append(getattr(self, self._meta.order_with_respect_to.attname)) + cursor.execute("INSERT INTO %s (%s) VALUES (%s)" % \ + (backend.quote_name(self._meta.db_table), ','.join(field_names), + ','.join(placeholders)), db_values) + if self._meta.has_auto_field and not pk_set: + setattr(self, self._meta.pk.attname, backend.get_last_insert_id(cursor, self._meta.db_table, self._meta.pk.column)) + transaction.commit_unless_managed() + + # Run any post-save hooks. + dispatcher.send(signal=signals.post_save, sender=self.__class__, instance=self) + + save.alters_data = True + + def validate(self): + """ + First coerces all fields on this instance to their proper Python types. + Then runs validation on every field. Returns a dictionary of + field_name -> error_list. + """ + error_dict = {} + invalid_python = {} + for f in self._meta.fields: + try: + setattr(self, f.attname, f.to_python(getattr(self, f.attname, f.get_default()))) + except validators.ValidationError, e: + error_dict[f.name] = e.messages + invalid_python[f.name] = 1 + for f in self._meta.fields: + if f.name in invalid_python: + continue + errors = f.validate_full(getattr(self, f.attname, f.get_default()), self.__dict__) + if errors: + error_dict[f.name] = errors + return error_dict + + def _collect_sub_objects(self, seen_objs): + """ + Recursively populates seen_objs with all objects related to this object. + When done, seen_objs will be in the format: + {model_class: {pk_val: obj, pk_val: obj, ...}, + model_class: {pk_val: obj, pk_val: obj, ...}, ...} + """ + pk_val = self._get_pk_val() + if pk_val in seen_objs.setdefault(self.__class__, {}): + return + seen_objs.setdefault(self.__class__, {})[pk_val] = self + + for related in self._meta.get_all_related_objects(): + rel_opts_name = related.get_accessor_name() + if isinstance(related.field.rel, OneToOneRel): + try: + sub_obj = getattr(self, rel_opts_name) + except ObjectDoesNotExist: + pass + else: + sub_obj._collect_sub_objects(seen_objs) + else: + for sub_obj in getattr(self, rel_opts_name).all(): + sub_obj._collect_sub_objects(seen_objs) + + def delete(self): + assert self._get_pk_val() is not None, "%s object can't be deleted because its %s attribute is set to None." % (self._meta.object_name, self._meta.pk.attname) + + # Find all the objects than need to be deleted + seen_objs = SortedDict() + self._collect_sub_objects(seen_objs) + + # Actually delete the objects + delete_objects(seen_objs) + + delete.alters_data = True + + def _get_FIELD_display(self, field): + value = getattr(self, field.attname) + return dict(field.choices).get(value, value) + + def _get_next_or_previous_by_FIELD(self, field, is_next): + op = is_next and '>' or '<' + where = '(%s %s %%s OR (%s = %%s AND %s.%s %s %%s))' % \ + (backend.quote_name(field.column), op, backend.quote_name(field.column), + backend.quote_name(self._meta.db_table), backend.quote_name(self._meta.pk.column), op) + param = str(getattr(self, field.attname)) + q = self.__class__._default_manager.order_by((not is_next and '-' or '') + field.name, (not is_next and '-' or '') + self._meta.pk.name) + q._where.append(where) + q._params.extend([param, param, getattr(self, self._meta.pk.attname)]) + return q[0] + + def _get_next_or_previous_in_order(self, is_next): + cachename = "__%s_order_cache" % is_next + if not hasattr(self, cachename): + op = is_next and '>' or '<' + order_field = self._meta.order_with_respect_to + where = ['%s %s (SELECT %s FROM %s WHERE %s=%%s)' % \ + (backend.quote_name('_order'), op, backend.quote_name('_order'), + backend.quote_name(opts.db_table), backend.quote_name(opts.pk.column)), + '%s=%%s' % backend.quote_name(order_field.column)] + params = [self._get_pk_val(), getattr(self, order_field.attname)] + obj = self._default_manager.order_by('_order').extra(where=where, params=params)[:1].get() + setattr(self, cachename, obj) + return getattr(self, cachename) + + def _get_FIELD_filename(self, field): + return os.path.join(settings.MEDIA_ROOT, getattr(self, field.attname)) + + def _get_FIELD_url(self, field): + if getattr(self, field.attname): # value is not blank + import urlparse + return urlparse.urljoin(settings.MEDIA_URL, getattr(self, field.attname)).replace('\\', '/') + return '' + + def _get_FIELD_size(self, field): + return os.path.getsize(self.__get_FIELD_filename(field)) + + def _save_FIELD_file(self, field, filename, raw_contents): + directory = field.get_directory_name() + try: # Create the date-based directory if it doesn't exist. + os.makedirs(os.path.join(settings.MEDIA_ROOT, directory)) + except OSError: # Directory probably already exists. + pass + filename = field.get_filename(filename) + + # If the filename already exists, keep adding an underscore to the name of + # the file until the filename doesn't exist. + while os.path.exists(os.path.join(settings.MEDIA_ROOT, filename)): + try: + dot_index = filename.rindex('.') + except ValueError: # filename has no dot + filename += '_' + else: + filename = filename[:dot_index] + '_' + filename[dot_index:] + + # Write the file to disk. + setattr(self, field.attname, filename) + + full_filename = self._get_FIELD_filename(field) + fp = open(full_filename, 'wb') + fp.write(raw_contents) + fp.close() + + # Save the width and/or height, if applicable. + if isinstance(field, ImageField) and (field.width_field or field.height_field): + from django.utils.images import get_image_dimensions + width, height = get_image_dimensions(full_filename) + if field.width_field: + setattr(self, field.width_field, width) + if field.height_field: + setattr(self, field.height_field, height) + + # Save the object, because it has changed. + self.save() + + _save_FIELD_file.alters_data = True + + def _get_FIELD_width(self, field): + return self.__get_image_dimensions(field)[0] + + def _get_FIELD_height(self, field): + return self.__get_image_dimensions(field)[1] + + def _get_image_dimensions(self, field): + cachename = "__%s_dimensions_cache" % field.name + if not hasattr(self, cachename): + from django.utils.images import get_image_dimensions + filename = self.__get_FIELD_filename(field)() + setattr(self, cachename, get_image_dimensions(filename)) + return getattr(self, cachename) + + # Handles setting many-to-many related objects. + # Example: Album.set_songs() + def _set_related_many_to_many(self, rel_class, rel_field, id_list): + id_list = map(int, id_list) # normalize to integers + rel = rel_field.rel.to + m2m_table = rel_field.m2m_db_table() + this_id = self._get_pk_val() + cursor = connection.cursor() + cursor.execute("DELETE FROM %s WHERE %s = %%s" % \ + (backend.quote_name(m2m_table), + backend.quote_name(rel_field.m2m_column_name())), [this_id]) + sql = "INSERT INTO %s (%s, %s) VALUES (%%s, %%s)" % \ + (backend.quote_name(m2m_table), + backend.quote_name(rel_field.m2m_column_name()), + backend.quote_name(rel_field.m2m_reverse_name())) + cursor.executemany(sql, [(this_id, i) for i in id_list]) + transaction.commit_unless_managed() + +############################################ +# HELPER FUNCTIONS (CURRIED MODEL METHODS) # +############################################ + +# ORDERING METHODS ######################### + +def method_set_order(ordered_obj, self, id_list): + cursor = connection.cursor() + # Example: "UPDATE poll_choices SET _order = %s WHERE poll_id = %s AND id = %s" + sql = "UPDATE %s SET %s = %%s WHERE %s = %%s AND %s = %%s" % \ + (backend.quote_name(ordered_obj.db_table), backend.quote_name('_order'), + backend.quote_name(ordered_obj.order_with_respect_to.column), + backend.quote_name(ordered_obj.pk.column)) + rel_val = getattr(self, ordered_obj.order_with_respect_to.rel.field_name) + cursor.executemany(sql, [(i, rel_val, j) for i, j in enumerate(id_list)]) + transaction.commit_unless_managed() + +def method_get_order(ordered_obj, self): + cursor = connection.cursor() + # Example: "SELECT id FROM poll_choices WHERE poll_id = %s ORDER BY _order" + sql = "SELECT %s FROM %s WHERE %s = %%s ORDER BY %s" % \ + (backend.quote_name(ordered_obj._meta.pk.column), + backend.quote_name(ordered_obj._meta.db_table), + backend.quote_name(ordered_obj._meta.order_with_respect_to.column), + backend.quote_name('_order')) + rel_val = getattr(self, ordered_obj._meta.order_with_respect_to.rel.field_name) + cursor.execute(sql, [rel_val]) + return [r[0] for r in cursor.fetchall()] + +############################################## +# HELPER FUNCTIONS (CURRIED MODEL FUNCTIONS) # +############################################## + +def get_absolute_url(opts, func, self): + return settings.ABSOLUTE_URL_OVERRIDES.get('%s.%s' % (opts.app_label, opts.module_name), func)(self) diff --git a/django/db/models/fields/__init__.py b/django/db/models/fields/__init__.py new file mode 100644 index 0000000000..8cc17079a9 --- /dev/null +++ b/django/db/models/fields/__init__.py @@ -0,0 +1,788 @@ +from django.db.models import signals +from django.dispatch import dispatcher +from django.conf import settings +from django.core import validators +from django import forms +from django.core.exceptions import ObjectDoesNotExist +from django.utils.functional import curry, lazy +from django.utils.text import capfirst +from django.utils.translation import gettext, gettext_lazy, ngettext +import datetime, os, time + +class NOT_PROVIDED: + pass + +# Values for filter_interface. +HORIZONTAL, VERTICAL = 1, 2 + +# The values to use for "blank" in SelectFields. Will be appended to the start of most "choices" lists. +BLANK_CHOICE_DASH = [("", "---------")] +BLANK_CHOICE_NONE = [("", "None")] + +# prepares a value for use in a LIKE query +prep_for_like_query = lambda x: str(x).replace("%", "\%").replace("_", "\_") + +# returns the <ul> class for a given radio_admin value +get_ul_class = lambda x: 'radiolist%s' % ((x == HORIZONTAL) and ' inline' or '') + +class FieldDoesNotExist(Exception): + pass + +def manipulator_validator_unique(f, opts, self, field_data, all_data): + "Validates that the value is unique for this field." + lookup_type = f.get_validator_unique_lookup_type() + try: + old_obj = self.manager.get(**{lookup_type: field_data}) + except ObjectDoesNotExist: + return + if getattr(self, 'original_object', None) and self.original_object._get_pk_val() == old_obj._get_pk_val(): + return + raise validators.ValidationError, gettext("%(optname)s with this %(fieldname)s already exists.") % {'optname': capfirst(opts.verbose_name), 'fieldname': f.verbose_name} + +# A guide to Field parameters: +# +# * name: The name of the field specifed in the model. +# * attname: The attribute to use on the model object. This is the same as +# "name", except in the case of ForeignKeys, where "_id" is +# appended. +# * db_column: The db_column specified in the model (or None). +# * column: The database column for this field. This is the same as +# "attname", except if db_column is specified. +# +# Code that introspects values, or does other dynamic things, should use +# attname. For example, this gets the primary key value of object "obj": +# +# getattr(obj, opts.pk.attname) + +class Field(object): + + # Designates whether empty strings fundamentally are allowed at the + # database level. + empty_strings_allowed = True + + # Tracks each time a Field instance is created. Used to retain order. + creation_counter = 0 + + def __init__(self, verbose_name=None, name=None, primary_key=False, + maxlength=None, unique=False, blank=False, null=False, db_index=False, + core=False, rel=None, default=NOT_PROVIDED, editable=True, + prepopulate_from=None, unique_for_date=None, unique_for_month=None, + unique_for_year=None, validator_list=None, choices=None, radio_admin=None, + help_text='', db_column=None): + self.name = name + self.verbose_name = verbose_name + self.primary_key = primary_key + self.maxlength, self.unique = maxlength, unique + self.blank, self.null = blank, null + self.core, self.rel, self.default = core, rel, default + self.editable = editable + self.validator_list = validator_list or [] + self.prepopulate_from = prepopulate_from + self.unique_for_date, self.unique_for_month = unique_for_date, unique_for_month + self.unique_for_year = unique_for_year + self.choices = choices or [] + self.radio_admin = radio_admin + self.help_text = help_text + self.db_column = db_column + + # Set db_index to True if the field has a relationship and doesn't explicitly set db_index. + self.db_index = db_index + + # Increase the creation counter, and save our local copy. + self.creation_counter = Field.creation_counter + Field.creation_counter += 1 + + def __cmp__(self, other): + # This is needed because bisect does not take a comparison function. + return cmp(self.creation_counter, other.creation_counter) + + def to_python(self, value): + """ + Converts the input value into the expected Python data type, raising + validators.ValidationError if the data can't be converted. Returns the + converted value. Subclasses should override this. + """ + return value + + def validate_full(self, field_data, all_data): + """ + Returns a list of errors for this field. This is the main interface, + as it encapsulates some basic validation logic used by all fields. + Subclasses should implement validate(), not validate_full(). + """ + if not self.blank and not field_data: + return [gettext_lazy('This field is required.')] + try: + self.validate(field_data, all_data) + except validators.ValidationError, e: + return e.messages + return [] + + def validate(self, field_data, all_data): + """ + Raises validators.ValidationError if field_data has any errors. + Subclasses should override this to specify field-specific validation + logic. This method should assume field_data has already been converted + into the appropriate data type by Field.to_python(). + """ + pass + + def set_attributes_from_name(self, name): + self.name = name + self.attname, self.column = self.get_attname_column() + self.verbose_name = self.verbose_name or (name and name.replace('_', ' ')) + + def contribute_to_class(self, cls, name): + self.set_attributes_from_name(name) + cls._meta.add_field(self) + if self.choices: + setattr(cls, 'get_%s_display' % self.name, curry(cls._get_FIELD_display, field=self)) + + def set_name(self, name): + self.name = name + self.verbose_name = self.verbose_name or name.replace('_', ' ') + self.attname, self.column = self.get_attname_column() + + def get_attname(self): + return self.name + + def get_attname_column(self): + attname = self.get_attname() + column = self.db_column or attname + return attname, column + + def get_cache_name(self): + return '_%s_cache' % self.name + + def get_internal_type(self): + return self.__class__.__name__ + + def pre_save(self, value, add): + "Returns field's value just before saving." + return value + + def get_db_prep_save(self, value): + "Returns field's value prepared for saving into a database." + return value + + def get_db_prep_lookup(self, lookup_type, value): + "Returns field's value prepared for database lookup." + if lookup_type in ('exact', 'gt', 'gte', 'lt', 'lte', 'ne', 'year', 'month', 'day'): + return [value] + elif lookup_type in ('range', 'in'): + return value + elif lookup_type in ('contains', 'icontains'): + return ["%%%s%%" % prep_for_like_query(value)] + elif lookup_type == 'iexact': + return [prep_for_like_query(value)] + elif lookup_type in ('startswith', 'istartswith'): + return ["%s%%" % prep_for_like_query(value)] + elif lookup_type in ('endswith', 'iendswith'): + return ["%%%s" % prep_for_like_query(value)] + elif lookup_type == 'isnull': + return [] + raise TypeError, "Field has invalid lookup: %s" % lookup_type + + def has_default(self): + "Returns a boolean of whether this field has a default value." + return self.default is not NOT_PROVIDED + + def get_default(self): + "Returns the default value for this field." + if self.default is not NOT_PROVIDED: + if callable(self.default): + return self.default() + return self.default + if not self.empty_strings_allowed or self.null: + return None + return "" + + def get_manipulator_field_names(self, name_prefix): + """ + Returns a list of field names that this object adds to the manipulator. + """ + return [name_prefix + self.name] + + def prepare_field_objs_and_params(self, manipulator, name_prefix): + params = {'validator_list': self.validator_list[:]} + if self.maxlength and not self.choices: # Don't give SelectFields a maxlength parameter. + params['maxlength'] = self.maxlength + + if self.choices: + if self.radio_admin: + field_objs = [forms.RadioSelectField] + params['ul_class'] = get_ul_class(self.radio_admin) + else: + field_objs = [forms.SelectField] + + params['choices'] = self.get_choices_default() + else: + field_objs = self.get_manipulator_field_objs() + return (field_objs, params) + + def get_manipulator_fields(self, opts, manipulator, change, name_prefix='', rel=False, follow=True): + """ + Returns a list of forms.FormField instances for this field. It + calculates the choices at runtime, not at compile time. + + name_prefix is a prefix to prepend to the "field_name" argument. + rel is a boolean specifying whether this field is in a related context. + """ + field_objs, params = self.prepare_field_objs_and_params(manipulator, name_prefix) + + # Add the "unique" validator(s). + for field_name_list in opts.unique_together: + if field_name_list[0] == self.name: + params['validator_list'].append(getattr(manipulator, 'isUnique%s' % '_'.join(field_name_list))) + + # Add the "unique for..." validator(s). + if self.unique_for_date: + params['validator_list'].append(getattr(manipulator, 'isUnique%sFor%s' % (self.name, self.unique_for_date))) + if self.unique_for_month: + params['validator_list'].append(getattr(manipulator, 'isUnique%sFor%s' % (self.name, self.unique_for_month))) + if self.unique_for_year: + params['validator_list'].append(getattr(manipulator, 'isUnique%sFor%s' % (self.name, self.unique_for_year))) + if self.unique or (self.primary_key and not rel): + params['validator_list'].append(curry(manipulator_validator_unique, self, opts, manipulator)) + + # Only add is_required=True if the field cannot be blank. Primary keys + # are a special case, and fields in a related context should set this + # as False, because they'll be caught by a separate validator -- + # RequiredIfOtherFieldGiven. + params['is_required'] = not self.blank and not self.primary_key and not rel + + # BooleanFields (CheckboxFields) are a special case. They don't take + # is_required or validator_list. + if isinstance(self, BooleanField): + del params['validator_list'], params['is_required'] + + # If this field is in a related context, check whether any other fields + # in the related object have core=True. If so, add a validator -- + # RequiredIfOtherFieldsGiven -- to this FormField. + if rel and not self.blank and not isinstance(self, AutoField) and not isinstance(self, FileField): + # First, get the core fields, if any. + core_field_names = [] + for f in opts.fields: + if f.core and f != self: + core_field_names.extend(f.get_manipulator_field_names(name_prefix)) + # Now, if there are any, add the validator to this FormField. + if core_field_names: + params['validator_list'].append(validators.RequiredIfOtherFieldsGiven(core_field_names, gettext_lazy("This field is required."))) + + # Finally, add the field_names. + field_names = self.get_manipulator_field_names(name_prefix) + return [man(field_name=field_names[i], **params) for i, man in enumerate(field_objs)] + + def get_validator_unique_lookup_type(self): + return '%s__exact' % self.name + + def get_manipulator_new_data(self, new_data, rel=False): + """ + Given the full new_data dictionary (from the manipulator), returns this + field's data. + """ + if rel: + return new_data.get(self.name, [self.get_default()])[0] + val = new_data.get(self.name, self.get_default()) + if not self.empty_strings_allowed and val == '' and self.null: + val = None + return val + + def get_choices(self, include_blank=True, blank_choice=BLANK_CHOICE_DASH): + "Returns a list of tuples used as SelectField choices for this field." + first_choice = include_blank and blank_choice or [] + if self.choices: + return first_choice + list(self.choices) + rel_model = self.rel.to + return first_choice + [(x._get_pk_val(), str(x)) + for x in rel_model._default_manager.filter(**self.rel.limit_choices_to)] + + def get_choices_default(self): + if self.radio_admin: + return self.get_choices(include_blank=self.blank, blank_choice=BLANK_CHOICE_NONE) + else: + return self.get_choices() + + def _get_val_from_obj(self, obj): + if obj: + return getattr(obj, self.attname) + else: + return self.get_default() + + def flatten_data(self, follow, obj=None): + """ + Returns a dictionary mapping the field's manipulator field names to its + "flattened" string values for the admin view. obj is the instance to + extract the values from. + """ + return {self.attname: self._get_val_from_obj(obj)} + + def get_follow(self, override=None): + if override != None: + return override + else: + return self.editable + + def bind(self, fieldmapping, original, bound_field_class): + return bound_field_class(self, fieldmapping, original) + +class AutoField(Field): + empty_strings_allowed = False + def __init__(self, *args, **kwargs): + assert kwargs.get('primary_key', False) is True, "%ss must have primary_key=True." % self.__class__.__name__ + kwargs['blank'] = True + Field.__init__(self, *args, **kwargs) + + def to_python(self, value): + if value is None: + return value + try: + return int(value) + except (TypeError, ValueError): + raise validators.ValidationError, gettext("This value must be an integer.") + + def get_manipulator_fields(self, opts, manipulator, change, name_prefix='', rel=False, follow=True): + if not rel: + return [] # Don't add a FormField unless it's in a related context. + return Field.get_manipulator_fields(self, opts, manipulator, change, name_prefix, rel, follow) + + def get_manipulator_field_objs(self): + return [forms.HiddenField] + + def get_manipulator_new_data(self, new_data, rel=False): + # Never going to be called + # Not in main change pages + # ignored in related context + if not rel: + return None + return Field.get_manipulator_new_data(self, new_data, rel) + + def contribute_to_class(self, cls, name): + assert not cls._meta.has_auto_field, "A model can't have more than one AutoField." + super(AutoField, self).contribute_to_class(cls, name) + cls._meta.has_auto_field = True + +class BooleanField(Field): + def __init__(self, *args, **kwargs): + kwargs['blank'] = True + Field.__init__(self, *args, **kwargs) + + def to_python(self, value): + if value in (True, False): return value + if value is 't': return True + if value is 'f': return False + raise validators.ValidationError, gettext("This value must be either True or False.") + + def get_manipulator_field_objs(self): + return [forms.CheckboxField] + +class CharField(Field): + def get_manipulator_field_objs(self): + return [forms.TextField] + + def to_python(self, value): + if isinstance(value, basestring): + return value + if value is None: + if self.null: + return value + else: + raise validators.ValidationError, gettext_lazy("This field cannot be null.") + return str(value) + +# TODO: Maybe move this into contrib, because it's specialized. +class CommaSeparatedIntegerField(CharField): + def get_manipulator_field_objs(self): + return [forms.CommaSeparatedIntegerField] + +class DateField(Field): + empty_strings_allowed = False + def __init__(self, verbose_name=None, name=None, auto_now=False, auto_now_add=False, **kwargs): + self.auto_now, self.auto_now_add = auto_now, auto_now_add + #HACKs : auto_now_add/auto_now should be done as a default or a pre_save. + if auto_now or auto_now_add: + kwargs['editable'] = False + kwargs['blank'] = True + Field.__init__(self, verbose_name, name, **kwargs) + + def to_python(self, value): + if isinstance(value, datetime.datetime): + return value.date() + if isinstance(value, datetime.date): + return value + validators.isValidANSIDate(value, None) + return datetime.date(*time.strptime(value, '%Y-%m-%d')[:3]) + + def get_db_prep_lookup(self, lookup_type, value): + if lookup_type == 'range': + value = [str(v) for v in value] + elif lookup_type in ('exact', 'gt', 'gte', 'lt', 'lte', 'ne'): + value = value.strftime('%Y-%m-%d') + else: + value = str(value) + return Field.get_db_prep_lookup(self, lookup_type, value) + + def pre_save(self, value, add): + if self.auto_now or (self.auto_now_add and add): + return datetime.datetime.now() + return value + + def contribute_to_class(self, cls, name): + super(DateField,self).contribute_to_class(cls, name) + if not self.null: + setattr(cls, 'get_next_by_%s' % self.name, + curry(cls._get_next_or_previous_by_FIELD, field=self, is_next=True)) + setattr(cls, 'get_previous_by_%s' % self.name, + curry(cls._get_next_or_previous_by_FIELD, field=self, is_next=False)) + + # Needed because of horrible auto_now[_add] behaviour wrt. editable + def get_follow(self, override=None): + if override != None: + return override + else: + return self.editable or self.auto_now or self.auto_now_add + + def get_db_prep_save(self, value): + # Casts dates into string format for entry into database. + if value is not None: + value = value.strftime('%Y-%m-%d') + return Field.get_db_prep_save(self, value) + + def get_manipulator_field_objs(self): + return [forms.DateField] + + def flatten_data(self, follow, obj = None): + val = self._get_val_from_obj(obj) + return {self.attname: (val is not None and val.strftime("%Y-%m-%d") or '')} + +class DateTimeField(DateField): + def to_python(self, value): + if isinstance(value, datetime.datetime): + return value + if isinstance(value, datetime.date): + return datetime.datetime(value.year, value.month, value.day) + try: # Seconds are optional, so try converting seconds first. + return datetime.datetime(*time.strptime(value, '%Y-%m-%d %H:%M:%S')[:6]) + except ValueError: + try: # Try without seconds. + return datetime.datetime(*time.strptime(value, '%Y-%m-%d %H:%M')[:5]) + except ValueError: # Try without hour/minutes/seconds. + try: + return datetime.datetime(*time.strptime(value, '%Y-%m-%d')[:3]) + except ValueError: + raise validators.ValidationError, gettext('Enter a valid date/time in YYYY-MM-DD HH:MM format.') + + def get_db_prep_save(self, value): + # Casts dates into string format for entry into database. + if value is not None: + # MySQL will throw a warning if microseconds are given, because it + # doesn't support microseconds. + if settings.DATABASE_ENGINE == 'mysql': + value = value.replace(microsecond=0) + value = str(value) + return Field.get_db_prep_save(self, value) + + def get_db_prep_lookup(self, lookup_type, value): + if lookup_type == 'range': + value = [str(v) for v in value] + else: + value = str(value) + return Field.get_db_prep_lookup(self, lookup_type, value) + + def get_manipulator_field_objs(self): + return [forms.DateField, forms.TimeField] + + def get_manipulator_field_names(self, name_prefix): + return [name_prefix + self.name + '_date', name_prefix + self.name + '_time'] + + def get_manipulator_new_data(self, new_data, rel=False): + date_field, time_field = self.get_manipulator_field_names('') + if rel: + d = new_data.get(date_field, [None])[0] + t = new_data.get(time_field, [None])[0] + else: + d = new_data.get(date_field, None) + t = new_data.get(time_field, None) + if d is not None and t is not None: + return datetime.datetime.combine(d, t) + return self.get_default() + + def flatten_data(self,follow, obj = None): + val = self._get_val_from_obj(obj) + date_field, time_field = self.get_manipulator_field_names('') + return {date_field: (val is not None and val.strftime("%Y-%m-%d") or ''), + time_field: (val is not None and val.strftime("%H:%M:%S") or '')} + +class EmailField(CharField): + def __init__(self, *args, **kwargs): + kwargs['maxlength'] = 75 + CharField.__init__(self, *args, **kwargs) + + def get_internal_type(self): + return "CharField" + + def get_manipulator_field_objs(self): + return [forms.EmailField] + + def validate(self, field_data, all_data): + validators.isValidEmail(field_data, all_data) + +class FileField(Field): + def __init__(self, verbose_name=None, name=None, upload_to='', **kwargs): + self.upload_to = upload_to + Field.__init__(self, verbose_name, name, **kwargs) + + def get_manipulator_fields(self, opts, manipulator, change, name_prefix='', rel=False, follow=True): + field_list = Field.get_manipulator_fields(self, opts, manipulator, change, name_prefix, rel, follow) + if not self.blank: + if rel: + # This validator makes sure FileFields work in a related context. + class RequiredFileField: + def __init__(self, other_field_names, other_file_field_name): + self.other_field_names = other_field_names + self.other_file_field_name = other_file_field_name + self.always_test = True + def __call__(self, field_data, all_data): + if not all_data.get(self.other_file_field_name, False): + c = validators.RequiredIfOtherFieldsGiven(self.other_field_names, gettext_lazy("This field is required.")) + c(field_data, all_data) + # First, get the core fields, if any. + core_field_names = [] + for f in opts.fields: + if f.core and f != self: + core_field_names.extend(f.get_manipulator_field_names(name_prefix)) + # Now, if there are any, add the validator to this FormField. + if core_field_names: + field_list[0].validator_list.append(RequiredFileField(core_field_names, field_list[1].field_name)) + else: + v = validators.RequiredIfOtherFieldNotGiven(field_list[1].field_name, gettext_lazy("This field is required.")) + v.always_test = True + field_list[0].validator_list.append(v) + field_list[0].is_required = field_list[1].is_required = False + + # If the raw path is passed in, validate it's under the MEDIA_ROOT. + def isWithinMediaRoot(field_data, all_data): + f = os.path.abspath(os.path.join(settings.MEDIA_ROOT, field_data)) + if not f.startswith(os.path.normpath(settings.MEDIA_ROOT)): + raise validators.ValidationError, _("Enter a valid filename.") + field_list[1].validator_list.append(isWithinMediaRoot) + return field_list + + def contribute_to_class(self, cls, name): + super(FileField, self).contribute_to_class(cls, name) + setattr(cls, 'get_%s_filename' % self.name, curry(cls._get_FIELD_filename, field=self)) + setattr(cls, 'get_%s_url' % self.name, curry(cls._get_FIELD_url, field=self)) + setattr(cls, 'get_%s_size' % self.name, curry(cls._get_FIELD_size, field=self)) + setattr(cls, 'save_%s_file' % self.name, lambda instance, filename, raw_contents: instance._save_FIELD_file(self, filename, raw_contents)) + dispatcher.connect(self.delete_file, signal=signals.post_delete, sender=cls) + + def delete_file(self, instance): + if getattr(instance, self.attname): + file_name = getattr(instance, 'get_%s_filename' % self.name)() + # If the file exists and no other object of this type references it, + # delete it from the filesystem. + if os.path.exists(file_name) and \ + not instance.__class__._default_manager.filter(**{'%s__exact' % self.name: getattr(instance, self.attname)}): + os.remove(file_name) + + def get_manipulator_field_objs(self): + return [forms.FileUploadField, forms.HiddenField] + + def get_manipulator_field_names(self, name_prefix): + return [name_prefix + self.name + '_file', name_prefix + self.name] + + def save_file(self, new_data, new_object, original_object, change, rel): + upload_field_name = self.get_manipulator_field_names('')[0] + if new_data.get(upload_field_name, False): + func = getattr(new_object, 'save_%s_file' % self.name) + if rel: + func(new_data[upload_field_name][0]["filename"], new_data[upload_field_name][0]["content"]) + else: + func(new_data[upload_field_name]["filename"], new_data[upload_field_name]["content"]) + + def get_directory_name(self): + return os.path.normpath(datetime.datetime.now().strftime(self.upload_to)) + + def get_filename(self, filename): + from django.utils.text import get_valid_filename + f = os.path.join(self.get_directory_name(), get_valid_filename(os.path.basename(filename))) + return os.path.normpath(f) + +class FilePathField(Field): + def __init__(self, verbose_name=None, name=None, path='', match=None, recursive=False, **kwargs): + self.path, self.match, self.recursive = path, match, recursive + Field.__init__(self, verbose_name, name, **kwargs) + + def get_manipulator_field_objs(self): + return [curry(forms.FilePathField, path=self.path, match=self.match, recursive=self.recursive)] + +class FloatField(Field): + empty_strings_allowed = False + def __init__(self, verbose_name=None, name=None, max_digits=None, decimal_places=None, **kwargs): + self.max_digits, self.decimal_places = max_digits, decimal_places + Field.__init__(self, verbose_name, name, **kwargs) + + def get_manipulator_field_objs(self): + return [curry(forms.FloatField, max_digits=self.max_digits, decimal_places=self.decimal_places)] + +class ImageField(FileField): + def __init__(self, verbose_name=None, name=None, width_field=None, height_field=None, **kwargs): + self.width_field, self.height_field = width_field, height_field + FileField.__init__(self, verbose_name, name, **kwargs) + + def get_manipulator_field_objs(self): + return [forms.ImageUploadField, forms.HiddenField] + + def contribute_to_class(self, cls, name): + super(ImageField, self).contribute_to_class(cls, name) + # Add get_BLAH_width and get_BLAH_height methods, but only if the + # image field doesn't have width and height cache fields. + if not self.width_field: + setattr(cls, 'get_%s_width' % self.name, curry(cls._get_FIELD_width, field=self)) + if not self.height_field: + setattr(cls, 'get_%s_height' % self.name, curry(cls._get_FIELD_height, field=self)) + + def save_file(self, new_data, new_object, original_object, change, rel): + FileField.save_file(self, new_data, new_object, original_object, change, rel) + # If the image has height and/or width field(s) and they haven't + # changed, set the width and/or height field(s) back to their original + # values. + if change and (self.width_field or self.height_field): + if self.width_field: + setattr(new_object, self.width_field, getattr(original_object, self.width_field)) + if self.height_field: + setattr(new_object, self.height_field, getattr(original_object, self.height_field)) + new_object.save() + +class IntegerField(Field): + empty_strings_allowed = False + def get_manipulator_field_objs(self): + return [forms.IntegerField] + +class IPAddressField(Field): + def __init__(self, *args, **kwargs): + kwargs['maxlength'] = 15 + Field.__init__(self, *args, **kwargs) + + def get_manipulator_field_objs(self): + return [forms.IPAddressField] + + def validate(self, field_data, all_data): + validators.isValidIPAddress4(field_data, None) + +class NullBooleanField(Field): + def __init__(self, *args, **kwargs): + kwargs['null'] = True + Field.__init__(self, *args, **kwargs) + + def get_manipulator_field_objs(self): + return [forms.NullBooleanField] + +class PhoneNumberField(IntegerField): + def get_manipulator_field_objs(self): + return [forms.PhoneNumberField] + + def validate(self, field_data, all_data): + validators.isValidPhone(field_data, all_data) + +class PositiveIntegerField(IntegerField): + def get_manipulator_field_objs(self): + return [forms.PositiveIntegerField] + +class PositiveSmallIntegerField(IntegerField): + def get_manipulator_field_objs(self): + return [forms.PositiveSmallIntegerField] + +class SlugField(Field): + def __init__(self, *args, **kwargs): + kwargs['maxlength'] = kwargs.get('maxlength', 50) + kwargs.setdefault('validator_list', []).append(validators.isSlug) + # Set db_index=True unless it's been set manually. + if not kwargs.has_key('db_index'): + kwargs['db_index'] = True + Field.__init__(self, *args, **kwargs) + + def get_manipulator_field_objs(self): + return [forms.TextField] + +class SmallIntegerField(IntegerField): + def get_manipulator_field_objs(self): + return [forms.SmallIntegerField] + +class TextField(Field): + def get_manipulator_field_objs(self): + return [forms.LargeTextField] + +class TimeField(Field): + empty_strings_allowed = False + def __init__(self, verbose_name=None, name=None, auto_now=False, auto_now_add=False, **kwargs): + self.auto_now, self.auto_now_add = auto_now, auto_now_add + if auto_now or auto_now_add: + kwargs['editable'] = False + Field.__init__(self, verbose_name, name, **kwargs) + + def get_db_prep_lookup(self, lookup_type, value): + if lookup_type == 'range': + value = [str(v) for v in value] + else: + value = str(value) + return Field.get_db_prep_lookup(self, lookup_type, value) + + def pre_save(self, value, add): + if self.auto_now or (self.auto_now_add and add): + return datetime.datetime.now().time() + return value + + def get_db_prep_save(self, value): + # Casts dates into string format for entry into database. + if value is not None: + # MySQL will throw a warning if microseconds are given, because it + # doesn't support microseconds. + if settings.DATABASE_ENGINE == 'mysql': + value = value.replace(microsecond=0) + value = str(value) + return Field.get_db_prep_save(self, value) + + def get_manipulator_field_objs(self): + return [forms.TimeField] + + def flatten_data(self,follow, obj = None): + val = self._get_val_from_obj(obj) + return {self.attname: (val is not None and val.strftime("%H:%M:%S") or '')} + +class URLField(Field): + def __init__(self, verbose_name=None, name=None, verify_exists=True, **kwargs): + if verify_exists: + kwargs.setdefault('validator_list', []).append(validators.isExistingURL) + Field.__init__(self, verbose_name, name, **kwargs) + + def get_manipulator_field_objs(self): + return [forms.URLField] + +class USStateField(Field): + def get_manipulator_field_objs(self): + return [forms.USStateField] + +class XMLField(TextField): + def __init__(self, verbose_name=None, name=None, schema_path=None, **kwargs): + self.schema_path = schema_path + Field.__init__(self, verbose_name, name, **kwargs) + + def get_internal_type(self): + return "TextField" + + def get_manipulator_field_objs(self): + return [curry(forms.XMLLargeTextField, schema_path=self.schema_path)] + +class OrderingField(IntegerField): + empty_strings_allowed=False + def __init__(self, with_respect_to, **kwargs): + self.wrt = with_respect_to + kwargs['null'] = True + IntegerField.__init__(self, **kwargs ) + + def get_internal_type(self): + return "IntegerField" + + def get_manipulator_fields(self, opts, manipulator, change, name_prefix='', rel=False, follow=True): + return [forms.HiddenField(name_prefix + self.name)] diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py new file mode 100644 index 0000000000..908aa75207 --- /dev/null +++ b/django/db/models/fields/related.py @@ -0,0 +1,718 @@ +from django.db import backend, connection, transaction +from django.db.models import signals +from django.db.models.fields import AutoField, Field, IntegerField, get_ul_class +from django.db.models.related import RelatedObject +from django.utils.translation import gettext_lazy, string_concat +from django.utils.functional import curry +from django.core import validators +from django import forms +from django.dispatch import dispatcher + +# For Python 2.3 +if not hasattr(__builtins__, 'set'): + from sets import Set as set + +# Values for Relation.edit_inline. +TABULAR, STACKED = 1, 2 + +RECURSIVE_RELATIONSHIP_CONSTANT = 'self' + +pending_lookups = {} + +def add_lookup(rel_cls, field): + name = field.rel.to + module = rel_cls.__module__ + key = (module, name) + pending_lookups.setdefault(key, []).append((rel_cls, field)) + +def do_pending_lookups(sender): + other_cls = sender + key = (other_cls.__module__, other_cls.__name__) + for rel_cls, field in pending_lookups.setdefault(key, []): + field.rel.to = other_cls + field.do_related_class(other_cls, rel_cls) + +dispatcher.connect(do_pending_lookups, signal=signals.class_prepared) + +def manipulator_valid_rel_key(f, self, field_data, all_data): + "Validates that the value is a valid foreign key" + klass = f.rel.to + try: + klass._default_manager.get(pk=field_data) + except klass.DoesNotExist: + raise validators.ValidationError, _("Please enter a valid %s.") % f.verbose_name + +#HACK +class RelatedField(object): + def contribute_to_class(self, cls, name): + sup = super(RelatedField, self) + + # Add an accessor to allow easy determination of the related query path for this field + self.related_query_name = curry(self._get_related_query_name, cls._meta) + + if hasattr(sup, 'contribute_to_class'): + sup.contribute_to_class(cls, name) + other = self.rel.to + if isinstance(other, basestring): + if other == RECURSIVE_RELATIONSHIP_CONSTANT: + self.rel.to = cls.__name__ + add_lookup(cls, self) + else: + self.do_related_class(other, cls) + + def set_attributes_from_rel(self): + self.name = self.name or (self.rel.to._meta.object_name.lower() + '_' + self.rel.to._meta.pk.name) + self.verbose_name = self.verbose_name or self.rel.to._meta.verbose_name + self.rel.field_name = self.rel.field_name or self.rel.to._meta.pk.name + + def do_related_class(self, other, cls): + self.set_attributes_from_rel() + related = RelatedObject(other, cls, self) + self.contribute_to_related_class(other, related) + + def _get_related_query_name(self, opts): + # This method defines the name that can be used to identify this related object + # in a table-spanning query. It uses the lower-cased object_name by default, + # but this can be overridden with the "related_name" option. + return self.rel.related_name or opts.object_name.lower() + +class SingleRelatedObjectDescriptor(object): + # This class provides the functionality that makes the related-object + # managers available as attributes on a model class, for fields that have + # a single "remote" value, on the class pointed to by a related field. + # In the example "place.restaurant", the restaurant attribute is a + # SingleRelatedObjectDescriptor instance. + def __init__(self, related): + self.related = related + + def __get__(self, instance, instance_type=None): + if instance is None: + raise AttributeError, "%s must be accessed via instance" % self.related.opts.object_name + + params = {'%s__pk' % self.related.field.name: instance._get_pk_val()} + rel_obj = self.related.model._default_manager.get(**params) + return rel_obj + + def __set__(self, instance, value): + if instance is None: + raise AttributeError, "%s must be accessed via instance" % self.related.opts.object_name + # Set the value of the related field + setattr(value, self.related.field.rel.get_related_field().attname, instance) + + # Clear the cache, if it exists + try: + delattr(value, self.related.field.get_cache_name()) + except AttributeError: + pass + +class ReverseSingleRelatedObjectDescriptor(object): + # This class provides the functionality that makes the related-object + # managers available as attributes on a model class, for fields that have + # a single "remote" value, on the class that defines the related field. + # In the example "choice.poll", the poll attribute is a + # ReverseSingleRelatedObjectDescriptor instance. + def __init__(self, field_with_rel): + self.field = field_with_rel + + def __get__(self, instance, instance_type=None): + if instance is None: + raise AttributeError, "%s must be accessed via instance" % self.field.name + cache_name = self.field.get_cache_name() + try: + return getattr(instance, cache_name) + except AttributeError: + val = getattr(instance, self.field.attname) + if val is None: + # If NULL is an allowed value, return it. + if self.field.null: + return None + raise self.field.rel.to.DoesNotExist + other_field = self.field.rel.get_related_field() + if other_field.rel: + params = {'%s__pk' % self.field.rel.field_name: val} + else: + params = {'%s__exact' % self.field.rel.field_name: val} + rel_obj = self.field.rel.to._default_manager.get(**params) + setattr(instance, cache_name, rel_obj) + return rel_obj + + def __set__(self, instance, value): + if instance is None: + raise AttributeError, "%s must be accessed via instance" % self._field.name + # Set the value of the related field + try: + val = getattr(value, self.field.rel.get_related_field().attname) + except AttributeError: + val = None + setattr(instance, self.field.attname, val) + + # Clear the cache, if it exists + try: + delattr(instance, self.field.get_cache_name()) + except AttributeError: + pass + +class ForeignRelatedObjectsDescriptor(object): + # This class provides the functionality that makes the related-object + # managers available as attributes on a model class, for fields that have + # multiple "remote" values and have a ForeignKey pointed at them by + # some other model. In the example "poll.choice_set", the choice_set + # attribute is a ForeignRelatedObjectsDescriptor instance. + def __init__(self, related): + self.related = related # RelatedObject instance + + def __get__(self, instance, instance_type=None): + if instance is None: + raise AttributeError, "Manager must be accessed via instance" + + rel_field = self.related.field + rel_model = self.related.model + + # Dynamically create a class that subclasses the related + # model's default manager. + superclass = self.related.model._default_manager.__class__ + + class RelatedManager(superclass): + def get_query_set(self): + return superclass.get_query_set(self).filter(**(self.core_filters)) + + def add(self, *objs): + for obj in objs: + setattr(obj, rel_field.name, instance) + obj.save() + add.alters_data = True + + def create(self, **kwargs): + new_obj = self.model(**kwargs) + self.add(new_obj) + return new_obj + create.alters_data = True + + # remove() and clear() are only provided if the ForeignKey can have a value of null. + if rel_field.null: + def remove(self, *objs): + val = getattr(instance, rel_field.rel.get_related_field().attname) + for obj in objs: + # Is obj actually part of this descriptor set? + if getattr(obj, rel_field.attname) == val: + setattr(obj, rel_field.name, None) + obj.save() + else: + raise rel_field.rel.to.DoesNotExist, "'%s' is not related to '%s'." % (obj, instance) + remove.alters_data = True + + def clear(self): + for obj in self.all(): + setattr(obj, rel_field.name, None) + obj.save() + clear.alters_data = True + + manager = RelatedManager() + manager.core_filters = {'%s__pk' % rel_field.name: getattr(instance, rel_field.rel.get_related_field().attname)} + manager.model = self.related.model + + return manager + + def __set__(self, instance, value): + if instance is None: + raise AttributeError, "Manager must be accessed via instance" + + manager = self.__get__(instance) + # If the foreign key can support nulls, then completely clear the related set. + # Otherwise, just move the named objects into the set. + if self.related.field.null: + manager.clear() + for obj in value: + manager.add(obj) + +def create_many_related_manager(superclass): + """Creates a manager that subclasses 'superclass' (which is a Manager) + and adds behavior for many-to-many related objects.""" + class ManyRelatedManager(superclass): + def __init__(self, model=None, core_filters=None, instance=None, symmetrical=None, + join_table=None, source_col_name=None, target_col_name=None): + super(ManyRelatedManager, self).__init__() + self.core_filters = core_filters + self.model = model + self.symmetrical = symmetrical + self.instance = instance + self.join_table = join_table + self.source_col_name = source_col_name + self.target_col_name = target_col_name + if instance: + self._pk_val = self.instance._get_pk_val() + + def get_query_set(self): + return superclass.get_query_set(self).filter(**(self.core_filters)) + + def add(self, *objs): + self._add_items(self.source_col_name, self.target_col_name, *objs) + + # If this is a symmetrical m2m relation to self, add the mirror entry in the m2m table + if self.symmetrical: + self._add_items(self.target_col_name, self.source_col_name, *objs) + add.alters_data = True + + def remove(self, *objs): + self._remove_items(self.source_col_name, self.target_col_name, *objs) + + # If this is a symmetrical m2m relation to self, remove the mirror entry in the m2m table + if self.symmetrical: + self._remove_items(self.target_col_name, self.source_col_name, *objs) + remove.alters_data = True + + def clear(self): + self._clear_items(self.source_col_name) + + # If this is a symmetrical m2m relation to self, clear the mirror entry in the m2m table + if self.symmetrical: + self._clear_items(self.target_col_name) + clear.alters_data = True + + def create(self, **kwargs): + new_obj = self.model(**kwargs) + new_obj.save() + self.add(new_obj) + return new_obj + create.alters_data = True + + def _add_items(self, source_col_name, target_col_name, *objs): + # join_table: name of the m2m link table + # source_col_name: the PK colname in join_table for the source object + # target_col_name: the PK colname in join_table for the target object + # *objs - objects to add + from django.db import connection + + # Add the newly created or already existing objects to the join table. + # First find out which items are already added, to avoid adding them twice + new_ids = set([obj._get_pk_val() for obj in objs]) + cursor = connection.cursor() + cursor.execute("SELECT %s FROM %s WHERE %s = %%s AND %s IN (%s)" % \ + (target_col_name, self.join_table, source_col_name, + target_col_name, ",".join(['%s'] * len(new_ids))), + [self._pk_val] + list(new_ids)) + if cursor.rowcount is not None and cursor.rowcount != 0: + existing_ids = set([row[0] for row in cursor.fetchmany(cursor.rowcount)]) + else: + existing_ids = set() + + # Add the ones that aren't there already + for obj_id in (new_ids - existing_ids): + cursor.execute("INSERT INTO %s (%s, %s) VALUES (%%s, %%s)" % \ + (self.join_table, source_col_name, target_col_name), + [self._pk_val, obj_id]) + transaction.commit_unless_managed() + + def _remove_items(self, source_col_name, target_col_name, *objs): + # source_col_name: the PK colname in join_table for the source object + # target_col_name: the PK colname in join_table for the target object + # *objs - objects to remove + from django.db import connection + + for obj in objs: + if not isinstance(obj, self.model): + raise ValueError, "objects to remove() must be %s instances" % self.model._meta.object_name + # Remove the specified objects from the join table + cursor = connection.cursor() + for obj in objs: + cursor.execute("DELETE FROM %s WHERE %s = %%s AND %s = %%s" % \ + (self.join_table, source_col_name, target_col_name), + [self._pk_val, obj._get_pk_val()]) + transaction.commit_unless_managed() + + def _clear_items(self, source_col_name): + # source_col_name: the PK colname in join_table for the source object + from django.db import connection + cursor = connection.cursor() + cursor.execute("DELETE FROM %s WHERE %s = %%s" % \ + (self.join_table, source_col_name), + [self._pk_val]) + transaction.commit_unless_managed() + + return ManyRelatedManager + +class ManyRelatedObjectsDescriptor(object): + # This class provides the functionality that makes the related-object + # managers available as attributes on a model class, for fields that have + # multiple "remote" values and have a ManyToManyField pointed at them by + # some other model (rather than having a ManyToManyField themselves). + # In the example "publication.article_set", the article_set attribute is a + # ManyRelatedObjectsDescriptor instance. + def __init__(self, related): + self.related = related # RelatedObject instance + + def __get__(self, instance, instance_type=None): + if instance is None: + raise AttributeError, "Manager must be accessed via instance" + + # Dynamically create a class that subclasses the related + # model's default manager. + rel_model = self.related.model + superclass = rel_model._default_manager.__class__ + RelatedManager = create_many_related_manager(superclass) + + qn = backend.quote_name + manager = RelatedManager( + model=rel_model, + core_filters={'%s__pk' % self.related.field.name: instance._get_pk_val()}, + instance=instance, + symmetrical=False, + join_table=qn(self.related.field.m2m_db_table()), + source_col_name=qn(self.related.field.m2m_reverse_name()), + target_col_name=qn(self.related.field.m2m_column_name()) + ) + + return manager + + def __set__(self, instance, value): + if instance is None: + raise AttributeError, "Manager must be accessed via instance" + + manager = self.__get__(instance) + manager.clear() + for obj in value: + manager.add(obj) + +class ReverseManyRelatedObjectsDescriptor(object): + # This class provides the functionality that makes the related-object + # managers available as attributes on a model class, for fields that have + # multiple "remote" values and have a ManyToManyField defined in their + # model (rather than having another model pointed *at* them). + # In the example "article.publications", the publications attribute is a + # ReverseManyRelatedObjectsDescriptor instance. + def __init__(self, m2m_field): + self.field = m2m_field + + def __get__(self, instance, instance_type=None): + if instance is None: + raise AttributeError, "Manager must be accessed via instance" + + # Dynamically create a class that subclasses the related + # model's default manager. + rel_model=self.field.rel.to + superclass = rel_model._default_manager.__class__ + RelatedManager = create_many_related_manager(superclass) + + qn = backend.quote_name + manager = RelatedManager( + model=rel_model, + core_filters={'%s__pk' % self.field.related_query_name(): instance._get_pk_val()}, + instance=instance, + symmetrical=(self.field.rel.symmetrical and instance.__class__ == rel_model), + join_table=qn(self.field.m2m_db_table()), + source_col_name=qn(self.field.m2m_column_name()), + target_col_name=qn(self.field.m2m_reverse_name()) + ) + + return manager + + def __set__(self, instance, value): + if instance is None: + raise AttributeError, "Manager must be accessed via instance" + + manager = self.__get__(instance) + manager.clear() + for obj in value: + manager.add(obj) + +class ForeignKey(RelatedField, Field): + empty_strings_allowed = False + def __init__(self, to, to_field=None, **kwargs): + try: + to_name = to._meta.object_name.lower() + except AttributeError: # to._meta doesn't exist, so it must be RECURSIVE_RELATIONSHIP_CONSTANT + assert isinstance(to, basestring), "ForeignKey(%r) is invalid. First parameter to ForeignKey must be either a model, a model name, or the string %r" % (to, RECURSIVE_RELATIONSHIP_CONSTANT) + else: + to_field = to_field or to._meta.pk.name + kwargs['verbose_name'] = kwargs.get('verbose_name', '') + + if kwargs.has_key('edit_inline_type'): + import warnings + warnings.warn("edit_inline_type is deprecated. Use edit_inline instead.") + kwargs['edit_inline'] = kwargs.pop('edit_inline_type') + + kwargs['rel'] = ManyToOneRel(to, to_field, + num_in_admin=kwargs.pop('num_in_admin', 3), + min_num_in_admin=kwargs.pop('min_num_in_admin', None), + max_num_in_admin=kwargs.pop('max_num_in_admin', None), + num_extra_on_change=kwargs.pop('num_extra_on_change', 1), + edit_inline=kwargs.pop('edit_inline', False), + related_name=kwargs.pop('related_name', None), + limit_choices_to=kwargs.pop('limit_choices_to', None), + lookup_overrides=kwargs.pop('lookup_overrides', None), + raw_id_admin=kwargs.pop('raw_id_admin', False)) + Field.__init__(self, **kwargs) + + self.db_index = True + + def get_attname(self): + return '%s_id' % self.name + + def get_validator_unique_lookup_type(self): + return '%s__%s__exact' % (self.name, self.rel.get_related_field().name) + + def prepare_field_objs_and_params(self, manipulator, name_prefix): + params = {'validator_list': self.validator_list[:], 'member_name': name_prefix + self.attname} + if self.rel.raw_id_admin: + field_objs = self.get_manipulator_field_objs() + params['validator_list'].append(curry(manipulator_valid_rel_key, self, manipulator)) + else: + if self.radio_admin: + field_objs = [forms.RadioSelectField] + params['ul_class'] = get_ul_class(self.radio_admin) + else: + if self.null: + field_objs = [forms.NullSelectField] + else: + field_objs = [forms.SelectField] + params['choices'] = self.get_choices_default() + return field_objs, params + + def get_manipulator_field_objs(self): + rel_field = self.rel.get_related_field() + if self.rel.raw_id_admin and not isinstance(rel_field, AutoField): + return rel_field.get_manipulator_field_objs() + else: + return [forms.IntegerField] + + def get_db_prep_save(self, value): + if value == '' or value == None: + return None + else: + return self.rel.get_related_field().get_db_prep_save(value) + + def flatten_data(self, follow, obj=None): + if not obj: + # In required many-to-one fields with only one available choice, + # select that one available choice. Note: For SelectFields + # (radio_admin=False), we have to check that the length of choices + # is *2*, not 1, because SelectFields always have an initial + # "blank" value. Otherwise (radio_admin=True), we check that the + # length is 1. + if not self.blank and (not self.rel.raw_id_admin or self.choices): + choice_list = self.get_choices_default() + if self.radio_admin and len(choice_list) == 1: + return {self.attname: choice_list[0][0]} + if not self.radio_admin and len(choice_list) == 2: + return {self.attname: choice_list[1][0]} + return Field.flatten_data(self, follow, obj) + + def contribute_to_class(self, cls, name): + super(ForeignKey, self).contribute_to_class(cls, name) + setattr(cls, self.name, ReverseSingleRelatedObjectDescriptor(self)) + + def contribute_to_related_class(self, cls, related): + setattr(cls, related.get_accessor_name(), ForeignRelatedObjectsDescriptor(related)) + +class OneToOneField(RelatedField, IntegerField): + def __init__(self, to, to_field=None, **kwargs): + kwargs['verbose_name'] = kwargs.get('verbose_name', '') + to_field = to_field or to._meta.pk.name + + if kwargs.has_key('edit_inline_type'): + import warnings + warnings.warn("edit_inline_type is deprecated. Use edit_inline instead.") + kwargs['edit_inline'] = kwargs.pop('edit_inline_type') + + kwargs['rel'] = OneToOneRel(to, to_field, + num_in_admin=kwargs.pop('num_in_admin', 0), + edit_inline=kwargs.pop('edit_inline', False), + related_name=kwargs.pop('related_name', None), + limit_choices_to=kwargs.pop('limit_choices_to', None), + lookup_overrides=kwargs.pop('lookup_overrides', None), + raw_id_admin=kwargs.pop('raw_id_admin', False)) + kwargs['primary_key'] = True + IntegerField.__init__(self, **kwargs) + + self.db_index = True + + def get_attname(self): + return '%s_id' % self.name + + def get_validator_unique_lookup_type(self): + return '%s__%s__exact' % (self.name, self.rel.get_related_field().name) + + # TODO: Copied from ForeignKey... putting this in RelatedField adversely affects + # ManyToManyField. This works for now. + def prepare_field_objs_and_params(self, manipulator, name_prefix): + params = {'validator_list': self.validator_list[:], 'member_name': name_prefix + self.attname} + if self.rel.raw_id_admin: + field_objs = self.get_manipulator_field_objs() + params['validator_list'].append(curry(manipulator_valid_rel_key, self, manipulator)) + else: + if self.radio_admin: + field_objs = [forms.RadioSelectField] + params['ul_class'] = get_ul_class(self.radio_admin) + else: + if self.null: + field_objs = [forms.NullSelectField] + else: + field_objs = [forms.SelectField] + params['choices'] = self.get_choices_default() + return field_objs, params + + def contribute_to_class(self, cls, name): + super(OneToOneField, self).contribute_to_class(cls, name) + setattr(cls, self.name, ReverseSingleRelatedObjectDescriptor(self)) + + def contribute_to_related_class(self, cls, related): + setattr(cls, related.get_accessor_name(), SingleRelatedObjectDescriptor(related)) + if not cls._meta.one_to_one_field: + cls._meta.one_to_one_field = self + +class ManyToManyField(RelatedField, Field): + def __init__(self, to, **kwargs): + kwargs['verbose_name'] = kwargs.get('verbose_name', None) + kwargs['rel'] = ManyToManyRel(to, kwargs.pop('singular', None), + num_in_admin=kwargs.pop('num_in_admin', 0), + related_name=kwargs.pop('related_name', None), + filter_interface=kwargs.pop('filter_interface', None), + limit_choices_to=kwargs.pop('limit_choices_to', None), + raw_id_admin=kwargs.pop('raw_id_admin', False), + symmetrical=kwargs.pop('symmetrical', True)) + if kwargs["rel"].raw_id_admin: + kwargs.setdefault("validator_list", []).append(self.isValidIDList) + Field.__init__(self, **kwargs) + + if self.rel.raw_id_admin: + msg = gettext_lazy('Separate multiple IDs with commas.') + else: + msg = gettext_lazy('Hold down "Control", or "Command" on a Mac, to select more than one.') + self.help_text = string_concat(self.help_text, msg) + + def get_manipulator_field_objs(self): + if self.rel.raw_id_admin: + return [forms.RawIdAdminField] + else: + choices = self.get_choices_default() + return [curry(forms.SelectMultipleField, size=min(max(len(choices), 5), 15), choices=choices)] + + def get_choices_default(self): + return Field.get_choices(self, include_blank=False) + + def _get_m2m_db_table(self, opts): + "Function that can be curried to provide the m2m table name for this relation" + return '%s_%s' % (opts.db_table, self.name) + + def _get_m2m_column_name(self, related): + "Function that can be curried to provide the source column name for the m2m table" + # If this is an m2m relation to self, avoid the inevitable name clash + if related.model == related.parent_model: + return 'from_' + related.model._meta.object_name.lower() + '_id' + else: + return related.model._meta.object_name.lower() + '_id' + + def _get_m2m_reverse_name(self, related): + "Function that can be curried to provide the related column name for the m2m table" + # If this is an m2m relation to self, avoid the inevitable name clash + if related.model == related.parent_model: + return 'to_' + related.parent_model._meta.object_name.lower() + '_id' + else: + return related.parent_model._meta.object_name.lower() + '_id' + + def isValidIDList(self, field_data, all_data): + "Validates that the value is a valid list of foreign keys" + mod = self.rel.to + try: + pks = map(int, field_data.split(',')) + except ValueError: + # the CommaSeparatedIntegerField validator will catch this error + return + objects = mod._default_manager.in_bulk(pks) + if len(objects) != len(pks): + badkeys = [k for k in pks if k not in objects] + raise validators.ValidationError, ngettext("Please enter valid %(self)s IDs. The value %(value)r is invalid.", + "Please enter valid %(self)s IDs. The values %(value)r are invalid.", len(badkeys)) % { + 'self': self.verbose_name, + 'value': len(badkeys) == 1 and badkeys[0] or tuple(badkeys), + } + + def flatten_data(self, follow, obj = None): + new_data = {} + if obj: + instance_ids = [instance._get_pk_val() for instance in getattr(obj, self.name).all()] + if self.rel.raw_id_admin: + new_data[self.name] = ",".join([str(id) for id in instance_ids]) + else: + new_data[self.name] = instance_ids + else: + # In required many-to-many fields with only one available choice, + # select that one available choice. + if not self.blank and not self.rel.edit_inline and not self.rel.raw_id_admin: + choices_list = self.get_choices_default() + if len(choices_list) == 1: + new_data[self.name] = [choices_list[0][0]] + return new_data + + def contribute_to_class(self, cls, name): + super(ManyToManyField, self).contribute_to_class(cls, name) + # Add the descriptor for the m2m relation + setattr(cls, self.name, ReverseManyRelatedObjectsDescriptor(self)) + + # Set up the accessor for the m2m table name for the relation + self.m2m_db_table = curry(self._get_m2m_db_table, cls._meta) + + def contribute_to_related_class(self, cls, related): + # m2m relations to self do not have a ManyRelatedObjectsDescriptor, + # as it would be redundant - unless the field is non-symmetrical. + if related.model != related.parent_model or not self.rel.symmetrical: + # Add the descriptor for the m2m relation + setattr(cls, related.get_accessor_name(), ManyRelatedObjectsDescriptor(related)) + + self.rel.singular = self.rel.singular or self.rel.to._meta.object_name.lower() + + # Set up the accessors for the column names on the m2m table + self.m2m_column_name = curry(self._get_m2m_column_name, related) + self.m2m_reverse_name = curry(self._get_m2m_reverse_name, related) + + def set_attributes_from_rel(self): + pass + +class ManyToOneRel: + def __init__(self, to, field_name, num_in_admin=3, min_num_in_admin=None, + max_num_in_admin=None, num_extra_on_change=1, edit_inline=False, + related_name=None, limit_choices_to=None, lookup_overrides=None, raw_id_admin=False): + try: + to._meta + except AttributeError: # to._meta doesn't exist, so it must be RECURSIVE_RELATIONSHIP_CONSTANT + assert isinstance(to, basestring), "'to' must be either a model, a model name or the string %r" % RECURSIVE_RELATIONSHIP_CONSTANT + self.to, self.field_name = to, field_name + self.num_in_admin, self.edit_inline = num_in_admin, edit_inline + self.min_num_in_admin, self.max_num_in_admin = min_num_in_admin, max_num_in_admin + self.num_extra_on_change, self.related_name = num_extra_on_change, related_name + self.limit_choices_to = limit_choices_to or {} + self.lookup_overrides = lookup_overrides or {} + self.raw_id_admin = raw_id_admin + self.multiple = True + + def get_related_field(self): + "Returns the Field in the 'to' object to which this relationship is tied." + return self.to._meta.get_field(self.field_name) + +class OneToOneRel(ManyToOneRel): + def __init__(self, to, field_name, num_in_admin=0, edit_inline=False, + related_name=None, limit_choices_to=None, lookup_overrides=None, + raw_id_admin=False): + self.to, self.field_name = to, field_name + self.num_in_admin, self.edit_inline = num_in_admin, edit_inline + self.related_name = related_name + self.limit_choices_to = limit_choices_to or {} + self.lookup_overrides = lookup_overrides or {} + self.raw_id_admin = raw_id_admin + self.multiple = False + +class ManyToManyRel: + def __init__(self, to, singular=None, num_in_admin=0, related_name=None, + filter_interface=None, limit_choices_to=None, raw_id_admin=False, symmetrical=True): + self.to = to + self.singular = singular or None + self.num_in_admin = num_in_admin + self.related_name = related_name + self.filter_interface = filter_interface + self.limit_choices_to = limit_choices_to or {} + self.edit_inline = False + self.raw_id_admin = raw_id_admin + self.symmetrical = symmetrical + self.multiple = True + + assert not (self.raw_id_admin and self.filter_interface), "ManyToManyRels may not use both raw_id_admin and filter_interface" diff --git a/django/db/models/loading.py b/django/db/models/loading.py new file mode 100644 index 0000000000..a9e0348f8e --- /dev/null +++ b/django/db/models/loading.py @@ -0,0 +1,71 @@ +"Utilities for loading models and the modules that contain them." + +from django.conf import settings +from django.core.exceptions import ImproperlyConfigured + +__all__ = ('get_apps', 'get_app', 'get_models', 'get_model', 'register_models') + +_app_list = None # Cache of installed apps. +_app_models = {} # Dictionary of models against app label + # Each value is a dictionary of model name: model class + +def get_apps(): + "Returns a list of all installed modules that contain models." + global _app_list + if _app_list is not None: + return _app_list + _app_list = [] + for app_name in settings.INSTALLED_APPS: + try: + _app_list.append(__import__(app_name, '', '', ['models']).models) + except (ImportError, AttributeError), e: + pass + return _app_list + +def get_app(app_label): + "Returns the module containing the models for the given app_label." + for app_name in settings.INSTALLED_APPS: + if app_label == app_name.split('.')[-1]: + return __import__(app_name, '', '', ['models']).models + raise ImproperlyConfigured, "App with label %s could not be found" % app_label + +def get_models(app_mod=None): + """ + Given a module containing models, returns a list of the models. Otherwise + returns a list of all installed models. + """ + app_list = get_apps() # Run get_apps() to populate the _app_list cache. Slightly hackish. + if app_mod: + return _app_models.get(app_mod.__name__.split('.')[-2], {}).values() + else: + model_list = [] + for app_mod in app_list: + model_list.extend(get_models(app_mod)) + return model_list + +def get_model(app_label, model_name): + """ + Returns the model matching the given app_label and case-insensitive model_name. + Returns None if no model is found. + """ + get_apps() # Run get_apps() to populate the _app_list cache. Slightly hackish. + try: + model_dict = _app_models[app_label] + except KeyError: + return None + + try: + return model_dict[model_name.lower()] + except KeyError: + return None + +def register_models(app_label, *models): + """ + Register a set of models as belonging to an app. + """ + for model in models: + # Store as 'name: model' pair in a dictionary + # in the _app_models dictionary + model_name = model._meta.object_name.lower() + model_dict = _app_models.setdefault(app_label, {}) + model_dict[model_name] = model diff --git a/django/db/models/manager.py b/django/db/models/manager.py new file mode 100644 index 0000000000..d847631c82 --- /dev/null +++ b/django/db/models/manager.py @@ -0,0 +1,101 @@ +from django.utils.functional import curry +from django.db import backend, connection +from django.db.models.query import QuerySet +from django.dispatch import dispatcher +from django.db.models import signals +from django.utils.datastructures import SortedDict + +# Size of each "chunk" for get_iterator calls. +# Larger values are slightly faster at the expense of more storage space. +GET_ITERATOR_CHUNK_SIZE = 100 + +def ensure_default_manager(sender): + cls = sender + if not hasattr(cls, '_default_manager'): + # Create the default manager, if needed. + if hasattr(cls, 'objects'): + raise ValueError, "Model %s must specify a custom Manager, because it has a field named 'objects'" % name + cls.add_to_class('objects', Manager()) + +dispatcher.connect(ensure_default_manager, signal=signals.class_prepared) + +class Manager(object): + # Tracks each time a Manager instance is created. Used to retain order. + creation_counter = 0 + + def __init__(self): + super(Manager, self).__init__() + # Increase the creation counter, and save our local copy. + self.creation_counter = Manager.creation_counter + Manager.creation_counter += 1 + self.model = None + + def contribute_to_class(self, model, name): + # TODO: Use weakref because of possible memory leak / circular reference. + self.model = model + setattr(model, name, ManagerDescriptor(self)) + if not hasattr(model, '_default_manager') or self.creation_counter < model._default_manager.creation_counter: + model._default_manager = self + + ####################### + # PROXIES TO QUERYSET # + ####################### + + def get_query_set(self): + """Returns a new QuerySet object. Subclasses can override this method + to easily customise the behaviour of the Manager. + """ + return QuerySet(self.model) + + def all(self): + return self.get_query_set() + + def count(self): + return self.get_query_set().count() + + def dates(self, *args, **kwargs): + return self.get_query_set().dates(*args, **kwargs) + + def distinct(self, *args, **kwargs): + return self.get_query_set().distinct(*args, **kwargs) + + def extra(self, *args, **kwargs): + return self.get_query_set().extra(*args, **kwargs) + + def get(self, *args, **kwargs): + return self.get_query_set().get(*args, **kwargs) + + def filter(self, *args, **kwargs): + return self.get_query_set().filter(*args, **kwargs) + + def exclude(self, *args, **kwargs): + return self.get_query_set().exclude(*args, **kwargs) + + def in_bulk(self, *args, **kwargs): + return self.get_query_set().in_bulk(*args, **kwargs) + + def iterator(self, *args, **kwargs): + return self.get_query_set().iterator(*args, **kwargs) + + def latest(self, *args, **kwargs): + return self.get_query_set().latest(*args, **kwargs) + + def order_by(self, *args, **kwargs): + return self.get_query_set().order_by(*args, **kwargs) + + def select_related(self, *args, **kwargs): + return self.get_query_set().select_related(*args, **kwargs) + + def values(self, *args, **kwargs): + return self.get_query_set().values(*args, **kwargs) + +class ManagerDescriptor(object): + # This class ensures managers aren't accessible via model instances. + # For example, Poll.objects works, but poll_obj.objects raises AttributeError. + def __init__(self, manager): + self.manager = manager + + def __get__(self, instance, type=None): + if instance != None: + raise AttributeError, "Manager isn't accessible via %s instances" % type.__name__ + return self.manager diff --git a/django/db/models/manipulators.py b/django/db/models/manipulators.py new file mode 100644 index 0000000000..fc553bc90c --- /dev/null +++ b/django/db/models/manipulators.py @@ -0,0 +1,330 @@ +from django.core.exceptions import ObjectDoesNotExist +from django import forms +from django.core import validators +from django.db.models.fields import FileField, AutoField +from django.dispatch import dispatcher +from django.db.models import signals +from django.utils.functional import curry +from django.utils.datastructures import DotExpandedDict, MultiValueDict +from django.utils.text import capfirst +import types + +def add_manipulators(sender): + cls = sender + cls.add_to_class('AddManipulator', AutomaticAddManipulator) + cls.add_to_class('ChangeManipulator', AutomaticChangeManipulator) + +dispatcher.connect(add_manipulators, signal=signals.class_prepared) + +class ManipulatorDescriptor(object): + # This class provides the functionality that makes the default model + # manipulators (AddManipulator and ChangeManipulator) available via the + # model class. + def __init__(self, name, base): + self.man = None # Cache of the manipulator class. + self.name = name + self.base = base + + def __get__(self, instance, model=None): + if instance != None: + raise AttributeError, "Manipulator cannot be accessed via instance" + else: + if not self.man: + # Create a class that inherits from the "Manipulator" class + # given in the model class (if specified) and the automatic + # manipulator. + bases = [self.base] + if hasattr(model, 'Manipulator'): + bases = [model.Manipulator] + bases + self.man = types.ClassType(self.name, tuple(bases), {}) + self.man._prepare(model) + return self.man + +class AutomaticManipulator(forms.Manipulator): + def _prepare(cls, model): + cls.model = model + cls.manager = model._default_manager + cls.opts = model._meta + for field_name_list in cls.opts.unique_together: + setattr(cls, 'isUnique%s' % '_'.join(field_name_list), curry(manipulator_validator_unique_together, field_name_list, cls.opts)) + for f in cls.opts.fields: + if f.unique_for_date: + setattr(cls, 'isUnique%sFor%s' % (f.name, f.unique_for_date), curry(manipulator_validator_unique_for_date, f, cls.opts.get_field(f.unique_for_date), cls.opts, 'date')) + if f.unique_for_month: + setattr(cls, 'isUnique%sFor%s' % (f.name, f.unique_for_month), curry(manipulator_validator_unique_for_date, f, cls.opts.get_field(f.unique_for_month), cls.opts, 'month')) + if f.unique_for_year: + setattr(cls, 'isUnique%sFor%s' % (f.name, f.unique_for_year), curry(manipulator_validator_unique_for_date, f, cls.opts.get_field(f.unique_for_year), cls.opts, 'year')) + _prepare = classmethod(_prepare) + + def contribute_to_class(cls, other_cls, name): + setattr(other_cls, name, ManipulatorDescriptor(name, cls)) + contribute_to_class = classmethod(contribute_to_class) + + def __init__(self, follow=None): + self.follow = self.opts.get_follow(follow) + self.fields = [] + + for f in self.opts.fields + self.opts.many_to_many: + if self.follow.get(f.name, False): + self.fields.extend(f.get_manipulator_fields(self.opts, self, self.change)) + + # Add fields for related objects. + for f in self.opts.get_all_related_objects(): + if self.follow.get(f.name, False): + fol = self.follow[f.name] + self.fields.extend(f.get_manipulator_fields(self.opts, self, self.change, fol)) + + # Add field for ordering. + if self.change and self.opts.get_ordered_objects(): + self.fields.append(formfields.CommaSeparatedIntegerField(field_name="order_")) + + def save(self, new_data): + # TODO: big cleanup when core fields go -> use recursive manipulators. + params = {} + for f in self.opts.fields: + # Fields with auto_now_add should keep their original value in the change stage. + auto_now_add = self.change and getattr(f, 'auto_now_add', False) + if self.follow.get(f.name, None) and not auto_now_add: + param = f.get_manipulator_new_data(new_data) + else: + if self.change: + param = getattr(self.original_object, f.attname) + else: + param = f.get_default() + params[f.attname] = param + + if self.change: + params[self.opts.pk.attname] = self.obj_key + + # First, save the basic object itself. + new_object = self.model(**params) + new_object.save() + + # Now that the object's been saved, save any uploaded files. + for f in self.opts.fields: + if isinstance(f, FileField): + f.save_file(new_data, new_object, self.change and self.original_object or None, self.change, rel=False) + + # Calculate which primary fields have changed. + if self.change: + self.fields_added, self.fields_changed, self.fields_deleted = [], [], [] + for f in self.opts.fields: + if not f.primary_key and str(getattr(self.original_object, f.attname)) != str(getattr(new_object, f.attname)): + self.fields_changed.append(f.verbose_name) + + # Save many-to-many objects. Example: Set sites for a poll. + for f in self.opts.many_to_many: + if self.follow.get(f.name, None): + if not f.rel.edit_inline: + if f.rel.raw_id_admin: + new_vals = new_data.get(f.name, ()) + else: + new_vals = new_data.getlist(f.name) + # First, clear the existing values. + rel_manager = getattr(new_object, f.name) + rel_manager.clear() + # Then, set the new values. + for n in new_vals: + rel_manager.add(f.rel.to._default_manager.get(pk=n)) + # TODO: Add to 'fields_changed' + + expanded_data = DotExpandedDict(dict(new_data)) + # Save many-to-one objects. Example: Add the Choice objects for a Poll. + for related in self.opts.get_all_related_objects(): + # Create obj_list, which is a DotExpandedDict such as this: + # [('0', {'id': ['940'], 'choice': ['This is the first choice']}), + # ('1', {'id': ['941'], 'choice': ['This is the second choice']}), + # ('2', {'id': [''], 'choice': ['']})] + child_follow = self.follow.get(related.name, None) + + if child_follow: + obj_list = expanded_data[related.var_name].items() + if not obj_list: + continue + + obj_list.sort(lambda x, y: cmp(int(x[0]), int(y[0]))) + + # For each related item... + for _, rel_new_data in obj_list: + + params = {} + + # Keep track of which core=True fields were provided. + # If all core fields were given, the related object will be saved. + # If none of the core fields were given, the object will be deleted. + # If some, but not all, of the fields were given, the validator would + # have caught that. + all_cores_given, all_cores_blank = True, True + + # Get a reference to the old object. We'll use it to compare the + # old to the new, to see which fields have changed. + old_rel_obj = None + if self.change: + if rel_new_data[related.opts.pk.name][0]: + try: + old_rel_obj = getattr(self.original_object, related.get_accessor_name()).get(**{'%s__exact' % related.opts.pk.name: rel_new_data[related.opts.pk.attname][0]}) + except ObjectDoesNotExist: + pass + + for f in related.opts.fields: + if f.core and not isinstance(f, FileField) and f.get_manipulator_new_data(rel_new_data, rel=True) in (None, ''): + all_cores_given = False + elif f.core and not isinstance(f, FileField) and f.get_manipulator_new_data(rel_new_data, rel=True) not in (None, ''): + all_cores_blank = False + # If this field isn't editable, give it the same value it had + # previously, according to the given ID. If the ID wasn't + # given, use a default value. FileFields are also a special + # case, because they'll be dealt with later. + + if f == related.field: + param = getattr(new_object, related.field.rel.field_name) + elif (not self.change) and isinstance(f, AutoField): + param = None + elif self.change and (isinstance(f, FileField) or not child_follow.get(f.name, None)): + if old_rel_obj: + param = getattr(old_rel_obj, f.column) + else: + param = f.get_default() + else: + param = f.get_manipulator_new_data(rel_new_data, rel=True) + if param != None: + params[f.attname] = param + + # Create the related item. + new_rel_obj = related.model(**params) + + # If all the core fields were provided (non-empty), save the item. + if all_cores_given: + new_rel_obj.save() + + # Save any uploaded files. + for f in related.opts.fields: + if child_follow.get(f.name, None): + if isinstance(f, FileField) and rel_new_data.get(f.name, False): + f.save_file(rel_new_data, new_rel_obj, self.change and old_rel_obj or None, old_rel_obj is not None, rel=True) + + # Calculate whether any fields have changed. + if self.change: + if not old_rel_obj: # This object didn't exist before. + self.fields_added.append('%s "%s"' % (related.opts.verbose_name, new_rel_obj)) + else: + for f in related.opts.fields: + if not f.primary_key and f != related.field and str(getattr(old_rel_obj, f.attname)) != str(getattr(new_rel_obj, f.attname)): + self.fields_changed.append('%s for %s "%s"' % (f.verbose_name, related.opts.verbose_name, new_rel_obj)) + + # Save many-to-many objects. + for f in related.opts.many_to_many: + if child_follow.get(f.name, None) and not f.rel.edit_inline: + was_changed = getattr(new_rel_obj, 'set_%s' % f.name)(rel_new_data[f.attname]) + if self.change and was_changed: + self.fields_changed.append('%s for %s "%s"' % (f.verbose_name, related.opts.verbose_name, new_rel_obj)) + + # If, in the change stage, all of the core fields were blank and + # the primary key (ID) was provided, delete the item. + if self.change and all_cores_blank and old_rel_obj: + new_rel_obj.delete() + self.fields_deleted.append('%s "%s"' % (related.opts.verbose_name, old_rel_obj)) + + # Save the order, if applicable. + if self.change and self.opts.get_ordered_objects(): + order = new_data['order_'] and map(int, new_data['order_'].split(',')) or [] + for rel_opts in self.opts.get_ordered_objects(): + getattr(new_object, 'set_%s_order' % rel_opts.object_name.lower())(order) + return new_object + + def get_related_objects(self): + return self.opts.get_followed_related_objects(self.follow) + + def flatten_data(self): + new_data = {} + obj = self.change and self.original_object or None + for f in self.opts.get_data_holders(self.follow): + fol = self.follow.get(f.name) + new_data.update(f.flatten_data(fol, obj)) + return new_data + +class AutomaticAddManipulator(AutomaticManipulator): + change = False + +class AutomaticChangeManipulator(AutomaticManipulator): + change = True + def __init__(self, obj_key, follow=None): + self.obj_key = obj_key + try: + self.original_object = self.manager.get(pk=obj_key) + except ObjectDoesNotExist: + # If the object doesn't exist, this might be a manipulator for a + # one-to-one related object that hasn't created its subobject yet. + # For example, this might be a Restaurant for a Place that doesn't + # yet have restaurant information. + if self.opts.one_to_one_field: + # Sanity check -- Make sure the "parent" object exists. + # For example, make sure the Place exists for the Restaurant. + # Let the ObjectDoesNotExist exception propagate up. + lookup_kwargs = self.opts.one_to_one_field.rel.limit_choices_to + lookup_kwargs['%s__exact' % self.opts.one_to_one_field.rel.field_name] = obj_key + self.opts.one_to_one_field.rel.to.get_model_module().get(**lookup_kwargs) + params = dict([(f.attname, f.get_default()) for f in self.opts.fields]) + params[self.opts.pk.attname] = obj_key + self.original_object = self.opts.get_model_module().Klass(**params) + else: + raise + super(AutomaticChangeManipulator, self).__init__(follow=follow) + +def manipulator_validator_unique_together(field_name_list, opts, self, field_data, all_data): + from django.db.models.fields.related import ManyToOneRel + from django.utils.text import get_text_list + field_list = [opts.get_field(field_name) for field_name in field_name_list] + if isinstance(field_list[0].rel, ManyToOneRel): + kwargs = {'%s__%s__iexact' % (field_name_list[0], field_list[0].rel.field_name): field_data} + else: + kwargs = {'%s__iexact' % field_name_list[0]: field_data} + for f in field_list[1:]: + # This is really not going to work for fields that have different + # form fields, e.g. DateTime. + # This validation needs to occur after html2python to be effective. + field_val = all_data.get(f.attname, None) + if field_val is None: + # This will be caught by another validator, assuming the field + # doesn't have blank=True. + return + if isinstance(f.rel, ManyToOneRel): + kwargs['%s__pk' % f.name] = field_val + else: + kwargs['%s__iexact' % f.name] = field_val + try: + old_obj = self.manager.get(**kwargs) + except ObjectDoesNotExist: + return + if hasattr(self, 'original_object') and self.original_object._get_pk_val() == old_obj._get_pk_val(): + pass + else: + raise validators.ValidationError, _("%(object)s with this %(type)s already exists for the given %(field)s.") % \ + {'object': capfirst(opts.verbose_name), 'type': field_list[0].verbose_name, 'field': get_text_list(field_name_list[1:], 'and')} + +def manipulator_validator_unique_for_date(from_field, date_field, opts, lookup_type, self, field_data, all_data): + from django.db.models.fields.related import ManyToOneRel + date_str = all_data.get(date_field.get_manipulator_field_names('')[0], None) + date_val = forms.DateField.html2python(date_str) + if date_val is None: + return # Date was invalid. This will be caught by another validator. + lookup_kwargs = {'%s__year' % date_field.name: date_val.year} + if isinstance(from_field.rel, ManyToOneRel): + lookup_kwargs['%s__pk' % from_field.name] = field_data + else: + lookup_kwargs['%s__iexact' % from_field.name] = field_data + if lookup_type in ('month', 'date'): + lookup_kwargs['%s__month' % date_field.name] = date_val.month + if lookup_type == 'date': + lookup_kwargs['%s__day' % date_field.name] = date_val.day + try: + old_obj = self.manager.get(**lookup_kwargs) + except ObjectDoesNotExist: + return + else: + if hasattr(self, 'original_object') and self.original_object._get_pk_val() == old_obj._get_pk_val(): + pass + else: + format_string = (lookup_type == 'date') and '%B %d, %Y' or '%B %Y' + raise validators.ValidationError, "Please enter a different %s. The one you entered is already being used for %s." % \ + (from_field.verbose_name, date_val.strftime(format_string)) diff --git a/django/db/models/options.py b/django/db/models/options.py new file mode 100644 index 0000000000..d1f5eeb756 --- /dev/null +++ b/django/db/models/options.py @@ -0,0 +1,269 @@ +from django.conf import settings +from django.db.models.related import RelatedObject +from django.db.models.fields.related import ManyToManyRel +from django.db.models.fields import AutoField, FieldDoesNotExist +from django.db.models.loading import get_models +from django.db.models.query import orderlist2sql +from django.db.models import Manager +from bisect import bisect +import re + +# Calculate the verbose_name by converting from InitialCaps to "lowercase with spaces". +get_verbose_name = lambda class_name: re.sub('([A-Z])', ' \\1', class_name).lower().strip() + +DEFAULT_NAMES = ('verbose_name', 'db_table', 'ordering', + 'unique_together', 'permissions', 'get_latest_by', + 'order_with_respect_to', 'app_label') + +class Options: + def __init__(self, meta): + self.fields, self.many_to_many = [], [] + self.module_name, self.verbose_name = None, None + self.verbose_name_plural = None + self.db_table = '' + self.ordering = [] + self.unique_together = [] + self.permissions = [] + self.object_name, self.app_label = None, None + self.get_latest_by = None + self.order_with_respect_to = None + self.admin = None + self.meta = meta + self.pk = None + self.has_auto_field = False + self.one_to_one_field = None + self.parents = [] + + def contribute_to_class(self, cls, name): + cls._meta = self + self.installed = re.sub('\.models$', '', cls.__module__) in settings.INSTALLED_APPS + # First, construct the default values for these options. + self.object_name = cls.__name__ + self.module_name = self.object_name.lower() + self.verbose_name = get_verbose_name(self.object_name) + # Next, apply any overridden values from 'class Meta'. + if self.meta: + meta_attrs = self.meta.__dict__ + del meta_attrs['__module__'] + del meta_attrs['__doc__'] + for attr_name in DEFAULT_NAMES: + setattr(self, attr_name, meta_attrs.pop(attr_name, getattr(self, attr_name))) + # verbose_name_plural is a special case because it uses a 's' + # by default. + setattr(self, 'verbose_name_plural', meta_attrs.pop('verbose_name_plural', self.verbose_name + 's')) + # Any leftover attributes must be invalid. + if meta_attrs != {}: + raise TypeError, "'class Meta' got invalid attribute(s): %s" % ','.join(meta_attrs.keys()) + else: + self.verbose_name_plural = self.verbose_name + 's' + del self.meta + + def _prepare(self, model): + if self.order_with_respect_to: + self.order_with_respect_to = self.get_field(self.order_with_respect_to) + self.ordering = ('_order',) + else: + self.order_with_respect_to = None + + if self.pk is None: + auto = AutoField(verbose_name='ID', primary_key=True) + auto.creation_counter = -1 + model.add_to_class('id', auto) + + # If the db_table wasn't provided, use the app_label + module_name. + if not self.db_table: + self.db_table = "%s_%s" % (self.app_label, self.module_name) + + def add_field(self, field): + # Insert the given field in the order in which it was created, using + # the "creation_counter" attribute of the field. + # Move many-to-many related fields from self.fields into self.many_to_many. + if field.rel and isinstance(field.rel, ManyToManyRel): + self.many_to_many.insert(bisect(self.many_to_many, field), field) + else: + self.fields.insert(bisect(self.fields, field), field) + if not self.pk and field.primary_key: + self.pk = field + + def __repr__(self): + return '<Options for %s>' % self.object_name + + def get_field(self, name, many_to_many=True): + "Returns the requested field by name. Raises FieldDoesNotExist on error." + to_search = many_to_many and (self.fields + self.many_to_many) or self.fields + for f in to_search: + if f.name == name: + return f + raise FieldDoesNotExist, "name=%s" % name + + def get_order_sql(self, table_prefix=''): + "Returns the full 'ORDER BY' clause for this object, according to self.ordering." + if not self.ordering: return '' + pre = table_prefix and (table_prefix + '.') or '' + return 'ORDER BY ' + orderlist2sql(self.ordering, self, pre) + + def get_add_permission(self): + return 'add_%s' % self.object_name.lower() + + def get_change_permission(self): + return 'change_%s' % self.object_name.lower() + + def get_delete_permission(self): + return 'delete_%s' % self.object_name.lower() + + def get_all_related_objects(self): + try: # Try the cache first. + return self._all_related_objects + except AttributeError: + rel_objs = [] + for klass in get_models(): + for f in klass._meta.fields: + if f.rel and self == f.rel.to._meta: + rel_objs.append(RelatedObject(f.rel.to, klass, f)) + self._all_related_objects = rel_objs + return rel_objs + + def get_followed_related_objects(self, follow=None): + if follow == None: + follow = self.get_follow() + return [f for f in self.get_all_related_objects() if follow.get(f.name, None)] + + def get_data_holders(self, follow=None): + if follow == None: + follow = self.get_follow() + return [f for f in self.fields + self.many_to_many + self.get_all_related_objects() if follow.get(f.name, None)] + + def get_follow(self, override=None): + follow = {} + for f in self.fields + self.many_to_many + self.get_all_related_objects(): + if override and override.has_key(f.name): + child_override = override[f.name] + else: + child_override = None + fol = f.get_follow(child_override) + if fol != None: + follow[f.name] = fol + return follow + + def get_all_related_many_to_many_objects(self): + try: # Try the cache first. + return self._all_related_many_to_many_objects + except AttributeError: + rel_objs = [] + for klass in get_models(): + for f in klass._meta.many_to_many: + if f.rel and self == f.rel.to._meta: + rel_objs.append(RelatedObject(f.rel.to, klass, f)) + self._all_related_many_to_many_objects = rel_objs + return rel_objs + + def get_ordered_objects(self): + "Returns a list of Options objects that are ordered with respect to this object." + if not hasattr(self, '_ordered_objects'): + objects = [] + # TODO + #for klass in get_models(get_app(self.app_label)): + # opts = klass._meta + # if opts.order_with_respect_to and opts.order_with_respect_to.rel \ + # and self == opts.order_with_respect_to.rel.to._meta: + # objects.append(opts) + self._ordered_objects = objects + return self._ordered_objects + + def has_field_type(self, field_type, follow=None): + """ + Returns True if this object's admin form has at least one of the given + field_type (e.g. FileField). + """ + # TODO: follow + if not hasattr(self, '_field_types'): + self._field_types = {} + if not self._field_types.has_key(field_type): + try: + # First check self.fields. + for f in self.fields: + if isinstance(f, field_type): + raise StopIteration + # Failing that, check related fields. + for related in self.get_followed_related_objects(follow): + for f in related.opts.fields: + if isinstance(f, field_type): + raise StopIteration + except StopIteration: + self._field_types[field_type] = True + else: + self._field_types[field_type] = False + return self._field_types[field_type] + +class AdminOptions: + def __init__(self, fields=None, js=None, list_display=None, list_filter=None, + date_hierarchy=None, save_as=False, ordering=None, search_fields=None, + save_on_top=False, list_select_related=False, manager=None, list_per_page=100): + self.fields = fields + self.js = js or [] + self.list_display = list_display or ['__str__'] + self.list_filter = list_filter or [] + self.date_hierarchy = date_hierarchy + self.save_as, self.ordering = save_as, ordering + self.search_fields = search_fields or [] + self.save_on_top = save_on_top + self.list_select_related = list_select_related + self.list_per_page = list_per_page + self.manager = manager or Manager() + + def get_field_sets(self, opts): + "Returns a list of AdminFieldSet objects for this AdminOptions object." + if self.fields is None: + field_struct = ((None, {'fields': [f.name for f in opts.fields + opts.many_to_many if f.editable and not isinstance(f, AutoField)]}),) + else: + field_struct = self.fields + new_fieldset_list = [] + for fieldset in field_struct: + fs_options = fieldset[1] + classes = fs_options.get('classes', ()) + description = fs_options.get('description', '') + new_fieldset_list.append(AdminFieldSet(fieldset[0], classes, + opts.get_field, fs_options['fields'], description)) + return new_fieldset_list + + def contribute_to_class(self, cls, name): + cls._meta.admin = self + # Make sure the admin manager has access to the model + self.manager.model = cls + +class AdminFieldSet(object): + def __init__(self, name, classes, field_locator_func, line_specs, description): + self.name = name + self.field_lines = [AdminFieldLine(field_locator_func, line_spec) for line_spec in line_specs] + self.classes = classes + self.description = description + + def __repr__(self): + return "FieldSet: (%s, %s)" % (self.name, self.field_lines) + + def bind(self, field_mapping, original, bound_field_set_class): + return bound_field_set_class(self, field_mapping, original) + + def __iter__(self): + for field_line in self.field_lines: + yield field_line + + def __len__(self): + return len(self.field_lines) + +class AdminFieldLine(object): + def __init__(self, field_locator_func, linespec): + if isinstance(linespec, basestring): + self.fields = [field_locator_func(linespec)] + else: + self.fields = [field_locator_func(field_name) for field_name in linespec] + + def bind(self, field_mapping, original, bound_field_line_class): + return bound_field_line_class(self, field_mapping, original) + + def __iter__(self): + for field in self.fields: + yield field + + def __len__(self): + return len(self.fields) diff --git a/django/db/models/query.py b/django/db/models/query.py new file mode 100644 index 0000000000..365ead2a3a --- /dev/null +++ b/django/db/models/query.py @@ -0,0 +1,888 @@ +from django.db import backend, connection, transaction +from django.db.models.fields import DateField, FieldDoesNotExist +from django.db.models import signals +from django.dispatch import dispatcher +from django.utils.datastructures import SortedDict + +import operator + +# For Python 2.3 +if not hasattr(__builtins__, 'set'): + from sets import Set as set + +LOOKUP_SEPARATOR = '__' + +# Size of each "chunk" for get_iterator calls. +# Larger values are slightly faster at the expense of more storage space. +GET_ITERATOR_CHUNK_SIZE = 100 + +#################### +# HELPER FUNCTIONS # +#################### + +# Django currently supports two forms of ordering. +# Form 1 (deprecated) example: +# order_by=(('pub_date', 'DESC'), ('headline', 'ASC'), (None, 'RANDOM')) +# Form 2 (new-style) example: +# order_by=('-pub_date', 'headline', '?') +# Form 1 is deprecated and will no longer be supported for Django's first +# official release. The following code converts from Form 1 to Form 2. + +LEGACY_ORDERING_MAPPING = {'ASC': '_', 'DESC': '-_', 'RANDOM': '?'} + +def handle_legacy_orderlist(order_list): + if not order_list or isinstance(order_list[0], basestring): + return order_list + else: + import warnings + new_order_list = [LEGACY_ORDERING_MAPPING[j.upper()].replace('_', str(i)) for i, j in order_list] + warnings.warn("%r ordering syntax is deprecated. Use %r instead." % (order_list, new_order_list), DeprecationWarning) + return new_order_list + +def orderfield2column(f, opts): + try: + return opts.get_field(f, False).column + except FieldDoesNotExist: + return f + +def orderlist2sql(order_list, opts, prefix=''): + if prefix.endswith('.'): + prefix = backend.quote_name(prefix[:-1]) + '.' + output = [] + for f in handle_legacy_orderlist(order_list): + if f.startswith('-'): + output.append('%s%s DESC' % (prefix, backend.quote_name(orderfield2column(f[1:], opts)))) + elif f == '?': + output.append(backend.get_random_function_sql()) + else: + output.append('%s%s ASC' % (prefix, backend.quote_name(orderfield2column(f, opts)))) + return ', '.join(output) + +def quote_only_if_word(word): + if ' ' in word: + return word + else: + return backend.quote_name(word) + +class QuerySet(object): + "Represents a lazy database lookup for a set of objects" + def __init__(self, model=None): + self.model = model + self._filters = Q() + self._order_by = None # Ordering, e.g. ('date', '-name'). If None, use model's ordering. + self._select_related = False # Whether to fill cache for related objects. + self._distinct = False # Whether the query should use SELECT DISTINCT. + self._select = {} # Dictionary of attname -> SQL. + self._where = [] # List of extra WHERE clauses to use. + self._params = [] # List of params to use for extra WHERE clauses. + self._tables = [] # List of extra tables to use. + self._offset = None # OFFSET clause + self._limit = None # LIMIT clause + self._result_cache = None + + ######################## + # PYTHON MAGIC METHODS # + ######################## + + def __repr__(self): + return repr(self._get_data()) + + def __len__(self): + return len(self._get_data()) + + def __iter__(self): + return iter(self._get_data()) + + def __getitem__(self, k): + "Retrieve an item or slice from the set of results." + if self._result_cache is None: + if isinstance(k, slice): + # Offset: + if self._offset is None: + offset = k.start + elif k.start is None: + offset = self._offset + else: + offset = self._offset + k.start + # Now adjust offset to the bounds of any existing limit: + if self._limit is not None and k.start is not None: + limit = self._limit - k.start + else: + limit = self._limit + + # Limit: + if k.stop is not None and k.start is not None: + if limit is None: + limit = k.stop - k.start + else: + limit = min((k.stop - k.start), limit) + else: + if limit is None: + limit = k.stop + else: + if k.stop is not None: + limit = min(k.stop, limit) + + if k.step is None: + return self._clone(_offset=offset, _limit=limit) + else: + return list(self._clone(_offset=offset, _limit=limit))[::k.step] + else: + return self._clone(_offset=k, _limit=1).get() + else: + return self._result_cache[k] + + def __and__(self, other): + combined = self._combine(other) + combined._filters = self._filters & other._filters + return combined + + def __or__(self, other): + combined = self._combine(other) + combined._filters = self._filters | other._filters + return combined + + #################################### + # METHODS THAT DO DATABASE QUERIES # + #################################### + + def iterator(self): + "Performs the SELECT database lookup of this QuerySet." + # self._select is a dictionary, and dictionaries' key order is + # undefined, so we convert it to a list of tuples. + extra_select = self._select.items() + + cursor = connection.cursor() + select, sql, params = self._get_sql_clause() + cursor.execute("SELECT " + (self._distinct and "DISTINCT " or "") + ",".join(select) + sql, params) + fill_cache = self._select_related + index_end = len(self.model._meta.fields) + while 1: + rows = cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE) + if not rows: + raise StopIteration + for row in rows: + if fill_cache: + obj, index_end = get_cached_row(self.model, row, 0) + else: + obj = self.model(*row[:index_end]) + for i, k in enumerate(extra_select): + setattr(obj, k[0], row[index_end+i]) + yield obj + + def count(self): + "Performs a SELECT COUNT() and returns the number of records as an integer." + counter = self._clone() + counter._order_by = () + counter._offset = None + counter._limit = None + counter._select_related = False + select, sql, params = counter._get_sql_clause() + cursor = connection.cursor() + cursor.execute("SELECT COUNT(*)" + sql, params) + return cursor.fetchone()[0] + + def get(self, *args, **kwargs): + "Performs the SELECT and returns a single object matching the given keyword arguments." + clone = self.filter(*args, **kwargs) + if not clone._order_by: + clone._order_by = () + obj_list = list(clone) + if len(obj_list) < 1: + raise self.model.DoesNotExist, "%s does not exist for %s" % (self.model._meta.object_name, kwargs) + assert len(obj_list) == 1, "get() returned more than one %s -- it returned %s! Lookup parameters were %s" % (self.model._meta.object_name, len(obj_list), kwargs) + return obj_list[0] + + def latest(self, field_name=None): + """ + Returns the latest object, according to the model's 'get_latest_by' + option or optional given field_name. + """ + latest_by = field_name or self.model._meta.get_latest_by + assert bool(latest_by), "latest() requires either a field_name parameter or 'get_latest_by' in the model" + assert self._limit is None and self._offset is None, \ + "Cannot change a query once a slice has been taken." + return self._clone(_limit=1, _order_by=('-'+latest_by,)).get() + + def in_bulk(self, id_list): + """ + Returns a dictionary mapping each of the given IDs to the object with + that ID. + """ + assert self._limit is None and self._offset is None, \ + "Cannot use 'limit' or 'offset' with in_bulk" + assert isinstance(id_list, (tuple, list)), "in_bulk() must be provided with a list of IDs." + id_list = list(id_list) + if id_list == []: + return {} + qs = self._clone() + qs._where.append("%s.%s IN (%s)" % (backend.quote_name(self.model._meta.db_table), backend.quote_name(self.model._meta.pk.column), ",".join(['%s'] * len(id_list)))) + qs._params.extend(id_list) + return dict([(obj._get_pk_val(), obj) for obj in qs.iterator()]) + + def delete(self): + """ + Deletes the records in the current QuerySet. + """ + assert self._limit is None and self._offset is None, \ + "Cannot use 'limit' or 'offset' with delete." + + del_query = self._clone() + + # disable non-supported fields + del_query._select_related = False + del_query._order_by = [] + + # Delete objects in chunks to prevent an the list of + # related objects from becoming too long + more_objects = True + while more_objects: + # Collect all the objects to be deleted in this chunk, and all the objects + # that are related to the objects that are to be deleted + seen_objs = SortedDict() + more_objects = False + for object in del_query[0:GET_ITERATOR_CHUNK_SIZE]: + more_objects = True + object._collect_sub_objects(seen_objs) + + # If one or more objects were found, delete them. + # Otherwise, stop looping. + if more_objects: + delete_objects(seen_objs) + + # Clear the result cache, in case this QuerySet gets reused. + self._result_cache = None + delete.alters_data = True + + ################################################## + # PUBLIC METHODS THAT RETURN A QUERYSET SUBCLASS # + ################################################## + + def values(self, *fields): + return self._clone(klass=ValuesQuerySet, _fields=fields) + + def dates(self, field_name, kind, order='ASC'): + """ + Returns a list of datetime objects representing all available dates + for the given field_name, scoped to 'kind'. + """ + assert kind in ("month", "year", "day"), "'kind' must be one of 'year', 'month' or 'day'." + assert order in ('ASC', 'DESC'), "'order' must be either 'ASC' or 'DESC'." + # Let the FieldDoesNotExist exception propagate. + field = self.model._meta.get_field(field_name, many_to_many=False) + assert isinstance(field, DateField), "%r isn't a DateField." % field_name + return self._clone(klass=DateQuerySet, _field=field, _kind=kind, _order=order) + + ################################################################## + # PUBLIC METHODS THAT ALTER ATTRIBUTES AND RETURN A NEW QUERYSET # + ################################################################## + + def filter(self, *args, **kwargs): + "Returns a new QuerySet instance with the args ANDed to the existing set." + return self._filter_or_exclude(Q, *args, **kwargs) + + def exclude(self, *args, **kwargs): + "Returns a new QuerySet instance with NOT (args) ANDed to the existing set." + return self._filter_or_exclude(QNot, *args, **kwargs) + + def _filter_or_exclude(self, qtype, *args, **kwargs): + if len(args) > 0 or len(kwargs) > 0: + assert self._limit is None and self._offset is None, \ + "Cannot filter a query once a slice has been taken." + + clone = self._clone() + if len(kwargs) > 0: + clone._filters = clone._filters & qtype(**kwargs) + if len(args) > 0: + clone._filters = clone._filters & reduce(operator.and_, args) + return clone + + def select_related(self, true_or_false=True): + "Returns a new QuerySet instance with '_select_related' modified." + return self._clone(_select_related=true_or_false) + + def order_by(self, *field_names): + "Returns a new QuerySet instance with the ordering changed." + assert self._limit is None and self._offset is None, \ + "Cannot reorder a query once a slice has been taken." + return self._clone(_order_by=field_names) + + def distinct(self, true_or_false=True): + "Returns a new QuerySet instance with '_distinct' modified." + return self._clone(_distinct=true_or_false) + + def extra(self, select=None, where=None, params=None, tables=None): + assert self._limit is None and self._offset is None, \ + "Cannot change a query once a slice has been taken" + clone = self._clone() + if select: clone._select.update(select) + if where: clone._where.extend(where) + if params: clone._params.extend(params) + if tables: clone._tables.extend(tables) + return clone + + ################### + # PRIVATE METHODS # + ################### + + def _clone(self, klass=None, **kwargs): + if klass is None: + klass = self.__class__ + c = klass() + c.model = self.model + c._filters = self._filters + c._order_by = self._order_by + c._select_related = self._select_related + c._distinct = self._distinct + c._select = self._select.copy() + c._where = self._where[:] + c._params = self._params[:] + c._tables = self._tables[:] + c._offset = self._offset + c._limit = self._limit + c.__dict__.update(kwargs) + return c + + def _combine(self, other): + assert self._limit is None and self._offset is None \ + and other._limit is None and other._offset is None, \ + "Cannot combine queries once a slice has been taken." + assert self._distinct == other._distinct, \ + "Cannot combine a unique query with a non-unique query" + # use 'other's order by + # (so that A.filter(args1) & A.filter(args2) does the same as + # A.filter(args1).filter(args2) + combined = other._clone() + # If 'self' is ordered and 'other' isn't, propagate 'self's ordering + if (self._order_by is not None and len(self._order_by) > 0) and \ + (combined._order_by is None or len(combined._order_by) == 0): + combined._order_by = self._order_by + return combined + + def _get_data(self): + if self._result_cache is None: + self._result_cache = list(self.iterator()) + return self._result_cache + + def _get_sql_clause(self): + opts = self.model._meta + + # Construct the fundamental parts of the query: SELECT X FROM Y WHERE Z. + select = ["%s.%s" % (backend.quote_name(opts.db_table), backend.quote_name(f.column)) for f in opts.fields] + tables = [quote_only_if_word(t) for t in self._tables] + joins = SortedDict() + where = self._where[:] + params = self._params[:] + + # Convert self._filters into SQL. + tables2, joins2, where2, params2 = self._filters.get_sql(opts) + tables.extend(tables2) + joins.update(joins2) + where.extend(where2) + params.extend(params2) + + # Add additional tables and WHERE clauses based on select_related. + if self._select_related: + fill_table_cache(opts, select, tables, where, opts.db_table, [opts.db_table]) + + # Add any additional SELECTs. + if self._select: + select.extend(['(%s) AS %s' % (quote_only_if_word(s[1]), backend.quote_name(s[0])) for s in self._select.items()]) + + # Start composing the body of the SQL statement. + sql = [" FROM", backend.quote_name(opts.db_table)] + + # Compose the join dictionary into SQL describing the joins. + if joins: + sql.append(" ".join(["%s %s AS %s ON %s" % (join_type, table, alias, condition) + for (alias, (table, join_type, condition)) in joins.items()])) + + # Compose the tables clause into SQL. + if tables: + sql.append(", " + ", ".join(tables)) + + # Compose the where clause into SQL. + if where: + sql.append(where and "WHERE " + " AND ".join(where)) + + # ORDER BY clause + order_by = [] + if self._order_by is not None: + ordering_to_use = self._order_by + else: + ordering_to_use = opts.ordering + for f in handle_legacy_orderlist(ordering_to_use): + if f == '?': # Special case. + order_by.append(backend.get_random_function_sql()) + else: + if f.startswith('-'): + col_name = f[1:] + order = "DESC" + else: + col_name = f + order = "ASC" + if "." in col_name: + table_prefix, col_name = col_name.split('.', 1) + table_prefix = backend.quote_name(table_prefix) + '.' + else: + # Use the database table as a column prefix if it wasn't given, + # and if the requested column isn't a custom SELECT. + if "." not in col_name and col_name not in (self._select or ()): + table_prefix = backend.quote_name(opts.db_table) + '.' + else: + table_prefix = '' + order_by.append('%s%s %s' % (table_prefix, backend.quote_name(orderfield2column(col_name, opts)), order)) + if order_by: + sql.append("ORDER BY " + ", ".join(order_by)) + + # LIMIT and OFFSET clauses + if self._limit is not None: + sql.append("%s " % backend.get_limit_offset_sql(self._limit, self._offset)) + else: + assert self._offset is None, "'offset' is not allowed without 'limit'" + + return select, " ".join(sql), params + +class ValuesQuerySet(QuerySet): + def iterator(self): + # select_related and select aren't supported in values(). + self._select_related = False + self._select = {} + + # self._fields is a list of field names to fetch. + if self._fields: + columns = [self.model._meta.get_field(f, many_to_many=False).column for f in self._fields] + field_names = self._fields + else: # Default to all fields. + columns = [f.column for f in self.model._meta.fields] + field_names = [f.attname for f in self.model._meta.fields] + + cursor = connection.cursor() + select, sql, params = self._get_sql_clause() + select = ['%s.%s' % (backend.quote_name(self.model._meta.db_table), backend.quote_name(c)) for c in columns] + cursor.execute("SELECT " + (self._distinct and "DISTINCT " or "") + ",".join(select) + sql, params) + while 1: + rows = cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE) + if not rows: + raise StopIteration + for row in rows: + yield dict(zip(field_names, row)) + + def _clone(self, klass=None, **kwargs): + c = super(ValuesQuerySet, self)._clone(klass, **kwargs) + c._fields = self._fields[:] + return c + +class DateQuerySet(QuerySet): + def iterator(self): + from django.db.backends.util import typecast_timestamp + self._order_by = () # Clear this because it'll mess things up otherwise. + if self._field.null: + date_query._where.append('%s.%s IS NOT NULL' % \ + (backend.quote_name(self.model._meta.db_table), backend.quote_name(self._field.column))) + select, sql, params = self._get_sql_clause() + sql = 'SELECT %s %s GROUP BY 1 ORDER BY 1 %s' % \ + (backend.get_date_trunc_sql(self._kind, '%s.%s' % (backend.quote_name(self.model._meta.db_table), + backend.quote_name(self._field.column))), sql, self._order) + cursor = connection.cursor() + cursor.execute(sql, params) + # We have to manually run typecast_timestamp(str()) on the results, because + # MySQL doesn't automatically cast the result of date functions as datetime + # objects -- MySQL returns the values as strings, instead. + return [typecast_timestamp(str(row[0])) for row in cursor.fetchall()] + + def _clone(self, klass=None, **kwargs): + c = super(DateQuerySet, self)._clone(klass, **kwargs) + c._field = self._field + c._kind = self._kind + c._order = self._order + return c + +class QOperator: + "Base class for QAnd and QOr" + def __init__(self, *args): + self.args = args + + def get_sql(self, opts): + tables, joins, where, params = [], SortedDict(), [], [] + for val in self.args: + tables2, joins2, where2, params2 = val.get_sql(opts) + tables.extend(tables2) + joins.update(joins2) + where.extend(where2) + params.extend(params2) + if where: + return tables, joins, ['(%s)' % self.operator.join(where)], params + return tables, joins, [], params + +class QAnd(QOperator): + "Encapsulates a combined query that uses 'AND'." + operator = ' AND ' + def __or__(self, other): + return QOr(self, other) + + def __and__(self, other): + if isinstance(other, QAnd): + return QAnd(*(self.args+other.args)) + elif isinstance(other, (Q, QOr)): + return QAnd(*(self.args+(other,))) + else: + raise TypeError, other + +class QOr(QOperator): + "Encapsulates a combined query that uses 'OR'." + operator = ' OR ' + def __and__(self, other): + return QAnd(self, other) + + def __or__(self, other): + if isinstance(other, QOr): + return QOr(*(self.args+other.args)) + elif isinstance(other, (Q, QAnd)): + return QOr(*(self.args+(other,))) + else: + raise TypeError, other + +class Q(object): + "Encapsulates queries as objects that can be combined logically." + def __init__(self, **kwargs): + self.kwargs = kwargs + + def __and__(self, other): + return QAnd(self, other) + + def __or__(self, other): + return QOr(self, other) + + def get_sql(self, opts): + return parse_lookup(self.kwargs.items(), opts) + +class QNot(Q): + "Encapsulates NOT (...) queries as objects" + + def get_sql(self, opts): + tables, joins, where, params = super(QNot, self).get_sql(opts) + where2 = ['(NOT (%s))' % " AND ".join(where)] + return tables, joins, where2, params + +def get_where_clause(lookup_type, table_prefix, field_name, value): + if table_prefix.endswith('.'): + table_prefix = backend.quote_name(table_prefix[:-1])+'.' + field_name = backend.quote_name(field_name) + try: + return '%s%s %s' % (table_prefix, field_name, (backend.OPERATOR_MAPPING[lookup_type] % '%s')) + except KeyError: + pass + if lookup_type == 'in': + return '%s%s IN (%s)' % (table_prefix, field_name, ','.join(['%s' for v in value])) + elif lookup_type == 'range': + return '%s%s BETWEEN %%s AND %%s' % (table_prefix, field_name) + elif lookup_type in ('year', 'month', 'day'): + return "%s = %%s" % backend.get_date_extract_sql(lookup_type, table_prefix + field_name) + elif lookup_type == 'isnull': + return "%s%s IS %sNULL" % (table_prefix, field_name, (not value and 'NOT ' or '')) + raise TypeError, "Got invalid lookup_type: %s" % repr(lookup_type) + +def get_cached_row(klass, row, index_start): + "Helper function that recursively returns an object with cache filled" + index_end = index_start + len(klass._meta.fields) + obj = klass(*row[index_start:index_end]) + for f in klass._meta.fields: + if f.rel and not f.null: + rel_obj, index_end = get_cached_row(f.rel.to, row, index_end) + setattr(obj, f.get_cache_name(), rel_obj) + return obj, index_end + +def fill_table_cache(opts, select, tables, where, old_prefix, cache_tables_seen): + """ + Helper function that recursively populates the select, tables and where (in + place) for fill-cache queries. + """ + for f in opts.fields: + if f.rel and not f.null: + db_table = f.rel.to._meta.db_table + if db_table not in cache_tables_seen: + tables.append(backend.quote_name(db_table)) + else: # The table was already seen, so give it a table alias. + new_prefix = '%s%s' % (db_table, len(cache_tables_seen)) + tables.append('%s %s' % (backend.quote_name(db_table), backend.quote_name(new_prefix))) + db_table = new_prefix + cache_tables_seen.append(db_table) + where.append('%s.%s = %s.%s' % \ + (backend.quote_name(old_prefix), backend.quote_name(f.column), + backend.quote_name(db_table), backend.quote_name(f.rel.get_related_field().column))) + select.extend(['%s.%s' % (backend.quote_name(db_table), backend.quote_name(f2.column)) for f2 in f.rel.to._meta.fields]) + fill_table_cache(f.rel.to._meta, select, tables, where, db_table, cache_tables_seen) + +def parse_lookup(kwarg_items, opts): + # Helper function that handles converting API kwargs + # (e.g. "name__exact": "tom") to SQL. + + # 'joins' is a sorted dictionary describing the tables that must be joined + # to complete the query. The dictionary is sorted because creation order + # is significant; it is a dictionary to ensure uniqueness of alias names. + # + # Each key-value pair follows the form + # alias: (table, join_type, condition) + # where + # alias is the AS alias for the joined table + # table is the actual table name to be joined + # join_type is the type of join (INNER JOIN, LEFT OUTER JOIN, etc) + # condition is the where-like statement over which narrows the join. + # alias will be derived from the lookup list name. + # + # At present, this method only every returns INNER JOINs; the option is + # there for others to implement custom Q()s, etc that return other join + # types. + tables, joins, where, params = [], SortedDict(), [], [] + + for kwarg, value in kwarg_items: + if value is not None: + path = kwarg.split(LOOKUP_SEPARATOR) + # Extract the last elements of the kwarg. + # The very-last is the clause (equals, like, etc). + # The second-last is the table column on which the clause is + # to be performed. + # The exceptions to this are: + # 1) "pk", which is an implicit id__exact; + # if we find "pk", make the clause "exact', and insert + # a dummy name of None, which we will replace when + # we know which table column to grab as the primary key. + # 2) If there is only one part, assume it to be an __exact + clause = path.pop() + if clause == 'pk': + clause = 'exact' + path.append(None) + elif len(path) == 0: + path.append(clause) + clause = 'exact' + + if len(path) < 1: + raise TypeError, "Cannot parse keyword query %r" % kwarg + + tables2, joins2, where2, params2 = lookup_inner(path, clause, value, opts, opts.db_table, None) + tables.extend(tables2) + joins.update(joins2) + where.extend(where2) + params.extend(params2) + return tables, joins, where, params + +class FieldFound(Exception): + "Exception used to short circuit field-finding operations." + pass + +def find_field(name, field_list, related_query): + """ + Finds a field with a specific name in a list of field instances. + Returns None if there are no matches, or several matches. + """ + if related_query: + matches = [f for f in field_list if f.field.related_query_name() == name] + else: + matches = [f for f in field_list if f.name == name] + if len(matches) != 1: + return None + return matches[0] + +def lookup_inner(path, clause, value, opts, table, column): + tables, joins, where, params = [], SortedDict(), [], [] + current_opts = opts + current_table = table + current_column = column + intermediate_table = None + join_required = False + + name = path.pop(0) + # Has the primary key been requested? If so, expand it out + # to be the name of the current class' primary key + if name is None: + name = current_opts.pk.name + + # Try to find the name in the fields associated with the current class + try: + # Does the name belong to a defined many-to-many field? + field = find_field(name, current_opts.many_to_many, False) + if field: + new_table = current_table + LOOKUP_SEPARATOR + name + new_opts = field.rel.to._meta + new_column = new_opts.pk.column + + # Need to create an intermediate table join over the m2m table + # This process hijacks current_table/column to point to the + # intermediate table. + current_table = "m2m_" + new_table + intermediate_table = field.m2m_db_table() + join_column = field.m2m_reverse_name() + intermediate_column = field.m2m_column_name() + + raise FieldFound + + # Does the name belong to a reverse defined many-to-many field? + field = find_field(name, current_opts.get_all_related_many_to_many_objects(), True) + if field: + new_table = current_table + LOOKUP_SEPARATOR + name + new_opts = field.opts + new_column = new_opts.pk.column + + # Need to create an intermediate table join over the m2m table. + # This process hijacks current_table/column to point to the + # intermediate table. + current_table = "m2m_" + new_table + intermediate_table = field.field.m2m_db_table() + join_column = field.field.m2m_column_name() + intermediate_column = field.field.m2m_reverse_name() + + raise FieldFound + + # Does the name belong to a one-to-many field? + field = find_field(name, current_opts.get_all_related_objects(), True) + if field: + new_table = table + LOOKUP_SEPARATOR + name + new_opts = field.opts + new_column = field.field.column + join_column = opts.pk.column + + # 1-N fields MUST be joined, regardless of any other conditions. + join_required = True + + raise FieldFound + + # Does the name belong to a one-to-one, many-to-one, or regular field? + field = find_field(name, current_opts.fields, False) + if field: + if field.rel: # One-to-One/Many-to-one field + new_table = current_table + LOOKUP_SEPARATOR + name + new_opts = field.rel.to._meta + new_column = new_opts.pk.column + join_column = field.column + + raise FieldFound + + except FieldFound: # Match found, loop has been shortcut. + pass + except: # Any other exception; rethrow + raise + else: # No match found. + raise TypeError, "Cannot resolve keyword '%s' into field" % name + + # Check to see if an intermediate join is required between current_table + # and new_table. + if intermediate_table: + joins[backend.quote_name(current_table)] = ( + backend.quote_name(intermediate_table), + "LEFT OUTER JOIN", + "%s.%s = %s.%s" % \ + (backend.quote_name(table), + backend.quote_name(current_opts.pk.column), + backend.quote_name(current_table), + backend.quote_name(intermediate_column)) + ) + + if path: + if len(path) == 1 and path[0] in (new_opts.pk.name, None) \ + and clause in ('exact', 'isnull') and not join_required: + # If the last name query is for a key, and the search is for + # isnull/exact, then the current (for N-1) or intermediate + # (for N-N) table can be used for the search - no need to join an + # extra table just to check the primary key. + new_table = current_table + else: + # There are 1 or more name queries pending, and we have ruled out + # any shortcuts; therefore, a join is required. + joins[backend.quote_name(new_table)] = ( + backend.quote_name(new_opts.db_table), + "INNER JOIN", + "%s.%s = %s.%s" % + (backend.quote_name(current_table), + backend.quote_name(join_column), + backend.quote_name(new_table), + backend.quote_name(new_column)) + ) + # If we have made the join, we don't need to tell subsequent + # recursive calls about the column name we joined on. + join_column = None + + # There are name queries remaining. Recurse deeper. + tables2, joins2, where2, params2 = lookup_inner(path, clause, value, new_opts, new_table, join_column) + + tables.extend(tables2) + joins.update(joins2) + where.extend(where2) + params.extend(params2) + else: + # Evaluate clause on current table. + if name in (current_opts.pk.name, None) and clause in ('exact', 'isnull') and current_column: + # If this is an exact/isnull key search, and the last pass + # found/introduced a current/intermediate table that we can use to + # optimize the query, then use that column name. + column = current_column + else: + column = field.column + + where.append(get_where_clause(clause, current_table + '.', column, value)) + params.extend(field.get_db_prep_lookup(clause, value)) + + return tables, joins, where, params + +def delete_objects(seen_objs): + "Iterate through a list of seen classes, and remove any instances that are referred to" + ordered_classes = seen_objs.keys() + ordered_classes.reverse() + + cursor = connection.cursor() + + for cls in ordered_classes: + seen_objs[cls] = seen_objs[cls].items() + seen_objs[cls].sort() + + # Pre notify all instances to be deleted + for pk_val, instance in seen_objs[cls]: + dispatcher.send(signal=signals.pre_delete, sender=cls, instance=instance) + + pk_list = [pk for pk,instance in seen_objs[cls]] + for related in cls._meta.get_all_related_many_to_many_objects(): + for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE): + cursor.execute("DELETE FROM %s WHERE %s IN (%s)" % \ + (backend.quote_name(related.field.m2m_db_table()), + backend.quote_name(related.field.m2m_reverse_name()), + ','.join(['%s' for pk in pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE]])), + pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE]) + for f in cls._meta.many_to_many: + for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE): + cursor.execute("DELETE FROM %s WHERE %s IN (%s)" % \ + (backend.quote_name(f.m2m_db_table()), + backend.quote_name(f.m2m_column_name()), + ','.join(['%s' for pk in pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE]])), + pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE]) + for field in cls._meta.fields: + if field.rel and field.null and field.rel.to in seen_objs: + for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE): + cursor.execute("UPDATE %s SET %s=NULL WHERE %s IN (%s)" % \ + (backend.quote_name(cls._meta.db_table), + backend.quote_name(field.column), + backend.quote_name(cls._meta.pk.column), + ','.join(['%s' for pk in pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE]])), + pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE]) + + # Now delete the actual data + for cls in ordered_classes: + seen_objs[cls].reverse() + pk_list = [pk for pk,instance in seen_objs[cls]] + for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE): + cursor.execute("DELETE FROM %s WHERE %s IN (%s)" % \ + (backend.quote_name(cls._meta.db_table), + backend.quote_name(cls._meta.pk.column), + ','.join(['%s' for pk in pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE]])), + pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE]) + + # Last cleanup; set NULLs where there once was a reference to the object, + # NULL the primary key of the found objects, and perform post-notification. + for pk_val, instance in seen_objs[cls]: + for field in cls._meta.fields: + if field.rel and field.null and field.rel.to in seen_objs: + setattr(instance, field.attname, None) + + setattr(instance, cls._meta.pk.attname, None) + dispatcher.send(signal=signals.post_delete, sender=cls, instance=instance) + + transaction.commit_unless_managed() diff --git a/django/db/models/related.py b/django/db/models/related.py new file mode 100644 index 0000000000..4ab8cde5e7 --- /dev/null +++ b/django/db/models/related.py @@ -0,0 +1,132 @@ +class BoundRelatedObject(object): + def __init__(self, related_object, field_mapping, original): + self.relation = related_object + self.field_mappings = field_mapping[related_object.opts.module_name] + + def template_name(self): + raise NotImplementedError + + def __repr__(self): + return repr(self.__dict__) + +class RelatedObject(object): + def __init__(self, parent_model, model, field): + self.parent_model = parent_model + self.model = model + self.opts = model._meta + self.field = field + self.edit_inline = field.rel.edit_inline + self.name = self.opts.module_name + self.var_name = self.opts.object_name.lower() + + def flatten_data(self, follow, obj=None): + new_data = {} + rel_instances = self.get_list(obj) + for i, rel_instance in enumerate(rel_instances): + instance_data = {} + for f in self.opts.fields + self.opts.many_to_many: + # TODO: Fix for recursive manipulators. + fol = follow.get(f.name, None) + if fol: + field_data = f.flatten_data(fol, rel_instance) + for name, value in field_data.items(): + instance_data['%s.%d.%s' % (self.var_name, i, name)] = value + new_data.update(instance_data) + return new_data + + def extract_data(self, data): + """ + Pull out the data meant for inline objects of this class, + i.e. anything starting with our module name. + """ + return data # TODO + + def get_list(self, parent_instance=None): + "Get the list of this type of object from an instance of the parent class." + if parent_instance is not None: + attr = getattr(parent_instance, self.get_accessor_name()) + if self.field.rel.multiple: + # For many-to-many relationships, return a list of objects + # corresponding to the xxx_num_in_admin options of the field + objects = list(attr.all()) + + count = len(objects) + self.field.rel.num_extra_on_change + if self.field.rel.min_num_in_admin: + count = max(count, self.field.rel.min_num_in_admin) + if self.field.rel.max_num_in_admin: + count = min(count, self.field.rel.max_num_in_admin) + + change = count - len(objects) + if change > 0: + return objects + [None] * change + if change < 0: + return objects[:change] + else: # Just right + return objects + else: + # A one-to-one relationship, so just return the single related + # object + return [attr] + else: + return [None] * self.field.rel.num_in_admin + + def editable_fields(self): + "Get the fields in this class that should be edited inline." + return [f for f in self.opts.fields + self.opts.many_to_many if f.editable and f != self.field] + + def get_follow(self, override=None): + if isinstance(override, bool): + if override: + over = {} + else: + return None + else: + if override: + over = override.copy() + elif self.edit_inline: + over = {} + else: + return None + + over[self.field.name] = False + return self.opts.get_follow(over) + + def get_manipulator_fields(self, opts, manipulator, change, follow): + if self.field.rel.multiple: + if change: + attr = getattr(manipulator.original_object, self.get_accessor_name()) + count = attr.count() + count += self.field.rel.num_extra_on_change + if self.field.rel.min_num_in_admin: + count = max(count, self.field.rel.min_num_in_admin) + if self.field.rel.max_num_in_admin: + count = min(count, self.field.rel.max_num_in_admin) + else: + count = self.field.rel.num_in_admin + else: + count = 1 + + fields = [] + for i in range(count): + for f in self.opts.fields + self.opts.many_to_many: + if follow.get(f.name, False): + prefix = '%s.%d.' % (self.var_name, i) + fields.extend(f.get_manipulator_fields(self.opts, manipulator, change, + name_prefix=prefix, rel=True)) + return fields + + def __repr__(self): + return "<RelatedObject: %s related to %s>" % (self.name, self.field.name) + + def bind(self, field_mapping, original, bound_related_object_class=BoundRelatedObject): + return bound_related_object_class(self, field_mapping, original) + + def get_accessor_name(self): + # This method encapsulates the logic that decides what name to give an + # accessor descriptor that retrieves related many-to-one or + # many-to-many objects. It uses the lower-cased object_name + "_set", + # but this can be overridden with the "related_name" option. + if self.field.rel.multiple: + return self.field.rel.related_name or (self.opts.object_name.lower() + '_set') + else: + return self.field.rel.related_name or (self.opts.object_name.lower()) diff --git a/django/db/models/signals.py b/django/db/models/signals.py new file mode 100644 index 0000000000..2171cb1bf3 --- /dev/null +++ b/django/db/models/signals.py @@ -0,0 +1,12 @@ +class_prepared = object() + +pre_init= object() +post_init = object() + +pre_save = object() +post_save = object() + +pre_delete = object() +post_delete = object() + +post_syncdb = object() |
