Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

import gc
import itertools
import os
import secrets
import unittest
from contextlib import contextmanager
from typing import Iterator

import astropy.time
import sqlalchemy

try:
    # It's possible but silly to have testing.postgresql installed without
    # having the postgresql server installed (because then nothing in
    # testing.postgresql would work), so we use the presence of that module
    # to test whether we can expect the server to be available.
    import testing.postgresql
except ImportError:
    testing = None

from lsst.daf.butler import ddl, Timespan
from lsst.daf.butler.registry import Registry
from lsst.daf.butler.registry.databases.postgresql import PostgresqlDatabase, _RangeTimespanType
from lsst.daf.butler.registry.tests import DatabaseTests, RegistryTests

45 

46 

def _startServer():
    """Start a temporary PostgreSQL server and enable ``btree_gist`` in it.

    The extension is created in the database that ``server.url()`` points
    at, which is the database the tests later connect to.

    Returns
    -------
    server : `testing.postgresql.Postgresql`
        The running server.  The caller is responsible for calling
        ``server.stop()``.

    Raises
    ------
    Exception
        Anything raised while connecting or creating the extension is
        re-raised after the server has been stopped.
    """
    server = testing.postgresql.Postgresql()
    try:
        engine = sqlalchemy.engine.create_engine(server.url())
        try:
            # ``engine.begin()`` commits on successful exit in both legacy
            # and 2.0-style SQLAlchemy, unlike the deprecated
            # ``Engine.execute`` with a raw string.
            with engine.begin() as connection:
                connection.execute(sqlalchemy.text("CREATE EXTENSION btree_gist;"))
        finally:
            # Close pooled connections now, so they are not still open
            # against the server when it is stopped at tearDownClass time.
            engine.dispose()
    except BaseException:
        # Don't leak a running server process if setup fails.
        server.stop()
        raise
    return server

55 

56 

