Coverage for tests/test_postgresql.py: 35%

114 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2023-10-26 15:13 +0000

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22import gc 

23import itertools 

24import os 

25import secrets 

26import unittest 

27import warnings 

28from contextlib import contextmanager 

29from typing import Optional 

30 

31import astropy.time 

32 

33try: 

34 # It's possible but silly to have testing.postgresql installed without 

35 # having the postgresql server installed (because then nothing in 

36 # testing.postgresql would work), so we use the presence of that module 

37 # to test whether we can expect the server to be available. 

38 import testing.postgresql 

39except ImportError: 

40 testing = None 

41 

42import sqlalchemy 

43from lsst.daf.butler import Timespan, ddl 

44from lsst.daf.butler.registry import Registry 

45from lsst.daf.butler.registry.databases.postgresql import PostgresqlDatabase, _RangeTimespanType 

46from lsst.daf.butler.registry.tests import DatabaseTests, RegistryTests 

47from lsst.daf.butler.tests.utils import makeTestTempDir, removeTestTempDir 

48 

# Absolute path of the directory holding this test module; the per-class
# temporary directories are created relative to it.
TESTDIR = os.path.dirname(os.path.abspath(__file__))

50 

51 

def _startServer(root):
    """Start a PostgreSQL server and create a database within it, returning
    an object encapsulating both.

    Parameters
    ----------
    root
        Passed to `testing.postgresql.Postgresql` as its ``base_dir``.

    Returns
    -------
    server : `testing.postgresql.Postgresql`
        The started server, after ``CREATE EXTENSION btree_gist`` has been
        executed against it.  The caller is responsible for calling
        ``stop()`` on it.
    """
    server = testing.postgresql.Postgresql(base_dir=root)
    try:
        engine = sqlalchemy.engine.create_engine(server.url())
        try:
            with engine.begin() as connection:
                connection.execute(sqlalchemy.text("CREATE EXTENSION btree_gist;"))
        finally:
            # This engine is only needed for the one-off setup statement;
            # dispose of it explicitly so its pooled connection does not
            # linger until garbage collection.
            engine.dispose()
    except Exception:
        # Nobody else holds a reference to the server yet, so shut it down
        # here rather than leaving it running after a failed setup.
        server.stop()
        raise
    return server

61 

62 

