Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22import os 

23from contextlib import contextmanager 

24import itertools 

25import secrets 

26import unittest 

27import gc 

28 

29import astropy.time 

30try: 

31 # It's possible but silly to have testing.postgresql installed without 

32 # having the postgresql server installed (because then nothing in 

33 # testing.postgresql would work), so we use the presence of that module 

34 # to test whether we can expect the server to be available. 

35 import testing.postgresql 

36except ImportError: 

37 testing = None 

38 

39import sqlalchemy 

40 

41from lsst.daf.butler import ddl, Timespan 

42from lsst.daf.butler.registry import Registry 

43from lsst.daf.butler.registry.databases.postgresql import PostgresqlDatabase, _RangeTimespanType 

44from lsst.daf.butler.registry.tests import DatabaseTests, RegistryTests 

45from lsst.daf.butler.tests.utils import makeTestTempDir, removeTestTempDir 

46 

# Absolute path of the directory that holds this test module; the test
# classes create their per-class temporary server directories beneath it.
TESTDIR = os.path.abspath(os.path.split(__file__)[0])

48 

49 

def _startServer(root):
    """Start a PostgreSQL test server and prepare its default database.

    Parameters
    ----------
    root : `str`
        Directory passed to ``testing.postgresql.Postgresql`` as ``base_dir``;
        the server keeps its files there.

    Returns
    -------
    server : `testing.postgresql.Postgresql`
        The running server.  The ``btree_gist`` extension has been created in
        the database that ``server.url()`` points at.  The caller is
        responsible for calling ``server.stop()``.
    """
    server = testing.postgresql.Postgresql(base_dir=root)
    try:
        engine = sqlalchemy.engine.create_engine(server.url())
        try:
            # NOTE(review): btree_gist is presumably required by the PostgreSQL
            # registry backend (e.g. GiST constraints mixing plain and range
            # columns) — confirm against PostgresqlDatabase.
            with engine.begin() as connection:
                connection.execute(sqlalchemy.text("CREATE EXTENSION btree_gist;"))
        finally:
            # Close pooled connections now instead of relying on garbage
            # collection, so they cannot linger until server shutdown.
            engine.dispose()
    except BaseException:
        # Do not leak a running server process if database setup fails.
        server.stop()
        raise
    return server

58 

59 

