diff options
Diffstat (limited to 'django')
| -rw-r--r-- | django/conf/global_settings.py | 51 | ||||
| -rw-r--r-- | django/conf/locale/sr_Latn/LC_MESSAGES/django.po | 2 | ||||
| -rw-r--r-- | django/contrib/auth/tests/views.py | 23 | ||||
| -rw-r--r-- | django/core/management/commands/inspectdb.py | 5 | ||||
| -rw-r--r-- | django/core/management/commands/test.py | 15 | ||||
| -rw-r--r-- | django/core/management/commands/test_windmill.py | 122 | ||||
| -rw-r--r-- | django/db/models/loading.py | 36 | ||||
| -rw-r--r-- | django/test/__init__.py | 7 | ||||
| -rw-r--r-- | django/test/decorators.py | 44 | ||||
| -rw-r--r-- | django/test/mocks.py | 38 | ||||
| -rw-r--r-- | django/test/simple.py | 175 | ||||
| -rw-r--r-- | django/test/test_coverage.py | 129 | ||||
| -rw-r--r-- | django/test/testcases.py | 77 | ||||
| -rw-r--r-- | django/test/twill_tests.py | 332 | ||||
| -rw-r--r-- | django/test/utils.py | 21 | ||||
| -rw-r--r-- | django/test/windmill_tests.py | 137 | ||||
| -rw-r--r-- | django/utils/module_tools/__init__.py | 3 | ||||
| -rw-r--r-- | django/utils/module_tools/data_storage.py | 42 | ||||
| -rw-r--r-- | django/utils/module_tools/module_loader.py | 79 | ||||
| -rw-r--r-- | django/utils/module_tools/module_walker.py | 135 |
20 files changed, 1419 insertions, 54 deletions
diff --git a/django/conf/global_settings.py b/django/conf/global_settings.py index 99fc72e468..eb9629990d 100644 --- a/django/conf/global_settings.py +++ b/django/conf/global_settings.py @@ -379,7 +379,7 @@ PASSWORD_RESET_TIMEOUT_DAYS = 3 ########### # The name of the method to use to invoke the test suite -TEST_RUNNER = 'django.test.simple.run_tests' +TEST_RUNNER = 'django.test.simple.DefaultTestRunner' # The name of the database to use for testing purposes. # If None, a name of 'test_' + DATABASE_NAME will be assumed @@ -393,6 +393,55 @@ TEST_DATABASE_CHARSET = None TEST_DATABASE_COLLATION = None ############ +# COVERAGE # +############ + + +# Specify the coverage test runner +COVERAGE_TEST_RUNNER = 'django.test.test_coverage.ConsoleReportCoverageRunner' + +# Specify regular expressions of code blocks the coverage analyzer should +# ignore as statements (e.g. ``raise NotImplemented``). +# These statements are not figured in as part of the coverage statistics. +# This setting is optional. +COVERAGE_CODE_EXCLUDES = [ + 'def __unicode__\(self\):', 'def get_absolute_url\(self\):', + 'from .* import .*', 'import .*', + ] + +# Specify a list of regular expressions of paths to exclude from +# coverage analysis. +# Note these paths are ignored by the module introspection tool and take +# precedence over any package/module settings such as: +# TODO: THE SETTING FOR MODULES +# Use this to exclude subdirectories like ``r'.svn'``, for example. +# This setting is optional. +COVERAGE_PATH_EXCLUDES = [r'.svn'] + +# Specify a list of additional module paths to include +# in the coverage analysis. By default, only modules within installed +# apps are reported. If you have utility modules outside of the app +# structure, you can include them here. +# Note this list is *NOT* regular expression, so you have to be explicit, +# such as 'myproject.utils', and not 'utils$'. +# This setting is optional. +COVERAGE_ADDITIONAL_MODULES = [] + +# Specify a list of regular expressions of module paths to exclude +# from the coverage analysis. Examples are ``'tests$'`` and ``'urls$'``. +# This setting is optional. +COVERAGE_MODULE_EXCLUDES = ['tests$', 'settings$','urls$', 'common.views.test', + '__init__', 'django'] + +# Specify the directory where you would like the coverage report to create +# the HTML files. +# You'll need to make sure this directory exists and is writable by the +# user account running the test. +# You should probably set this one explicitly in your own settings file. +COVERAGE_REPORT_HTML_OUTPUT_DIR = 'test_html' + + +############ # FIXTURES # ############ diff --git a/django/conf/locale/sr_Latn/LC_MESSAGES/django.po b/django/conf/locale/sr_Latn/LC_MESSAGES/django.po index fd7ae5f869..06dfd356b9 100644 --- a/django/conf/locale/sr_Latn/LC_MESSAGES/django.po +++ b/django/conf/locale/sr_Latn/LC_MESSAGES/django.po @@ -314,7 +314,7 @@ msgstr "zapisi u logovima" #: contrib/admin/options.py:133 contrib/admin/options.py:147 msgid "None" -msgstr "Ništa" +msgstr "Ништа" #: contrib/admin/options.py:519 #, python-format diff --git a/django/contrib/auth/tests/views.py b/django/contrib/auth/tests/views.py index 532f92b523..d7915fd51d 100644 --- a/django/contrib/auth/tests/views.py +++ b/django/contrib/auth/tests/views.py @@ -9,6 +9,7 @@ from django.contrib.auth.models import User from django.test import TestCase from django.core import mail from django.core.urlresolvers import reverse +from django.test.decorators import views_required class AuthViewsTestCase(TestCase): """ @@ -49,22 +50,26 @@ class PasswordResetTest(AuthViewsTestCase): def test_email_not_found(self): "Error is raised if the provided email address isn't currently registered" - response = self.client.get('/password_reset/') + response = self.client.get(reverse('django.contrib.auth.views.password_reset')) self.assertEquals(response.status_code, 200) - response = self.client.post('/password_reset/', {'email': 'not_a_real_email@email.com'}) + response = self.client.post(reverse('django.contrib.auth.views.password_reset'), {'email': 'not_a_real_email@email.com'}) self.assertContains(response, "That e-mail address doesn't have an associated user account") self.assertEquals(len(mail.outbox), 0) - + + test_email_not_found = views_required(required_views=['django.contrib.auth.views.password_reset'])(test_email_not_found) + def test_email_found(self): "Email is sent if a valid email address is provided for password reset" - response = self.client.post('/password_reset/', {'email': 'staffmember@example.com'}) + response = self.client.post(reverse('django.contrib.auth.views.password_reset'), {'email': 'staffmember@example.com'}) self.assertEquals(response.status_code, 302) self.assertEquals(len(mail.outbox), 1) self.assert_("http://" in mail.outbox[0].body) + + test_email_found = views_required(required_views=['django.contrib.auth.views.password_reset'])(test_email_found) def _test_confirm_start(self): # Start by creating the email - response = self.client.post('/password_reset/', {'email': 'staffmember@example.com'}) + response = self.client.post(reverse('django.contrib.auth.views.password_reset'), {'email': 'staffmember@example.com'}) self.assertEquals(response.status_code, 302) self.assertEquals(len(mail.outbox), 1) return self._read_signup_email(mail.outbox[0]) @@ -80,6 +85,8 @@ class PasswordResetTest(AuthViewsTestCase): # redirect to a 'complete' page: self.assertEquals(response.status_code, 200) self.assert_("Please enter your new password" in response.content) + test_confirm_valid = views_required(required_views=['django.contrib.auth.views.password_reset'])(test_confirm_valid) + def test_confirm_invalid(self): url, path = self._test_confirm_start() @@ -90,6 +97,7 @@ class PasswordResetTest(AuthViewsTestCase): response = self.client.get(path) self.assertEquals(response.status_code, 200) self.assert_("The password reset link was invalid" in response.content) + test_confirm_invalid = views_required(required_views=['django.contrib.auth.views.password_reset'])(test_confirm_invalid) def test_confirm_invalid_post(self): # Same as test_confirm_invalid, but trying @@ -102,6 +110,7 @@ class PasswordResetTest(AuthViewsTestCase): # Check the password has not been changed u = User.objects.get(email='staffmember@example.com') self.assert_(not u.check_password("anewpassword")) + test_confirm_invalid_post = views_required(required_views=['django.contrib.auth.views.password_reset'])(test_confirm_invalid_post) def test_confirm_complete(self): url, path = self._test_confirm_start() @@ -117,6 +126,7 @@ class PasswordResetTest(AuthViewsTestCase): response = self.client.get(path) self.assertEquals(response.status_code, 200) self.assert_("The password reset link was invalid" in response.content) + test_confirm_complete = views_required(required_views=['django.contrib.auth.views.password_reset'])(test_confirm_complete) def test_confirm_different_passwords(self): url, path = self._test_confirm_start() @@ -124,7 +134,8 @@ class PasswordResetTest(AuthViewsTestCase): 'new_password2':' x'}) self.assertEquals(response.status_code, 200) self.assert_("The two password fields didn't match" in response.content) - + + test_confirm_different_passwords = views_required(required_views=['django.contrib.auth.views.password_reset'])(test_confirm_different_passwords) class ChangePasswordTest(AuthViewsTestCase): def login(self, password='password'): diff --git a/django/core/management/commands/inspectdb.py b/django/core/management/commands/inspectdb.py index fbe539274e..54203c532b 100644 --- a/django/core/management/commands/inspectdb.py +++ b/django/core/management/commands/inspectdb.py @@ -15,8 +15,9 @@ class Command(NoArgsCommand): def handle_inspection(self): from django.db import connection import keyword - - table2model = lambda table_name: table_name.title().replace('_', '').replace(' ', '').replace('-', '') + + table2model = lambda table_name: table_name.title() + .replace('_', '').replace(' ', '').replace('-', '').replace('*','_').replace(',','_') cursor = connection.cursor() yield "# This is an auto-generated Django model module." diff --git a/django/core/management/commands/test.py b/django/core/management/commands/test.py index 8ebf3daea6..92949fc367 100644 --- a/django/core/management/commands/test.py +++ b/django/core/management/commands/test.py @@ -6,6 +6,10 @@ class Command(BaseCommand): option_list = BaseCommand.option_list + ( make_option('--noinput', action='store_false', dest='interactive', default=True, help='Tells Django to NOT prompt the user for input of any kind.'), + make_option('--coverage', action='store_true', dest='coverage', default=False, + help='Tells Django to run the coverage runner'), + make_option('--reports', action='store_true', dest='reports', default=False, + help='Tells Django to output coverage results as HTML reports'), ) help = 'Runs the test suite for the specified applications, or the entire site if no apps are specified.' args = '[appname ...]' @@ -18,8 +22,13 @@ class Command(BaseCommand): verbosity = int(options.get('verbosity', 1)) interactive = options.get('interactive', True) - test_runner = get_runner(settings) - - failures = test_runner(test_labels, verbosity=verbosity, interactive=interactive) + cover = options.get('coverage', False) + report = options.get('reports', False) + test_runner = get_runner(settings, coverage=cover, reports=report) + if(type(test_runner) == 'function'): + failures = test_runner(test_labels, verbosity=verbosity, interactive=interactive) + else: + tr = test_runner() + failures = tr.run_tests(test_labels, verbosity=verbosity, interactive=interactive) if failures: sys.exit(failures) diff --git a/django/core/management/commands/test_windmill.py b/django/core/management/commands/test_windmill.py new file mode 100644 index 0000000000..c3da16fe84 --- /dev/null +++ b/django/core/management/commands/test_windmill.py @@ -0,0 +1,122 @@ +from django.core.management.base import BaseCommand +#from windmill.authoring import djangotest +from django.utils import importlib +from django.test import windmill_tests as djangotest +import sys, os +from time import sleep +import types +import logging +import threading + +class ServerContainer(object): + start_test_server = djangotest.start_test_server + stop_test_server = djangotest.stop_test_server + +def attempt_import(name, suffix): + try: + mod = importlib.import_module(name+'.'+suffix) + except ImportError: + mod = None + if mod is not None: + s = name.split('.') + mod = importlib.import_module(s.pop(0)) + for x in s+[suffix]: + try: + mod = getattr(mod, x) + except Exception, e: + pass + return mod + +class Command(BaseCommand): + + help = "Run windmill tests. Specify a browser, if one is not passed Firefox will be used" + + args = '<label label ...>' + label = 'label' + + def handle(self, *labels, **options): + + from windmill.conf import global_settings + from django.test.windmill_tests import WindmillDjangoUnitTest + if 'ie' in labels: + global_settings.START_IE = True + sys.argv.remove('ie') + elif 'safari' in labels: + global_settings.START_SAFARI = True + sys.argv.remove('safari') + elif 'chrome' in labels: + global_settings.START_CHROME = True + sys.argv.remove('chrome') + else: + global_settings.START_FIREFOX = True + if 'firefox' in labels: + sys.argv.remove('firefox') + + if 'manage.py' in sys.argv: + sys.argv.remove('manage.py') + if 'test_windmill' in sys.argv: + sys.argv.remove('test_windmill') + + from django.conf import settings + tests = [] + for name in settings.INSTALLED_APPS: + for suffix in ['tests', 'wmtests', 'windmilltests']: + x = attempt_import(name, suffix) + if x is not None: tests.append((suffix,x,)); + wmfixs = [] + wmtests = [] + for (ttype, mod,) in tests: + if ttype == 'tests': + for ucls in [getattr(mod, x) for x in dir(mod) + if ( type(getattr(mod, x, None)) in (types.ClassType, + types.TypeType) ) and + issubclass(getattr(mod, x), WindmillDjangoUnitTest) + ]: + wmtests.append(ucls.test_dir) + + else: + if mod.__file__.endswith('__init__.py') or mod.__file__.endswith('__init__.pyc'): + wmtests.append(os.path.join(*os.path.split(os.path.abspath(mod.__file__))[:-1])) + else: + wmtests.append(os.path.abspath(mod.__file__)) + # Look for any attribute named fixtures and try to load it. + if hasattr(mod, 'fixtures'): + for fixture in getattr(mod, 'fixtures'): + wmfixs.append(fixture) + + # Create the threaded server. + server_container = ServerContainer() + # Set the server's 'fixtures' attribute so they can be loaded in-thread if using sqlite's memory backend. + server_container.__setattr__('fixtures', wmfixs) + # Start the server thread. + started = server_container.start_test_server() + + print 'Waiting for threaded server to come online.' + started.wait() + print 'DB Ready, Server online.' + + global_settings.TEST_URL = 'http://localhost:%d' % server_container.server_thread.port + + # import windmill + # windmill.stdout, windmill.stdin = sys.stdout, sys.stdin + from windmill.authoring import setup_module, teardown_module + + + if len(wmtests) is 0: + print 'Sorry, no windmill tests found.' + else: + testtotals = {} + x = logging.getLogger() + x.setLevel(0) + from windmill.server.proxy import logger + from functest import bin + from functest import runner + runner.CLIRunner.final = classmethod(lambda self, totals: testtotals.update(totals) ) + import windmill + setup_module(tests[0][1]) + sys.argv = sys.argv + wmtests + bin.cli() + teardown_module(tests[0][1]) + if testtotals['fail'] is not 0: + sleep(.5) + sys.exit(1) diff --git a/django/db/models/loading.py b/django/db/models/loading.py index e07aab4efe..96ebe43941 100644 --- a/django/db/models/loading.py +++ b/django/db/models/loading.py @@ -10,7 +10,7 @@ import os import threading __all__ = ('get_apps', 'get_app', 'get_models', 'get_model', 'register_models', - 'load_app', 'app_cache_ready') + 'load_app', 'app_cache_ready', 'remove_model') class AppCache(object): """ @@ -178,6 +178,39 @@ class AppCache(object): continue model_dict[model_name] = model + def clear_apps(self): + """Clears cache so on next call, it will be reloaded.""" + self.loaded = False + self.write_lock.acquire() + self.handled.clear() + try: + if self.loaded: + return + for app_name in settings.INSTALLED_APPS: + if app_name in self.handled: + continue + self.load_app(app_name, True) + if not self.nesting_level: + for app_name in self.postponed: + self.load_app(app_name) + self.loaded = True + finally: + self.write_lock.release() + + def remove_model(self, model_name): + """Removes a model from the cache. Used when loading test-only models.""" + try: + try: + self.write_lock.acquire() + if model_name in self.app_models: + del self.app_models[model_name] + except Exception, e: + raise e + finally: + self.write_lock.release() + + + cache = AppCache() # These methods were always module level, so are kept that way for backwards @@ -190,3 +223,4 @@ get_model = cache.get_model register_models = cache.register_models load_app = cache.load_app app_cache_ready = cache.app_cache_ready +remove_model = cache.remove_model
\ No newline at end of file diff --git a/django/test/__init__.py b/django/test/__init__.py index 957b293e12..e5647243f2 100644 --- a/django/test/__init__.py +++ b/django/test/__init__.py @@ -4,3 +4,10 @@ Django Unit Test and Doctest framework. from django.test.client import Client from django.test.testcases import TestCase, TransactionTestCase + +class SkippedTest(Exception): + def __init__(self, reason): + self.reason = reason + + def __str__(self): + return self.reason
\ No newline at end of file diff --git a/django/test/decorators.py b/django/test/decorators.py new file mode 100644 index 0000000000..ca621894e3 --- /dev/null +++ b/django/test/decorators.py @@ -0,0 +1,44 @@ +from django.core import urlresolvers +from django.test import SkippedTest + +def views_required(required_views=[]): + def urls_found(): + try: + for view in required_views: + urlresolvers.reverse(view) + return True + except urlresolvers.NoReverseMatch: + return False + reason = 'Required view%s for this test not found: %s' % \ + (len(required_views) > 1 and 's' or '', ', '.join(required_views)) + return conditional_skip(urls_found, reason=reason) + +def modules_required(required_modules=[]): + def modules_found(): + try: + for module in required_modules: + __import__(module) + return True + except ImportError: + return False + reason = 'Required module%s for this test not found: %s' % \ + (len(required_modules) > 1 and 's' or '', ', '.join(required_modules)) + return conditional_skip(modules_found, reason=reason) + +def skip_specific_database(database_engine): + def database_check(): + from django.conf import settings + return database_engine == settings.DATABASE_ENGINE + reason = 'Test not run for database engine %s.' % database_engine + return conditional_skip(database_check, reason=reason) + +def conditional_skip(required_condition, reason=''): + if required_condition(): + return lambda x: x + else: + return skip_test(reason) + +def skip_test(reason=''): + def _skip(x): + raise SkippedTest(reason=reason) + return lambda x: _skip diff --git a/django/test/mocks.py b/django/test/mocks.py new file mode 100644 index 0000000000..94048a2a2a --- /dev/null +++ b/django/test/mocks.py @@ -0,0 +1,38 @@ +from django.test import Client +from django.core.handlers.wsgi import WSGIRequest + +class RequestFactory(Client): + """ + Class that lets you create mock Request objects for use in testing. + + Usage: + + rf = RequestFactory() + get_request = rf.get('/hello/') + post_request = rf.post('/submit/', {'foo': 'bar'}) + + This class re-uses the django.test.client.Client interface, docs here: + http://www.djangoproject.com/documentation/testing/#the-test-client + + Once you have a request object you can pass it to any view function, + just as if that view had been hooked up using a URLconf. + + """ + def request(self, **request): + """ + Similar to parent class, but returns the request object as soon as it + has created it. + """ + environ = { + 'HTTP_COOKIE': self.cookies, + 'PATH_INFO': '/', + 'QUERY_STRING': '', + 'REQUEST_METHOD': 'GET', + 'SCRIPT_NAME': '', + 'SERVER_NAME': 'testserver', + 'SERVER_PORT': 80, + 'SERVER_PROTOCOL': 'HTTP/1.1', + } + environ.update(self.defaults) + environ.update(request) + return WSGIRequest(environ) diff --git a/django/test/simple.py b/django/test/simple.py index f3c48bae33..8a60b69b36 100644 --- a/django/test/simple.py +++ b/django/test/simple.py @@ -1,4 +1,4 @@ -import unittest +import sys, time, traceback, unittest from django.conf import settings from django.db.models import get_app, get_apps from django.test import _doctest as doctest @@ -146,52 +146,151 @@ def reorder_suite(suite, classes): bins[0].addTests(bins[i+1]) return bins[0] -def run_tests(test_labels, verbosity=1, interactive=True, extra_tests=[]): +class DefaultTestRunner(object): + """ + The original test runner. No coverage reporting. """ - Run the unit tests for all the test labels in the provided list. - Labels must be of the form: - - app.TestClass.test_method - Run a single specific test method - - app.TestClass - Run all the test methods in a given class - - app - Search for doctests and unittests in the named application. - When looking for tests, the test runner will look in the models and - tests modules for the application. + def __init__(self): + """ + Placeholder constructor. Want to make it obvious that it can + be overridden. + """ + self.isloaded = True - A list of 'extra' tests may also be provided; these tests - will be added to the test suite. - Returns the number of tests that failed. - """ - setup_test_environment() + def run_tests(self, test_labels, verbosity=1, interactive=True, extra_tests=[]): + """ + Run the unit tests for all the test labels in the provided list. + Labels must be of the form: + - app.TestClass.test_method + Run a single specific test method + - app.TestClass + Run all the test methods in a given class + - app + Search for doctests and unittests in the named application. - settings.DEBUG = False - suite = unittest.TestSuite() + When looking for tests, the test runner will look in the models and + tests modules for the application. - if test_labels: - for label in test_labels: - if '.' in label: - suite.addTest(build_test(label)) - else: - app = get_app(label) + A list of 'extra' tests may also be provided; these tests + will be added to the test suite. + + Returns the number of tests that failed. + """ + setup_test_environment() + + settings.DEBUG = False + suite = unittest.TestSuite() + + if test_labels: + for label in test_labels: + if '.' in label: + suite.addTest(build_test(label)) + else: + app = get_app(label) + suite.addTest(build_suite(app)) + else: + for app in get_apps(): suite.addTest(build_suite(app)) - else: - for app in get_apps(): - suite.addTest(build_suite(app)) - for test in extra_tests: - suite.addTest(test) + for test in extra_tests: + suite.addTest(test) + + suite = reorder_suite(suite, (TestCase,)) + + old_name = settings.DATABASE_NAME + from django.db import connection + connection.creation.create_test_db(verbosity, autoclobber=not interactive) + result = SkipTestRunner(verbosity=verbosity).run(suite) + connection.creation.destroy_test_db(old_name, verbosity) + + teardown_test_environment() + + return len(result.failures) + len(result.errors) - suite = reorder_suite(suite, (TestCase,)) - old_name = settings.DATABASE_NAME - from django.db import connection - connection.creation.create_test_db(verbosity, autoclobber=not interactive) - result = unittest.TextTestRunner(verbosity=verbosity).run(suite) - connection.creation.destroy_test_db(old_name, verbosity) +class SkipTestRunner: + """ + A test runner class that adds a Skipped category in the output layer. + + Modeled after unittest.TextTestRunner. + + Similarly to unittest.TextTestRunner, prints summary of the results at the end. + (Including a count of skipped tests.) + """ + + def __init__(self, stream=sys.stderr, descriptions=1, verbosity=1): + self.stream = unittest._WritelnDecorator(stream) + self.descriptions = descriptions + self.verbosity = verbosity + self.result = _SkipTestResult(self.stream, descriptions, verbosity) + + def run(self, test): + "Run the given test case or test suite." + startTime = time.time() + test.run(self.result) + stopTime = time.time() + timeTaken = stopTime - startTime + + self.result.printErrors() + self.stream.writeln(self.result.separator2) + run = self.result.testsRun + self.stream.writeln('Ran %d test%s in %.3fs' % + (run, run != 1 and 's' or '', timeTaken)) + self.stream.writeln() + if not self.result.wasSuccessful(): + self.stream.write('FAILED (') + failed, errored, skipped = map(len, (self.result.failures, self.result.errors, self.result.skipped)) + if failed: + self.stream.write('failures=%d' % failed) + if errored: + if failed: self.stream.write(', ') + self.stream.write('errors=%d' % errored) + if skipped: + if errored or failed: self.stream.write(', ') + self.stream.write('skipped=%d' % skipped) + self.stream.writeln(')') + else: + self.stream.writeln('OK') + return self.result + +class _SkipTestResult(unittest._TextTestResult): + """ + A test result class that adds a Skipped category in the output layer. + + Modeled after unittest._TextTestResult. + + Similarly to unittest._TextTestResult, prints out the names of tests as they are + run and errors as they occur. + """ + + def __init__(self, stream, descriptions, verbosity): + unittest._TextTestResult.__init__(self, stream, descriptions, verbosity) + self.skipped = [] + + def addError(self, test, err): + # Determine if this is a skipped test + tracebacks = traceback.extract_tb(err[2]) + if tracebacks[-1][-1].startswith('raise SkippedTest'): + self.skipped.append((test, self._exc_info_to_string(err, test))) + if self.showAll: + self.stream.writeln('SKIPPED') + elif self.dots: + self.stream.write('S') + self.stream.flush() + else: + unittest.TestResult.addError(self, test, err) + if self.showAll: + self.stream.writeln('ERROR') + elif self.dots: + self.stream.write('E') + self.stream.flush() - teardown_test_environment() + def printErrors(self): + if self.dots or self.showAll: + self.stream.writeln() + self.printErrorList('SKIPPED', self.skipped) + self.printErrorList('ERROR', self.errors) + self.printErrorList('FAIL', self.failures) - return len(result.failures) + len(result.errors) diff --git a/django/test/test_coverage.py b/django/test/test_coverage.py new file mode 100644 index 0000000000..7c22edacb1 --- /dev/null +++ b/django/test/test_coverage.py @@ -0,0 +1,129 @@ +import coverage, time + +import os, sys + +from django.conf import settings +from django.db.models.loading import get_app, get_apps + +from django.test.simple import DefaultTestRunner as base_run_tests + +from django.utils.module_tools import get_all_modules +from django.utils.translation import ugettext as _ + +def _get_app_package(app_model_module): + """ + Returns the app module name from the app model module. + """ + return '.'.join(app_model_module.__name__.split('.')[:-1]) + + +class BaseCoverageRunner(object): + """ + Placeholder class for coverage runners. Intended to be easily extended. + """ + + def __init__(self): + """Placeholder (since it is overrideable)""" + self.cov = coverage.coverage(cover_pylib=True, auto_data=True) + self.cov.use_cache(True) + self.cov.load() + #self.cov.combine() + + + + def run_tests(self, test_labels, verbosity=1, interactive=True, + extra_tests=[]): + """ + Runs the specified tests while generating code coverage statistics. Upon + the tests' completion, the results are printed to stdout. + """ + #self.cov.erase() + #Allow an on-disk cache of coverage stats. + #self.cov.use_cache(0) + #for e in getattr(settings, 'COVERAGE_CODE_EXCLUDES', []): + # self.cov.exclude(e) + + + self.cov.start() + brt = base_run_tests() + results = brt.run_tests(test_labels, verbosity, interactive, extra_tests) + self.cov.stop() + #self.cov.erase() + + coverage_modules = [] + if test_labels: + for label in test_labels: + label = label.split('.')[0] + app = get_app(label) + coverage_modules.append(_get_app_package(app)) + else: + for app in get_apps(): + coverage_modules.append(_get_app_package(app)) + + coverage_modules.extend(getattr(settings, 'COVERAGE_ADDITIONAL_MODULES', [])) + + packages, self.modules, self.excludes, self.errors = get_all_modules( + coverage_modules, getattr(settings, 'COVERAGE_MODULE_EXCLUDES', []), + getattr(settings, 'COVERAGE_PATH_EXCLUDES', [])) + + + + return results + +class ConsoleReportCoverageRunner(BaseCoverageRunner): + + def run_tests(self, *args, **kwargs): + """docstring for run_tests""" + res = super(ConsoleReportCoverageRunner, self).run_tests( *args, **kwargs) + self.cov.report(self.modules.values(), show_missing=1) + + if self.excludes: + print >> sys.stdout + print >> sys.stdout, _("The following packages or modules were excluded:"), + for e in self.excludes: + print >> sys.stdout, e, + print >>sys.stdout + if self.errors: + print >> sys.stdout + print >> sys.stderr, _("There were problems with the following packages or modules:"), + for e in self.errors: + print >> sys.stderr, e, + print >> sys.stdout + return res + +class ReportingCoverageRunner(BaseCoverageRunner): + """Runs coverage.py analysis, as well as generating detailed HTML reports.""" + + def __init__(self, outdir = None): + """ + Constructor, overrides BaseCoverageRunner. Sets output directory + for reports. Parameter or setting. + """ + super(ReportingCoverageRunner, self).__init__() + if(outdir): + self.outdir = outdir + else: + # Realistically, we aren't going to ship the entire reporting framework.. + # but for the time being I have left it in. + self.outdir = getattr(settings, 'COVERAGE_REPORT_HTML_OUTPUT_DIR', 'test_html') + self.outdir = os.path.abspath(self.outdir) + # Create directory + if( not os.path.exists(self.outdir)): + os.mkdir(self.outdir) + + + def run_tests(self, *args, **kwargs): + """ + Overrides BaseCoverageRunner.run_tests, and adds html report generation + with the results + """ + res = super(ReportingCoverageRunner, self).run_tests( *args, **kwargs) + self.cov.html_report(self.modules.values(), + directory=self.outdir, + ignore_errors=True, + omit_prefixes='modeltests') + print >>sys.stdout + print >>sys.stdout, _("HTML reports were output to '%s'") %self.outdir + + return res + diff --git a/django/test/testcases.py b/django/test/testcases.py index 8c73c63796..010827fe67 100644 --- a/django/test/testcases.py +++ b/django/test/testcases.py @@ -212,8 +212,11 @@ class TransactionTestCase(unittest.TestCase): named fixtures. * If the Test Case class has a 'urls' member, replace the ROOT_URLCONF with it. + * If the Test Case class has a 'test_models' member, load the relivent + named models. * Clearing the mail test outbox. """ + self._test_model_setup() self._fixture_setup() self._urlconf_setup() mail.outbox = [] @@ -225,6 +228,37 @@ class TransactionTestCase(unittest.TestCase): # that we're using *args and **kwargs together. call_command('loaddata', *self.fixtures, **{'verbosity': 0}) + def _test_model_setup(self): + if hasattr(self, 'test_models'): + #print self.test_models + if self.__module__.endswith('tests'): + app_label = self.__module__.split('.')[:-1][-1] + app_path = '.'.join(self.__module__.split('.')[:-1]) + else: + app_label = self.__module__.split('.')[:-2][-1] + app_path = '.'.join(self.__module__.split('.')[:-2]) + from django.db.models.loading import cache + from django.utils import importlib + from django.db import models + #importlib.import_module() + cache.write_lock.acquire() + try: + app_mods = cache.app_models[app_label] + for tm in self.test_models: + #print "importing %s " % tm + im = importlib.import_module(app_path + '.' + tm) + #cache.app_store[im] = len(cache.app_store) + #print "finding model classes" + mod_classes = [f for f in im.__dict__.values() if hasattr(f,'__bases__') and issubclass(f,models.Model)] + #print "Found models %s " % mod_classes + for mc in mod_classes: + print "Adding %s to AppCache" % mc + app_mods[mc.__name__.lower()] = mc + finally: + cache.write_lock.release() + #call_command('syncdb', **{'verbosity': 0}) + + def _urlconf_setup(self): if hasattr(self, 'urls'): self._old_root_urlconf = settings.ROOT_URLCONF @@ -261,12 +295,55 @@ class TransactionTestCase(unittest.TestCase): * Putting back the original ROOT_URLCONF if it was changed. """ + self._test_model_teardown() self._fixture_teardown() self._urlconf_teardown() def _fixture_teardown(self): pass + def _test_model_teardown(self): + if hasattr(self, 'test_models'): + #print self.test_models + if self.__module__.endswith('tests'): + app_label = self.__module__.split('.')[:-1][-1] + app_path = '.'.join(self.__module__.split('.')[:-1]) + else: + app_label = self.__module__.split('.')[:-2][-1] + app_path = '.'.join(self.__module__.split('.')[:-2]) + from django.db.models.loading import cache + from django.utils import importlib + from django.db import models + #importlib.import_module() + # cc = cache.get_app(app_label) + # del cache.app_store[cc] + # #del cache.app_models[app_label] + # cache.loaded = False + # print cache.handled + # print '.'.join(cc.__name__.split('.')[:-1]) + # print cc.__package__ + # del cache.handled[cc.__package__] + # cache._populate() + # print cache.get_app(app_label) + #cc = cache.get_app(app_label) + + #reload(cache.get_app(app_label)) + cache.write_lock.acquire() + try: + app_mods = cache.app_models[app_label] + #print app_mods + for tm in self.test_models: + #print "importing %s " % tm + im = importlib.import_module(app_path + '.' + tm) + #cache.app_store[im] = len(cache.app_store) + #print "finding model classes" + mod_classes = [f for f in im.__dict__.values() if hasattr(f,'__bases__') and issubclass(f,models.Model)] + #print "Found models %s " % mod_classes + for mc in mod_classes: + print "Deleting %s from AppCache" % mc + del app_mods[mc.__name__.lower()] + finally: + cache.write_lock.release() def _urlconf_teardown(self): if hasattr(self, '_old_root_urlconf'): settings.ROOT_URLCONF = self._old_root_urlconf diff --git a/django/test/twill_tests.py b/django/test/twill_tests.py new file mode 100644 index 0000000000..e14dd6e14c --- /dev/null +++ b/django/test/twill_tests.py @@ -0,0 +1,332 @@ +""" + +This code is originally by miracle2k: +http://bitbucket.org/miracle2k/djutils/src/97f92c32c621/djutils/test/twill.py + + +Integrates the twill web browsing scripting language with Django. + +Provides too main functions, ``setup()`` and ``teardown``, that hook +(and unhook) a certain host name to the WSGI interface of your Django +app, making it possible to test your site using twill without actually +going through TCP/IP. + +It also changes the twill browsing behaviour, so that relative urls +per default point to the intercept (e.g. your Django app), so long +as you don't browse away from that host. Further, you are allowed to +specify the target url as arguments to Django's ``reverse()``. + +Usage: + + from test_utils.utils import twill_runner as twill + twill.setup() + try: + twill.go('/') # --> Django WSGI + twill.code(200) + + twill.go('http://google.com') + twill.go('/services') # --> http://google.com/services + + twill.go('/list', default=True) # --> back to Django WSGI + + twill.go('proj.app.views.func', + args=[1,2,3]) + finally: + twill.teardown() + +For more information about twill, see: + http://twill.idyll.org/ +""" + +# allows us to import global twill as opposed to this module +from __future__ import absolute_import + +# TODO: import all names with a _-prefix to keep the namespace clean with the twill stuff? +import urlparse +import cookielib + +import twill +import twill.commands +import twill.browser + +from django.conf import settings +from django.core.servers.basehttp import AdminMediaHandler +from django.core.handlers.wsgi import WSGIHandler +from django.core.urlresolvers import reverse +from django.http import HttpRequest +from django.utils.datastructures import SortedDict +from django.contrib import auth + +from django.core.management.commands.test_windmill import ServerContainer, attempt_import + +# make available through this module +from twill.commands import * + +__all__ = ('INSTALLED', 'setup', 'teardown', 'reverse',) + tuple(twill.commands.__all__) + + +DEFAULT_HOST = '127.0.0.1' +DEFAULT_PORT = 9090 +INSTALLED = SortedDict() # keep track of the installed hooks + + +def setup(host=None, port=None, allow_xhtml=True, propagate=True): + """Install the WSGI hook for ``host`` and ``port``. + + The default values will be used if host or port are not specified. + + ``allow_xhtml`` enables a workaround for the "not viewer HTML" + error when browsing sites that are determined to be XHTML, e.g. + featuring xhtml-ish mimetypes. + + Unless ``propagate specifies otherwise``, the + ``DEBUG_PROPAGATE_EXCEPTIONS`` will be enabled for better debugging: + when using twill, we don't really want to see 500 error pages, + but rather directly the exceptions that occured on the view side. + + Multiple calls to this function will only result in one handler + for each host/port combination being installed. + """ + + host = host or DEFAULT_HOST + port = port or DEFAULT_PORT + key = (host, port) + + if not key in INSTALLED: + # installer wsgi handler + app = AdminMediaHandler(WSGIHandler()) + twill.add_wsgi_intercept(host, port, lambda: app) + + # start browser fresh + browser = get_browser() + browser.diverged = False + + # enable xhtml mode if requested + _enable_xhtml(browser, allow_xhtml) + + # init debug propagate setting, and remember old value + if propagate: + old_propgate_setting = settings.DEBUG_PROPAGATE_EXCEPTIONS + settings.DEBUG_PROPAGATE_EXCEPTIONS = True + else: + old_propgate_setting = None + + INSTALLED[key] = (app, old_propgate_setting) + return browser + return False + + +def teardown(host=None, port=None): + """Remove an installed WSGI hook for ``host`` and ```port``. + + If no host or port is passed, the default values will be assumed. + If no hook is installed for the defaults, and both the host and + port are missing, the last hook installed will be removed. + + Returns True if a hook was removed, otherwise False. + """ + + both_missing = not host and not port + host = host or DEFAULT_HOST + port = port or DEFAULT_PORT + key = (host, port) + + key_to_delete = None + if key in INSTALLED: + key_to_delete = key + if not key in INSTALLED and both_missing and len(INSTALLED) > 0: + host, port = key_to_delete = INSTALLED.keys()[-1] + + if key_to_delete: + _, old_propagate = INSTALLED[key_to_delete] + del INSTALLED[key_to_delete] + result = True + if old_propagate is not None: + settings.DEBUG_PROPAGATE_EXCEPTIONS = old_propagate + else: + result = False + + # note that our return value is just a guess according to our + # own records, we pass the request on to twill in any case + twill.remove_wsgi_intercept(host, port) + return result + + +def _enable_xhtml(browser, enable): + """Twill (darcs from 19-09-2008) does not work with documents + identifying themselves as XHTML. + + This is a workaround. + """ + factory = browser._browser._factory + factory.basic_factory._response_type_finder._allow_xhtml = \ + factory.soup_factory._response_type_finder._allow_xhtml = \ + enable + + +class _EasyTwillBrowser(twill.browser.TwillBrowser): + """Custom version of twill's browser class that defaults relative + URLs to the last installed hook, if available. + + It also supports reverse resolving, and some additional commands. + """ + + def __init__(self, *args, **kwargs): + self.diverged = False + self._testing_ = False + super(_EasyTwillBrowser, self).__init__(*args, **kwargs) + + def go(self, url, args=None, kwargs=None, default=None): + assert not ((args or kwargs) and default==False) + + if args or kwargs: + url = reverse(url, args=args, kwargs=kwargs) + default = True # default is implied + + if INSTALLED: + netloc = '%s:%s' % INSTALLED.keys()[-1] + urlbits = urlparse.urlsplit(url) + if not urlbits[0]: + if default: + # force "undiverge" + self.diverged = False + if not self.diverged: + url = urlparse.urlunsplit(('http', netloc)+urlbits[2:]) + else: + self.diverged = True + + if self._testing_: # hack that makes it simple for us to test this + return url + return super(_EasyTwillBrowser, self).go(url) + + def login(self, **credentials): + """Log the user with the given credentials into your Django + site. + + To further simplify things, rather than giving the credentials, + you may pass a ``user`` parameter with the ``User`` instance you + want to login. Note that in this case the user will not be + further validated, i.e. it is possible to login an inactive user + this way. + + This works regardless of the url currently browsed, but does + require the WSGI intercept to be setup. + + Returns ``True`` if login was possible; ``False`` if the + provided credentials are incorrect, or the user is inactive, + or if the sessions framework is not available. + + Based on ``django.test.client.Client.logout``. + + Note: A ``twill.reload()`` will not refresh the cookies sent + with the request, so your login will not have any effect there. + This is different for ``logout``, since it actually invalidates + the session server-side, thus making the current key invalid. + """ + + if not 'django.contrib.sessions' in settings.INSTALLED_APPS: + return False + + host, port = INSTALLED.keys()[-1] + + # determine the user we want to login + user = credentials.pop('user', None) + if user: + # Login expects the user object to reference it's backend. + # Since we're not going through ``authenticate``, we'll + # have to do this ourselves. + backend = auth.get_backends()[0] + user.backend = user.backend = "%s.%s" % ( + backend.__module__, backend.__class__.__name__) + else: + user = auth.authenticate(**credentials) + if not user or not user.is_active: + return False + + # create a fake request to use with ``auth.login`` + request = HttpRequest() + request.session = __import__(settings.SESSION_ENGINE, {}, {}, ['']).SessionStore() + auth.login(request, user) + request.session.save() + + # set the cookie to represent the session + self.cj.set_cookie(cookielib.Cookie( + version=None, + name=settings.SESSION_COOKIE_NAME, + value=request.session.session_key, + port=str(port), # must be a string + port_specified = False, + domain=host, #settings.SESSION_COOKIE_DOMAIN, + domain_specified=True, + domain_initial_dot=False, + path='/', + path_specified=True, + secure=settings.SESSION_COOKIE_SECURE or None, + expires=None, + discard=None, + comment=None, + comment_url=None, + rest=None + )) + + return True + + def logout(self): + """Log the current user out of your Django site. + + This works regardless of the url currently browsed, but does + require the WSGI intercept to be setup. + + Based on ``django.test.client.Client.logout``. + """ + host, port = INSTALLED.keys()[-1] + for cookie in self.cj: + if cookie.name == settings.SESSION_COOKIE_NAME \ + and cookie.domain==host \ + and (not cookie.port or str(cookie.port)==str(port)): + session = __import__(settings.SESSION_ENGINE, {}, {}, ['']).SessionStore() + session.delete(session_key=cookie.value) + self.cj.clear(cookie.domain, cookie.path, cookie.name) + return True + return False + + +def go(*args, **kwargs): + # replace the default ``go`` to make the additional + # arguments that our custom browser provides available. + browser = get_browser() + browser.go(*args, **kwargs) + return browser.get_url() + +def login(*args, **kwargs): + return get_browser().login(*args, **kwargs) + +def logout(*args, **kwargs): + return get_browser().logout(*args, **kwargs) + +def reset_browser(*args, **kwargs): + # replace the default ``reset_browser`` to ensure + # that our custom browser class is used + result = twill.commands.reset_browser(*args, **kwargs) + twill.commands.browser = _EasyTwillBrowser() + return result + +# Monkey-patch our custom browser into twill; this will be global, but +# will only have an actual effect when intercepts are installed through +# our module (via ``setup``). +# Unfortunately, twill pretty much forces us to use the same global +# state it does itself, lest us reimplement everything from +# ``twill.commands``. It's a bit of a shame, we could provide dedicated +# browser instances for each call to ``setup()``. +reset_browser() + + +def url(should_be=None): + """Like the default ``url()``, but can be called without arguments, + in which case it returns the current url. + """ + + if should_be is None: + return get_browser().get_url() + else: + return twill.commands.url(should_be) diff --git a/django/test/utils.py b/django/test/utils.py index d34dd33d15..8430ef8f10 100644 --- a/django/test/utils.py +++ b/django/test/utils.py @@ -5,6 +5,7 @@ from django.core import mail from django.test import signals from django.template import Template from django.utils.translation import deactivate +import inspect class ContextList(list): """A wrapper that provides direct key access to context items contained @@ -80,8 +81,18 @@ def teardown_test_environment(): del mail.outbox -def get_runner(settings): - test_path = settings.TEST_RUNNER.split('.') +def get_runner(settings, coverage = False, reports = False): + """ + Based on the settings and parameters, returns the appropriate test + runner class. + """ + if(coverage): + if(reports): + test_path = 'django.test.test_coverage.ReportingCoverageRunner'.split('.') + else: + test_path = settings.COVERAGE_TEST_RUNNER.split('.') + else: + test_path = settings.TEST_RUNNER.split('.') # Allow for Python 2.5 relative paths if len(test_path) > 1: test_module_name = '.'.join(test_path[:-1]) @@ -90,3 +101,9 @@ def get_runner(settings): test_module = __import__(test_module_name, {}, {}, test_path[-1]) test_runner = getattr(test_module, test_path[-1]) return test_runner + +def calling_func_name(): + """ + Inspect's on the stack to determine the calling functions name. + """ + return inspect.stack()[1][3]
\ No newline at end of file diff --git a/django/test/windmill_tests.py b/django/test/windmill_tests.py new file mode 100644 index 0000000000..3212e9c1e8 --- /dev/null +++ b/django/test/windmill_tests.py @@ -0,0 +1,137 @@ + +# Code from django_live_server_r8458.diff @ http://code.djangoproject.com/ticket/2879#comment:41 +# Editing to monkey patch django rather than be in trunk + +import socket +import threading +from django.core.handlers.wsgi import WSGIHandler +from django.core.servers import basehttp +from django.test.testcases import call_command + +#from django.core.management import call_command + +# support both django 1.0 and 1.1 +try: + from django.test.testcases import TransactionTestCase as TestCase +except ImportError: + from django.test.testcases import TestCase + +try: + from windmill.authoring import unit +except Exception, e: + print "You don't appear to have windmill installed, please install before trying to run windmill tests again." + unit = None + +class StoppableWSGIServer(basehttp.WSGIServer): + """WSGIServer with short timeout, so that server thread can stop this server.""" + + def server_bind(self): + """Sets timeout to 1 second.""" + basehttp.WSGIServer.server_bind(self) + self.socket.settimeout(1) + + def get_request(self): + """Checks for timeout when getting request.""" + try: + sock, address = self.socket.accept() + sock.settimeout(None) + return (sock, address) + except socket.timeout: + raise + +class TestServerThread(threading.Thread): + """Thread for running a http server while tests are running.""" + + def __init__(self, address, port): + self.address = address + self.port = port + self._stopevent = threading.Event() + self.started = threading.Event() + self.error = None + super(TestServerThread, self).__init__() + + def run(self): + """Sets up test server and database and loops over handling http requests.""" + + # Must do database stuff in this new thread if database in memory. + from django.conf import settings + #if settings.DATABASE_ENGINE == 'sqlite3' \ + # and (not settings.TEST_DATABASE_NAME or settings.TEST_DATABASE_NAME == ':memory:'): + from django.db import connection + print 'Creating test DB' + db_name = connection.creation.create_test_db(0,autoclobber=True) + #call_command('syncdb', 0, 0) + # Import the fixture data into the test database. + if hasattr(self, 'fixtures'): + print 'Loading fixtures.' + # We have to use this slightly awkward syntax due to the fact + # that we're using *args and **kwargs together. + call_command('loaddata', *self.fixtures, **{'verbosity': 0}) + + try: + print "running thread" + handler = basehttp.AdminMediaHandler(WSGIHandler()) + httpd = None + while httpd is None: + try: + server_address = (self.address, self.port) + httpd = StoppableWSGIServer(server_address, basehttp.WSGIRequestHandler) + except basehttp.WSGIServerException, e: + if "Address already in use" in str(e): + print "Address already in use" + self.port +=1 + else: + raise e + httpd.set_app(handler) + self.started.set() + except basehttp.WSGIServerException, e: + self.error = e + self.started.set() + return + + + # Loop until we get a stop event. + while not self._stopevent.isSet(): + httpd.handle_request() + httpd.server_close() + + def join(self, timeout=None): + """Stop the thread and wait for it to finish.""" + self._stopevent.set() + threading.Thread.join(self, timeout) + + +def start_test_server(self, address='localhost', port=8000): + """Creates a live test server object (instance of WSGIServer).""" + self.server_thread = TestServerThread(address, port) + if hasattr(self, 'fixtures'): + print 'loading fixtures %s' % self.fixtures + self.server_thread.__setattr__('fixtures', self.fixtures) + self.server_thread.start() + self.server_thread.started.wait() + if self.server_thread.error: + raise self.server_thread.error + return self.server_thread.started + +def stop_test_server(self): + if self.server_thread: + self.server_thread.join() + +## New Code + +TestCase.start_test_server = classmethod(start_test_server) +TestCase.stop_test_server = classmethod(stop_test_server) + + +class WindmillDjangoUnitTest(TestCase, unit.WindmillUnitTestCase): + test_port = 8000 + def setUp(self): + self.start_test_server('localhost', self.test_port) + self.test_url = 'http://localhost:%d' % self.server_thread.port + unit.WindmillUnitTestCase.setUp(self) + + def tearDown(self): + unit.WindmillUnitTestCase.tearDown(self) + self.stop_test_server() + +WindmillDjangoTransactionUnitTest = WindmillDjangoUnitTest diff --git a/django/utils/module_tools/__init__.py b/django/utils/module_tools/__init__.py new file mode 100644 index 0000000000..976d4b5e4f --- /dev/null +++ b/django/utils/module_tools/__init__.py @@ -0,0 +1,3 @@ +from module_loader import * +from module_walker import * + diff --git a/django/utils/module_tools/data_storage.py b/django/utils/module_tools/data_storage.py new file mode 100644 index 0000000000..aed5980e6b --- /dev/null +++ b/django/utils/module_tools/data_storage.py @@ -0,0 +1,42 @@ +""" +Copyright 2009 55 Minutes (http://www.55minutes.com) + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +__all__ = ('Packages', 'Modules', 'Excluded', 'Errors') + +class SingletonType(type): + def __call__(cls, *args, **kwargs): + if getattr(cls, '__instance__', None) is None: + instance = cls.__new__(cls) + instance.__init__(*args, **kwargs) + cls.__instance__ = instance + return cls.__instance__ + +class Packages(object): + __metaclass__ = SingletonType + packages = {} + +class Modules(object): + __metaclass__ = SingletonType + modules = {} + +class Excluded(object): + __metaclass__ = SingletonType + excluded = [] + +class Errors(object): + __metaclass__ = SingletonType + errors = [] + diff --git a/django/utils/module_tools/module_loader.py b/django/utils/module_tools/module_loader.py new file mode 100644 index 0000000000..e6dd6ce4b8 --- /dev/null +++ b/django/utils/module_tools/module_loader.py @@ -0,0 +1,79 @@ +""" +Copyright 2009 55 Minutes (http://www.55minutes.com) + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import imp, sys, types + +__all__ = ('find_or_load_module',) + +def _brute_force_find_module(module_name, module_path, module_type): + for m in [m for n, m in sys.modules.iteritems() if type(m) == types.ModuleType]: + m_path = [] + try: + if module_type in (imp.PY_COMPILED, imp.PY_SOURCE): + m_path = [m.__file__] + elif module_type==imp.PKG_DIRECTORY: + m_path = m.__path__ + except AttributeError: + pass + for p in m_path: + if p.startswith(module_path): + return m + return None + +def _load_module(module_name, fo, fp, desc): + suffix, mode, mtype = desc + if module_name in sys.modules and \ + sys.modules[module_name].__file__.startswith(fp): + module = sys.modules[module_name] + else: + module = _brute_force_find_module(module_name, fp, mtype) + if not module: + try: + module = imp.load_module(module_name, fo, fp, desc) + except: + raise ImportError + return module + +def _load_package(pkg_name, fp, desc): + suffix, mode, mtype = desc + if pkg_name in sys.modules: + if fp in sys.modules[pkg_name].__path__: + pkg = sys.modules[pkg_name] + else: + pkg = _brute_force_find_module(pkg_name, fp, mtype) + if not pkg: + pkg = imp.load_module(pkg_name, None, fp, desc) + return pkg + +def find_or_load_module(module_name, path=None): + """ + Attempts to lookup ``module_name`` in ``sys.modules``, else uses the + facilities in the ``imp`` module to load the module. + + If module_name specified is not of type ``imp.PY_SOURCE`` or + ``imp.PKG_DIRECTORY``, raise ``ImportError`` since we don't know + what to do with those. + """ + fo, fp, desc = imp.find_module(module_name.split('.')[-1], path) + suffix, mode, mtype = desc + if mtype in (imp.PY_SOURCE, imp.PY_COMPILED): + module = _load_module(module_name, fo, fp, desc) + elif mtype==imp.PKG_DIRECTORY: + module = _load_package(module_name, fp, desc) + else: + raise ImportError("Don't know how to handle this module type.") + return module + diff --git a/django/utils/module_tools/module_walker.py b/django/utils/module_tools/module_walker.py new file mode 100644 index 0000000000..442150e689 --- /dev/null +++ b/django/utils/module_tools/module_walker.py @@ -0,0 +1,135 @@ +""" +Copyright 2009 55 Minutes (http://www.55minutes.com) + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import os, re, sys +from glob import glob + +from data_storage import * +from module_loader import find_or_load_module + +try: + set +except: + from sets import Set as set + +__all__ = ('get_all_modules',) + +def _build_pkg_path(pkg_name, pkg, path): + for rp in [x for x in pkg.__path__ if path.startswith(x)]: + p = path.replace(rp, '').replace(os.path.sep, '.') + return pkg_name + p + +def _build_module_path(pkg_name, pkg, path): + return _build_pkg_path(pkg_name, pkg, os.path.splitext(path)[0]) + +def _prune_whitelist(whitelist, blacklist): + excluded = Excluded().excluded + + for wp in whitelist[:]: + for bp in blacklist: + if re.search(bp, wp): + whitelist.remove(wp) + excluded.append(wp) + break + return whitelist + +def _parse_module_list(m_list): + packages = Packages().packages + modules = Modules().modules + excluded = Excluded().excluded + errors = Errors().errors + + for m in m_list: + components = m.split('.') + m_name = '' + search_path = [] + processed=False + for i, c in enumerate(components): + m_name = '.'.join([x for x in m_name.split('.') if x] + [c]) + try: + module = find_or_load_module(m_name, search_path or None) + except ImportError: + processed=True + errors.append(m) + break + try: + search_path.extend(module.__path__) + except AttributeError: + processed = True + if i+1==len(components): + modules[m_name] = module + else: + errors.append(m) + break + if not processed: + packages[m_name] = module + +def prune_dirs(root, dirs, exclude_dirs): + _dirs = [os.path.join(root, d) for d in dirs] + for i, p in enumerate(_dirs): + for e in exclude_dirs: + if re.search(e, p): + del dirs[i] + break + +def _get_all_packages(pkg_name, pkg, blacklist, exclude_dirs): + packages = Packages().packages + errors = Errors().errors + + for path in pkg.__path__: + for root, dirs, files in os.walk(path): + prune_dirs(root, dirs, exclude_dirs or []) + m_name = _build_pkg_path(pkg_name, pkg, root) + try: + if _prune_whitelist([m_name], blacklist): + m = find_or_load_module(m_name, [os.path.split(root)[0]]) + packages[m_name] = m + else: + for d in dirs[:]: + dirs.remove(d) + except ImportError: + errors.append(m_name) + for d in dirs[:]: + dirs.remove(d) + +def _get_all_modules(pkg_name, pkg, blacklist): + modules = Modules().modules + errors = Errors().errors + + for p in pkg.__path__: + for f in glob('%s/*.py' %p): + m_name = _build_module_path(pkg_name, pkg, f) + try: + if _prune_whitelist([m_name], blacklist): + m = find_or_load_module(m_name, [p]) + modules[m_name] = m + except ImportError: + errors.append(m_name) + +def get_all_modules(whitelist, blacklist=None, exclude_dirs=None): + packages = Packages().packages + modules = Modules().modules + excluded = Excluded().excluded + errors = Errors().errors + + whitelist = _prune_whitelist(whitelist, blacklist or []) + _parse_module_list(whitelist) + for pkg_name, pkg in packages.copy().iteritems(): + _get_all_packages(pkg_name, pkg, blacklist, exclude_dirs) + for pkg_name, pkg in packages.copy().iteritems(): + _get_all_modules(pkg_name, pkg, blacklist) + return packages, modules, list(set(excluded)), list(set(errors)) + |
