Coverage for tests / test_postgresql.py: 40%

96 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-01 08:18 +0000

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27 

28from __future__ import annotations 

29 

30import itertools 

31import os 

32import unittest 

33import warnings 

34from contextlib import contextmanager 

35from typing import cast 

36 

37import astropy.time 

38import sqlalchemy 

39 

40from lsst.daf.butler import Butler, ButlerConfig, StorageClassFactory, Timespan, ddl 

41from lsst.daf.butler.datastore import NullDatastore 

42from lsst.daf.butler.direct_butler import DirectButler 

43from lsst.daf.butler.registry import RegistryConfig, _RegistryFactory 

44from lsst.daf.butler.tests.postgresql import setup_postgres_test_db 

45 

46try: 

47 from lsst.daf.butler.registry.databases.postgresql import PostgresqlDatabase, _RangeTimespanType 

48except ImportError: 

49 PostgresqlDatabase = None 

50from lsst.daf.butler.registry.tests import DatabaseTests, RegistryTests 

51 

# Absolute path of the directory that contains this test module.
TESTDIR = os.path.abspath(os.path.dirname(__file__))

53 

54 

@unittest.skipUnless(PostgresqlDatabase is not None, "Couldn't load PostgresqlDatabase")
class PostgresqlDatabaseTestCase(unittest.TestCase, DatabaseTests):
    """Exercise `PostgresqlDatabase` through the shared `DatabaseTests` mixin.

    The whole class is skipped when `PostgresqlDatabase` could not be
    imported.  In addition to whatever the mixin provides, two tests defined
    here target PostgreSQL-specific behavior: shrinking of over-long
    identifier names and the range-based timespan column type.
    """

    @classmethod
    def setUpClass(cls):
        """Start the shared PostgreSQL test server for this class."""
        super().setUpClass()
        # enterClassContext ties the lifetime of the test database to the
        # class, so it is torn down with the other class-level cleanups.
        server_context = setup_postgres_test_db()
        cls.postgres = cls.enterClassContext(server_context)

    def makeEmptyDatabase(self, origin: int = 0) -> PostgresqlDatabase:
        """Return a `PostgresqlDatabase` bound to a newly generated namespace.

        Parameters
        ----------
        origin : `int`, optional
            Origin value forwarded to `PostgresqlDatabase.fromUri`.
        """
        server = self.postgres
        uri = server.url
        namespace = server.generate_namespace_name()
        return PostgresqlDatabase.fromUri(origin=origin, uri=uri, namespace=namespace)

    def getNewConnection(self, database: PostgresqlDatabase, *, writeable: bool) -> PostgresqlDatabase:
        """Open another `PostgresqlDatabase` on the same origin and namespace.

        Parameters
        ----------
        database : `PostgresqlDatabase`
            Existing database whose ``origin`` and ``namespace`` are reused.
        writeable : `bool`
            Forwarded to `PostgresqlDatabase.fromUri`.
        """
        return PostgresqlDatabase.fromUri(
            origin=database.origin,
            uri=self.postgres.url,
            namespace=database.namespace,
            writeable=writeable,
        )

    @contextmanager
    def asReadOnly(self, database: PostgresqlDatabase) -> PostgresqlDatabase:
        """Context manager yielding a non-writeable connection to
        ``database``.
        """
        read_only = self.getNewConnection(database, writeable=False)
        yield read_only

    def testNameShrinking(self):
        """Test that too-long names for database entities other than tables
        and columns (which we preserve, and just expect to fit) are shrunk.
        """
        # Table and field names are each below the 63-char limit even when
        # accounting for the prefix, but their combination (which will
        # appear in sequences and constraints) is not.
        tableName = "a_table_with_a_very_very_long_42_char_name"
        fieldName1 = "a_column_with_a_very_very_long_43_char_name"
        fieldName2 = "another_column_with_a_very_very_long_49_char_name"

        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            static_spec = ddl.TableSpec(
                fields=[
                    ddl.FieldSpec(
                        fieldName1, dtype=sqlalchemy.BigInteger, autoincrement=True, primaryKey=True
                    ),
                    ddl.FieldSpec(fieldName2, dtype=sqlalchemy.String, length=16, nullable=False),
                ],
                unique={(fieldName2,)},
            )
            context.addTable(tableName, static_spec)

        # Add another table, this time dynamically, with a foreign key to the
        # first table.
        key_b = fieldName1 + "_b"
        value_b = fieldName2 + "_b"
        dynamic_spec = ddl.TableSpec(
            fields=[
                ddl.FieldSpec(key_b, dtype=sqlalchemy.BigInteger, autoincrement=True, primaryKey=True),
                ddl.FieldSpec(value_b, dtype=sqlalchemy.String, length=16, nullable=False),
            ],
            foreignKeys=[
                ddl.ForeignKeySpec(tableName, source=(value_b,), target=(fieldName2,)),
            ],
        )
        db.ensureTableExists(tableName + "_b", dynamic_spec)

    def test_RangeTimespanType(self):
        """Round-trip `Timespan` values through a `_RangeTimespanType` column
        and compare database-side and Python-side ``overlaps`` results.
        """
        epoch = astropy.time.Time("2020-01-01T00:00:00", format="isot", scale="tai")
        step = astropy.time.TimeDelta(60, format="sec")
        instants = [epoch + step * k for k in range(3)]

        # Fully unbounded, half-bounded on either side, then every bounded
        # pair drawn from the three instants.
        timespans = [
            Timespan(begin=None, end=None),
            *(Timespan(begin=None, end=t) for t in instants),
            *(Timespan(begin=t, end=None) for t in instants),
            *(Timespan(begin=a, end=b) for a, b in itertools.combinations(instants, 2)),
        ]

        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            table_spec = ddl.TableSpec(
                fields=[
                    ddl.FieldSpec(name="id", dtype=sqlalchemy.Integer, primaryKey=True),
                    ddl.FieldSpec(name="timespan", dtype=_RangeTimespanType),
                ],
            )
            tbl = context.addTable("tbl", table_spec)

        rows = [{"id": index, "timespan": span} for index, span in enumerate(timespans)]
        db.insert(tbl, *rows)

        # Test basic round-trip through database.
        ordered_select = tbl.select().order_by(tbl.columns.id)
        with db.query(ordered_select) as sql_result:
            fetched = [row._asdict() for row in sql_result]
            self.assertEqual(rows, fetched)

        # Test that Timespan's Python methods are consistent with our usage of
        # half-open ranges and PostgreSQL operators on ranges.
        def subquery(alias: str) -> sqlalchemy.sql.FromClause:
            selection = sqlalchemy.sql.select(
                tbl.columns.id.label("id"), tbl.columns.timespan.label("timespan")
            )
            return selection.select_from(tbl).alias(alias)

        sq1 = subquery("sq1")
        sq2 = subquery("sq2")
        query = sqlalchemy.sql.select(
            sq1.columns.id.label("n1"),
            sq2.columns.id.label("n2"),
            sq1.columns.timespan.overlaps(sq2.columns.timespan).label("overlaps"),
        )

        # `columns` is deprecated since 1.4, but
        # `selected_columns` method did not exist in 1.3.
        columns = query.selected_columns if hasattr(query, "selected_columns") else query.columns

        # SQLAlchemy issues a warning about cartesian product of two tables,
        # which we do intentionally. Disable that warning temporarily.
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore", message=".*cartesian product", category=sqlalchemy.exc.SAWarning
            )
            with db.query(query) as sql_result:
                dbResults = {}
                for row in sql_result.mappings():
                    dbResults[(row[columns.n1], row[columns.n2])] = row[columns.overlaps]

        pyResults = {
            (n1, n2): t1.overlaps(t2)
            for (n1, t1), (n2, t2) in itertools.product(enumerate(timespans), enumerate(timespans))
        }
        self.assertEqual(pyResults, dbResults)

