Coverage for tests/test_postgresql.py: 39%

115 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-08-05 01:25 +0000

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24import gc 

25import itertools 

26import os 

27import secrets 

28import unittest 

29import warnings 

30from contextlib import contextmanager 

31 

32import astropy.time 

33 

34try: 

35 # It's possible but silly to have testing.postgresql installed without 

36 # having the postgresql server installed (because then nothing in 

37 # testing.postgresql would work), so we use the presence of that module 

38 # to test whether we can expect the server to be available. 

39 import testing.postgresql 

40except ImportError: 

41 testing = None 

42 

43import sqlalchemy 

44from lsst.daf.butler import Timespan, ddl 

45from lsst.daf.butler.registry import _ButlerRegistry, _RegistryFactory 

46 

47try: 

48 from lsst.daf.butler.registry.databases.postgresql import PostgresqlDatabase, _RangeTimespanType 

49except ImportError: 

50 testing = None 

51from lsst.daf.butler.registry.tests import DatabaseTests, RegistryTests 

52from lsst.daf.butler.tests.utils import makeTestTempDir, removeTestTempDir 

53 

# Absolute path of the directory holding this test module; it is handed to
# makeTestTempDir() by the setUpClass methods below.
TESTDIR = os.path.abspath(os.path.dirname(__file__))

55 

56 

def _startServer(root):
    """Start a PostgreSQL server and create a database within it, returning
    an object encapsulating both.

    Parameters
    ----------
    root : `str`
        Directory passed to `testing.postgresql.Postgresql` as ``base_dir``.

    Returns
    -------
    server : `testing.postgresql.Postgresql`
        The running server, with the ``btree_gist`` extension created in its
        database.  The caller is responsible for calling ``stop()`` on it.

    Raises
    ------
    Exception
        Anything raised while connecting or creating the extension is
        re-raised after the server has been stopped again.
    """
    server = testing.postgresql.Postgresql(base_dir=root)
    try:
        engine = sqlalchemy.engine.create_engine(server.url())
        try:
            with engine.begin() as connection:
                connection.execute(sqlalchemy.text("CREATE EXTENSION btree_gist;"))
        finally:
            # This engine is only needed for the statement above; release its
            # pooled connection now instead of waiting for garbage collection.
            engine.dispose()
    except BaseException:
        # Do not leave an orphaned server process behind if setup failed.
        server.stop()
        raise
    return server

66 

67 

