Coverage for tests/test_expressions.py: 22%

150 statements  

« prev     ^ index     » next       coverage.py v7.2.5, created at 2023-05-17 02:31 -0700

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import datetime 

23import unittest 

24 

25import astropy.time 

26import sqlalchemy 

27from lsst.daf.butler import ( 

28 ColumnTypeInfo, 

29 DataCoordinate, 

30 DatasetColumnTag, 

31 DimensionUniverse, 

32 TimespanDatabaseRepresentation, 

33 ddl, 

34 time_utils, 

35) 

36from lsst.daf.butler.registry.queries.expressions import make_string_expression_predicate 

37from lsst.daf.butler.registry.queries.expressions.check import CheckVisitor, InspectionVisitor 

38from lsst.daf.butler.registry.queries.expressions.normalForm import NormalForm, NormalFormExpression 

39from lsst.daf.butler.registry.queries.expressions.parser import ParserYacc 

40from lsst.daf.relation import ColumnContainer, ColumnExpression 

41from sqlalchemy.schema import Column 

42 

43 

class FakeDatasetRecordStorageManager:
    """Minimal stand-in that only exposes an ``ingestDate`` column.

    It is not referenced by the test cases in this module.
    """

    # SQLAlchemy column named "ingest_date"; no type is given.
    ingestDate = Column("ingest_date")

46 

47 

class ConvertExpressionToPredicateTestCase(unittest.TestCase):
    """A test case for the make_string_expression_predicate function.

    The ``ingest_date_*`` class attributes describe how the ``ingest_date``
    column is stored, which Python type is expected for it, and which literal
    ``T'2020-01-01 00:00/utc'`` should turn into. Subclasses override them to
    repeat every test with a different storage type.
    """

    ingest_date_dtype = sqlalchemy.TIMESTAMP
    ingest_date_pytype = datetime.datetime
    ingest_date_literal = datetime.datetime(2020, 1, 1)

    def setUp(self):
        """Build the column type information shared by every test."""
        self.column_types = ColumnTypeInfo(
            timespan_cls=TimespanDatabaseRepresentation.Compound,
            universe=DimensionUniverse(),
            dataset_id_spec=ddl.FieldSpec("dataset_id", dtype=ddl.GUID),
            run_key_spec=ddl.FieldSpec("run_id", dtype=sqlalchemy.BigInteger),
            ingest_date_dtype=self.ingest_date_dtype,
        )

    def _predicate(self, expression, **kwargs):
        """Convert ``expression`` and return the first element of the result.

        The conversion uses ``self.column_types.universe.empty`` as the
        dimensions argument. Extra keyword arguments (``bind``,
        ``dataset_type_name``) are forwarded unchanged.
        """
        converted = make_string_expression_predicate(
            expression,
            self.column_types.universe.empty,
            column_types=self.column_types,
            **kwargs,
        )
        return converted[0]

    @staticmethod
    def _int(value):
        """Return an ``int`` literal column expression for ``value``."""
        return ColumnExpression.literal(value, dtype=int)

    @classmethod
    def _expected_or_in(cls, values):
        """Return the expected predicate for ``1 > 2 OR 0 IN (*values)``.

        Every bind-parameter test below resolves to this shape once
        ``a=1``, ``b=2`` and ``t=0`` have been substituted.
        """
        comparison = cls._int(1).gt(cls._int(2))
        candidates = ColumnContainer.sequence([cls._int(v) for v in values], dtype=int)
        return comparison.logical_or(candidates.contains(cls._int(0)))

    def test_simple(self):
        """Test with a trivial expression"""
        actual = self._predicate("1 > 0")
        self.assertEqual(actual, self._int(1).gt(self._int(0)))

    def test_time(self):
        """Test with a trivial expression including times"""
        converter = time_utils.TimeConverter()
        actual = self._predicate("T'1970-01-01 00:00/tai' < T'2020-01-01 00:00/tai'")
        earlier = ColumnExpression.literal(converter.nsec_to_astropy(0), dtype=astropy.time.Time)
        later = ColumnExpression.literal(
            converter.nsec_to_astropy(1577836800000000000), dtype=astropy.time.Time
        )
        self.assertEqual(actual, earlier.lt(later))

    def test_ingest_date(self):
        """Test with an expression including ingest_date which is native UTC"""
        actual = self._predicate("ingest_date < T'2020-01-01 00:00/utc'", dataset_type_name="fake")
        column = ColumnExpression.reference(
            DatasetColumnTag("fake", "ingest_date"), dtype=self.ingest_date_pytype
        )
        bound = ColumnExpression.literal(self.ingest_date_literal, dtype=self.ingest_date_pytype)
        self.assertEqual(actual, column.lt(bound))

    def test_bind(self):
        """Test with bind parameters"""
        actual = self._predicate(
            "a > b OR t in (x, y, z)",
            bind={"a": 1, "b": 2, "t": 0, "x": 10, "y": 20, "z": 30},
        )
        self.assertEqual(actual, self._expected_or_in([10, 20, 30]))

    def test_bind_list(self):
        """Test with bind parameter which is list/tuple/set inside IN rhs."""
        # Each case: expression, bind values for the IN list, and the
        # literals expected inside the IN container. A sequence-valued bind
        # contributes its items in place, in iteration order.
        cases = [
            # A single tuple-valued bind parameter.
            ("a > b OR t in (x)", {"x": (10, 20, 30)}, [10, 20, 30]),
            # Couple of bound variables inside IN() with different
            # combinations of scalars and list.
            ("a > b OR t in (x, y)", {"x": 10, "y": 20}, [10, 20]),
            ("a > b OR t in (x, y)", {"x": [10, 30], "y": 20}, [10, 30, 20]),
            ("a > b OR t in (x, y)", {"x": (10, 30), "y": {20}}, [10, 30, 20]),
        ]
        for expression, in_binds, expected_values in cases:
            bind = {"a": 1, "b": 2, "t": 0, **in_binds}
            actual = self._predicate(expression, bind=bind)
            self.assertEqual(actual, self._expected_or_in(expected_values))

