Coverage for tests/test_postgresql.py: 39%

115 statements  

« prev     ^ index     » next       coverage.py v7.3.1, created at 2023-10-02 07:59 +0000

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27 

28from __future__ import annotations 

29 

30import gc 

31import itertools 

32import os 

33import secrets 

34import unittest 

35import warnings 

36from contextlib import contextmanager 

37 

38import astropy.time 

39 

40try: 

41 # It's possible but silly to have testing.postgresql installed without 

42 # having the postgresql server installed (because then nothing in 

43 # testing.postgresql would work), so we use the presence of that module 

44 # to test whether we can expect the server to be available. 

45 import testing.postgresql 

46except ImportError: 

47 testing = None 

48 

49import sqlalchemy 

50from lsst.daf.butler import Timespan, ddl 

51from lsst.daf.butler.registry import _ButlerRegistry, _RegistryFactory 

52 

53try: 

54 from lsst.daf.butler.registry.databases.postgresql import PostgresqlDatabase, _RangeTimespanType 

55except ImportError: 

56 testing = None 

57from lsst.daf.butler.registry.tests import DatabaseTests, RegistryTests 

58from lsst.daf.butler.tests.utils import makeTestTempDir, removeTestTempDir 

59 

# Absolute path of the directory containing this test module; temporary test
# directories are created relative to it.
TESTDIR = os.path.dirname(os.path.abspath(__file__))

61 

62 

def _startServer(root):
    """Start a PostgreSQL server and create a database within it, returning
    an object encapsulating both.

    Parameters
    ----------
    root : `str`
        Directory handed to ``testing.postgresql.Postgresql`` as its
        ``base_dir``.

    Returns
    -------
    server : `testing.postgresql.Postgresql`
        The running server, with the ``btree_gist`` extension created in its
        database. The caller is responsible for calling ``stop()`` on it.

    Raises
    ------
    Exception
        Anything raised while connecting or creating the extension is
        re-raised after the server has been stopped.
    """
    server = testing.postgresql.Postgresql(base_dir=root)
    try:
        engine = sqlalchemy.engine.create_engine(server.url())
        try:
            with engine.begin() as connection:
                connection.execute(sqlalchemy.text("CREATE EXTENSION btree_gist;"))
        finally:
            # This engine is only needed for the one-off statement above;
            # release its pooled connections now instead of leaving them
            # open against the server until garbage collection.
            engine.dispose()
    except Exception:
        # Do not leave a server process running if it could not be prepared.
        server.stop()
        raise
    return server

72 

73 