@unittest.skipUnless(testing is not None, "testing.postgresql module not found")
class PostgresqlDatabaseTestCase(unittest.TestCase, DatabaseTests):
    """Run the generic `DatabaseTests` suite, plus PostgreSQL-specific tests,
    against `PostgresqlDatabase` using a temporary server started once per
    class.
    """

    @classmethod
    def setUpClass(cls):
        """Create a temporary directory and start a server inside it."""
        cls.root = makeTestTempDir(TESTDIR)
        try:
            cls.server = _startServer(cls.root)
        except BaseException:
            # unittest does not call tearDownClass when setUpClass raises, so
            # the temporary directory has to be removed here.
            removeTestTempDir(cls.root)
            raise

    @classmethod
    def tearDownClass(cls):
        """Stop the server and remove the temporary directory."""
        # Clean up any lingering SQLAlchemy engines/connections
        # so they're closed before we shut down the server.
        gc.collect()
        try:
            cls.server.stop()
        finally:
            # Remove the directory even if stopping the server failed.
            removeTestTempDir(cls.root)

    def makeEmptyDatabase(self, origin: int = 0) -> PostgresqlDatabase:
        """Return a `PostgresqlDatabase` on the test server that uses a new,
        randomly named namespace.
        """
        namespace = f"namespace_{secrets.token_hex(8).lower()}"
        return PostgresqlDatabase.fromUri(origin=origin, uri=self.server.url(), namespace=namespace)

    def getNewConnection(self, database: PostgresqlDatabase, *, writeable: bool) -> PostgresqlDatabase:
        """Return a new `PostgresqlDatabase` with the same origin and
        namespace as ``database``, created with the requested ``writeable``
        flag.
        """
        return PostgresqlDatabase.fromUri(
            origin=database.origin, uri=self.server.url(), namespace=database.namespace, writeable=writeable
        )

    @contextmanager
    def asReadOnly(self, database: PostgresqlDatabase) -> PostgresqlDatabase:
        """Context manager that yields a new non-writeable connection to the
        namespace used by ``database``.
        """
        yield self.getNewConnection(database, writeable=False)

    def testNameShrinking(self):
        """Test that too-long names for database entities other than tables
        and columns (which we preserve, and just expect to fit) are shrunk.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            # Table and field names are each below the 63-char limit even when
            # accounting for the prefix, but their combination (which will
            # appear in sequences and constraints) is not.
            tableName = "a_table_with_a_very_very_long_42_char_name"
            fieldName1 = "a_column_with_a_very_very_long_43_char_name"
            fieldName2 = "another_column_with_a_very_very_long_49_char_name"
            context.addTable(
                tableName,
                ddl.TableSpec(
                    fields=[
                        ddl.FieldSpec(
                            fieldName1, dtype=sqlalchemy.BigInteger, autoincrement=True, primaryKey=True
                        ),
                        ddl.FieldSpec(
                            fieldName2,
                            dtype=sqlalchemy.String,
                            length=16,
                            nullable=False,
                        ),
                    ],
                    unique={(fieldName2,)},
                ),
            )
        # Add another table, this time dynamically, with a foreign key to the
        # first table.
        db.ensureTableExists(
            tableName + "_b",
            ddl.TableSpec(
                fields=[
                    ddl.FieldSpec(
                        fieldName1 + "_b", dtype=sqlalchemy.BigInteger, autoincrement=True, primaryKey=True
                    ),
                    ddl.FieldSpec(
                        fieldName2 + "_b",
                        dtype=sqlalchemy.String,
                        length=16,
                        nullable=False,
                    ),
                ],
                foreignKeys=[
                    ddl.ForeignKeySpec(tableName, source=(fieldName2 + "_b",), target=(fieldName2,)),
                ],
            ),
        )

    def test_RangeTimespanType(self):
        """Test that `Timespan` values round-trip through a
        `_RangeTimespanType` column and that the database ``overlaps``
        operator agrees with `Timespan.overlaps`.
        """
        start = astropy.time.Time("2020-01-01T00:00:00", format="isot", scale="tai")
        offset = astropy.time.TimeDelta(60, format="sec")
        timestamps = [start + offset * n for n in range(3)]
        # Cover unbounded, half-bounded and fully bounded timespans.
        timespans = [Timespan(begin=None, end=None)]
        timespans.extend(Timespan(begin=None, end=t) for t in timestamps)
        timespans.extend(Timespan(begin=t, end=None) for t in timestamps)
        timespans.extend(Timespan(begin=a, end=b) for a, b in itertools.combinations(timestamps, 2))
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tbl = context.addTable(
                "tbl",
                ddl.TableSpec(
                    fields=[
                        ddl.FieldSpec(name="id", dtype=sqlalchemy.Integer, primaryKey=True),
                        ddl.FieldSpec(name="timespan", dtype=_RangeTimespanType),
                    ],
                ),
            )
        rows = [{"id": n, "timespan": t} for n, t in enumerate(timespans)]
        db.insert(tbl, *rows)

        # Test basic round-trip through database.
        with db.query(tbl.select().order_by(tbl.columns.id)) as sql_result:
            self.assertEqual(rows, [row._asdict() for row in sql_result])

        # Test that Timespan's Python methods are consistent with our usage of
        # half-open ranges and PostgreSQL operators on ranges.
        def subquery(alias: str) -> sqlalchemy.sql.FromClause:
            return (
                sqlalchemy.sql.select(tbl.columns.id.label("id"), tbl.columns.timespan.label("timespan"))
                .select_from(tbl)
                .alias(alias)
            )

        sq1 = subquery("sq1")
        sq2 = subquery("sq2")
        query = sqlalchemy.sql.select(
            sq1.columns.id.label("n1"),
            sq2.columns.id.label("n2"),
            sq1.columns.timespan.overlaps(sq2.columns.timespan).label("overlaps"),
        )

        # `columns` is deprecated since 1.4, but
        # `selected_columns` method did not exist in 1.3.
        if hasattr(query, "selected_columns"):
            columns = query.selected_columns
        else:
            columns = query.columns

        # SQLAlchemy issues a warning about cartesian product of two tables,
        # which we do intentionally. Disable that warning temporarily.
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore", message=".*cartesian product", category=sqlalchemy.exc.SAWarning
            )
            with db.query(query) as sql_result:
                dbResults = {
                    (row[columns.n1], row[columns.n2]): row[columns.overlaps] for row in sql_result.mappings()
                }

        # Compute the same overlap table in Python; the ids stored in the
        # database are the enumerate() indices of ``timespans``.
        pyResults = {
            (n1, n2): t1.overlaps(t2)
            for (n1, t1), (n2, t2) in itertools.product(enumerate(timespans), enumerate(timespans))
        }
        self.assertEqual(pyResults, dbResults)

215 

216 

@unittest.skipUnless(testing is not None, "testing.postgresql module not found")
class PostgresqlRegistryTests(RegistryTests):
    """Tests for `Registry` backed by a PostgreSQL database.

    Notes
    -----
    This is not a subclass of `unittest.TestCase` but to avoid repetition it
    defines methods that override `unittest.TestCase` methods. To make this
    work subclasses have to have this class first in the bases list.
    """

    @classmethod
    def setUpClass(cls):
        """Create a temporary directory and start a server inside it."""
        cls.root = makeTestTempDir(TESTDIR)
        try:
            cls.server = _startServer(cls.root)
        except BaseException:
            # unittest does not call tearDownClass when setUpClass raises, so
            # the temporary directory has to be removed here.
            removeTestTempDir(cls.root)
            raise

    @classmethod
    def tearDownClass(cls):
        """Stop the server and remove the temporary directory."""
        # Clean up any lingering SQLAlchemy engines/connections
        # so they're closed before we shut down the server.
        gc.collect()
        try:
            cls.server.stop()
        finally:
            # Remove the directory even if stopping the server failed.
            removeTestTempDir(cls.root)

    @classmethod
    def getDataDir(cls) -> str:
        """Return the normalized path of the ``data/registry`` directory
        located next to this test module.
        """
        return os.path.normpath(os.path.join(os.path.dirname(__file__), "data", "registry"))

    def makeRegistry(self, share_repo_with: _ButlerRegistry | None = None) -> _ButlerRegistry:
        """Return a registry that uses the test server.

        Parameters
        ----------
        share_repo_with : `_ButlerRegistry`, optional
            If `None`, the registry is created in a new, randomly named
            namespace via ``create_from_config``.  Otherwise the namespace of
            the given registry's database is reused and the registry is
            obtained via ``from_config``.

        Returns
        -------
        registry : `_ButlerRegistry`
            Registry connected to the test server.
        """
        if share_repo_with is None:
            namespace = f"namespace_{secrets.token_hex(8).lower()}"
        else:
            namespace = share_repo_with._db.namespace
        config = self.makeRegistryConfig()
        config["db"] = self.server.url()
        config["namespace"] = namespace
        factory = _RegistryFactory(config)
        if share_repo_with is None:
            return factory.create_from_config()
        return factory.from_config()

257 

258 

class PostgresqlRegistryNameKeyCollMgrUUIDTestCase(PostgresqlRegistryTests, unittest.TestCase):
    """Run the `PostgresqlRegistryTests` suite as a `unittest.TestCase`.

    The manager classes configured for this variant are
    NameKeyCollectionManager (collections) and
    ByDimensionsDatasetRecordStorageManagerUUID (datasets).
    """

    # Fully qualified class names of the managers used by this variant.
    collectionsManager = "lsst.daf.butler.registry.collections.nameKey.NameKeyCollectionManager"
    datasetsManager = (
        "lsst.daf.butler.registry.datasets.byDimensions.ByDimensionsDatasetRecordStorageManagerUUID"
    )

270 

271 

class PostgresqlRegistrySynthIntKeyCollMgrUUIDTestCase(PostgresqlRegistryTests, unittest.TestCase):
    """Run the `PostgresqlRegistryTests` suite as a `unittest.TestCase`.

    The manager classes configured for this variant are
    SynthIntKeyCollectionManager (collections) and
    ByDimensionsDatasetRecordStorageManagerUUID (datasets).
    """

    # Fully qualified class names of the managers used by this variant.
    collectionsManager = "lsst.daf.butler.registry.collections.synthIntKey.SynthIntKeyCollectionManager"
    datasetsManager = (
        "lsst.daf.butler.registry.datasets.byDimensions.ByDimensionsDatasetRecordStorageManagerUUID"
    )

283 

284 

# Allow the module to be run directly as a script as well as via a test runner.
if __name__ == "__main__":
    unittest.main()