@unittest.skipUnless(testing is not None, "testing.postgresql module not found")
class PostgresqlDatabaseTestCase(unittest.TestCase, DatabaseTests):
    """Run the shared `DatabaseTests` suite, plus PostgreSQL-specific tests,
    against `PostgresqlDatabase`.

    One temporary server is started for the whole class.  Each call to
    `makeEmptyDatabase` uses a freshly generated, random namespace on that
    server.
    """

    @classmethod
    def setUpClass(cls):
        cls.server = _startServer()

    @classmethod
    def tearDownClass(cls):
        # Clean up any lingering SQLAlchemy engines/connections
        # so they're closed before we shut down the server.
        gc.collect()
        cls.server.stop()

    def makeEmptyDatabase(self, origin: int = 0) -> PostgresqlDatabase:
        """Return a `PostgresqlDatabase` in a new, randomly named namespace
        on the test server.
        """
        namespace = f"namespace_{secrets.token_hex(8).lower()}"
        return PostgresqlDatabase.fromUri(origin=origin, uri=self.server.url(), namespace=namespace)

    def getNewConnection(self, database: PostgresqlDatabase, *, writeable: bool) -> PostgresqlDatabase:
        """Return a new `PostgresqlDatabase` with the same origin and
        namespace as ``database``, opened with the given ``writeable`` flag.
        """
        return PostgresqlDatabase.fromUri(origin=database.origin, uri=self.server.url(),
                                          namespace=database.namespace, writeable=writeable)

    @contextmanager
    def asReadOnly(self, database: PostgresqlDatabase) -> Iterator[PostgresqlDatabase]:
        """Context manager yielding a new connection to ``database``'s
        namespace opened with ``writeable=False``.
        """
        # Annotated as Iterator: this is a generator function wrapped by
        # @contextmanager, not a function returning the database directly.
        yield self.getNewConnection(database, writeable=False)

    def testNameShrinking(self):
        """Test that too-long names for database entities other than tables
        and columns (which we preserve, and just expect to fit) are shrunk.

        The test passes if creating both tables does not raise; nothing else
        is asserted.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            # Table and field names are each below the 63-char limit even when
            # accounting for the prefix, but their combination (which will
            # appear in sequences and constraints) is not.
            tableName = "a_table_with_a_very_very_long_42_char_name"
            fieldName1 = "a_column_with_a_very_very_long_43_char_name"
            fieldName2 = "another_column_with_a_very_very_long_49_char_name"
            context.addTable(
                tableName,
                ddl.TableSpec(
                    fields=[
                        ddl.FieldSpec(
                            fieldName1,
                            dtype=sqlalchemy.BigInteger,
                            autoincrement=True,
                            primaryKey=True
                        ),
                        ddl.FieldSpec(
                            fieldName2,
                            dtype=sqlalchemy.String,
                            length=16,
                            nullable=False,
                        ),
                    ],
                    unique={(fieldName2,)},
                )
            )
        # Add another table, this time dynamically, with a foreign key to the
        # first table.
        db.ensureTableExists(
            tableName + "_b",
            ddl.TableSpec(
                fields=[
                    ddl.FieldSpec(
                        fieldName1 + "_b",
                        dtype=sqlalchemy.BigInteger,
                        autoincrement=True,
                        primaryKey=True
                    ),
                    ddl.FieldSpec(
                        fieldName2 + "_b",
                        dtype=sqlalchemy.String,
                        length=16,
                        nullable=False,
                    ),
                ],
                foreignKeys=[
                    ddl.ForeignKeySpec(tableName, source=(fieldName2 + "_b",), target=(fieldName2,)),
                ]
            )
        )

    def test_RangeTimespanType(self):
        """Test that `Timespan` values round-trip through a
        `_RangeTimespanType` column, and that the database's ``overlaps``
        comparison agrees with `Timespan.overlaps` for every pair of values.
        """
        start = astropy.time.Time('2020-01-01T00:00:00', format="isot", scale="tai")
        offset = astropy.time.TimeDelta(60, format="sec")
        timestamps = [start + offset*n for n in range(3)]
        # Cover fully unbounded, half-bounded (each side), and bounded spans.
        timespans = [Timespan(begin=None, end=None)]
        timespans.extend(Timespan(begin=None, end=t) for t in timestamps)
        timespans.extend(Timespan(begin=t, end=None) for t in timestamps)
        timespans.extend(Timespan(begin=a, end=b) for a, b in itertools.combinations(timestamps, 2))
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tbl = context.addTable(
                "tbl",
                ddl.TableSpec(
                    fields=[
                        ddl.FieldSpec(name="id", dtype=sqlalchemy.Integer, primaryKey=True),
                        ddl.FieldSpec(name="timespan", dtype=_RangeTimespanType),
                    ],
                )
            )
        rows = [{"id": n, "timespan": t} for n, t in enumerate(timespans)]
        db.insert(tbl, *rows)

        # Test basic round-trip through database.
        self.assertEqual(
            rows,
            [dict(row) for row in db.query(tbl.select().order_by(tbl.columns.id)).fetchall()]
        )

        # Test that Timespan's Python methods are consistent with our usage of
        # half-open ranges and PostgreSQL operators on ranges.
        def subquery(alias: str) -> sqlalchemy.sql.FromClause:
            return sqlalchemy.sql.select(
                [tbl.columns.id.label("id"), tbl.columns.timespan.label("timespan")]
            ).select_from(
                tbl
            ).alias(alias)
        sq1 = subquery("sq1")
        sq2 = subquery("sq2")
        # No join condition: this is the full cross product of rows.
        query = sqlalchemy.sql.select([
            sq1.columns.id.label("n1"),
            sq2.columns.id.label("n2"),
            sq1.columns.timespan.overlaps(sq2.columns.timespan).label("overlaps"),
        ])

        dbResults = {
            (row[query.columns.n1], row[query.columns.n2]): row[query.columns.overlaps]
            for row in db.query(query)
        }
        pyResults = {
            (n1, n2): t1.overlaps(t2)
            for (n1, t1), (n2, t2) in itertools.product(enumerate(timespans), enumerate(timespans))
        }
        self.assertEqual(pyResults, dbResults)

193 

194 

@unittest.skipUnless(testing is not None, "testing.postgresql module not found")
class PostgresqlRegistryTests(RegistryTests):
    """Tests for `Registry` backed by a PostgreSQL database.

    Notes
    -----
    This is not a subclass of `unittest.TestCase` but to avoid repetition it
    defines methods that override `unittest.TestCase` methods. To make this
    work subclasses have to have this class first in the bases list.
    """

    @classmethod
    def setUpClass(cls):
        cls.server = _startServer()

    @classmethod
    def tearDownClass(cls):
        # Clean up any lingering SQLAlchemy engines/connections held by
        # registries created in makeRegistry, so they're closed before we
        # shut down the server.
        gc.collect()
        cls.server.stop()

    @classmethod
    def getDataDir(cls) -> str:
        """Return the ``data/registry`` directory next to this test file."""
        return os.path.normpath(os.path.join(os.path.dirname(__file__), "data", "registry"))

    def makeRegistry(self) -> Registry:
        """Create a new `Registry` on the test server, in a new, randomly
        named namespace.
        """
        namespace = f"namespace_{secrets.token_hex(8).lower()}"
        config = self.makeRegistryConfig()
        config["db"] = self.server.url()
        config["namespace"] = namespace
        return Registry.fromConfig(config, create=True)

224 

225 

class PostgresqlRegistryNameKeyCollMgrTestCase(PostgresqlRegistryTests, unittest.TestCase):
    """PostgreSQL-backed `Registry` tests run with the name-keyed
    collection manager (``NameKeyCollectionManager``).
    """

    collectionsManager = (
        "lsst.daf.butler.registry.collections."
        "nameKey.NameKeyCollectionManager"
    )

232 

233 

class PostgresqlRegistrySynthIntKeyCollMgrTestCase(PostgresqlRegistryTests, unittest.TestCase):
    """PostgreSQL-backed `Registry` tests run with the synthetic-integer-key
    collection manager (``SynthIntKeyCollectionManager``).
    """

    collectionsManager = (
        "lsst.daf.butler.registry.collections."
        "synthIntKey.SynthIntKeyCollectionManager"
    )

240 

241 

if __name__ == "__main__":
    # Executed as a script (not imported by a test runner): run this
    # module's test cases via unittest's command-line entry point.
    unittest.main()