217 

218 

class ConvertExpressionToPredicateTestCaseAstropy(ConvertExpressionToPredicateTestCase):
    """A test case for the make_string_expression_predicate function with
    ingest_date defined as nanoseconds.

    Every inherited test is re-run with ``ingest_date`` stored as
    ``ddl.AstropyTimeNsecTai``. The expected ingest_date literal is
    therefore an `astropy.time.Time` (UTC scale) rather than a
    `datetime.datetime`.
    """

    ingest_date_dtype = ddl.AstropyTimeNsecTai
    ingest_date_pytype = astropy.time.Time
    ingest_date_literal = astropy.time.Time(datetime.datetime(2020, 1, 1), scale="utc")

227 

228 

class InspectionVisitorTestCase(unittest.TestCase):
    """Tests for InspectionVisitor class."""

    def test_simple(self):
        """Test for simple expressions"""

        universe = DimensionUniverse()
        parser = ParserYacc()

        # A single equality on instrument is reported as a data ID key/value.
        tree = parser.parse("instrument = 'LSST'")
        bind = {}
        summary = tree.visit(InspectionVisitor(universe, bind))
        self.assertEqual(summary.dimensions.names, {"instrument"})
        self.assertFalse(summary.columns)
        self.assertFalse(summary.hasIngestDate)
        self.assertEqual(summary.dataIdKey, universe["instrument"])
        self.assertEqual(summary.dataIdValue, "LSST")

        # Inequality still references the dimension but gives no data ID value.
        tree = parser.parse("instrument != 'LSST'")
        summary = tree.visit(InspectionVisitor(universe, bind))
        self.assertEqual(summary.dimensions.names, {"instrument"})
        self.assertFalse(summary.columns)
        self.assertIsNone(summary.dataIdKey)
        self.assertIsNone(summary.dataIdValue)

        # Referencing visit also reports instrument, band and physical_filter;
        # with more than one constraint no single data ID key is reported.
        tree = parser.parse("instrument = 'LSST' AND visit = 1")
        summary = tree.visit(InspectionVisitor(universe, bind))
        self.assertEqual(summary.dimensions.names, {"instrument", "visit", "band", "physical_filter"})
        self.assertFalse(summary.columns)
        self.assertIsNone(summary.dataIdKey)
        self.assertIsNone(summary.dataIdValue)

        tree = parser.parse("instrument = 'LSST' AND visit = 1 AND skymap = 'x'")
        summary = tree.visit(InspectionVisitor(universe, bind))
        self.assertEqual(
            summary.dimensions.names, {"instrument", "visit", "band", "physical_filter", "skymap"}
        )
        self.assertFalse(summary.columns)
        self.assertIsNone(summary.dataIdKey)
        self.assertIsNone(summary.dataIdValue)

    def test_bind(self):
        """Test for simple expressions with binds."""

        universe = DimensionUniverse()
        parser = ParserYacc()

        # A bind parameter on the right-hand side is resolved to its value.
        tree = parser.parse("instrument = instr")
        bind = {"instr": "LSST"}
        summary = tree.visit(InspectionVisitor(universe, bind))
        self.assertEqual(summary.dimensions.names, {"instrument"})
        self.assertFalse(summary.hasIngestDate)
        self.assertEqual(summary.dataIdKey, universe["instrument"])
        self.assertEqual(summary.dataIdValue, "LSST")

        tree = parser.parse("instrument != instr")
        # Visit the new tree *before* asserting, so the dimension check is
        # made against this expression's summary rather than the previous one.
        summary = tree.visit(InspectionVisitor(universe, bind))
        self.assertEqual(summary.dimensions.names, {"instrument"})
        self.assertIsNone(summary.dataIdKey)
        self.assertIsNone(summary.dataIdValue)

        tree = parser.parse("instrument = instr AND visit = visit_id")
        bind = {"instr": "LSST", "visit_id": 1}
        summary = tree.visit(InspectionVisitor(universe, bind))
        self.assertEqual(summary.dimensions.names, {"instrument", "visit", "band", "physical_filter"})
        self.assertIsNone(summary.dataIdKey)
        self.assertIsNone(summary.dataIdValue)

        # Mixes literal values with a bind parameter in one expression.
        tree = parser.parse("instrument = 'LSST' AND visit = 1 AND skymap = skymap_name")
        bind = {"instr": "LSST", "visit_id": 1, "skymap_name": "x"}
        summary = tree.visit(InspectionVisitor(universe, bind))
        self.assertEqual(
            summary.dimensions.names, {"instrument", "visit", "band", "physical_filter", "skymap"}
        )
        self.assertIsNone(summary.dataIdKey)
        self.assertIsNone(summary.dataIdValue)

    def test_in(self):
        """Test for IN expressions."""

        universe = DimensionUniverse()
        parser = ParserYacc()

        tree = parser.parse("instrument IN ('LSST')")
        bind = {}
        summary = tree.visit(InspectionVisitor(universe, bind))
        self.assertEqual(summary.dimensions.names, {"instrument"})
        self.assertFalse(summary.hasIngestDate)
        # we do not handle IN with a single item as `=`
        self.assertIsNone(summary.dataIdKey)
        self.assertIsNone(summary.dataIdValue)

        tree = parser.parse("instrument IN (instr)")
        bind = {"instr": "LSST"}
        summary = tree.visit(InspectionVisitor(universe, bind))
        self.assertEqual(summary.dimensions.names, {"instrument"})
        self.assertIsNone(summary.dataIdKey)
        self.assertIsNone(summary.dataIdValue)

        tree = parser.parse("visit IN (1,2,3)")
        bind = {}
        summary = tree.visit(InspectionVisitor(universe, bind))
        self.assertEqual(summary.dimensions.names, {"instrument", "visit", "band", "physical_filter"})
        self.assertIsNone(summary.dataIdKey)
        self.assertIsNone(summary.dataIdValue)

        # One scalar bind parameter per IN item.
        tree = parser.parse("visit IN (visit1, visit2, visit3)")
        bind = {"visit1": 1, "visit2": 2, "visit3": 3}
        summary = tree.visit(InspectionVisitor(universe, bind))
        self.assertEqual(summary.dimensions.names, {"instrument", "visit", "band", "physical_filter"})
        self.assertIsNone(summary.dataIdKey)
        self.assertIsNone(summary.dataIdValue)

        # A single tuple-valued bind parameter supplies all IN items.
        tree = parser.parse("visit IN (visits)")
        bind = {"visits": (1, 2, 3)}
        summary = tree.visit(InspectionVisitor(universe, bind))
        self.assertEqual(summary.dimensions.names, {"instrument", "visit", "band", "physical_filter"})
        self.assertIsNone(summary.dataIdKey)
        self.assertIsNone(summary.dataIdValue)

348 

349 

class CheckVisitorTestCase(unittest.TestCase):
    """Tests for CheckVisitor class."""

    def test_governor(self):
        """Test with governor dimension in expression"""

        parser = ParserYacc()

        universe = DimensionUniverse()
        graph = universe.extract(("instrument", "visit"))
        dataId = DataCoordinate.makeEmpty(universe)
        defaults = DataCoordinate.makeEmpty(universe)

        # Each case passes if visiting its disjunctive normal form does not
        # raise; the visitor's return value is not inspected.
        cases = (
            # governor-only constraint, in both operand orders
            ("instrument = 'LSST'", {}),
            ("'LSST' = instrument", {}),
            # use bind for governor
            ("instrument = instr", {"instr": "LSST"}),
        )
        for expression, binds in cases:
            normal_form = NormalFormExpression.fromTree(parser.parse(expression), NormalForm.DISJUNCTIVE)
            normal_form.visit(CheckVisitor(dataId, graph, binds, defaults))

382 

383 

if __name__ == "__main__":
    # Allow running this module directly in addition to via a test runner.
    unittest.main()