Coverage for tests/test_postgresql.py: 40%

116 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-15 02:02 -0700

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27 

28from __future__ import annotations 

29 

30import gc 

31import itertools 

32import os 

33import secrets 

34import unittest 

35import warnings 

36from contextlib import contextmanager 

37 

38import astropy.time 

39 

40try: 

41 # It's possible but silly to have testing.postgresql installed without 

42 # having the postgresql server installed (because then nothing in 

43 # testing.postgresql would work), so we use the presence of that module 

44 # to test whether we can expect the server to be available. 

45 import testing.postgresql 

46except ImportError: 

47 testing = None 

48 

49import sqlalchemy 

50from lsst.daf.butler import Timespan, ddl 

51from lsst.daf.butler.registry import _RegistryFactory 

52from lsst.daf.butler.registry.sql_registry import SqlRegistry 

53 

54try: 

55 from lsst.daf.butler.registry.databases.postgresql import PostgresqlDatabase, _RangeTimespanType 

56except ImportError: 

57 testing = None 

58from lsst.daf.butler.registry.tests import DatabaseTests, RegistryTests 

59from lsst.daf.butler.tests.utils import makeTestTempDir, removeTestTempDir 

60 

# Absolute path of the directory containing this test module.  It is passed
# to ``makeTestTempDir`` by the ``setUpClass`` methods below (presumably as
# the parent of the per-class temporary directory -- confirm against
# ``lsst.daf.butler.tests.utils``).
TESTDIR = os.path.abspath(os.path.dirname(__file__))

62 

63 

def _startServer(root):
    """Start a PostgreSQL server and create a database within it, returning
    an object encapsulating both.

    Parameters
    ----------
    root
        Value passed to ``testing.postgresql.Postgresql`` as ``base_dir``
        (the callers in this file pass the result of ``makeTestTempDir``).

    Returns
    -------
    server : `testing.postgresql.Postgresql`
        The running server, with the ``btree_gist`` extension created in the
        database that ``server.url()`` points at.  The caller is responsible
        for calling ``server.stop()``.

    Raises
    ------
    Exception
        Anything raised while connecting or creating the extension is
        re-raised after the server has been stopped.
    """
    server = testing.postgresql.Postgresql(base_dir=root)
    try:
        engine = sqlalchemy.engine.create_engine(server.url())
        try:
            with engine.begin() as connection:
                connection.execute(sqlalchemy.text("CREATE EXTENSION btree_gist;"))
        finally:
            # Close the engine's pooled connection right away rather than
            # leaving it open against the server until garbage collection.
            engine.dispose()
    except BaseException:
        # The caller never receives ``server`` on failure, so it could not
        # stop it; shut it down here to avoid an orphaned server process.
        server.stop()
        raise
    return server

73 

74 