@unittest.skipUnless(testing is not None, "testing.postgresql module not found")
class PostgresqlDatabaseTestCase(unittest.TestCase, DatabaseTests):
    """Tests for `PostgresqlDatabase`, run against a throwaway PostgreSQL
    server that is started once for the whole class.
    """

    @classmethod
    def setUpClass(cls):
        cls.root = makeTestTempDir(TESTDIR)
        try:
            cls.server = _startServer(cls.root)
        except BaseException:
            # unittest does not call tearDownClass when setUpClass raises, so
            # remove the temporary directory here instead of leaking it.
            removeTestTempDir(cls.root)
            raise

    @classmethod
    def tearDownClass(cls):
        # Clean up any lingering SQLAlchemy engines/connections
        # so they're closed before we shut down the server.
        gc.collect()
        try:
            cls.server.stop()
        finally:
            # Remove the temporary directory even if stopping the server fails.
            removeTestTempDir(cls.root)

    def makeEmptyDatabase(self, origin: int = 0) -> PostgresqlDatabase:
        """Return a `PostgresqlDatabase` using a fresh, randomly named
        namespace on the shared test server.
        """
        # token_hex already produces lowercase hexadecimal digits.
        namespace = f"namespace_{secrets.token_hex(8)}"
        return PostgresqlDatabase.fromUri(origin=origin, uri=self.server.url(), namespace=namespace)

    def getNewConnection(self, database: PostgresqlDatabase, *, writeable: bool) -> PostgresqlDatabase:
        """Return a new `PostgresqlDatabase` connected to the same server and
        namespace as ``database``, with the requested writeability.
        """
        return PostgresqlDatabase.fromUri(origin=database.origin, uri=self.server.url(),
                                          namespace=database.namespace, writeable=writeable)

    @contextmanager
    def asReadOnly(self, database: PostgresqlDatabase) -> PostgresqlDatabase:
        """Context manager yielding a read-only connection to the same
        namespace as ``database``.
        """
        yield self.getNewConnection(database, writeable=False)

    def testNameShrinking(self):
        """Test that too-long names for database entities other than tables
        and columns (which we preserve, and just expect to fit) are shrunk.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            # Table and field names are each below the 63-char limit even when
            # accounting for the prefix, but their combination (which will
            # appear in sequences and constraints) is not.
            tableName = "a_table_with_a_very_very_long_42_char_name"
            fieldName1 = "a_column_with_a_very_very_long_43_char_name"
            fieldName2 = "another_column_with_a_very_very_long_49_char_name"
            context.addTable(
                tableName,
                ddl.TableSpec(
                    fields=[
                        ddl.FieldSpec(
                            fieldName1,
                            dtype=sqlalchemy.BigInteger,
                            autoincrement=True,
                            primaryKey=True
                        ),
                        ddl.FieldSpec(
                            fieldName2,
                            dtype=sqlalchemy.String,
                            length=16,
                            nullable=False,
                        ),
                    ],
                    unique={(fieldName2,)},
                )
            )
        # Add another table, this time dynamically, with a foreign key to the
        # first table.
        db.ensureTableExists(
            tableName + "_b",
            ddl.TableSpec(
                fields=[
                    ddl.FieldSpec(
                        fieldName1 + "_b",
                        dtype=sqlalchemy.BigInteger,
                        autoincrement=True,
                        primaryKey=True
                    ),
                    ddl.FieldSpec(
                        fieldName2 + "_b",
                        dtype=sqlalchemy.String,
                        length=16,
                        nullable=False,
                    ),
                ],
                foreignKeys=[
                    ddl.ForeignKeySpec(tableName, source=(fieldName2 + "_b",), target=(fieldName2,)),
                ]
            )
        )

    def test_RangeTimespanType(self):
        """Test round-tripping `Timespan` values through a
        `_RangeTimespanType` column, and that the database's overlap test
        agrees with `Timespan.overlaps` for every pair of timespans.
        """
        start = astropy.time.Time('2020-01-01T00:00:00', format="isot", scale="tai")
        offset = astropy.time.TimeDelta(60, format="sec")
        timestamps = [start + offset * n for n in range(3)]
        # Cover unbounded, half-bounded and fully-bounded timespans.
        timespans = [Timespan(begin=None, end=None)]
        timespans.extend(Timespan(begin=None, end=t) for t in timestamps)
        timespans.extend(Timespan(begin=t, end=None) for t in timestamps)
        timespans.extend(Timespan(begin=a, end=b) for a, b in itertools.combinations(timestamps, 2))
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tbl = context.addTable(
                "tbl",
                ddl.TableSpec(
                    fields=[
                        ddl.FieldSpec(name="id", dtype=sqlalchemy.Integer, primaryKey=True),
                        ddl.FieldSpec(name="timespan", dtype=_RangeTimespanType),
                    ],
                )
            )
        rows = [{"id": n, "timespan": t} for n, t in enumerate(timespans)]
        db.insert(tbl, *rows)

        # Test basic round-trip through database.
        self.assertEqual(
            rows,
            [dict(row) for row in db.query(tbl.select().order_by(tbl.columns.id)).fetchall()]
        )

        # Test that Timespan's Python methods are consistent with our usage of
        # half-open ranges and PostgreSQL operators on ranges.
        def subquery(alias: str) -> sqlalchemy.sql.FromClause:
            return sqlalchemy.sql.select(
                [tbl.columns.id.label("id"), tbl.columns.timespan.label("timespan")]
            ).select_from(
                tbl
            ).alias(alias)
        sq1 = subquery("sq1")
        sq2 = subquery("sq2")
        # No join condition between the two aliases: this deliberately
        # compares every pair of rows, mirroring itertools.product below.
        query = sqlalchemy.sql.select([
            sq1.columns.id.label("n1"),
            sq2.columns.id.label("n2"),
            sq1.columns.timespan.overlaps(sq2.columns.timespan).label("overlaps"),
        ])

        dbResults = {
            (row[query.columns.n1], row[query.columns.n2]): row[query.columns.overlaps]
            for row in db.query(query)
        }
        pyResults = {
            (n1, n2): t1.overlaps(t2)
            for (n1, t1), (n2, t2) in itertools.product(enumerate(timespans), enumerate(timespans))
        }
        self.assertEqual(pyResults, dbResults)

198 

199 

@unittest.skipUnless(testing is not None, "testing.postgresql module not found")
class PostgresqlRegistryTests(RegistryTests):
    """Tests for `Registry` backed by a PostgreSQL database.

    Notes
    -----
    This is not a subclass of `unittest.TestCase`, but to avoid repetition it
    defines methods that override `unittest.TestCase` methods.  To make this
    work, subclasses have to have this class first in the bases list.
    """

    @classmethod
    def setUpClass(cls):
        cls.root = makeTestTempDir(TESTDIR)
        try:
            cls.server = _startServer(cls.root)
        except BaseException:
            # unittest does not call tearDownClass when setUpClass raises, so
            # remove the temporary directory here instead of leaking it.
            removeTestTempDir(cls.root)
            raise

    @classmethod
    def tearDownClass(cls):
        # Clean up any lingering SQLAlchemy engines/connections
        # so they're closed before we shut down the server.
        gc.collect()
        try:
            cls.server.stop()
        finally:
            # Remove the temporary directory even if stopping the server fails.
            removeTestTempDir(cls.root)

    @classmethod
    def getDataDir(cls) -> str:
        """Return the directory holding the registry test data files."""
        # Reuse the module-level absolute test directory so the result does
        # not depend on how __file__ was spelled.
        return os.path.join(TESTDIR, "data", "registry")

    def makeRegistry(self) -> Registry:
        """Create a new `Registry` in a fresh, randomly named namespace on the
        shared test server.
        """
        # token_hex already produces lowercase hexadecimal digits.
        namespace = f"namespace_{secrets.token_hex(8)}"
        config = self.makeRegistryConfig()
        config["db"] = self.server.url()
        config["namespace"] = namespace
        return Registry.createFromConfig(config)

234 

235 

class PostgresqlRegistryNameKeyCollMgrTestCase(PostgresqlRegistryTests, unittest.TestCase):
    """PostgreSQL-backed `Registry` tests using the name-keyed collection
    manager (`NameKeyCollectionManager`).
    """

    # Fully qualified name of the collection manager class under test.
    collectionsManager = (
        "lsst.daf.butler.registry.collections.nameKey."
        "NameKeyCollectionManager"
    )

242 

243 

class PostgresqlRegistrySynthIntKeyCollMgrTestCase(PostgresqlRegistryTests, unittest.TestCase):
    """PostgreSQL-backed `Registry` tests using the synthetic-integer-keyed
    collection manager (`SynthIntKeyCollectionManager`).
    """

    # Fully qualified name of the collection manager class under test.
    collectionsManager = (
        "lsst.daf.butler.registry.collections.synthIntKey."
        "SynthIntKeyCollectionManager"
    )

250 

251 

if __name__ == "__main__":
    # Run the test cases in this module when it is executed as a script.
    unittest.main()