Coverage for tests/test_expressions.py: 22%

150 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-08-05 01:26 +0000

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import datetime 

23import unittest 

24 

25import astropy.time 

26import sqlalchemy 

27from lsst.daf.butler import ( 

28 ColumnTypeInfo, 

29 DataCoordinate, 

30 DatasetColumnTag, 

31 DimensionUniverse, 

32 TimespanDatabaseRepresentation, 

33 ddl, 

34 time_utils, 

35) 

36from lsst.daf.butler.registry.queries.expressions import make_string_expression_predicate 

37from lsst.daf.butler.registry.queries.expressions.check import CheckVisitor, InspectionVisitor 

38from lsst.daf.butler.registry.queries.expressions.normalForm import NormalForm, NormalFormExpression 

39from lsst.daf.butler.registry.queries.expressions.parser import ParserYacc 

40from lsst.daf.relation import ColumnContainer, ColumnExpression 

41from sqlalchemy.schema import Column 

42 

43 

class FakeDatasetRecordStorageManager:
    """Minimal stand-in for a dataset record storage manager.

    It carries nothing except the ``ingestDate`` column attribute.
    """

    # SQLAlchemy column object named "ingest_date".
    ingestDate = Column("ingest_date")

48 

49 

class ConvertExpressionToPredicateTestCase(unittest.TestCase):
    """Checks for the ``make_string_expression_predicate`` function.

    Every test converts a small expression string and compares the first
    element of the returned value with a hand-built ``ColumnExpression``
    tree.  The ``ingest_date_*`` class attributes pick the column type, the
    Python type and the literal value used by ``test_ingest_date``.
    """

    ingest_date_dtype = sqlalchemy.TIMESTAMP
    ingest_date_pytype = datetime.datetime
    ingest_date_literal = datetime.datetime(2020, 1, 1)

    def setUp(self):
        """Build the ``ColumnTypeInfo`` instance used by every test."""
        dataset_id_spec = ddl.FieldSpec("dataset_id", dtype=ddl.GUID)
        run_key_spec = ddl.FieldSpec("run_id", dtype=sqlalchemy.BigInteger)
        self.column_types = ColumnTypeInfo(
            timespan_cls=TimespanDatabaseRepresentation.Compound,
            universe=DimensionUniverse(),
            dataset_id_spec=dataset_id_spec,
            run_key_spec=run_key_spec,
            ingest_date_dtype=self.ingest_date_dtype,
        )

    def _predicate(self, expression, **kwargs):
        """Convert ``expression`` and return element 0 of the result.

        Extra keyword arguments (``bind``, ``dataset_type_name``) are passed
        straight through to ``make_string_expression_predicate``.
        """
        result = make_string_expression_predicate(
            expression,
            self.column_types.universe.empty,
            column_types=self.column_types,
            **kwargs,
        )
        return result[0]

    @staticmethod
    def _int(value):
        """Return an integer literal column expression for ``value``."""
        return ColumnExpression.literal(value, dtype=int)

    def _a_gt_b_or_t_in(self, values):
        """Return the expected tree for ``1 > 2 OR 0 in (values...)``."""
        comparison = self._int(1).gt(self._int(2))
        container = ColumnContainer.sequence([self._int(v) for v in values], dtype=int)
        return comparison.logical_or(container.contains(self._int(0)))

    def test_simple(self):
        """Test with a trivial integer comparison."""
        actual = self._predicate("1 > 0")
        expected = self._int(1).gt(self._int(0))
        self.assertEqual(actual, expected)

    def test_time(self):
        """Test with a trivial comparison between two TAI time literals."""
        converter = time_utils.TimeConverter()
        actual = self._predicate("T'1970-01-01 00:00/tai' < T'2020-01-01 00:00/tai'")
        earlier = ColumnExpression.literal(converter.nsec_to_astropy(0), dtype=astropy.time.Time)
        later = ColumnExpression.literal(
            converter.nsec_to_astropy(1577836800000000000), dtype=astropy.time.Time
        )
        self.assertEqual(actual, earlier.lt(later))

    def test_ingest_date(self):
        """Test comparing the ``ingest_date`` dataset column to a UTC time."""
        actual = self._predicate(
            "ingest_date < T'2020-01-01 00:00/utc'",
            dataset_type_name="fake",
        )
        column = ColumnExpression.reference(
            DatasetColumnTag("fake", "ingest_date"), dtype=self.ingest_date_pytype
        )
        cutoff = ColumnExpression.literal(self.ingest_date_literal, dtype=self.ingest_date_pytype)
        self.assertEqual(actual, column.lt(cutoff))

    def test_bind(self):
        """Test with scalar bind parameters."""
        actual = self._predicate(
            "a > b OR t in (x, y, z)",
            bind={"a": 1, "b": 2, "t": 0, "x": 10, "y": 20, "z": 30},
        )
        self.assertEqual(actual, self._a_gt_b_or_t_in([10, 20, 30]))

    def test_bind_list(self):
        """Test with list/tuple/set bind parameters on the IN right-hand side."""
        # Each entry: (expression, bind mapping, expected IN members in order).
        cases = (
            # A single bound name that holds a tuple.
            (
                "a > b OR t in (x)",
                {"a": 1, "b": 2, "t": 0, "x": (10, 20, 30)},
                [10, 20, 30],
            ),
            # Two bound names inside IN(): scalar and scalar.
            (
                "a > b OR t in (x, y)",
                {"a": 1, "b": 2, "t": 0, "x": 10, "y": 20},
                [10, 20],
            ),
            # List and scalar.
            (
                "a > b OR t in (x, y)",
                {"a": 1, "b": 2, "t": 0, "x": [10, 30], "y": 20},
                [10, 30, 20],
            ),
            # Tuple and set.
            (
                "a > b OR t in (x, y)",
                {"a": 1, "b": 2, "t": 0, "x": (10, 30), "y": {20}},
                [10, 30, 20],
            ),
        )
        for expression, bind, members in cases:
            actual = self._predicate(expression, bind=bind)
            self.assertEqual(actual, self._a_gt_b_or_t_in(members))

