Coverage for python / lsst / daf / butler / tests / postgresql.py: 50%
48 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-22 08:55 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-22 08:55 +0000
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <https://www.gnu.org/licenses/>.
28from __future__ import annotations
30import gc
31import secrets
32import unittest
33from collections.abc import Iterator
34from contextlib import contextmanager
36import sqlalchemy
38from .._butler_config import ButlerConfig
39from .._config import Config
41try:
42 from testing.postgresql import Postgresql
43except ImportError:
44 Postgresql = None
@contextmanager
def setup_postgres_test_db() -> Iterator[TemporaryPostgresInstance]:
    """Set up a temporary postgres instance that can be used for testing the
    Butler.

    A ``testing.postgresql`` server is run for the duration of the ``with``
    block, and the ``btree_gist`` extension is created in its database
    before the instance is handed to the caller.

    Yields
    ------
    instance : `TemporaryPostgresInstance`
        Wrapper around the running server and a SQLAlchemy engine bound to
        its URL.

    Raises
    ------
    unittest.SkipTest
        Raised if the ``testing.postgresql`` module could not be imported.
    """
    if Postgresql is None:
        raise unittest.SkipTest("testing.postgresql module not available.")

    with Postgresql() as server:
        engine = sqlalchemy.engine.create_engine(server.url())
        try:
            instance = TemporaryPostgresInstance(server, engine)
            with instance.begin() as connection:
                connection.execute(sqlalchemy.text("CREATE EXTENSION btree_gist;"))

            yield instance
        finally:
            # Clean up any lingering SQLAlchemy engines/connections so they're
            # closed before we shut down the server. This lives in ``finally``
            # so it also runs when extension setup or the caller's ``with``
            # body raises; otherwise the server would be torn down underneath
            # still-open pooled connections.
            gc.collect()
            engine.dispose()
class TemporaryPostgresInstance:  # numpydoc ignore=PR01
    """Handle on a throwaway Postgres database, with helpers for pointing a
    Butler (or a bare registry) at it.
    """

    def __init__(self, server: Postgresql, engine: sqlalchemy.Engine) -> None:
        self._engine = engine
        self._server = server

    @property
    def url(self) -> str:
        """Connection URL reported by the temporary database server."""
        server = self._server
        return server.url()

    @contextmanager
    def begin(self) -> Iterator[sqlalchemy.Connection]:
        """Yield a SQLAlchemy connection to the test database, obtained via
        ``Engine.begin()`` for the duration of the ``with`` block.
        """
        connection_context = self._engine.begin()
        with connection_context as conn:
            yield conn

    def _fresh_connection_settings(self) -> tuple[str, str]:
        """Return the server URL paired with a newly generated namespace."""
        return self.url, self.generate_namespace_name()

    def patch_butler_config(self, config: ButlerConfig | Config) -> None:  # numpydoc ignore=PR01
        """Rewrite, in place, the registry section of a butler configuration
        so it targets the temporary database under a brand-new namespace.
        """
        db_url, namespace = self._fresh_connection_settings()
        config["registry", "db"] = db_url
        config["registry", "namespace"] = namespace

    def patch_registry_config(self, config: Config) -> None:  # numpydoc ignore=PR01
        """Rewrite, in place, a registry configuration so its database
        connection targets the temporary database under a brand-new namespace.
        """
        db_url, namespace = self._fresh_connection_settings()
        config["db"] = db_url
        config["namespace"] = namespace

    def generate_namespace_name(self) -> str:
        """Produce a random namespace name so that data written by different
        tests stays separated.
        """
        # token_hex yields lowercase hexadecimal, 2 characters per byte.
        suffix = secrets.token_hex(8)
        return "namespace_" + suffix

    def server_major_version(self) -> int:
        """Report the major version number of the running Postgres server
        (for example 13 or 16).
        """
        # Deferred import; presumably avoids an import cycle at module load
        # time -- confirm before hoisting it to module scope.
        from ..registry.databases.postgresql import get_postgres_server_version

        with self.begin() as conn:
            major = get_postgres_server_version(conn)[0]
        return major