summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJean-Michel Vourgère <nirgal@debian.org>2015-06-17 15:37:09 +0200
committerTim Graham <timograham@gmail.com>2015-06-30 18:21:51 -0400
commitb64c0d4d613b5cabedbc9b847682fe14877537de (patch)
treeac3ea4e0c72ccfe41e4c314944d4636791c30da0
parenteecd42ea7d97bce04bc909c71bed14850060c39c (diff)
Fixed #23658 -- Provided the password to PostgreSQL dbshell command
The password from settings.py is written in a temporary .pgpass file file whose name is given to psql using the PGPASSFILE environment variable.
-rw-r--r--django/db/backends/postgresql_psycopg2/client.py67
-rw-r--r--docs/releases/1.9.txt4
-rw-r--r--tests/dbshell/test_postgresql_psycopg2.py117
3 files changed, 178 insertions, 10 deletions
diff --git a/django/db/backends/postgresql_psycopg2/client.py b/django/db/backends/postgresql_psycopg2/client.py
index aa60e58943..5e3e288301 100644
--- a/django/db/backends/postgresql_psycopg2/client.py
+++ b/django/db/backends/postgresql_psycopg2/client.py
@@ -1,19 +1,66 @@
+import os
import subprocess
+from django.core.files.temp import NamedTemporaryFile
from django.db.backends.base.client import BaseDatabaseClient
+from django.utils.six import print_
+
+
+def _escape_pgpass(txt):
+ """
+ Escape a fragment of a PostgreSQL .pgpass file.
+ """
+ return txt.replace('\\', '\\\\').replace(':', '\\:')
class DatabaseClient(BaseDatabaseClient):
executable_name = 'psql'
+ @classmethod
+ def runshell_db(cls, settings_dict):
+ args = [cls.executable_name]
+
+ host = settings_dict.get('HOST', '')
+ port = settings_dict.get('PORT', '')
+ name = settings_dict.get('NAME', '')
+ user = settings_dict.get('USER', '')
+ passwd = settings_dict.get('PASSWORD', '')
+
+ if user:
+ args += ['-U', user]
+ if host:
+ args += ['-h', host]
+ if port:
+ args += ['-p', str(port)]
+ args += [name]
+
+ temp_pgpass = None
+ try:
+ if passwd:
+ # Create temporary .pgpass file.
+ temp_pgpass = NamedTemporaryFile(mode='w+')
+ try:
+ print_(
+ _escape_pgpass(host) or '*',
+ str(port) or '*',
+ _escape_pgpass(name) or '*',
+ _escape_pgpass(user) or '*',
+ _escape_pgpass(passwd),
+ file=temp_pgpass,
+ sep=':',
+ flush=True,
+ )
+ os.environ['PGPASSFILE'] = temp_pgpass.name
+ except UnicodeEncodeError:
+ # If the current locale can't encode the data, we let
+ # the user input the password manually.
+ pass
+ subprocess.call(args)
+ finally:
+ if temp_pgpass:
+ temp_pgpass.close()
+ if 'PGPASSFILE' in os.environ: # unit tests need cleanup
+ del os.environ['PGPASSFILE']
+
def runshell(self):
- settings_dict = self.connection.settings_dict
- args = [self.executable_name]
- if settings_dict['USER']:
- args += ["-U", settings_dict['USER']]
- if settings_dict['HOST']:
- args.extend(["-h", settings_dict['HOST']])
- if settings_dict['PORT']:
- args.extend(["-p", str(settings_dict['PORT'])])
- args += [settings_dict['NAME']]
- subprocess.call(args)
+ DatabaseClient.runshell_db(self.connection.settings_dict)
diff --git a/docs/releases/1.9.txt b/docs/releases/1.9.txt
index 600c88ca2c..9f2dd6ccd1 100644
--- a/docs/releases/1.9.txt
+++ b/docs/releases/1.9.txt
@@ -350,6 +350,10 @@ Management Commands
* The :djadmin:`startapp` command creates an ``apps.py`` file and adds
``default_app_config`` in ``__init__.py``.
+* When using the PostgreSQL backend, the :djadmin:`dbshell` command can connect
+ to the database using the password from your settings file (instead of
+ requiring it to be manually entered).
+
Models
^^^^^^
diff --git a/tests/dbshell/test_postgresql_psycopg2.py b/tests/dbshell/test_postgresql_psycopg2.py
new file mode 100644
index 0000000000..aecbba7f42
--- /dev/null
+++ b/tests/dbshell/test_postgresql_psycopg2.py
@@ -0,0 +1,117 @@
+# -*- coding: utf8 -*-
+from __future__ import unicode_literals
+
+import locale
+import os
+
+from django.db.backends.postgresql_psycopg2.client import DatabaseClient
+from django.test import SimpleTestCase, mock
+from django.utils import six
+from django.utils.encoding import force_bytes, force_str
+
+
+class PostgreSqlDbshellCommandTestCase(SimpleTestCase):
+
+ def _run_it(self, dbinfo):
+ """
+ That function invokes the runshell command, while mocking
+ subprocess.call. It returns a 2-tuple with:
+ - The command line list
+ - The binary content of file pointed by environment PGPASSFILE, or
+ None.
+ """
+ def _mock_subprocess_call(*args):
+ self.subprocess_args = list(*args)
+ if 'PGPASSFILE' in os.environ:
+ self.pgpass = open(os.environ['PGPASSFILE'], 'rb').read()
+ else:
+ self.pgpass = None
+ return 0
+ self.subprocess_args = None
+ self.pgpass = None
+ with mock.patch('subprocess.call', new=_mock_subprocess_call):
+ DatabaseClient.runshell_db(dbinfo)
+ return self.subprocess_args, self.pgpass
+
+ def test_basic(self):
+ self.assertEqual(
+ self._run_it({
+ 'NAME': 'dbname',
+ 'USER': 'someuser',
+ 'PASSWORD': 'somepassword',
+ 'HOST': 'somehost',
+ 'PORT': 444,
+ }), (
+ ['psql', '-U', 'someuser', '-h', 'somehost', '-p', '444', 'dbname'],
+ b'somehost:444:dbname:someuser:somepassword\n',
+ )
+ )
+
+ def test_nopass(self):
+ self.assertEqual(
+ self._run_it({
+ 'NAME': 'dbname',
+ 'USER': 'someuser',
+ 'HOST': 'somehost',
+ 'PORT': 444,
+ }), (
+ ['psql', '-U', 'someuser', '-h', 'somehost', '-p', '444', 'dbname'],
+ None,
+ )
+ )
+
+ def test_column(self):
+ self.assertEqual(
+ self._run_it({
+ 'NAME': 'dbname',
+ 'USER': 'some:user',
+ 'PASSWORD': 'some:password',
+ 'HOST': '::1',
+ 'PORT': 444,
+ }), (
+ ['psql', '-U', 'some:user', '-h', '::1', '-p', '444', 'dbname'],
+ b'\\:\\:1:444:dbname:some\\:user:some\\:password\n',
+ )
+ )
+
+ def test_escape_characters(self):
+ self.assertEqual(
+ self._run_it({
+ 'NAME': 'dbname',
+ 'USER': 'some\\user',
+ 'PASSWORD': 'some\\password',
+ 'HOST': 'somehost',
+ 'PORT': 444,
+ }), (
+ ['psql', '-U', 'some\\user', '-h', 'somehost', '-p', '444', 'dbname'],
+ b'somehost:444:dbname:some\\\\user:some\\\\password\n',
+ )
+ )
+
+ def test_accent(self):
+ # The pgpass temporary file needs to be encoded using the system locale.
+ encoding = locale.getpreferredencoding()
+ username = 'rôle'
+ password = 'sésame'
+ try:
+ username_str = force_str(username, encoding)
+ password_str = force_str(password, encoding)
+ pgpass_bytes = force_bytes(
+ 'somehost:444:dbname:%s:%s\n' % (username, password),
+ encoding=encoding,
+ )
+ except UnicodeEncodeError:
+ if six.PY2:
+ self.skipTest("Your locale can't run this test.")
+ self.assertEqual(
+ self._run_it({
+ 'NAME': 'dbname',
+ 'USER': username_str,
+ 'PASSWORD': password_str,
+ 'HOST': 'somehost',
+ 'PORT': 444,
+ }), (
+ ['psql', '-U', username_str, '-h', 'somehost', '-p', '444', 'dbname'],
+ pgpass_bytes,
+ )
+ )