195 

196 

class PostgresqlRegistryTests(RegistryTests):
    """Shared `Registry` tests that run against a PostgreSQL database.

    Notes
    -----
    This class does not derive from `unittest.TestCase`, yet to avoid
    repetition it overrides methods that `unittest.TestCase` defines.  For
    that to work, concrete subclasses must list this class ahead of
    `unittest.TestCase` in their bases.
    """

    @classmethod
    def setUpClass(cls):
        """Start the shared PostgreSQL test server for this class."""
        super().setUpClass()
        server_context = setup_postgres_test_db()
        cls.postgres = cls.enterClassContext(server_context)

    @classmethod
    def getDataDir(cls) -> str:
        """Return the normalized path of ``data/registry`` next to this
        module.
        """
        here = os.path.dirname(__file__)
        return os.path.normpath(os.path.join(here, "data", "registry"))

    def make_butler(self, config: RegistryConfig | None = None) -> Butler:
        """Build a `DirectButler` whose registry uses the test database.

        Parameters
        ----------
        config : `RegistryConfig` or `None`, optional
            Registry configuration to use; when `None`, one is obtained from
            ``self.makeRegistryConfig()``.  The configuration is passed to
            ``self.postgres.patch_registry_config`` before the registry is
            created.

        Returns
        -------
        butler : `Butler`
            The new butler, already entered as a context on this test case.
        """
        config = self.makeRegistryConfig() if config is None else config
        self.postgres.patch_registry_config(config)
        factory = _RegistryFactory(config)

        butler = DirectButler(
            config=ButlerConfig(),
            registry=factory.create_from_config(),
            datastore=NullDatastore(None, None),
            storageClasses=StorageClassFactory(),
        )
        # The cast only informs type checkers: concrete subclasses mix in
        # unittest.TestCase, which is what provides enterContext.
        test_case = cast(unittest.TestCase, self)
        test_case.enterContext(butler)

        return butler

    def testSkipCalibs(self):
        """Run the inherited ``testSkipCalibs`` unless the server is older
        than PostgreSQL 16.
        """
        server_version = self.postgres.server_major_version()
        if server_version < 16:
            # TODO DM-44875: This test currently fails for older Postgres.
            self.skipTest("TODO DM-44875")
        return super().testSkipCalibs()