217 

218 

class ConvertExpressionToPredicateTestCaseAstropy(ConvertExpressionToPredicateTestCase):
    """Rerun the inherited ``make_string_expression_predicate`` tests with
    ingest_date defined as nanoseconds (``ddl.AstropyTimeNsecTai``).
    """

    # Column type, matching Python type, and the literal the inherited
    # ingest-date test compares against.
    ingest_date_dtype = ddl.AstropyTimeNsecTai
    ingest_date_pytype = astropy.time.Time
    ingest_date_literal = astropy.time.Time(datetime.datetime(2020, 1, 1), scale="utc")

227 

228 

class InspectionVisitorTestCase(unittest.TestCase):
    """Tests for InspectionVisitor class.

    Each case parses an expression with ``ParserYacc``, visits the tree with
    an ``InspectionVisitor`` and checks fields of the returned summary.
    """

    def test_simple(self):
        """Test for simple expressions"""
        universe = DimensionUniverse()
        parser = ParserYacc()

        # A single equality on a dimension fills in dataIdKey/dataIdValue.
        tree = parser.parse("instrument = 'LSST'")
        bind = {}
        summary = tree.visit(InspectionVisitor(universe, bind))
        self.assertEqual(summary.dimensions.names, {"instrument"})
        self.assertFalse(summary.columns)
        self.assertFalse(summary.hasIngestDate)
        self.assertEqual(summary.dataIdKey, universe["instrument"])
        self.assertEqual(summary.dataIdValue, "LSST")

        # An inequality leaves dataIdKey/dataIdValue unset.
        tree = parser.parse("instrument != 'LSST'")
        summary = tree.visit(InspectionVisitor(universe, bind))
        self.assertEqual(summary.dimensions.names, {"instrument"})
        self.assertFalse(summary.columns)
        self.assertIsNone(summary.dataIdKey)
        self.assertIsNone(summary.dataIdValue)

        # So does a conjunction of two equalities.
        tree = parser.parse("instrument = 'LSST' AND visit = 1")
        summary = tree.visit(InspectionVisitor(universe, bind))
        self.assertEqual(summary.dimensions.names, {"instrument", "visit", "band", "physical_filter"})
        self.assertFalse(summary.columns)
        self.assertIsNone(summary.dataIdKey)
        self.assertIsNone(summary.dataIdValue)

        tree = parser.parse("instrument = 'LSST' AND visit = 1 AND skymap = 'x'")
        summary = tree.visit(InspectionVisitor(universe, bind))
        self.assertEqual(
            summary.dimensions.names, {"instrument", "visit", "band", "physical_filter", "skymap"}
        )
        self.assertFalse(summary.columns)
        self.assertIsNone(summary.dataIdKey)
        self.assertIsNone(summary.dataIdValue)

    def test_bind(self):
        """Test for simple expressions with binds."""
        universe = DimensionUniverse()
        parser = ParserYacc()

        tree = parser.parse("instrument = instr")
        bind = {"instr": "LSST"}
        summary = tree.visit(InspectionVisitor(universe, bind))
        self.assertEqual(summary.dimensions.names, {"instrument"})
        self.assertFalse(summary.hasIngestDate)
        self.assertEqual(summary.dataIdKey, universe["instrument"])
        self.assertEqual(summary.dataIdValue, "LSST")

        tree = parser.parse("instrument != instr")
        # Visit the new tree before asserting, so the checks below apply to
        # this expression and not to the summary of the previous one.
        summary = tree.visit(InspectionVisitor(universe, bind))
        self.assertEqual(summary.dimensions.names, {"instrument"})
        self.assertIsNone(summary.dataIdKey)
        self.assertIsNone(summary.dataIdValue)

        tree = parser.parse("instrument = instr AND visit = visit_id")
        bind = {"instr": "LSST", "visit_id": 1}
        summary = tree.visit(InspectionVisitor(universe, bind))
        self.assertEqual(summary.dimensions.names, {"instrument", "visit", "band", "physical_filter"})
        self.assertIsNone(summary.dataIdKey)
        self.assertIsNone(summary.dataIdValue)

        tree = parser.parse("instrument = 'LSST' AND visit = 1 AND skymap = skymap_name")
        bind = {"instr": "LSST", "visit_id": 1, "skymap_name": "x"}
        summary = tree.visit(InspectionVisitor(universe, bind))
        self.assertEqual(
            summary.dimensions.names, {"instrument", "visit", "band", "physical_filter", "skymap"}
        )
        self.assertIsNone(summary.dataIdKey)
        self.assertIsNone(summary.dataIdValue)

    def test_in(self):
        """Test for IN expressions."""
        universe = DimensionUniverse()
        parser = ParserYacc()

        tree = parser.parse("instrument IN ('LSST')")
        bind = {}
        summary = tree.visit(InspectionVisitor(universe, bind))
        self.assertEqual(summary.dimensions.names, {"instrument"})
        self.assertFalse(summary.hasIngestDate)
        # we do not handle IN with a single item as `=`
        self.assertIsNone(summary.dataIdKey)
        self.assertIsNone(summary.dataIdValue)

        # Same single-item IN, with the item supplied through a bind.
        tree = parser.parse("instrument IN (instr)")
        bind = {"instr": "LSST"}
        summary = tree.visit(InspectionVisitor(universe, bind))
        self.assertEqual(summary.dimensions.names, {"instrument"})
        self.assertIsNone(summary.dataIdKey)
        self.assertIsNone(summary.dataIdValue)

        # Several literal items.
        tree = parser.parse("visit IN (1,2,3)")
        bind = {}
        summary = tree.visit(InspectionVisitor(universe, bind))
        self.assertEqual(summary.dimensions.names, {"instrument", "visit", "band", "physical_filter"})
        self.assertIsNone(summary.dataIdKey)
        self.assertIsNone(summary.dataIdValue)

        # Several scalar binds.
        tree = parser.parse("visit IN (visit1, visit2, visit3)")
        bind = {"visit1": 1, "visit2": 2, "visit3": 3}
        summary = tree.visit(InspectionVisitor(universe, bind))
        self.assertEqual(summary.dimensions.names, {"instrument", "visit", "band", "physical_filter"})
        self.assertIsNone(summary.dataIdKey)
        self.assertIsNone(summary.dataIdValue)

        # One bind holding a tuple of values.
        tree = parser.parse("visit IN (visits)")
        bind = {"visits": (1, 2, 3)}
        summary = tree.visit(InspectionVisitor(universe, bind))
        self.assertEqual(summary.dimensions.names, {"instrument", "visit", "band", "physical_filter"})
        self.assertIsNone(summary.dataIdKey)
        self.assertIsNone(summary.dataIdValue)

345 

346 

class CheckVisitorTestCase(unittest.TestCase):
    """Tests for CheckVisitor class."""

    def test_governor(self):
        """Test with governor dimension in expression.

        There are no explicit assertions: the test passes when every
        expression is parsed, normalized and visited without raising.
        """
        parser = ParserYacc()

        universe = DimensionUniverse()
        graph = universe.extract(("instrument", "visit"))
        empty_data_id = DataCoordinate.makeEmpty(universe)
        empty_defaults = DataCoordinate.makeEmpty(universe)

        # Each entry: (expression string, bind mapping).
        cases = (
            # governor-only constraint, written both ways round
            ("instrument = 'LSST'", {}),
            ("'LSST' = instrument", {}),
            # use bind for governor
            ("instrument = instr", {"instr": "LSST"}),
        )
        for expression, binds in cases:
            tree = parser.parse(expression)
            normalized = NormalFormExpression.fromTree(tree, NormalForm.DISJUNCTIVE)
            normalized.visit(CheckVisitor(empty_data_id, graph, binds, empty_defaults))

378 

379 

# Run the unittest command-line runner when this file is executed directly.
if __name__ == "__main__":
    unittest.main()