@unittest.skipUnless(testing is not None, "testing.postgresql module not found")
class PostgresqlDatabaseTestCase(unittest.TestCase, DatabaseTests):
    """Test the `PostgresqlDatabase` implementation against a temporary
    PostgreSQL server shared by all tests in this class.
    """

    @classmethod
    def setUpClass(cls):
        """Create the temporary directory and start the server in it."""
        cls.root = makeTestTempDir(TESTDIR)
        try:
            cls.server = _startServer(cls.root)
        except BaseException:
            # unittest does not call tearDownClass when setUpClass fails, so
            # the temporary directory has to be removed here.
            removeTestTempDir(cls.root)
            raise

    @classmethod
    def tearDownClass(cls):
        """Stop the server and remove the temporary directory."""
        # Clean up any lingering SQLAlchemy engines/connections
        # so they're closed before we shut down the server.
        gc.collect()
        try:
            cls.server.stop()
        finally:
            # Remove the directory even if stopping the server failed.
            removeTestTempDir(cls.root)

    def makeEmptyDatabase(self, origin: int = 0) -> PostgresqlDatabase:
        """Return a `PostgresqlDatabase` connected to the class's server
        under a freshly generated, random namespace.

        Parameters
        ----------
        origin : `int`, optional
            Forwarded to `PostgresqlDatabase.fromUri`.
        """
        namespace = f"namespace_{secrets.token_hex(8).lower()}"
        return PostgresqlDatabase.fromUri(origin=origin, uri=self.server.url(), namespace=namespace)

    def getNewConnection(self, database: PostgresqlDatabase, *, writeable: bool) -> PostgresqlDatabase:
        """Return a new `PostgresqlDatabase` that uses the same origin and
        namespace as ``database``.

        Parameters
        ----------
        database : `PostgresqlDatabase`
            Database whose ``origin`` and ``namespace`` are reused.
        writeable : `bool`
            Forwarded to `PostgresqlDatabase.fromUri`.
        """
        return PostgresqlDatabase.fromUri(
            origin=database.origin, uri=self.server.url(), namespace=database.namespace, writeable=writeable
        )

    @contextmanager
    def asReadOnly(self, database: PostgresqlDatabase) -> PostgresqlDatabase:
        """Context manager yielding a non-writeable connection to the same
        origin and namespace as ``database``.

        The return annotation names the yielded type.
        """
        yield self.getNewConnection(database, writeable=False)

    def testNameShrinking(self):
        """Test that too-long names for database entities other than tables
        and columns (which we preserve, and just expect to fit) are shrunk.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            # Table and field names are each below the 63-char limit even when
            # accounting for the prefix, but their combination (which will
            # appear in sequences and constraints) is not.
            tableName = "a_table_with_a_very_very_long_42_char_name"
            fieldName1 = "a_column_with_a_very_very_long_43_char_name"
            fieldName2 = "another_column_with_a_very_very_long_49_char_name"
            context.addTable(
                tableName,
                ddl.TableSpec(
                    fields=[
                        ddl.FieldSpec(
                            fieldName1, dtype=sqlalchemy.BigInteger, autoincrement=True, primaryKey=True
                        ),
                        ddl.FieldSpec(
                            fieldName2,
                            dtype=sqlalchemy.String,
                            length=16,
                            nullable=False,
                        ),
                    ],
                    unique={(fieldName2,)},
                ),
            )
        # Add another table, this time dynamically, with a foreign key to the
        # first table.
        db.ensureTableExists(
            tableName + "_b",
            ddl.TableSpec(
                fields=[
                    ddl.FieldSpec(
                        fieldName1 + "_b", dtype=sqlalchemy.BigInteger, autoincrement=True, primaryKey=True
                    ),
                    ddl.FieldSpec(
                        fieldName2 + "_b",
                        dtype=sqlalchemy.String,
                        length=16,
                        nullable=False,
                    ),
                ],
                foreignKeys=[
                    ddl.ForeignKeySpec(tableName, source=(fieldName2 + "_b",), target=(fieldName2,)),
                ],
            ),
        )

    def test_RangeTimespanType(self):
        """Test that `Timespan` values round-trip through a
        `_RangeTimespanType` column and that the database's ``overlaps``
        operator agrees with `Timespan.overlaps`.
        """
        start = astropy.time.Time("2020-01-01T00:00:00", format="isot", scale="tai")
        offset = astropy.time.TimeDelta(60, format="sec")
        timestamps = [start + offset * n for n in range(3)]
        # Cover fully unbounded, half-bounded (both sides) and fully bounded
        # timespans.
        timespans = [Timespan(begin=None, end=None)]
        timespans.extend(Timespan(begin=None, end=t) for t in timestamps)
        timespans.extend(Timespan(begin=t, end=None) for t in timestamps)
        timespans.extend(Timespan(begin=a, end=b) for a, b in itertools.combinations(timestamps, 2))
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tbl = context.addTable(
                "tbl",
                ddl.TableSpec(
                    fields=[
                        ddl.FieldSpec(name="id", dtype=sqlalchemy.Integer, primaryKey=True),
                        ddl.FieldSpec(name="timespan", dtype=_RangeTimespanType),
                    ],
                ),
            )
        rows = [{"id": n, "timespan": t} for n, t in enumerate(timespans)]
        db.insert(tbl, *rows)

        # Test basic round-trip through database.
        with db.query(tbl.select().order_by(tbl.columns.id)) as sql_result:
            self.assertEqual(rows, [row._asdict() for row in sql_result])

        # Test that Timespan's Python methods are consistent with our usage of
        # half-open ranges and PostgreSQL operators on ranges.
        def subquery(alias: str) -> sqlalchemy.sql.FromClause:
            return (
                sqlalchemy.sql.select(tbl.columns.id.label("id"), tbl.columns.timespan.label("timespan"))
                .select_from(tbl)
                .alias(alias)
            )

        sq1 = subquery("sq1")
        sq2 = subquery("sq2")
        query = sqlalchemy.sql.select(
            sq1.columns.id.label("n1"),
            sq2.columns.id.label("n2"),
            sq1.columns.timespan.overlaps(sq2.columns.timespan).label("overlaps"),
        )

        # `columns` is deprecated since 1.4, but
        # `selected_columns` method did not exist in 1.3.
        if hasattr(query, "selected_columns"):
            columns = query.selected_columns
        else:
            columns = query.columns

        # SQLAlchemy issues a warning about cartesian product of two tables,
        # which we do intentionally. Disable that warning temporarily.
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore", message=".*cartesian product", category=sqlalchemy.exc.SAWarning
            )
            with db.query(query) as sql_result:
                dbResults = {
                    (row[columns.n1], row[columns.n2]): row[columns.overlaps] for row in sql_result.mappings()
                }

        # Compute the same pairwise overlap table in Python; the ids stored in
        # the table are the enumerate() indices used here.
        pyResults = {
            (n1, n2): t1.overlaps(t2)
            for (n1, t1), (n2, t2) in itertools.product(enumerate(timespans), enumerate(timespans))
        }
        self.assertEqual(pyResults, dbResults)

222 

223 

@unittest.skipUnless(testing is not None, "testing.postgresql module not found")
class PostgresqlRegistryTests(RegistryTests):
    """Tests for `Registry` backed by a PostgreSQL database.

    Notes
    -----
    This is not a subclass of `unittest.TestCase` but to avoid repetition it
    defines methods that override `unittest.TestCase` methods. To make this
    work subclasses have to have this class first in the bases list.
    """

    @classmethod
    def setUpClass(cls):
        """Create the temporary directory and start the server in it."""
        cls.root = makeTestTempDir(TESTDIR)
        try:
            cls.server = _startServer(cls.root)
        except BaseException:
            # unittest does not call tearDownClass when setUpClass fails, so
            # the temporary directory has to be removed here.
            removeTestTempDir(cls.root)
            raise

    @classmethod
    def tearDownClass(cls):
        """Stop the server and remove the temporary directory."""
        # Clean up any lingering SQLAlchemy engines/connections
        # so they're closed before we shut down the server.
        gc.collect()
        try:
            cls.server.stop()
        finally:
            # Remove the directory even if stopping the server failed.
            removeTestTempDir(cls.root)

    @classmethod
    def getDataDir(cls) -> str:
        """Return the normalized path of the ``data/registry`` directory
        located next to this test module.
        """
        return os.path.normpath(os.path.join(os.path.dirname(__file__), "data", "registry"))

    def makeRegistry(self, share_repo_with: SqlRegistry | None = None) -> SqlRegistry:
        """Return a `SqlRegistry` connected to the class's PostgreSQL server.

        Parameters
        ----------
        share_repo_with : `SqlRegistry`, optional
            If `None` (default), a random new namespace is generated and the
            registry is made with ``create_from_config``.  Otherwise the
            namespace of this registry's database is reused and the registry
            is made with ``from_config``.

        Returns
        -------
        registry : `SqlRegistry`
            The object returned by the registry factory.
        """
        create = share_repo_with is None
        if create:
            namespace = f"namespace_{secrets.token_hex(8).lower()}"
        else:
            namespace = share_repo_with._db.namespace
        config = self.makeRegistryConfig()
        config["db"] = self.server.url()
        config["namespace"] = namespace
        factory = _RegistryFactory(config)
        if create:
            return factory.create_from_config()
        return factory.from_config()

264 

265 

class PostgresqlRegistryNameKeyCollMgrUUIDTestCase(PostgresqlRegistryTests, unittest.TestCase):
    """Run the PostgreSQL-backed `Registry` tests with name-keyed collections.

    The manager classes selected by this test case are
    ``NameKeyCollectionManager`` for collections and
    ``ByDimensionsDatasetRecordStorageManagerUUID`` for datasets.
    """

    # Fully-qualified manager class names.  Presumably consumed by the
    # inherited ``makeRegistryConfig`` -- confirm against `RegistryTests`.
    datasetsManager = (
        "lsst.daf.butler.registry.datasets.byDimensions.ByDimensionsDatasetRecordStorageManagerUUID"
    )
    collectionsManager = "lsst.daf.butler.registry.collections.nameKey.NameKeyCollectionManager"

277 

278 

class PostgresqlRegistrySynthIntKeyCollMgrUUIDTestCase(PostgresqlRegistryTests, unittest.TestCase):
    """Run the PostgreSQL-backed `Registry` tests with synthetic-integer-keyed
    collections.

    The manager classes selected by this test case are
    ``SynthIntKeyCollectionManager`` for collections and
    ``ByDimensionsDatasetRecordStorageManagerUUID`` for datasets.
    """

    # Fully-qualified manager class names.  Presumably consumed by the
    # inherited ``makeRegistryConfig`` -- confirm against `RegistryTests`.
    datasetsManager = (
        "lsst.daf.butler.registry.datasets.byDimensions.ByDimensionsDatasetRecordStorageManagerUUID"
    )
    collectionsManager = "lsst.daf.butler.registry.collections.synthIntKey.SynthIntKeyCollectionManager"

290 

291 

# Run the test cases in this module when it is executed directly as a script.
if __name__ == "__main__":
    unittest.main()