@unittest.skipUnless(testing is not None, "testing.postgresql module not found")
class PostgresqlDatabaseTestCase(unittest.TestCase, DatabaseTests):
    """Run the generic `DatabaseTests` suite, plus PostgreSQL-specific tests,
    against `PostgresqlDatabase`.

    Notes
    -----
    One PostgreSQL server is started per test class; each database created by
    `makeEmptyDatabase` gets its own randomly named namespace on that server.
    """

    @classmethod
    def setUpClass(cls):
        """Create the class-wide temporary directory and start the server."""
        cls.root = makeTestTempDir(TESTDIR)
        try:
            cls.server = _startServer(cls.root)
        except Exception:
            # unittest does not call tearDownClass when setUpClass raises, so
            # the temporary directory has to be removed here.
            removeTestTempDir(cls.root)
            raise

    @classmethod
    def tearDownClass(cls):
        """Stop the server and remove the temporary directory."""
        # Clean up any lingering SQLAlchemy engines/connections
        # so they're closed before we shut down the server.
        gc.collect()
        cls.server.stop()
        removeTestTempDir(cls.root)

    def makeEmptyDatabase(self, origin: int = 0) -> PostgresqlDatabase:
        """Return a `PostgresqlDatabase` on the class server that uses a
        fresh, randomly named namespace.
        """
        namespace = f"namespace_{secrets.token_hex(8).lower()}"
        return PostgresqlDatabase.fromUri(origin=origin, uri=self.server.url(), namespace=namespace)

    def getNewConnection(self, database: PostgresqlDatabase, *, writeable: bool) -> PostgresqlDatabase:
        """Return a new `PostgresqlDatabase` with the same origin and
        namespace as ``database`` and the requested ``writeable`` flag.
        """
        return PostgresqlDatabase.fromUri(
            origin=database.origin, uri=self.server.url(), namespace=database.namespace, writeable=writeable
        )

    @contextmanager
    def asReadOnly(self, database: PostgresqlDatabase) -> PostgresqlDatabase:
        """Context manager yielding a non-writeable connection to the same
        origin and namespace as ``database``.
        """
        yield self.getNewConnection(database, writeable=False)

    def testNameShrinking(self):
        """Test that too-long names for database entities other than tables
        and columns (which we preserve, and just expect to fit) are shrunk.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            # Table and field names are each below the 63-char limit even when
            # accounting for the prefix, but their combination (which will
            # appear in sequences and constraints) is not.
            tableName = "a_table_with_a_very_very_long_42_char_name"
            fieldName1 = "a_column_with_a_very_very_long_43_char_name"
            fieldName2 = "another_column_with_a_very_very_long_49_char_name"
            context.addTable(
                tableName,
                ddl.TableSpec(
                    fields=[
                        ddl.FieldSpec(
                            fieldName1, dtype=sqlalchemy.BigInteger, autoincrement=True, primaryKey=True
                        ),
                        ddl.FieldSpec(
                            fieldName2,
                            dtype=sqlalchemy.String,
                            length=16,
                            nullable=False,
                        ),
                    ],
                    unique={(fieldName2,)},
                ),
            )
        # Add another table, this time dynamically, with a foreign key to the
        # first table.
        db.ensureTableExists(
            tableName + "_b",
            ddl.TableSpec(
                fields=[
                    ddl.FieldSpec(
                        fieldName1 + "_b", dtype=sqlalchemy.BigInteger, autoincrement=True, primaryKey=True
                    ),
                    ddl.FieldSpec(
                        fieldName2 + "_b",
                        dtype=sqlalchemy.String,
                        length=16,
                        nullable=False,
                    ),
                ],
                foreignKeys=[
                    ddl.ForeignKeySpec(tableName, source=(fieldName2 + "_b",), target=(fieldName2,)),
                ],
            ),
        )

    def test_RangeTimespanType(self):
        """Test that `_RangeTimespanType` values round-trip through the
        database and that overlap tests done in SQL agree with
        `Timespan.overlaps` evaluated in Python.
        """
        start = astropy.time.Time("2020-01-01T00:00:00", format="isot", scale="tai")
        offset = astropy.time.TimeDelta(60, format="sec")
        timestamps = [start + offset * n for n in range(3)]
        # Unbounded, half-bounded (both directions) and fully bounded spans.
        timespans = [Timespan(begin=None, end=None)]
        timespans.extend(Timespan(begin=None, end=t) for t in timestamps)
        timespans.extend(Timespan(begin=t, end=None) for t in timestamps)
        timespans.extend(Timespan(begin=a, end=b) for a, b in itertools.combinations(timestamps, 2))
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tbl = context.addTable(
                "tbl",
                ddl.TableSpec(
                    fields=[
                        ddl.FieldSpec(name="id", dtype=sqlalchemy.Integer, primaryKey=True),
                        ddl.FieldSpec(name="timespan", dtype=_RangeTimespanType),
                    ],
                ),
            )
        rows = [{"id": n, "timespan": t} for n, t in enumerate(timespans)]
        db.insert(tbl, *rows)

        # Test basic round-trip through database.
        with db.query(tbl.select().order_by(tbl.columns.id)) as sql_result:
            self.assertEqual(rows, [row._asdict() for row in sql_result])

        # Test that Timespan's Python methods are consistent with our usage of
        # half-open ranges and PostgreSQL operators on ranges.
        def subquery(alias: str) -> sqlalchemy.sql.FromClause:
            return (
                sqlalchemy.sql.select(tbl.columns.id.label("id"), tbl.columns.timespan.label("timespan"))
                .select_from(tbl)
                .alias(alias)
            )

        sq1 = subquery("sq1")
        sq2 = subquery("sq2")
        query = sqlalchemy.sql.select(
            sq1.columns.id.label("n1"),
            sq2.columns.id.label("n2"),
            sq1.columns.timespan.overlaps(sq2.columns.timespan).label("overlaps"),
        )

        # `columns` is deprecated since 1.4, but
        # `selected_columns` method did not exist in 1.3.
        if hasattr(query, "selected_columns"):
            columns = query.selected_columns
        else:
            columns = query.columns

        # SQLAlchemy issues a warning about cartesian product of two tables,
        # which we do intentionally. Disable that warning temporarily.
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore", message=".*cartesian product", category=sqlalchemy.exc.SAWarning
            )
            with db.query(query) as sql_result:
                dbResults = {
                    (row[columns.n1], row[columns.n2]): row[columns.overlaps] for row in sql_result.mappings()
                }

        pyResults = {
            (n1, n2): t1.overlaps(t2)
            for (n1, t1), (n2, t2) in itertools.product(enumerate(timespans), enumerate(timespans))
        }
        self.assertEqual(pyResults, dbResults)

221 

222 

@unittest.skipUnless(testing is not None, "testing.postgresql module not found")
class PostgresqlRegistryTests(RegistryTests):
    """Tests for `Registry` backed by a PostgreSQL database.

    Notes
    -----
    This is not a subclass of `unittest.TestCase` but to avoid repetition it
    defines methods that override `unittest.TestCase` methods. To make this
    work subclasses have to have this class first in the bases list.
    """

    @classmethod
    def setUpClass(cls):
        """Create the class-wide temporary directory and start the server."""
        cls.root = makeTestTempDir(TESTDIR)
        try:
            cls.server = _startServer(cls.root)
        except Exception:
            # unittest does not call tearDownClass when setUpClass raises, so
            # the temporary directory has to be removed here.
            removeTestTempDir(cls.root)
            raise

    @classmethod
    def tearDownClass(cls):
        """Stop the server and remove the temporary directory."""
        # Clean up any lingering SQLAlchemy engines/connections
        # so they're closed before we shut down the server.
        gc.collect()
        cls.server.stop()
        removeTestTempDir(cls.root)

    @classmethod
    def getDataDir(cls) -> str:
        """Return the normalized path of the ``data/registry`` directory
        located next to this test module.
        """
        return os.path.normpath(os.path.join(os.path.dirname(__file__), "data", "registry"))

    def makeRegistry(self, share_repo_with: _ButlerRegistry | None = None) -> _ButlerRegistry:
        """Return a registry connected to the class-wide PostgreSQL server.

        Parameters
        ----------
        share_repo_with : `_ButlerRegistry`, optional
            If `None` (default), a registry is created in a new, randomly
            named namespace. Otherwise the returned registry is loaded from
            the namespace already used by the given registry.

        Returns
        -------
        registry : `_ButlerRegistry`
            Result of ``create_from_config()`` for a new namespace, or of
            ``from_config()`` when sharing an existing one.
        """
        createNew = share_repo_with is None
        if createNew:
            namespace = f"namespace_{secrets.token_hex(8).lower()}"
        else:
            namespace = share_repo_with._db.namespace
        config = self.makeRegistryConfig()
        config["db"] = self.server.url()
        config["namespace"] = namespace
        factory = _RegistryFactory(config)
        return factory.create_from_config() if createNew else factory.from_config()

263 

264 

class PostgresqlRegistryNameKeyCollMgrUUIDTestCase(PostgresqlRegistryTests, unittest.TestCase):
    """Run the PostgreSQL-backed `Registry` tests with name-keyed
    collections.

    The managers selected are NameKeyCollectionManager and
    ByDimensionsDatasetRecordStorageManagerUUID.
    """

    # Fully-qualified class names of the managers used by this test case.
    collectionsManager = (
        "lsst.daf.butler.registry.collections.nameKey"
        ".NameKeyCollectionManager"
    )
    datasetsManager = (
        "lsst.daf.butler.registry.datasets.byDimensions"
        ".ByDimensionsDatasetRecordStorageManagerUUID"
    )

276 

277 

class PostgresqlRegistrySynthIntKeyCollMgrUUIDTestCase(PostgresqlRegistryTests, unittest.TestCase):
    """Run the PostgreSQL-backed `Registry` tests with synthetic-integer-keyed
    collections.

    The managers selected are SynthIntKeyCollectionManager and
    ByDimensionsDatasetRecordStorageManagerUUID.
    """

    # Fully-qualified class names of the managers used by this test case.
    collectionsManager = (
        "lsst.daf.butler.registry.collections.synthIntKey"
        ".SynthIntKeyCollectionManager"
    )
    datasetsManager = (
        "lsst.daf.butler.registry.datasets.byDimensions"
        ".ByDimensionsDatasetRecordStorageManagerUUID"
    )

289 

290 

# Allow the module to be executed directly as a script.
if __name__ == "__main__":
    unittest.main()