Skip to content

Commit

Permalink
Add postgresql schema support. (#507)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Mark Bakhit <archiethemonger@gmail.com>
  • Loading branch information
3 people authored Mar 6, 2024
1 parent 6c11a67 commit 87952dc
Show file tree
Hide file tree
Showing 8 changed files with 196 additions and 8 deletions.
27 changes: 26 additions & 1 deletion dbbackup/db/postgresql.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from typing import List, Optional
from urllib.parse import quote

from .base import BaseCommandDBConnector
Expand Down Expand Up @@ -37,16 +38,23 @@ class PgDumpConnector(BaseCommandDBConnector):
restore_cmd = "psql"
single_transaction = True
drop = True
schemas: Optional[List[str]] = []

def _create_dump(self):
cmd = f"{self.dump_cmd} "
cmd = cmd + create_postgres_uri(self)

for table in self.exclude:
cmd += f" --exclude-table-data={table}"

if self.drop:
cmd += " --clean"

if self.schemas:
# First schema is not prefixed with -n
# when using join function so add it manually.
cmd += " -n " + " -n ".join(self.schemas)

cmd = f"{self.dump_prefix} {cmd} {self.dump_suffix}"
stdout, stderr = self.run_command(cmd, env=self.dump_env)
return stdout
Expand All @@ -57,8 +65,13 @@ def _restore_dump(self, dump):

# without this, psql terminates with an exit value of 0 regardless of errors
cmd += " --set ON_ERROR_STOP=on"

if self.schemas:
cmd += " -n " + " -n ".join(self.schemas)

if self.single_transaction:
cmd += " --single-transaction"

cmd += " {}".format(self.settings["NAME"])
cmd = f"{self.restore_prefix} {cmd} {self.restore_suffix}"
stdout, stderr = self.run_command(cmd, stdin=dump, env=self.restore_env)
Expand All @@ -77,10 +90,13 @@ def _enable_postgis(self):
cmd = f'{self.psql_cmd} -c "CREATE EXTENSION IF NOT EXISTS postgis;"'
cmd += " --username={}".format(self.settings["ADMIN_USER"])
cmd += " --no-password"

if self.settings.get("HOST"):
cmd += " --host={}".format(self.settings["HOST"])

if self.settings.get("PORT"):
cmd += " --port={}".format(self.settings["PORT"])

return self.run_command(cmd)

def _restore_dump(self, dump):
Expand Down Expand Up @@ -108,8 +124,12 @@ def _create_dump(self):
cmd += " --format=custom"
for table in self.exclude:
cmd += f" --exclude-table-data={table}"

if self.schemas:
cmd += " -n " + " -n ".join(self.schemas)

cmd = f"{self.dump_prefix} {cmd} {self.dump_suffix}"
stdout, stderr = self.run_command(cmd, env=self.dump_env)
stdout, _ = self.run_command(cmd, env=self.dump_env)
return stdout

def _restore_dump(self, dump):
Expand All @@ -118,8 +138,13 @@ def _restore_dump(self, dump):

if self.single_transaction:
cmd += " --single-transaction"

if self.drop:
cmd += " --clean"

if self.schemas:
cmd += " -n " + " -n ".join(self.schemas)

cmd = f"{self.restore_prefix} {cmd} {self.restore_suffix}"
stdout, stderr = self.run_command(cmd, stdin=dump, env=self.restore_env)
return stdout, stderr
22 changes: 20 additions & 2 deletions dbbackup/management/commands/dbbackup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@


class Command(BaseDbBackupCommand):
help = "Backup a database, encrypt and/or compress and write to " "storage." ""
help = "Backup a database, encrypt and/or compress."
content_type = "db"

option_list = BaseDbBackupCommand.option_list + (
Expand Down Expand Up @@ -60,6 +60,13 @@ class Command(BaseDbBackupCommand):
make_option(
"-x", "--exclude-tables", default=None, help="Exclude tables from backup"
),
make_option(
"-n",
"--schema",
action="append",
default=[],
help="Specify schema(s) to backup. Can be used multiple times.",
),
)

@utils.email_uncaught_exception
Expand All @@ -78,6 +85,7 @@ def handle(self, **options):
self.path = options.get("output_path")
self.exclude_tables = options.get("exclude_tables")
self.storage = get_storage()
self.schemas = options.get("schema")

self.database = options.get("database") or ""

Expand All @@ -103,22 +111,32 @@ def _save_new_backup(self, database):
Save a new backup file.
"""
self.logger.info("Backing Up Database: %s", database["NAME"])
# Get backup and name
# Get backup, schema and name
filename = self.connector.generate_filename(self.servername)

if self.schemas:
self.connector.schemas = self.schemas

outputfile = self.connector.create_dump()

# Apply trans
if self.compress:
compressed_file, filename = utils.compress_file(outputfile, filename)
outputfile = compressed_file

if self.encrypt:
encrypted_file, filename = utils.encrypt_file(outputfile, filename)
outputfile = encrypted_file

# Set file name
filename = self.filename or filename
self.logger.debug("Backup size: %s", utils.handle_size(outputfile))

# Store backup
outputfile.seek(0)

if self.path is None:
self.write_to_storage(outputfile, filename)

else:
self.write_local_file(outputfile, self.path)
20 changes: 18 additions & 2 deletions dbbackup/management/commands/dbrestore.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@


class Command(BaseDbBackupCommand):
help = """Restore a database backup from storage, encrypted and/or
compressed."""
help = "Restore a database backup from storage, encrypted and/or compressed."
content_type = "db"

option_list = BaseDbBackupCommand.option_list + (
Expand Down Expand Up @@ -46,6 +45,13 @@ class Command(BaseDbBackupCommand):
default=False,
help="Uncompress gzip data before restoring",
),
make_option(
"-n",
"--schema",
action="append",
default=[],
help="Specify schema(s) to restore. Can be used multiple times.",
),
)

def handle(self, *args, **options):
Expand All @@ -68,6 +74,7 @@ def handle(self, *args, **options):
self.input_database_name
)
self.storage = get_storage()
self.schemas = options.get("schema")
self._restore_backup()
except StorageError as err:
raise CommandError(err) from err
Expand All @@ -91,11 +98,16 @@ def _restore_backup(self):
input_filename, input_file = self._get_backup_file(
database=self.input_database_name, servername=self.servername
)

self.logger.info(
"Restoring backup for database '%s' and server '%s'",
self.database_name,
self.servername,
)

if self.schemas:
self.logger.info(f"Restoring schemas: {self.schemas}")

self.logger.info(f"Restoring: {input_filename}")

if self.decrypt:
Expand All @@ -117,4 +129,8 @@ def _restore_backup(self):

input_file.seek(0)
self.connector = get_connector(self.database_name)

if self.schemas:
self.connector.schemas = self.schemas

self.connector.restore_dump(input_file)
10 changes: 9 additions & 1 deletion dbbackup/tests/commands/test_dbbackup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
"""

import os
from unittest.mock import patch

from django.test import TestCase
from mock import patch

from dbbackup.db.base import get_connector
from dbbackup.management.commands.dbbackup import Command as DbbackupCommand
Expand All @@ -27,6 +27,7 @@ def setUp(self):
self.command.stdout = DEV_NULL
self.command.filename = None
self.command.path = None
self.command.schemas = []

def tearDown(self):
clean_gpg_keys()
Expand All @@ -50,6 +51,12 @@ def test_path(self):
# tearDown
os.remove(self.command.path)

def test_schema(self):
self.command.schemas = ["public"]
result = self.command._save_new_backup(TEST_DATABASE)

self.assertIsNone(result)

@patch("dbbackup.settings.DATABASES", ["db-from-settings"])
def test_get_database_keys(self):
with self.subTest("use --database from CLI"):
Expand All @@ -76,6 +83,7 @@ def setUp(self):
self.command.filename = None
self.command.path = None
self.command.connector = get_connector("default")
self.command.schemas = []

def tearDown(self):
clean_gpg_keys()
Expand Down
44 changes: 43 additions & 1 deletion dbbackup/tests/commands/test_dbrestore.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,17 @@

from shutil import copyfileobj
from tempfile import mktemp
from unittest.mock import patch

from django.conf import settings
from django.core.files import File
from django.core.management.base import CommandError
from django.test import TestCase
from mock import patch

from dbbackup import utils
from dbbackup.db.base import get_connector
from dbbackup.db.mongodb import MongoDumpConnector
from dbbackup.db.postgresql import PgDumpConnector
from dbbackup.management.commands.dbrestore import Command as DbrestoreCommand
from dbbackup.settings import HOSTNAME
from dbbackup.storage import get_storage
Expand Down Expand Up @@ -47,6 +48,7 @@ def setUp(self):
self.command.input_database_name = None
self.command.database_name = "default"
self.command.connector = get_connector("default")
self.command.schemas = []
HANDLED_FILES.clean()

def tearDown(self):
Expand Down Expand Up @@ -103,6 +105,45 @@ def test_path(self, *args):
HANDLED_FILES["written_files"].append((self.command.filepath, get_dump()))
self.command._restore_backup()

@patch("dbbackup.management.commands.dbrestore.get_connector")
@patch("dbbackup.db.base.BaseDBConnector.restore_dump")
def test_schema(self, mock_restore_dump, mock_get_connector, *args):
"""Schema is only used for postgresql."""
mock_get_connector.return_value = PgDumpConnector()
mock_restore_dump.return_value = True

mock_file = File(get_dump())
HANDLED_FILES["written_files"].append((self.command.filename, mock_file))

with self.assertLogs("dbbackup.command", "INFO") as cm:
# Without
self.command.path = None
self.command._restore_backup()
self.assertEqual(self.command.connector.schemas, [])

# With
self.command.path = None
self.command.schemas = ["public"]
self.command._restore_backup()
self.assertEqual(self.command.connector.schemas, ["public"])
self.assertIn(
"INFO:dbbackup.command:Restoring schemas: ['public']",
cm.output,
)

# With multiple
self.command.path = None
self.command.schemas = ["public", "other"]
self.command._restore_backup()
self.assertEqual(self.command.connector.schemas, ["public", "other"])
self.assertIn(
"INFO:dbbackup.command:Restoring schemas: ['public', 'other']",
cm.output,
)

mock_get_connector.assert_called_with("default")
mock_restore_dump.assert_called_with(mock_file)


class DbrestoreCommandGetDatabaseTest(TestCase):
def setUp(self):
Expand Down Expand Up @@ -147,6 +188,7 @@ def setUp(self):
self.command.database_name = "mongo"
self.command.input_database_name = None
self.command.servername = HOSTNAME
self.command.schemas = []
HANDLED_FILES.clean()
add_private_gpg()

Expand Down
Loading

0 comments on commit 87952dc

Please sign in to comment.