Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# This file is part of afw. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22""" 

23Tests for table.Schema and table.SchemaMapper 

24 

25Run with: 

26 python test_schema.py 

27or 

28 pytest test_schema.py 

29""" 

30 

31import unittest 

32import pickle 

33 

34import numpy as np 

35 

36import lsst.utils.tests 

37import lsst.pex.exceptions 

38import lsst.geom 

39import lsst.afw.table 

40 

41 

def _checkSchemaIdentical(schema1, schema2):
    """Return whether two schemas compare as IDENTICAL.

    Parameters
    ----------
    schema1, schema2 : `lsst.afw.table.Schema`
        Schemas to compare.

    Returns
    -------
    identical : `bool`
        True when ``schema1.compare(schema2, IDENTICAL)`` reports every
        comparison bit of ``Schema.IDENTICAL`` as set.
    """
    identicalMask = lsst.afw.table.Schema.IDENTICAL
    comparison = schema1.compare(schema2, identicalMask)
    return comparison == identicalMask

44 

45 

def addTestFields(schema):
    """Populate ``schema`` with a set of fields covering many field types.

    Adds Angle, "D", "I", "F" and Flag fields, plus fixed-size and
    variable-length (``size=0``) String and ArrayF fields.

    Parameters
    ----------
    schema : `lsst.afw.table.Schema`
        Schema to add the fields to; modified in place.
    """
    fieldSpecs = (
        ("ra", dict(type="Angle", doc="coord_ra")),
        ("dec", dict(type="Angle", doc="coord_dec")),
        ("x", dict(type="D", doc="position_x", units="pixel")),
        ("y", dict(type="D", doc="position_y")),
        ("i", dict(type="I", doc="int")),
        ("f", dict(type="F", doc="float", units="m2")),
        ("flag", dict(type="Flag", doc="a flag")),
        ("string", dict(type="String", doc="A string field", size=42)),
        ("variable_string", dict(type="String", doc="A variable-length string field", size=0)),
        ("array", dict(type="ArrayF", doc="An array field", size=10)),
        ("variable_array", dict(type="ArrayF", doc="A variable-length array field", size=0)),
    )
    for fieldName, options in fieldSpecs:
        schema.addField(fieldName, **options)

60 

61 

class SchemaTestCase(unittest.TestCase):
    """Tests for `lsst.afw.table.Schema`."""

    def testSchema(self):
        """Test key lookup, name queries, copying, equality and flag bits."""
        def testKey(name, key):
            # ``schema`` is bound below, before this helper is first called.
            col = schema.find(name)
            self.assertEqual(col.key, key)
            self.assertEqual(col.field.getName(), name)

        schema = lsst.afw.table.Schema()
        ab_k = lsst.afw.table.CoordKey.addFields(schema, "a_b", "parent coord")
        abp_k = lsst.afw.table.Point2DKey.addFields(
            schema, "a_b_p", "point", "pixel")
        abi_k = schema.addField("a_b_i", type=np.int32, doc="int")
        acf_k = schema.addField("a_c_f", type=np.float32, doc="float")
        egd_k = schema.addField("e_g_d", type=lsst.geom.Angle, doc="angle")

        # Basic test for all native key types.
        for name, key in (("a_b_i", abi_k), ("a_c_f", acf_k), ("e_g_d", egd_k)):
            testKey(name, key)

        # Extra tests for special (compound) key types.
        self.assertEqual(ab_k.getRa(), schema["a_b_ra"].asKey())
        abpx_si = schema.find("a_b_p_x")
        self.assertEqual(abp_k.getX(), abpx_si.key)
        self.assertEqual(abpx_si.field.getName(), "a_b_p_x")
        self.assertEqual(abpx_si.field.getDoc(), "point")
        self.assertEqual(abp_k.getX(), schema["a_b_p_x"].asKey())
        self.assertEqual(schema.getNames(), {'a_b_dec', 'a_b_i', 'a_b_p_x', 'a_b_p_y', 'a_b_ra', 'a_c_f',
                                             'e_g_d'})
        # getNames(True) is expected to return only the first name component.
        self.assertEqual(schema.getNames(True), {"a", "e"})
        self.assertEqual(schema["a"].getNames(), {
            'b_dec', 'b_i', 'b_p_x', 'b_p_y', 'b_ra', 'c_f'})
        self.assertEqual(schema["a"].getNames(True), {"b", "c"})

        # A copy compares equal until it is modified.
        schema2 = lsst.afw.table.Schema(schema)
        self.assertEqual(schema, schema2)
        schema2.addField("q", type="F", doc="another double")
        self.assertNotEqual(schema, schema2)

        # schema3 uses different names but the same sequence of field types
        # as ``schema``; the test expects the two to compare equal.
        schema3 = lsst.afw.table.Schema()
        schema3.addField("ra", type="Angle", doc="coord_ra")
        schema3.addField("dec", type="Angle", doc="coord_dec")
        schema3.addField("x", type="D", doc="position_x")
        schema3.addField("y", type="D", doc="position_y")
        schema3.addField("i", type="I", doc="int")
        schema3.addField("f", type="F", doc="float")
        schema3.addField("d", type="Angle", doc="angle")
        self.assertEqual(schema3, schema)

        # Flag fields are expected to get consecutive bits, even with a
        # non-flag field added between them.
        schema4 = lsst.afw.table.Schema()
        keys = []
        keys.append(schema4.addField("a", type="Angle", doc="a"))
        keys.append(schema4.addField("b", type="Flag", doc="b"))
        keys.append(schema4.addField("c", type="I", doc="c"))
        keys.append(schema4.addField("d", type="Flag", doc="d"))
        self.assertEqual(keys[1].getBit(), 0)
        self.assertEqual(keys[3].getBit(), 1)
        # Compare the complete list: zip() would stop at the shorter sequence
        # and let a missing (or extra) field go unnoticed.
        self.assertEqual(list(schema4.getOrderedNames()), ["a", "b", "c", "d"])
        keys2 = [x.key for x in schema4]
        self.assertEqual(keys, keys2)

    def testUnits(self):
        """Test unit validation in addField and in checkUnits."""
        schema = lsst.afw.table.Schema()
        # first insert some valid units
        schema.addField("a", type="I", units="pixel")
        schema.addField("b", type="I", units="m2")
        schema.addField("c", type="I", units="electron / adu")
        schema.addField("d", type="I", units="kg m s^(-2)")
        schema.addField("e", type="I", units="GHz / Mpc")
        schema.addField("f", type="Angle", units="deg")
        schema.addField("g", type="Angle", units="rad")
        schema.checkUnits()
        # Now try inserting invalid units. Use names not already in the
        # schema, so the expected ValueError is attributable to the unit
        # string rather than to a duplicate field name.
        with self.assertRaises(ValueError):
            schema.addField("i", type="I", units="camel")
        with self.assertRaises(ValueError):
            schema.addField("j", type="I", units="pixels^2^2")
        # add invalid units in silent mode, should work fine
        schema.addField("h", type="I", units="lala", parse_strict='silent')
        # Now this check should raise because there is an invalid unit
        with self.assertRaises(ValueError):
            schema.checkUnits()

    def testInspection(self):
        """Test iteration over a Schema and membership tests with ``in``."""
        schema = lsst.afw.table.Schema()
        keys = []
        keys.append(schema.addField("d", type=np.int32))
        keys.append(schema.addField("c", type=np.float32))
        keys.append(schema.addField("b", type="ArrayF", size=3))
        keys.append(schema.addField("a", type="F"))
        # Iteration should yield exactly the added fields in insertion order;
        # comparing whole lists also catches a length mismatch that zip()
        # would silently ignore.
        self.assertEqual([item.key for item in schema], keys)
        for key in keys:
            self.assertIn(key, schema)
        for name in ("a", "b", "c", "d"):
            self.assertIn(name, schema)
        self.assertNotIn("e", schema)
        otherSchema = lsst.afw.table.Schema()
        otherKey = otherSchema.addField("d", type=np.float32)
        self.assertNotIn(otherKey, schema)
        self.assertNotEqual(keys[0], keys[1])

    def testKeyAccessors(self):
        """Test that indexing an array key yields a scalar ``Key["F"]``."""
        schema = lsst.afw.table.Schema()
        arrayKey = schema.addField(
            "a", type="ArrayF", doc="doc for array field", size=5)
        arrayElementKey = arrayKey[1]
        self.assertEqual(lsst.afw.table.Key["F"], type(arrayElementKey))

    def testComparison(self):
        """Test the bit flags returned by Schema.compare and Schema.contains."""
        schema1 = lsst.afw.table.Schema()
        schema1.addField("a", type=np.float32, doc="doc for a", units="m")
        schema1.addField("b", type=np.int32, doc="doc for b", units="s")
        schema2 = lsst.afw.table.Schema()
        schema2.addField("a", type=np.int32, doc="doc for a", units="m")
        schema2.addField("b", type=np.float32, doc="doc for b", units="s")
        # Same names/docs/units but swapped types: only EQUAL_KEYS is unset.
        cmp1 = schema1.compare(schema2, lsst.afw.table.Schema.IDENTICAL)
        self.assertTrue(cmp1 & lsst.afw.table.Schema.EQUAL_NAMES)
        self.assertTrue(cmp1 & lsst.afw.table.Schema.EQUAL_DOCS)
        self.assertTrue(cmp1 & lsst.afw.table.Schema.EQUAL_UNITS)
        self.assertFalse(cmp1 & lsst.afw.table.Schema.EQUAL_KEYS)
        schema3 = lsst.afw.table.Schema(schema1)
        schema3.addField("c", type=str, doc="doc for c", size=4)
        self.assertFalse(schema1.compare(schema3))
        self.assertFalse(schema1.contains(schema3))
        self.assertTrue(schema3.contains(schema1))
        # A same-typed field with a different name and doc.
        schema1.addField("d", type=str, doc="no docs!", size=4)
        cmp2 = schema1.compare(schema3, lsst.afw.table.Schema.IDENTICAL)
        self.assertFalse(cmp2 & lsst.afw.table.Schema.EQUAL_NAMES)
        self.assertFalse(cmp2 & lsst.afw.table.Schema.EQUAL_DOCS)
        self.assertTrue(cmp2 & lsst.afw.table.Schema.EQUAL_KEYS)
        self.assertTrue(cmp2 & lsst.afw.table.Schema.EQUAL_UNITS)
        self.assertFalse(schema1.compare(
            schema3, lsst.afw.table.Schema.EQUAL_NAMES))

    def testPickle(self):
        """Test that a Schema survives a pickle round trip."""
        schema = lsst.afw.table.Schema()
        addTestFields(schema)

        pickled = pickle.dumps(schema, protocol=pickle.HIGHEST_PROTOCOL)
        unpickled = pickle.loads(pickled)
        self.assertEqual(schema, unpickled)

201 

202 

class SchemaMapperTestCase(unittest.TestCase):
    """Tests for `lsst.afw.table.SchemaMapper`."""

    @staticmethod
    def _pickleRoundTrip(obj):
        """Return the result of pickling ``obj`` and unpickling it again."""
        return pickle.loads(pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL))

    def testJoin(self):
        """Test SchemaMapper.join with and without name prefixes."""
        inputSchemas = [lsst.afw.table.Schema() for _ in range(3)]
        prefixes = ["u", "v", "w"]
        keyA = inputSchemas[0].addField("a", type=np.float64, doc="doc for a")
        keyB = inputSchemas[0].addField("b", type=np.int32, doc="doc for b")
        keyC = inputSchemas[1].addField("c", type=np.float32, doc="doc for c")
        keyD = inputSchemas[2].addField("d", type=np.int64, doc="doc for d")
        exactFlags = lsst.afw.table.Schema.IDENTICAL
        # The prefixed join is compared without EQUAL_NAMES.
        namelessFlags = exactFlags & ~lsst.afw.table.Schema.EQUAL_NAMES
        plainMappers = lsst.afw.table.SchemaMapper.join(inputSchemas)
        prefixedMappers = lsst.afw.table.SchemaMapper.join(inputSchemas, prefixes)
        inputRecords = [lsst.afw.table.BaseTable.make(s).makeRecord() for s in inputSchemas]
        inputRecords[0].set(keyA, 3.14159)
        inputRecords[0].set(keyB, 21623)
        inputRecords[1].set(keyC, 1.5616)
        inputRecords[2].set(keyD, 1261236)
        # (record, key) pairs in the order their fields are expected to
        # appear in the joined output schema.
        sources = ((inputRecords[0], keyA), (inputRecords[0], keyB),
                   (inputRecords[1], keyC), (inputRecords[2], keyD))
        for mappers, flags in ((plainMappers, exactFlags), (prefixedMappers, namelessFlags)):
            joined = lsst.afw.table.BaseTable.make(
                mappers[0].getOutputSchema()).makeRecord()
            for mapper, record in zip(mappers, inputRecords):
                joined.assign(record, mapper)
                outputComparison = mapper.getOutputSchema().compare(joined.getSchema(), flags)
                self.assertEqual(outputComparison, flags)
                inputComparison = mapper.getInputSchema().compare(record.getSchema(), flags)
                self.assertEqual(inputComparison, flags)
            orderedNames = joined.getSchema().getOrderedNames()
            for position, (record, key) in enumerate(sources):
                self.assertEqual(joined.get(orderedNames[position]), record.get(key))

    def testMinimalSchema(self):
        """Test addMinimalSchema and removeMinimalSchema."""
        front = lsst.afw.table.Schema()
        keyA = front.addField("a", type=np.float64, doc="doc for a")
        keyB = front.addField("b", type=np.int32, doc="doc for b")
        full = lsst.afw.table.Schema(front)
        keyC = full.addField("c", type=np.float32, doc="doc for c")
        keyD = full.addField("d", type=np.int64, doc="doc for d")
        mappedMapper = lsst.afw.table.SchemaMapper(full)
        unmappedMapper = lsst.afw.table.SchemaMapper(full)
        strippedMapper = lsst.afw.table.SchemaMapper.removeMinimalSchema(full, front)
        mappedMapper.addMinimalSchema(front)
        unmappedMapper.addMinimalSchema(front, False)

        frontKeys = (keyA, keyB)
        extraKeys = (keyC, keyD)
        # Only the minimal ("front") keys appear in either output schema,
        # whether or not they were mapped.
        for mapper in (mappedMapper, unmappedMapper):
            for key in frontKeys:
                self.assertIn(key, mapper.getOutputSchema())
            for key in extraKeys:
                self.assertNotIn(key, mapper.getOutputSchema())
        # None of the original keys are expected in the stripped mapper's output.
        for key in frontKeys + extraKeys:
            self.assertNotIn(key, strippedMapper.getOutputSchema())

        sourceRecord = lsst.afw.table.BaseTable.make(full).makeRecord()
        for key, value in ((keyA, np.pi), (keyB, 2), (keyC, np.exp(1)), (keyD, 4)):
            sourceRecord.set(key, value)
        targetRecord = lsst.afw.table.BaseTable.make(
            mappedMapper.getOutputSchema()).makeRecord()
        targetRecord.assign(sourceRecord, mappedMapper)
        for key in frontKeys:
            self.assertEqual(sourceRecord.get(key), targetRecord.get(key))

    def testOutputSchema(self):
        """Contrast getOutputSchema with editOutputSchema.

        The assertions expect ``getOutputSchema()`` to return an independent
        copy, and ``editOutputSchema()`` to return an object whose changes
        are visible through the mapper.
        """
        mapper = lsst.afw.table.SchemaMapper(lsst.afw.table.Schema())
        snapshot = mapper.getOutputSchema()
        editable = mapper.editOutputSchema()
        key1 = snapshot.addField("a1", type=np.int32)
        self.assertNotIn(key1, mapper.getOutputSchema())
        self.assertIn(key1, snapshot)
        self.assertNotIn(key1, editable)
        key2 = mapper.addOutputField(
            lsst.afw.table.Field[np.float32]("a2", "doc for a2"))
        self.assertNotIn(key2, snapshot)
        self.assertIn(key2, mapper.getOutputSchema())
        self.assertIn(key2, editable)
        key3 = editable.addField("a3", type=np.float32, doc="doc for a3")
        self.assertNotIn(key3, snapshot)
        self.assertIn(key3, mapper.getOutputSchema())
        self.assertIn(key3, editable)
        self.assertIn(key2, editable)

    def testDoReplace(self):
        """Test addMapping with doReplace onto fields already in the output."""
        inSchema = lsst.afw.table.Schema()
        keyA = inSchema.addField("a", type=np.int32)
        outSchema = lsst.afw.table.Schema(inSchema)
        keyB = outSchema.addField("b", type=np.int32)
        keyC = outSchema.addField("c", type=np.int32)

        # Default target: the input field maps onto itself.
        byKey = lsst.afw.table.SchemaMapper(inSchema, outSchema)
        byKey.addMapping(keyA, True)
        self.assertEqual(byKey.getMapping(keyA), keyA)

        # Target given as a Field.
        byField = lsst.afw.table.SchemaMapper(inSchema, outSchema)
        byField.addMapping(
            keyA, lsst.afw.table.Field[np.int32]("b", "doc for b"), True)
        self.assertEqual(byField.getMapping(keyA), keyB)

        # Target given by name.
        byName = lsst.afw.table.SchemaMapper(inSchema, outSchema)
        byName.addMapping(keyA, "c", True)
        self.assertEqual(byName.getMapping(keyA), keyC)

    def testJoin2(self):
        """Test Schema.join with two, three and four name components."""
        schema = lsst.afw.table.Schema()
        cases = (
            (("a", "b"), "a_b"),
            (("a", "b", "c"), "a_b_c"),
            (("a", "b", "c", "d"), "a_b_c_d"),
        )
        for parts, expected in cases:
            self.assertEqual(schema.join(*parts), expected)

    def _makeMapper(self, name1="a", name2="bb", name3="ccc", name4="dddd"):
        """Build a SchemaMapper with a mix of mapped and unmapped fields.

        Parameters
        ----------
        name1 : `str`, optional
            Input field given a default mapping to the output.
        name2 : `str`, optional
            Input field mapped onto a new output field called ``name3``.
        name3 : `str`, optional
            Field present (but unmapped) in the input, and the output target
            of ``name2``.
        name4 : `str`, optional
            Output-only field with no input mapping.

        Returns
        -------
        mapper : `lsst.afw.table.SchemaMapper`
            The newly built mapper.
        """
        inputSchema = lsst.afw.table.Schema()
        for fieldName in (name1, name2, name3):
            inputSchema.addField(fieldName, type=float)
        inputSchema.addField("asdf", type="Flag")
        mapper = lsst.afw.table.SchemaMapper(inputSchema)

        # name1: default mapping.
        mapper.addMapping(inputSchema.find(name1).key)

        # name2 -> new output field name3.
        targetField = lsst.afw.table.Field[float](name3, "doc for thingy")
        mapper.addMapping(inputSchema.find(name2).key, targetField)

        # name4: output-only field.
        mapper.addOutputField(lsst.afw.table.Field[float](name4, 'docstring'))
        return mapper

    def testOperatorEquals(self):
        """Check that identically built mappers compare equal."""
        self.assertEqual(self._makeMapper(), self._makeMapper())

    def testNotEqualInput(self):
        """Check that differing input schema compare not equal."""
        altered = self._makeMapper(name2="somethingelse")
        reference = self._makeMapper()
        # Only the input side differs; the output schemas still match.
        self.assertTrue(_checkSchemaIdentical(altered.getOutputSchema(), reference.getOutputSchema()))
        self.assertNotEqual(altered, reference)

    def testNotEqualOutput(self):
        """Check that differing output schema compare not equal."""
        altered = self._makeMapper(name4="another")
        reference = self._makeMapper()
        # Only the output side differs; the input schemas still match.
        self.assertTrue(_checkSchemaIdentical(altered.getInputSchema(), reference.getInputSchema()))
        self.assertNotEqual(altered, reference)

    def testNotEqualMappings(self):
        """Check that differing mappings but same schema compare not equal."""
        schema = lsst.afw.table.Schema()
        for fieldName in ('a', 'b'):
            schema.addField(fieldName, type=np.int32, doc="int")
        forward = lsst.afw.table.SchemaMapper(schema)
        swapped = lsst.afw.table.SchemaMapper(schema)
        # Both mappers target 'c' then 'd', but from opposite input fields.
        for mapper, sourceNames in ((forward, ('a', 'b')), (swapped, ('b', 'a'))):
            for inName, outName in zip(sourceNames, ('c', 'd')):
                mapper.addMapping(schema[inName].asKey(), outName)

        # Input and output schemas match; only the mappings differ.
        self.assertTrue(_checkSchemaIdentical(forward.getInputSchema(), swapped.getInputSchema()))
        self.assertTrue(_checkSchemaIdentical(forward.getOutputSchema(), swapped.getOutputSchema()))
        self.assertNotEqual(forward, swapped)

    def testNotEqualMappingsSomeFieldsUnmapped(self):
        """Check that differing mappings, with some unmapped fields, but the
        same input and output schema compare not equal.
        """
        schema = lsst.afw.table.Schema()
        for fieldName in ('a', 'b'):
            schema.addField(fieldName, type=np.int32, doc="int")
        fullyMapped = lsst.afw.table.SchemaMapper(schema)
        partlyMapped = lsst.afw.table.SchemaMapper(schema)
        fullyMapped.addMapping(schema['a'].asKey(), 'c')
        fullyMapped.addMapping(schema['b'].asKey(), 'd')
        partlyMapped.addMapping(schema['b'].asKey(), 'c')
        # Give partlyMapped an unmapped output field 'd' so that its output
        # schema matches fullyMapped's.
        partlyMapped.addOutputField(lsst.afw.table.Field[np.int32]('d', doc="int"))

        # Input and output schemas match; only the mappings differ.
        self.assertTrue(_checkSchemaIdentical(fullyMapped.getInputSchema(), partlyMapped.getInputSchema()))
        self.assertTrue(_checkSchemaIdentical(fullyMapped.getOutputSchema(), partlyMapped.getOutputSchema()))
        self.assertNotEqual(fullyMapped, partlyMapped)

    def testPickle(self):
        """Test that a SchemaMapper survives a pickle round trip."""
        schema = lsst.afw.table.Schema()
        addTestFields(schema)
        mapper = lsst.afw.table.SchemaMapper(schema)
        mapper.addMinimalSchema(schema)
        # NOTE(review): "bb" is added to ``schema`` only after ``mapper`` was
        # constructed from it, so it may not be part of the mapper's input
        # schema and the bb->cc mapping may not actually be exercised by the
        # round trip -- confirm the intended ordering.
        inKey = schema.addField("bb", type=float)
        outField = lsst.afw.table.Field[float]("cc", "doc for bb->cc")
        mapper.addMapping(inKey, outField, True)

        self.assertEqual(mapper, self._pickleRoundTrip(mapper))

    def testPickleMissingInput(self):
        """Test pickling with some fields not being mapped."""
        mapper = self._makeMapper()
        self.assertEqual(mapper, self._pickleRoundTrip(mapper))

436 

437 

class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Run the checks inherited unchanged from `lsst.utils.tests.MemoryTestCase`.

    Presumably these are leak checks performed after the other test cases;
    see the base class for details.
    """

440 

441 

def setup_module(module):
    """pytest module-level setup hook.

    Parameters
    ----------
    module : `module`
        The test module being set up (unused).
    """
    # Initialise lsst.utils.tests before any test in this module runs.
    lsst.utils.tests.init()

444 

445 

if __name__ == "__main__":
    # Direct execution (``python test_schema.py``): initialise the LSST test
    # utilities, then hand control to unittest's runner.
    lsst.utils.tests.init()
    unittest.main()