@unittest.skipUnless(testing is not None, "testing.postgresql module not found")
class PostgresqlDatabaseTestCase(unittest.TestCase, DatabaseTests):
    """Run the shared `DatabaseTests` against `PostgresqlDatabase`, plus
    tests that are specific to the PostgreSQL backend.

    A single PostgreSQL server is started for the whole class; each test
    obtains its own randomly-named namespace within it via
    `makeEmptyDatabase`.
    """

    @classmethod
    def setUpClass(cls):
        """Create a temporary directory and start a server inside it."""
        cls.root = makeTestTempDir(TESTDIR)
        try:
            cls.server = _startServer(cls.root)
        except Exception:
            # unittest does not run tearDownClass when setUpClass raises, so
            # the temporary directory has to be removed here.
            removeTestTempDir(cls.root)
            raise

    @classmethod
    def tearDownClass(cls):
        """Stop the server and remove the temporary directory."""
        # Clean up any lingering SQLAlchemy engines/connections
        # so they're closed before we shut down the server.
        gc.collect()
        try:
            cls.server.stop()
        finally:
            # Remove the directory even if the server fails to stop cleanly.
            removeTestTempDir(cls.root)

    def makeEmptyDatabase(self, origin: int = 0) -> PostgresqlDatabase:
        """Return a `PostgresqlDatabase` connected to the class-wide server
        under a fresh, randomly-named namespace.

        Parameters
        ----------
        origin : `int`, optional
            Forwarded to `PostgresqlDatabase.fromUri` as ``origin``.
        """
        namespace = f"namespace_{secrets.token_hex(8).lower()}"
        return PostgresqlDatabase.fromUri(origin=origin, uri=self.server.url(), namespace=namespace)

    def getNewConnection(self, database: PostgresqlDatabase, *, writeable: bool) -> PostgresqlDatabase:
        """Return a new `PostgresqlDatabase` with the same origin and
        namespace as ``database``.

        Parameters
        ----------
        database : `PostgresqlDatabase`
            Database whose ``origin`` and ``namespace`` are reused.
        writeable : `bool`
            Forwarded to `PostgresqlDatabase.fromUri` as ``writeable``.
        """
        return PostgresqlDatabase.fromUri(
            origin=database.origin, uri=self.server.url(), namespace=database.namespace, writeable=writeable
        )

    @contextmanager
    def asReadOnly(self, database: PostgresqlDatabase) -> PostgresqlDatabase:
        """Context manager yielding a new connection to ``database`` opened
        with ``writeable=False``.

        Because of the `~contextlib.contextmanager` decorator, callers
        receive a context manager rather than the database itself; the
        yielded value is the `PostgresqlDatabase`.
        """
        yield self.getNewConnection(database, writeable=False)

    def testNameShrinking(self):
        """Test that too-long names for database entities other than tables
        and columns (which we preserve, and just expect to fit) are shrunk.

        There are no explicit assertions: the test passes if creating the
        tables does not raise.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            # Table and field names are each below the 63-char limit even when
            # accounting for the prefix, but their combination (which will
            # appear in sequences and constraints) is not.
            tableName = "a_table_with_a_very_very_long_42_char_name"
            fieldName1 = "a_column_with_a_very_very_long_43_char_name"
            fieldName2 = "another_column_with_a_very_very_long_49_char_name"
            context.addTable(
                tableName,
                ddl.TableSpec(
                    fields=[
                        ddl.FieldSpec(
                            fieldName1, dtype=sqlalchemy.BigInteger, autoincrement=True, primaryKey=True
                        ),
                        ddl.FieldSpec(
                            fieldName2,
                            dtype=sqlalchemy.String,
                            length=16,
                            nullable=False,
                        ),
                    ],
                    unique={(fieldName2,)},
                ),
            )
        # Add another table, this time dynamically, with a foreign key to the
        # first table.
        db.ensureTableExists(
            tableName + "_b",
            ddl.TableSpec(
                fields=[
                    ddl.FieldSpec(
                        fieldName1 + "_b", dtype=sqlalchemy.BigInteger, autoincrement=True, primaryKey=True
                    ),
                    ddl.FieldSpec(
                        fieldName2 + "_b",
                        dtype=sqlalchemy.String,
                        length=16,
                        nullable=False,
                    ),
                ],
                foreignKeys=[
                    ddl.ForeignKeySpec(tableName, source=(fieldName2 + "_b",), target=(fieldName2,)),
                ],
            ),
        )

    def test_RangeTimespanType(self):
        """Test that `_RangeTimespanType` columns round-trip `Timespan`
        values, and that the database ``overlaps`` operator agrees with
        `Timespan.overlaps` for every pair of test timespans.
        """
        # Three timestamps one minute apart, combined into ten timespans:
        # fully unbounded, unbounded on either side, and each bounded pair.
        start = astropy.time.Time("2020-01-01T00:00:00", format="isot", scale="tai")
        offset = astropy.time.TimeDelta(60, format="sec")
        timestamps = [start + offset * n for n in range(3)]
        timespans = [Timespan(begin=None, end=None)]
        timespans.extend(Timespan(begin=None, end=t) for t in timestamps)
        timespans.extend(Timespan(begin=t, end=None) for t in timestamps)
        timespans.extend(Timespan(begin=a, end=b) for a, b in itertools.combinations(timestamps, 2))
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tbl = context.addTable(
                "tbl",
                ddl.TableSpec(
                    fields=[
                        ddl.FieldSpec(name="id", dtype=sqlalchemy.Integer, primaryKey=True),
                        ddl.FieldSpec(name="timespan", dtype=_RangeTimespanType),
                    ],
                ),
            )
        # The row id is the index of the timespan in ``timespans``.
        rows = [{"id": n, "timespan": t} for n, t in enumerate(timespans)]
        db.insert(tbl, *rows)

        # Test basic round-trip through database.
        with db.query(tbl.select().order_by(tbl.columns.id)) as sql_result:
            self.assertEqual(rows, [row._asdict() for row in sql_result])

        # Test that Timespan's Python methods are consistent with our usage of
        # half-open ranges and PostgreSQL operators on ranges.
        def subquery(alias: str) -> sqlalchemy.sql.FromClause:
            return (
                sqlalchemy.sql.select(tbl.columns.id.label("id"), tbl.columns.timespan.label("timespan"))
                .select_from(tbl)
                .alias(alias)
            )

        sq1 = subquery("sq1")
        sq2 = subquery("sq2")
        query = sqlalchemy.sql.select(
            sq1.columns.id.label("n1"),
            sq2.columns.id.label("n2"),
            sq1.columns.timespan.overlaps(sq2.columns.timespan).label("overlaps"),
        )

        # `columns` is deprecated since 1.4, but
        # `selected_columns` method did not exist in 1.3.
        if hasattr(query, "selected_columns"):
            columns = query.selected_columns
        else:
            columns = query.columns

        # SQLAlchemy issues a warning about cartesian product of two tables,
        # which we do intentionally. Disable that warning temporarily.
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore", message=".*cartesian product", category=sqlalchemy.exc.SAWarning
            )
            with db.query(query) as sql_result:
                dbResults = {
                    (row[columns.n1], row[columns.n2]): row[columns.overlaps] for row in sql_result.mappings()
                }

        # The same pairwise table computed in Python, keyed by the same ids.
        pyResults = {
            (n1, n2): t1.overlaps(t2)
            for (n1, t1), (n2, t2) in itertools.product(enumerate(timespans), enumerate(timespans))
        }
        self.assertEqual(pyResults, dbResults)

208 

209 

@unittest.skipUnless(testing is not None, "testing.postgresql module not found")
class PostgresqlRegistryTests(RegistryTests):
    """Tests for `Registry` backed by a PostgreSQL database.

    Notes
    -----
    This is not a subclass of `unittest.TestCase` but to avoid repetition it
    defines methods that override `unittest.TestCase` methods. To make this
    work subclasses have to have this class first in the bases list.
    """

    @classmethod
    def setUpClass(cls):
        """Create a temporary directory and start a server inside it."""
        cls.root = makeTestTempDir(TESTDIR)
        try:
            cls.server = _startServer(cls.root)
        except Exception:
            # unittest does not run tearDownClass when setUpClass raises, so
            # the temporary directory has to be removed here.
            removeTestTempDir(cls.root)
            raise

    @classmethod
    def tearDownClass(cls):
        """Stop the server and remove the temporary directory."""
        # Clean up any lingering SQLAlchemy engines/connections
        # so they're closed before we shut down the server.
        gc.collect()
        try:
            cls.server.stop()
        finally:
            # Remove the directory even if the server fails to stop cleanly.
            removeTestTempDir(cls.root)

    @classmethod
    def getDataDir(cls) -> str:
        """Return the normalized path of the ``data/registry`` directory
        located next to this test module.
        """
        return os.path.normpath(os.path.join(os.path.dirname(__file__), "data", "registry"))

    def makeRegistry(self, share_repo_with: Optional[Registry] = None) -> Registry:
        """Return a `Registry` connected to the class-wide server.

        Parameters
        ----------
        share_repo_with : `Registry`, optional
            If `None` (default), a new registry is created with
            `Registry.createFromConfig` under a fresh, randomly-named
            namespace.  Otherwise the namespace of this registry's database
            is reused and the registry is loaded with `Registry.fromConfig`.

        Returns
        -------
        registry : `Registry`
            The new or re-opened registry.
        """
        config = self.makeRegistryConfig()
        config["db"] = self.server.url()
        if share_repo_with is None:
            config["namespace"] = f"namespace_{secrets.token_hex(8).lower()}"
            return Registry.createFromConfig(config)
        config["namespace"] = share_repo_with._db.namespace
        return Registry.fromConfig(config)

250 

251 

class PostgresqlRegistryNameKeyCollMgrUUIDTestCase(PostgresqlRegistryTests, unittest.TestCase):
    """Run the PostgreSQL-backed `Registry` tests with name-keyed
    collections.

    The manager configuration is NameKeyCollectionManager together with
    ByDimensionsDatasetRecordStorageManagerUUID.
    """

    # Fully-qualified class names of the managers this test case selects.
    # NOTE(review): presumably consumed by the inherited registry-config
    # machinery in `RegistryTests` - confirm there.
    collectionsManager = "lsst.daf.butler.registry.collections.nameKey.NameKeyCollectionManager"
    datasetsManager = (
        "lsst.daf.butler.registry.datasets.byDimensions.ByDimensionsDatasetRecordStorageManagerUUID"
    )

263 

264 

class PostgresqlRegistrySynthIntKeyCollMgrUUIDTestCase(PostgresqlRegistryTests, unittest.TestCase):
    """Run the PostgreSQL-backed `Registry` tests with synthetic
    integer-keyed collections.

    The manager configuration is SynthIntKeyCollectionManager together with
    ByDimensionsDatasetRecordStorageManagerUUID.
    """

    # Fully-qualified class names of the managers this test case selects.
    # NOTE(review): presumably consumed by the inherited registry-config
    # machinery in `RegistryTests` - confirm there.
    collectionsManager = "lsst.daf.butler.registry.collections.synthIntKey.SynthIntKeyCollectionManager"
    datasetsManager = (
        "lsst.daf.butler.registry.datasets.byDimensions.ByDimensionsDatasetRecordStorageManagerUUID"
    )

276 

277 

# Run the test cases in this module when it is executed as a script.
if __name__ == "__main__":
    unittest.main()