Coverage for tests/test_schema.py: 11%

316 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2022-12-14 03:43 -0800

1# This file is part of afw. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22""" 

23Tests for table.Schema 

24 

25Run with: 

26 python test_schema.py 

27or 

28 pytest test_schema.py 

29""" 

30 

31import unittest 

32import pickle 

33 

34import numpy as np 

35 

36import lsst.utils.tests 

37import lsst.pex.exceptions 

38import lsst.geom 

39import lsst.afw.table 

40 

41 

def _checkSchemaIdentical(schema1, schema2):
    """Report whether two schemas compare as identical.

    Parameters
    ----------
    schema1, schema2 : `lsst.afw.table.Schema`
        The schemas to compare.

    Returns
    -------
    identical
        Result of testing that ``schema1.compare(schema2, IDENTICAL)``
        equals the ``Schema.IDENTICAL`` flag value itself.
    """
    identicalFlags = lsst.afw.table.Schema.IDENTICAL
    comparison = schema1.compare(schema2, identicalFlags)
    return comparison == identicalFlags

44 

45 

def addTestFields(schema):
    """Populate ``schema`` with one field of each type exercised by the tests.

    Parameters
    ----------
    schema : `lsst.afw.table.Schema`
        Schema to which the fields are added, in the order listed below.
    """
    # (name, keyword arguments forwarded to ``addField``), in insertion order.
    fieldSpecs = (
        ("ra", dict(type="Angle", doc="coord_ra")),
        ("dec", dict(type="Angle", doc="coord_dec")),
        ("x", dict(type="D", doc="position_x", units="pixel")),
        ("y", dict(type="D", doc="position_y")),
        ("i", dict(type="I", doc="int")),
        ("f", dict(type="F", doc="float", units="m2")),
        ("flag", dict(type="Flag", doc="a flag")),
        ("string", dict(type="String", doc="A string field", size=42)),
        ("variable_string", dict(type="String", doc="A variable-length string field", size=0)),
        ("array", dict(type="ArrayF", doc="An array field", size=10)),
        ("variable_array", dict(type="ArrayF", doc="A variable-length array field", size=0)),
    )
    for fieldName, fieldKwargs in fieldSpecs:
        schema.addField(fieldName, **fieldKwargs)

60 

61 

class SchemaTestCase(unittest.TestCase):
    """Tests for `lsst.afw.table.Schema`: lookup, units, inspection, key
    accessors, comparison and pickling.
    """

    def testSchema(self):
        """Check field lookup, name listing, copying, equality and ordering."""
        def testKey(name, key):
            # ``schema`` is bound in the enclosing scope before the first call.
            col = schema.find(name)
            self.assertEqual(col.key, key)
            self.assertEqual(hash(col.key), hash(key))
            self.assertEqual(col.field.getName(), name)

        schema = lsst.afw.table.Schema()
        ab_k = lsst.afw.table.CoordKey.addFields(schema, "a_b", "parent coord")
        abp_k = lsst.afw.table.Point2DKey.addFields(
            schema, "a_b_p", "point", "pixel")
        abi_k = schema.addField("a_b_i", type=np.int32, doc="int")
        acf_k = schema.addField("a_c_f", type=np.float32, doc="float")
        egd_k = schema.addField("e_g_d", type=lsst.geom.Angle, doc="angle")

        # Basic test for all native key types.
        for name, key in (("a_b_i", abi_k), ("a_c_f", acf_k), ("e_g_d", egd_k)):
            testKey(name, key)

        # Extra tests for special types
        self.assertEqual(ab_k.getRa(), schema["a_b_ra"].asKey())
        self.assertEqual(hash(ab_k.getRa()), hash(schema["a_b_ra"].asKey()))
        abpx_si = schema.find("a_b_p_x")
        self.assertEqual(abp_k.getX(), abpx_si.key)
        self.assertEqual(hash(abp_k.getX()), hash(abpx_si.key))
        self.assertEqual(abpx_si.field.getName(), "a_b_p_x")
        self.assertEqual(abpx_si.field.getDoc(), "point")
        self.assertEqual(abp_k.getX(), schema["a_b_p_x"].asKey())

        # Full names, top-level prefixes, and the same for the "a" sub-schema.
        self.assertEqual(schema.getNames(), {'a_b_dec', 'a_b_i', 'a_b_p_x', 'a_b_p_y', 'a_b_ra', 'a_c_f',
                                             'e_g_d'})
        self.assertEqual(schema.getNames(True), {"a", "e"})
        self.assertEqual(schema["a"].getNames(), {
                         'b_dec', 'b_i', 'b_p_x', 'b_p_y', 'b_ra', 'c_f'})
        self.assertEqual(schema["a"].getNames(True), {"b", "c"})

        # A copy compares equal until a field is added to it.
        schema2 = lsst.afw.table.Schema(schema)
        self.assertEqual(schema, schema2)
        schema2.addField("q", type="F", doc="another double")
        self.assertNotEqual(schema, schema2)

        # Same field types in the same order but different names and docs;
        # the test expects this to compare equal to ``schema``.
        # NOTE(review): presumably ``==`` only considers keys - confirm.
        schema3 = lsst.afw.table.Schema()
        schema3.addField("ra", type="Angle", doc="coord_ra")
        schema3.addField("dec", type="Angle", doc="coord_dec")
        schema3.addField("x", type="D", doc="position_x")
        schema3.addField("y", type="D", doc="position_y")
        schema3.addField("i", type="I", doc="int")
        schema3.addField("f", type="F", doc="float")
        schema3.addField("d", type="Angle", doc="angle")
        self.assertEqual(schema3, schema)

        schema4 = lsst.afw.table.Schema()
        keys = []
        keys.append(schema4.addField("a", type="Angle", doc="a"))
        keys.append(schema4.addField("b", type="Flag", doc="b"))
        keys.append(schema4.addField("c", type="I", doc="c"))
        keys.append(schema4.addField("d", type="Flag", doc="d"))
        # The two Flag fields are expected to occupy consecutive bits.
        self.assertEqual(keys[1].getBit(), 0)
        self.assertEqual(keys[3].getBit(), 1)
        # Compare whole lists: a pairwise zip() would stop at the shorter
        # sequence and silently accept missing or extra names.
        self.assertEqual(list(schema4.getOrderedNames()), list("abcd"))
        keys2 = [x.key for x in schema4]
        self.assertEqual(keys, keys2)

    def testUnits(self):
        """Check that unit strings are validated on insertion and by
        ``checkUnits``.
        """
        schema = lsst.afw.table.Schema()
        # first insert some valid units
        validFields = (
            ("a", "I", "pixel"),
            ("b", "I", "m2"),
            ("c", "I", "electron / adu"),
            ("d", "I", "kg m s^(-2)"),
            ("e", "I", "GHz / Mpc"),
            ("f", "Angle", "deg"),
            ("g", "Angle", "rad"),
        )
        for name, fieldType, units in validFields:
            schema.addField(name, type=fieldType, units=units)
        schema.checkUnits()
        # now try inserting invalid units
        for name, badUnits in (("a", "camel"), ("b", "pixels^2^2")):
            with self.assertRaises(ValueError):
                schema.addField(name, type="I", units=badUnits)
        # add invalid units in silent mode, should work fine
        schema.addField("h", type="I", units="lala", parse_strict='silent')
        # Now this check should raise because there is an invalid unit
        with self.assertRaises(ValueError):
            schema.checkUnits()

    def testInspection(self):
        """Check iteration order and ``in`` for both keys and names."""
        schema = lsst.afw.table.Schema()
        keys = [
            schema.addField("d", type=np.int32),
            schema.addField("c", type=np.float32),
            schema.addField("b", type="ArrayF", size=3),
            schema.addField("a", type="F"),
        ]
        items = list(schema)
        # Guard the zip() below: it truncates to the shorter sequence, so
        # confirm iteration yields exactly one item per added field.
        self.assertEqual(len(items), len(keys))
        for key, item in zip(keys, items):
            self.assertEqual(item.key, key)
            self.assertIn(key, schema)
        for name in ("a", "b", "c", "d"):
            self.assertIn(name, schema)
        self.assertNotIn("e", schema)
        # A key for a same-named field of another type, from another schema,
        # must not be reported as contained.
        otherSchema = lsst.afw.table.Schema()
        otherKey = otherSchema.addField("d", type=np.float32)
        self.assertNotIn(otherKey, schema)
        self.assertNotEqual(keys[0], keys[1])

    def testKeyAccessors(self):
        """Check that indexing an array key gives a scalar key of the
        element type.
        """
        schema = lsst.afw.table.Schema()
        arrayKey = schema.addField(
            "a", type="ArrayF", doc="doc for array field", size=5)
        arrayElementKey = arrayKey[1]
        self.assertEqual(lsst.afw.table.Key["F"], type(arrayElementKey))

    def testComparison(self):
        """Check the flags returned by ``compare`` and the results of
        ``contains``.
        """
        Schema = lsst.afw.table.Schema

        # Same names, docs and units, but the field types are swapped.
        schema1 = Schema()
        schema1.addField("a", type=np.float32, doc="doc for a", units="m")
        schema1.addField("b", type=np.int32, doc="doc for b", units="s")
        schema2 = Schema()
        schema2.addField("a", type=np.int32, doc="doc for a", units="m")
        schema2.addField("b", type=np.float32, doc="doc for b", units="s")
        cmp1 = schema1.compare(schema2, Schema.IDENTICAL)
        self.assertTrue(cmp1 & Schema.EQUAL_NAMES)
        self.assertTrue(cmp1 & Schema.EQUAL_DOCS)
        self.assertTrue(cmp1 & Schema.EQUAL_UNITS)
        self.assertFalse(cmp1 & Schema.EQUAL_KEYS)

        # schema3 is schema1 plus one extra field.
        schema3 = Schema(schema1)
        schema3.addField("c", type=str, doc="doc for c", size=4)
        self.assertFalse(schema1.compare(schema3))
        self.assertFalse(schema1.contains(schema3))
        self.assertTrue(schema3.contains(schema1))

        # Give schema1 a same-typed field with a different name and doc.
        schema1.addField("d", type=str, doc="no docs!", size=4)
        cmp2 = schema1.compare(schema3, Schema.IDENTICAL)
        self.assertFalse(cmp2 & Schema.EQUAL_NAMES)
        self.assertFalse(cmp2 & Schema.EQUAL_DOCS)
        self.assertTrue(cmp2 & Schema.EQUAL_KEYS)
        self.assertTrue(cmp2 & Schema.EQUAL_UNITS)
        self.assertFalse(schema1.compare(
            schema3, Schema.EQUAL_NAMES))

    def testPickle(self):
        """Check that a Schema compares equal after a pickle round trip."""
        schema = lsst.afw.table.Schema()
        addTestFields(schema)

        pickled = pickle.dumps(schema, protocol=pickle.HIGHEST_PROTOCOL)
        unpickled = pickle.loads(pickled)
        self.assertEqual(schema, unpickled)

204 

205 

class SchemaMapperTestCase(unittest.TestCase):
    """Tests for `lsst.afw.table.SchemaMapper`: joins, minimal schemas,
    output-schema access, mappings, equality and pickling.
    """

    def testJoin(self):
        """Join three schemas, with and without prefixes, and check that
        values are carried over.
        """
        Schema = lsst.afw.table.Schema
        inputs = [Schema() for _ in range(3)]
        prefixes = ["u", "v", "w"]
        ka = inputs[0].addField("a", type=np.float64, doc="doc for a")
        kb = inputs[0].addField("b", type=np.int32, doc="doc for b")
        kc = inputs[1].addField("c", type=np.float32, doc="doc for c")
        kd = inputs[2].addField("d", type=np.int64, doc="doc for d")
        # With prefixes the names are not expected to match, so that bit is
        # masked out of the flags used for comparison.
        flags1 = Schema.IDENTICAL
        flags2 = flags1 & ~Schema.EQUAL_NAMES
        mappers1 = lsst.afw.table.SchemaMapper.join(inputs)
        mappers2 = lsst.afw.table.SchemaMapper.join(inputs, prefixes)
        records = [lsst.afw.table.BaseTable.make(inputSchema).makeRecord()
                   for inputSchema in inputs]
        records[0].set(ka, 3.14159)
        records[0].set(kb, 21623)
        records[1].set(kc, 1.5616)
        records[2].set(kd, 1261236)
        for mappers, flags in ((mappers1, flags1), (mappers2, flags2)):
            outputTable = lsst.afw.table.BaseTable.make(mappers[0].getOutputSchema())
            output = outputTable.makeRecord()
            for mapper, record in zip(mappers, records):
                output.assign(record, mapper)
                outputCmp = mapper.getOutputSchema().compare(output.getSchema(), flags)
                self.assertEqual(outputCmp, flags)
                inputCmp = mapper.getInputSchema().compare(record.getSchema(), flags)
                self.assertEqual(inputCmp, flags)
            names = output.getSchema().getOrderedNames()
            self.assertEqual(output.get(names[0]), records[0].get(ka))
            self.assertEqual(output.get(names[1]), records[0].get(kb))
            self.assertEqual(output.get(names[2]), records[1].get(kc))
            self.assertEqual(output.get(names[3]), records[2].get(kd))

    def testMinimalSchema(self):
        """Check ``addMinimalSchema`` and ``removeMinimalSchema``."""
        front = lsst.afw.table.Schema()
        ka = front.addField("a", type=np.float64, doc="doc for a")
        kb = front.addField("b", type=np.int32, doc="doc for b")
        full = lsst.afw.table.Schema(front)
        kc = full.addField("c", type=np.float32, doc="doc for c")
        kd = full.addField("d", type=np.int64, doc="doc for d")
        mapper1 = lsst.afw.table.SchemaMapper(full)
        mapper2 = lsst.afw.table.SchemaMapper(full)
        mapper3 = lsst.afw.table.SchemaMapper.removeMinimalSchema(full, front)
        mapper1.addMinimalSchema(front)
        mapper2.addMinimalSchema(front, False)
        # Both "add" mappers expose the front keys and none of the extras.
        for mapper in (mapper1, mapper2):
            for frontKey in (ka, kb):
                self.assertIn(frontKey, mapper.getOutputSchema())
            for extraKey in (kc, kd):
                self.assertNotIn(extraKey, mapper.getOutputSchema())
        # None of the original keys is expected in the "remove" output.
        for anyKey in (ka, kb, kc, kd):
            self.assertNotIn(anyKey, mapper3.getOutputSchema())
        inputRecord = lsst.afw.table.BaseTable.make(full).makeRecord()
        for key, value in ((ka, np.pi), (kb, 2), (kc, np.exp(1)), (kd, 4)):
            inputRecord.set(key, value)
        outputTable1 = lsst.afw.table.BaseTable.make(mapper1.getOutputSchema())
        outputRecord1 = outputTable1.makeRecord()
        outputRecord1.assign(inputRecord, mapper1)
        for frontKey in (ka, kb):
            self.assertEqual(inputRecord.get(frontKey), outputRecord1.get(frontKey))

    def testOutputSchema(self):
        """Check that ``getOutputSchema`` results are detached from the
        mapper while ``editOutputSchema`` results track it.
        """
        mapper = lsst.afw.table.SchemaMapper(lsst.afw.table.Schema())
        detached = mapper.getOutputSchema()
        editable = mapper.editOutputSchema()

        # A field added to the detached schema stays there only.
        k1 = detached.addField("a1", type=np.int32)
        self.assertNotIn(k1, mapper.getOutputSchema())
        self.assertIn(k1, detached)
        self.assertNotIn(k1, editable)

        # A field added through the mapper shows up in the editable schema.
        a2Field = lsst.afw.table.Field[np.float32]("a2", "doc for a2")
        k2 = mapper.addOutputField(a2Field)
        self.assertNotIn(k2, detached)
        self.assertIn(k2, mapper.getOutputSchema())
        self.assertIn(k2, editable)

        # A field added to the editable schema shows up in the mapper.
        k3 = editable.addField("a3", type=np.float32, doc="doc for a3")
        self.assertNotIn(k3, detached)
        self.assertIn(k3, mapper.getOutputSchema())
        self.assertIn(k3, editable)
        self.assertIn(k2, editable)

    def testDoReplace(self):
        """Check ``addMapping`` with the final ``True`` argument, for an
        implicit, a Field and a name target.
        """
        inSchema = lsst.afw.table.Schema()
        ka = inSchema.addField("a", type=np.int32)
        outSchema = lsst.afw.table.Schema(inSchema)
        kb = outSchema.addField("b", type=np.int32)
        kc = outSchema.addField("c", type=np.int32)

        # No explicit target: maps onto the existing "a".
        mapper1 = lsst.afw.table.SchemaMapper(inSchema, outSchema)
        mapper1.addMapping(ka, True)
        self.assertEqual(mapper1.getMapping(ka), ka)

        # Target given as a Field named "b".
        mapper2 = lsst.afw.table.SchemaMapper(inSchema, outSchema)
        bField = lsst.afw.table.Field[np.int32]("b", "doc for b")
        mapper2.addMapping(ka, bField, True)
        self.assertEqual(mapper2.getMapping(ka), kb)

        # Target given by name.
        mapper3 = lsst.afw.table.SchemaMapper(inSchema, outSchema)
        mapper3.addMapping(ka, "c", True)
        self.assertEqual(mapper3.getMapping(ka), kc)

    def testJoin2(self):
        """Check that ``Schema.join`` joins name parts with underscores."""
        s1 = lsst.afw.table.Schema()
        cases = (
            (("a", "b"), "a_b"),
            (("a", "b", "c"), "a_b_c"),
            (("a", "b", "c", "d"), "a_b_c_d"),
        )
        for parts, expected in cases:
            self.assertEqual(s1.join(*parts), expected)

    def _makeMapper(self, name1="a", name2="bb", name3="ccc", name4="dddd"):
        """Build a SchemaMapper with a mix of mapped and unmapped fields.

        Parameters
        ----------
        name1 : `str`, optional
            Input field given a default mapping to the output.
        name2 : `str`, optional
            Input field mapped to an output field called ``name3``.
        name3 : `str`, optional
            Input field left unmapped; also the name of the output field
            that ``name2`` is mapped to.
        name4 : `str`, optional
            Output-only field with no input mapped to it.

        Returns
        -------
        mapper : `lsst.afw.table.SchemaMapper`
            The created mapper.
        """
        inputSchema = lsst.afw.table.Schema()
        for floatName in (name1, name2, name3):
            inputSchema.addField(floatName, type=float)
        inputSchema.addField("asdf", type="Flag")
        mapper = lsst.afw.table.SchemaMapper(inputSchema)

        # name1: default mapping
        defaultKey = inputSchema.find(name1).key
        mapper.addMapping(defaultKey)

        # name2: mapped onto a new output field named name3
        renamedKey = inputSchema.find(name2).key
        renamedField = lsst.afw.table.Field[float](name3, "doc for thingy")
        mapper.addMapping(renamedKey, renamedField)

        # name4: output-only field
        extraField = lsst.afw.table.Field[float](name4, 'docstring')
        mapper.addOutputField(extraField)
        return mapper

    def testOperatorEquals(self):
        """Check that identically constructed mappers compare equal."""
        first = self._makeMapper()
        second = self._makeMapper()
        self.assertEqual(first, second)

    def testNotEqualInput(self):
        """Check that differing input schema compare not equal."""
        changed = self._makeMapper(name2="somethingelse")
        reference = self._makeMapper()
        # output schema should still be equal
        outputsMatch = _checkSchemaIdentical(changed.getOutputSchema(), reference.getOutputSchema())
        self.assertTrue(outputsMatch)
        self.assertNotEqual(changed, reference)

    def testNotEqualOutput(self):
        """Check that differing output schema compare not equal."""
        changed = self._makeMapper(name4="another")
        reference = self._makeMapper()
        # input schema should still be equal
        inputsMatch = _checkSchemaIdentical(changed.getInputSchema(), reference.getInputSchema())
        self.assertTrue(inputsMatch)
        self.assertNotEqual(changed, reference)

    def testNotEqualMappings(self):
        """Check that differing mappings but same schema compare not equal."""
        schema = lsst.afw.table.Schema()
        schema.addField('a', type=np.int32, doc="int")
        schema.addField('b', type=np.int32, doc="int")
        mapper1 = lsst.afw.table.SchemaMapper(schema)
        mapper2 = lsst.afw.table.SchemaMapper(schema)
        # mapper1 sends a->c, b->d; mapper2 swaps the sources.
        mapper1.addMapping(schema['a'].asKey(), 'c')
        mapper1.addMapping(schema['b'].asKey(), 'd')
        mapper2.addMapping(schema['b'].asKey(), 'c')
        mapper2.addMapping(schema['a'].asKey(), 'd')

        # input and output schemas should still be equal
        inputsMatch = _checkSchemaIdentical(mapper1.getInputSchema(), mapper2.getInputSchema())
        self.assertTrue(inputsMatch)
        outputsMatch = _checkSchemaIdentical(mapper1.getOutputSchema(), mapper2.getOutputSchema())
        self.assertTrue(outputsMatch)
        self.assertNotEqual(mapper1, mapper2)

    def testNotEqualMappingsSomeFieldsUnmapped(self):
        """Check that mappers with the same input and output schema, but
        different mappings and some unmapped fields, compare not equal.
        """
        schema = lsst.afw.table.Schema()
        schema.addField('a', type=np.int32, doc="int")
        schema.addField('b', type=np.int32, doc="int")
        mapper1 = lsst.afw.table.SchemaMapper(schema)
        mapper2 = lsst.afw.table.SchemaMapper(schema)
        mapper1.addMapping(schema['a'].asKey(), 'c')
        mapper1.addMapping(schema['b'].asKey(), 'd')
        mapper2.addMapping(schema['b'].asKey(), 'c')
        # add an unmapped field to output of 2 to match 1
        unmappedField = lsst.afw.table.Field[np.int32]('d', doc="int")
        mapper2.addOutputField(unmappedField)

        # input and output schemas should still be equal
        inputsMatch = _checkSchemaIdentical(mapper1.getInputSchema(), mapper2.getInputSchema())
        self.assertTrue(inputsMatch)
        outputsMatch = _checkSchemaIdentical(mapper1.getOutputSchema(), mapper2.getOutputSchema())
        self.assertTrue(outputsMatch)
        self.assertNotEqual(mapper1, mapper2)

    def testPickle(self):
        """Check that a mapper compares equal after a pickle round trip."""
        schema = lsst.afw.table.Schema()
        addTestFields(schema)
        mapper = lsst.afw.table.SchemaMapper(schema)
        mapper.addMinimalSchema(schema)
        # This field is added to ``schema`` after the mapper was built.
        inKey = schema.addField("bb", type=float)
        outField = lsst.afw.table.Field[float]("cc", "doc for bb->cc")
        mapper.addMapping(inKey, outField, True)

        roundTripped = pickle.loads(pickle.dumps(mapper, protocol=pickle.HIGHEST_PROTOCOL))
        self.assertEqual(mapper, roundTripped)

    def testPickleMissingInput(self):
        """Test pickling with some fields not being mapped."""
        mapper = self._makeMapper()

        roundTripped = pickle.loads(pickle.dumps(mapper, protocol=pickle.HIGHEST_PROTOCOL))

        self.assertEqual(mapper, roundTripped)

439 

440 

class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Run the tests inherited from `lsst.utils.tests.MemoryTestCase`.

    Nothing is added or overridden here.
    """

443 

444 

def setup_module(module):
    """Initialise the LSST test utilities for this module.

    Presumably picked up by pytest as its module-level setup hook; the
    ``module`` argument is not used.
    """
    lsst.utils.tests.init()

447 

448 

# Script entry point: only runs when the file is executed directly, not when
# it is imported by a test runner.
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()