237 

238 

class PostgresqlRegistryNameKeyCollMgrUUIDTestCase(PostgresqlRegistryTests, unittest.TestCase):
    """Run the PostgreSQL-backed `Registry` tests with
    NameKeyCollectionManager and
    ByDimensionsDatasetRecordStorageManagerUUID selected as managers.
    """

    # Fully-qualified class names of the managers used by this test case.
    collectionsManager = "lsst.daf.butler.registry.collections.nameKey.NameKeyCollectionManager"
    datasetsManager = (
        "lsst.daf.butler.registry.datasets.byDimensions.ByDimensionsDatasetRecordStorageManagerUUID"
    )

250 

251 

class PostgresqlRegistrySynthIntKeyCollMgrUUIDTestCase(PostgresqlRegistryTests, unittest.TestCase):
    """Run the PostgreSQL-backed `Registry` tests with
    SynthIntKeyCollectionManager and
    ByDimensionsDatasetRecordStorageManagerUUID selected as managers.
    """

    # Fully-qualified class names of the managers used by this test case.
    collectionsManager = "lsst.daf.butler.registry.collections.synthIntKey.SynthIntKeyCollectionManager"
    datasetsManager = (
        "lsst.daf.butler.registry.datasets.byDimensions.ByDimensionsDatasetRecordStorageManagerUUID"
    )

263 

264 

# Run this module's tests when it is executed directly as a script.
if __name__ == "__main__":
    unittest.main()