Coverage for tests/test_schema.py : 11%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
# This file is part of afw.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
22"""
23Tests for table.Schema
25Run with:
26 python test_schema.py
27or
28 pytest test_schema.py
29"""
31import unittest
32import pickle
34import numpy as np
36import lsst.utils.tests
37import lsst.pex.exceptions
38import lsst.geom
39import lsst.afw.table
def _checkSchemaIdentical(schema1, schema2):
    """Return True if ``schema1`` and ``schema2`` compare as identical.

    The schemas are compared with ``lsst.afw.table.Schema.IDENTICAL``, and
    the result must have every one of those comparison flags set.
    """
    identicalFlags = lsst.afw.table.Schema.IDENTICAL
    comparison = schema1.compare(schema2, identicalFlags)
    return comparison == identicalFlags
def addTestFields(schema):
    """Add one field of each type exercised by the tests to ``schema``.

    The types covered are Angle, D, I, F, Flag, fixed- and variable-length
    String, and fixed- and variable-length ArrayF. A ``size`` of 0 marks the
    variable-length variants.
    """
    # (name, keyword arguments for Schema.addField), in insertion order.
    fieldSpecs = (
        ("ra", {"type": "Angle", "doc": "coord_ra"}),
        ("dec", {"type": "Angle", "doc": "coord_dec"}),
        ("x", {"type": "D", "doc": "position_x", "units": "pixel"}),
        ("y", {"type": "D", "doc": "position_y"}),
        ("i", {"type": "I", "doc": "int"}),
        ("f", {"type": "F", "doc": "float", "units": "m2"}),
        ("flag", {"type": "Flag", "doc": "a flag"}),
        ("string", {"type": "String", "doc": "A string field", "size": 42}),
        ("variable_string", {"type": "String", "doc": "A variable-length string field", "size": 0}),
        ("array", {"type": "ArrayF", "doc": "An array field", "size": 10}),
        ("variable_array", {"type": "ArrayF", "doc": "A variable-length array field", "size": 0}),
    )
    for fieldName, fieldKwargs in fieldSpecs:
        schema.addField(fieldName, **fieldKwargs)
class SchemaTestCase(unittest.TestCase):
    """Tests for `lsst.afw.table.Schema`."""

    def testSchema(self):
        """Check field lookup, names, copying, equality and ordering."""
        def testKey(name, key):
            col = schema.find(name)
            self.assertEqual(col.key, key)
            self.assertEqual(col.field.getName(), name)

        schema = lsst.afw.table.Schema()
        ab_k = lsst.afw.table.CoordKey.addFields(schema, "a_b", "parent coord")
        abp_k = lsst.afw.table.Point2DKey.addFields(
            schema, "a_b_p", "point", "pixel")
        abi_k = schema.addField("a_b_i", type=np.int32, doc="int")
        acf_k = schema.addField("a_c_f", type=np.float32, doc="float")
        egd_k = schema.addField("e_g_d", type=lsst.geom.Angle, doc="angle")

        # Basic test for all native key types.
        for name, key in (("a_b_i", abi_k), ("a_c_f", acf_k), ("e_g_d", egd_k)):
            testKey(name, key)

        # Extra tests for special types
        self.assertEqual(ab_k.getRa(), schema["a_b_ra"].asKey())
        abpx_si = schema.find("a_b_p_x")
        self.assertEqual(abp_k.getX(), abpx_si.key)
        self.assertEqual(abpx_si.field.getName(), "a_b_p_x")
        self.assertEqual(abpx_si.field.getDoc(), "point")
        self.assertEqual(abp_k.getX(), schema["a_b_p_x"].asKey())
        self.assertEqual(schema.getNames(), {'a_b_dec', 'a_b_i', 'a_b_p_x', 'a_b_p_y', 'a_b_ra', 'a_c_f',
                                             'e_g_d'})
        self.assertEqual(schema.getNames(True), {"a", "e"})
        self.assertEqual(schema["a"].getNames(), {
            'b_dec', 'b_i', 'b_p_x', 'b_p_y', 'b_ra', 'c_f'})
        self.assertEqual(schema["a"].getNames(True), {"b", "c"})
        schema2 = lsst.afw.table.Schema(schema)
        self.assertEqual(schema, schema2)
        schema2.addField("q", type="F", doc="another double")
        self.assertNotEqual(schema, schema2)
        # schema3 uses different names/docs than ``schema`` but the same
        # sequence of field types; this test expects the two to compare equal.
        schema3 = lsst.afw.table.Schema()
        schema3.addField("ra", type="Angle", doc="coord_ra")
        schema3.addField("dec", type="Angle", doc="coord_dec")
        schema3.addField("x", type="D", doc="position_x")
        schema3.addField("y", type="D", doc="position_y")
        schema3.addField("i", type="I", doc="int")
        schema3.addField("f", type="F", doc="float")
        schema3.addField("d", type="Angle", doc="angle")
        self.assertEqual(schema3, schema)
        schema4 = lsst.afw.table.Schema()
        keys = []
        keys.append(schema4.addField("a", type="Angle", doc="a"))
        keys.append(schema4.addField("b", type="Flag", doc="b"))
        keys.append(schema4.addField("c", type="I", doc="c"))
        keys.append(schema4.addField("d", type="Flag", doc="d"))
        # Flag fields are numbered consecutively, skipping non-flag fields.
        self.assertEqual(keys[1].getBit(), 0)
        self.assertEqual(keys[3].getBit(), 1)
        # Compare the full list: a zip() against "abcd" would silently pass
        # if fewer names were returned than expected.
        self.assertEqual(list(schema4.getOrderedNames()), ["a", "b", "c", "d"])
        keys2 = [x.key for x in schema4]
        self.assertEqual(keys, keys2)

    def testUnits(self):
        """Valid units are accepted; invalid ones raise unless silenced."""
        schema = lsst.afw.table.Schema()
        # first insert some valid units
        schema.addField("a", type="I", units="pixel")
        schema.addField("b", type="I", units="m2")
        schema.addField("c", type="I", units="electron / adu")
        schema.addField("d", type="I", units="kg m s^(-2)")
        schema.addField("e", type="I", units="GHz / Mpc")
        schema.addField("f", type="Angle", units="deg")
        schema.addField("g", type="Angle", units="rad")
        schema.checkUnits()
        # now try inserting invalid units; use names that are not already in
        # the schema so the expected ValueError can only be attributed to the
        # unit string, not to reusing an existing field name.
        with self.assertRaises(ValueError):
            schema.addField("bad_units_1", type="I", units="camel")
        with self.assertRaises(ValueError):
            schema.addField("bad_units_2", type="I", units="pixels^2^2")
        # add invalid units in silent mode, should work fine
        schema.addField("h", type="I", units="lala", parse_strict='silent')
        # Now this check should raise because there is an invalid unit
        with self.assertRaises(ValueError):
            schema.checkUnits()

    def testInspection(self):
        """Iteration order and membership tests for keys and names."""
        schema = lsst.afw.table.Schema()
        keys = []
        keys.append(schema.addField("d", type=np.int32))
        keys.append(schema.addField("c", type=np.float32))
        keys.append(schema.addField("b", type="ArrayF", size=3))
        keys.append(schema.addField("a", type="F"))
        items = list(schema)
        # Guard against zip() below silently truncating.
        self.assertEqual(len(items), len(keys))
        for key, item in zip(keys, items):
            self.assertEqual(item.key, key)
            self.assertIn(key, schema)
        for name in ("a", "b", "c", "d"):
            self.assertIn(name, schema)
        self.assertNotIn("e", schema)
        otherSchema = lsst.afw.table.Schema()
        otherKey = otherSchema.addField("d", type=np.float32)
        self.assertNotIn(otherKey, schema)
        self.assertNotEqual(keys[0], keys[1])

    def testKeyAccessors(self):
        """Indexing an array key yields a scalar key of the element type."""
        schema = lsst.afw.table.Schema()
        arrayKey = schema.addField(
            "a", type="ArrayF", doc="doc for array field", size=5)
        arrayElementKey = arrayKey[1]
        self.assertEqual(lsst.afw.table.Key["F"], type(arrayElementKey))

    def testComparison(self):
        """Check the flags returned by Schema.compare and Schema.contains."""
        schema1 = lsst.afw.table.Schema()
        schema1.addField("a", type=np.float32, doc="doc for a", units="m")
        schema1.addField("b", type=np.int32, doc="doc for b", units="s")
        schema2 = lsst.afw.table.Schema()
        schema2.addField("a", type=np.int32, doc="doc for a", units="m")
        schema2.addField("b", type=np.float32, doc="doc for b", units="s")
        cmp1 = schema1.compare(schema2, lsst.afw.table.Schema.IDENTICAL)
        self.assertTrue(cmp1 & lsst.afw.table.Schema.EQUAL_NAMES)
        self.assertTrue(cmp1 & lsst.afw.table.Schema.EQUAL_DOCS)
        self.assertTrue(cmp1 & lsst.afw.table.Schema.EQUAL_UNITS)
        self.assertFalse(cmp1 & lsst.afw.table.Schema.EQUAL_KEYS)
        schema3 = lsst.afw.table.Schema(schema1)
        schema3.addField("c", type=str, doc="doc for c", size=4)
        self.assertFalse(schema1.compare(schema3))
        self.assertFalse(schema1.contains(schema3))
        self.assertTrue(schema3.contains(schema1))
        schema1.addField("d", type=str, doc="no docs!", size=4)
        cmp2 = schema1.compare(schema3, lsst.afw.table.Schema.IDENTICAL)
        self.assertFalse(cmp2 & lsst.afw.table.Schema.EQUAL_NAMES)
        self.assertFalse(cmp2 & lsst.afw.table.Schema.EQUAL_DOCS)
        self.assertTrue(cmp2 & lsst.afw.table.Schema.EQUAL_KEYS)
        self.assertTrue(cmp2 & lsst.afw.table.Schema.EQUAL_UNITS)
        self.assertFalse(schema1.compare(
            schema3, lsst.afw.table.Schema.EQUAL_NAMES))

    def testPickle(self):
        """A schema with every tested field type survives a pickle round trip."""
        schema = lsst.afw.table.Schema()
        addTestFields(schema)

        pickled = pickle.dumps(schema, protocol=pickle.HIGHEST_PROTOCOL)
        unpickled = pickle.loads(pickled)
        self.assertEqual(schema, unpickled)
class SchemaMapperTestCase(unittest.TestCase):
    """Tests for `lsst.afw.table.SchemaMapper`."""

    def testJoin(self):
        """Join three schemas, with and without prefixes, and check that
        assigning through the resulting mappers copies every value.
        """
        afwTable = lsst.afw.table
        inputs = [afwTable.Schema() for _ in range(3)]
        prefixes = ["u", "v", "w"]
        ka = inputs[0].addField("a", type=np.float64, doc="doc for a")
        kb = inputs[0].addField("b", type=np.int32, doc="doc for b")
        kc = inputs[1].addField("c", type=np.float32, doc="doc for c")
        kd = inputs[2].addField("d", type=np.int64, doc="doc for d")
        # The unprefixed join is checked against all IDENTICAL flags; the
        # prefixed join against all of them except EQUAL_NAMES.
        fullFlags = afwTable.Schema.IDENTICAL
        noNameFlags = fullFlags & ~afwTable.Schema.EQUAL_NAMES
        plainMappers = afwTable.SchemaMapper.join(inputs)
        prefixedMappers = afwTable.SchemaMapper.join(inputs, prefixes)
        records = [afwTable.BaseTable.make(inputSchema).makeRecord() for inputSchema in inputs]
        records[0].set(ka, 3.14159)
        records[0].set(kb, 21623)
        records[1].set(kc, 1.5616)
        records[2].set(kd, 1261236)
        for mappers, flags in ((plainMappers, fullFlags), (prefixedMappers, noNameFlags)):
            output = afwTable.BaseTable.make(mappers[0].getOutputSchema()).makeRecord()
            for mapper, record in zip(mappers, records):
                output.assign(record, mapper)
                outputCmp = mapper.getOutputSchema().compare(output.getSchema(), flags)
                self.assertEqual(outputCmp, flags)
                inputCmp = mapper.getInputSchema().compare(record.getSchema(), flags)
                self.assertEqual(inputCmp, flags)
            names = output.getSchema().getOrderedNames()
            # Output fields appear in the same order as the joined inputs.
            sources = ((records[0], ka), (records[0], kb), (records[1], kc), (records[2], kd))
            for index, (sourceRecord, sourceKey) in enumerate(sources):
                self.assertEqual(output.get(names[index]), sourceRecord.get(sourceKey))

    def testMinimalSchema(self):
        """Check addMinimalSchema and removeMinimalSchema with a schema prefix."""
        afwTable = lsst.afw.table
        front = afwTable.Schema()
        ka = front.addField("a", type=np.float64, doc="doc for a")
        kb = front.addField("b", type=np.int32, doc="doc for b")
        full = afwTable.Schema(front)
        kc = full.addField("c", type=np.float32, doc="doc for c")
        kd = full.addField("d", type=np.int64, doc="doc for d")
        mapper1 = afwTable.SchemaMapper(full)
        mapper2 = afwTable.SchemaMapper(full)
        mapper3 = afwTable.SchemaMapper.removeMinimalSchema(full, front)
        mapper1.addMinimalSchema(front)
        mapper2.addMinimalSchema(front, False)
        frontKeys = (ka, kb)
        extraKeys = (kc, kd)
        # Both addMinimalSchema variants keep only the front fields.
        for mapper in (mapper1, mapper2):
            for key in frontKeys:
                self.assertIn(key, mapper.getOutputSchema())
            for key in extraKeys:
                self.assertNotIn(key, mapper.getOutputSchema())
        # None of the input keys are in mapper3's output schema.
        for key in frontKeys + extraKeys:
            self.assertNotIn(key, mapper3.getOutputSchema())
        inputRecord = afwTable.BaseTable.make(full).makeRecord()
        inputRecord.set(ka, np.pi)
        inputRecord.set(kb, 2)
        inputRecord.set(kc, np.exp(1))
        inputRecord.set(kd, 4)
        outputRecord1 = afwTable.BaseTable.make(mapper1.getOutputSchema()).makeRecord()
        outputRecord1.assign(inputRecord, mapper1)
        for key in frontKeys:
            self.assertEqual(inputRecord.get(key), outputRecord1.get(key))

    def testOutputSchema(self):
        """getOutputSchema() returns a copy, while editOutputSchema() returns
        a schema whose edits are visible through the mapper.
        """
        mapper = lsst.afw.table.SchemaMapper(lsst.afw.table.Schema())
        copied = mapper.getOutputSchema()
        edited = mapper.editOutputSchema()
        # Adding to the copy affects neither the mapper nor the edited schema.
        key1 = copied.addField("a1", type=np.int32)
        self.assertNotIn(key1, mapper.getOutputSchema())
        self.assertIn(key1, copied)
        self.assertNotIn(key1, edited)
        # Adding through the mapper shows up in the edited schema, not the copy.
        key2 = mapper.addOutputField(
            lsst.afw.table.Field[np.float32]("a2", "doc for a2"))
        self.assertNotIn(key2, copied)
        self.assertIn(key2, mapper.getOutputSchema())
        self.assertIn(key2, edited)
        # Adding to the edited schema shows up in the mapper, not the copy.
        key3 = edited.addField("a3", type=np.float32, doc="doc for a3")
        self.assertNotIn(key3, copied)
        self.assertIn(key3, mapper.getOutputSchema())
        self.assertIn(key3, edited)
        self.assertIn(key2, edited)

    def testDoReplace(self):
        """addMapping with doReplace=True maps onto existing output fields."""
        inSchema = lsst.afw.table.Schema()
        ka = inSchema.addField("a", type=np.int32)
        outSchema = lsst.afw.table.Schema(inSchema)
        kb = outSchema.addField("b", type=np.int32)
        kc = outSchema.addField("c", type=np.int32)
        # Map by key alone: the existing output field "a" is used.
        byKey = lsst.afw.table.SchemaMapper(inSchema, outSchema)
        byKey.addMapping(ka, True)
        self.assertEqual(byKey.getMapping(ka), ka)
        # Map onto an explicit Field named "b".
        byField = lsst.afw.table.SchemaMapper(inSchema, outSchema)
        byField.addMapping(
            ka, lsst.afw.table.Field[np.int32]("b", "doc for b"), True)
        self.assertEqual(byField.getMapping(ka), kb)
        # Map onto an output field given only by name.
        byName = lsst.afw.table.SchemaMapper(inSchema, outSchema)
        byName.addMapping(ka, "c", True)
        self.assertEqual(byName.getMapping(ka), kc)

    def testJoin2(self):
        """Schema.join concatenates its arguments with underscores."""
        schema = lsst.afw.table.Schema()
        cases = (
            (("a", "b"), "a_b"),
            (("a", "b", "c"), "a_b_c"),
            (("a", "b", "c", "d"), "a_b_c_d"),
        )
        for parts, expected in cases:
            self.assertEqual(schema.join(*parts), expected)

    def _makeMapper(self, name1="a", name2="bb", name3="ccc", name4="dddd"):
        """Make a SchemaMapper for testing.

        Parameters
        ----------
        name1 : `str`, optional
            Name of a field to "default map" from input to output.
        name2 : `str`, optional
            Name of a field to map from input to ``name3`` in output.
        name3 : `str`, optional
            Name of a field that is unmapped in input, and mapped from
            ``name2`` in output.
        name4 : `str`, optional
            Name of a field that is unmapped in output.

        Returns
        -------
        mapper : `lsst.afw.table.SchemaMapper`
            The created mapper.
        """
        schema = lsst.afw.table.Schema()
        for fieldName in (name1, name2, name3):
            schema.addField(fieldName, type=float)
        # An extra input field that is never mapped.
        schema.addField("asdf", type="Flag")
        mapper = lsst.afw.table.SchemaMapper(schema)

        # Default mapping for name1 (no explicit output field given).
        mapper.addMapping(schema.find(name1).key)

        # Map name2 onto a new output field called name3.
        renamedField = lsst.afw.table.Field[float](name3, "doc for thingy")
        mapper.addMapping(schema.find(name2).key, renamedField)

        # An output-only field with no input counterpart.
        mapper.addOutputField(lsst.afw.table.Field[float](name4, 'docstring'))
        return mapper

    @staticmethod
    def _makeTwoIntSchema():
        """Return a new schema holding int32 fields 'a' and 'b'."""
        schema = lsst.afw.table.Schema()
        schema.addField('a', type=np.int32, doc="int")
        schema.addField('b', type=np.int32, doc="int")
        return schema

    @staticmethod
    def _roundTrip(obj):
        """Pickle and unpickle ``obj`` using the highest pickle protocol."""
        return pickle.loads(pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL))

    def testOperatorEquals(self):
        """Two mappers built identically compare equal."""
        self.assertEqual(self._makeMapper(), self._makeMapper())

    def testNotEqualInput(self):
        """Check that differing input schema compare not equal."""
        changedInput = self._makeMapper(name2="somethingelse")
        reference = self._makeMapper()
        # output schema should still be equal
        self.assertTrue(_checkSchemaIdentical(changedInput.getOutputSchema(),
                                              reference.getOutputSchema()))
        self.assertNotEqual(changedInput, reference)

    def testNotEqualOutput(self):
        """Check that differing output schema compare not equal."""
        changedOutput = self._makeMapper(name4="another")
        reference = self._makeMapper()
        # input schema should still be equal
        self.assertTrue(_checkSchemaIdentical(changedOutput.getInputSchema(),
                                              reference.getInputSchema()))
        self.assertNotEqual(changedOutput, reference)

    def testNotEqualMappings(self):
        """Check that differing mappings but same schema compare not equal."""
        schema = self._makeTwoIntSchema()
        keyA = schema['a'].asKey()
        keyB = schema['b'].asKey()
        mapper1 = lsst.afw.table.SchemaMapper(schema)
        mapper2 = lsst.afw.table.SchemaMapper(schema)
        # Same output names, crossed sources: a->c, b->d versus b->c, a->d.
        mapper1.addMapping(keyA, 'c')
        mapper1.addMapping(keyB, 'd')
        mapper2.addMapping(keyB, 'c')
        mapper2.addMapping(keyA, 'd')

        # input and output schemas should still be equal
        self.assertTrue(_checkSchemaIdentical(mapper1.getInputSchema(), mapper2.getInputSchema()))
        self.assertTrue(_checkSchemaIdentical(mapper1.getOutputSchema(), mapper2.getOutputSchema()))
        self.assertNotEqual(mapper1, mapper2)

    def testNotEqualMappingsSomeFieldsUnmapped(self):
        """Check that differing mappings, with some unmapped fields, but the
        same input and output schema compare not equal.
        """
        schema = self._makeTwoIntSchema()
        keyA = schema['a'].asKey()
        keyB = schema['b'].asKey()
        mapper1 = lsst.afw.table.SchemaMapper(schema)
        mapper2 = lsst.afw.table.SchemaMapper(schema)
        mapper1.addMapping(keyA, 'c')
        mapper1.addMapping(keyB, 'd')
        mapper2.addMapping(keyB, 'c')
        # add an unmapped field to output of 2 to match 1
        mapper2.addOutputField(lsst.afw.table.Field[np.int32]('d', doc="int"))

        # input and output schemas should still be equal
        self.assertTrue(_checkSchemaIdentical(mapper1.getInputSchema(), mapper2.getInputSchema()))
        self.assertTrue(_checkSchemaIdentical(mapper1.getOutputSchema(), mapper2.getOutputSchema()))
        self.assertNotEqual(mapper1, mapper2)

    def testPickle(self):
        """A mapper with a minimal schema plus an explicit mapping survives
        a pickle round trip.
        """
        schema = lsst.afw.table.Schema()
        addTestFields(schema)
        mapper = lsst.afw.table.SchemaMapper(schema)
        mapper.addMinimalSchema(schema)
        # NOTE(review): "bb" is added to ``schema`` after the mapper was
        # constructed, so it may not be part of the mapper's input schema;
        # confirm the bb->cc mapping is actually exercised by this round trip.
        inKey = schema.addField("bb", type=float)
        outField = lsst.afw.table.Field[float]("cc", "doc for bb->cc")
        mapper.addMapping(inKey, outField, True)

        self.assertEqual(mapper, self._roundTrip(mapper))

    def testPickleMissingInput(self):
        """Test pickling with some fields not being mapped."""
        mapper = self._makeMapper()

        self.assertEqual(mapper, self._roundTrip(mapper))
class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Hook this module into ``lsst.utils.tests.MemoryTestCase``.

    All test behavior is inherited unchanged from the base class.
    """
def setup_module(module):
    """pytest module-level setup hook: initialize ``lsst.utils.tests``.

    ``module`` is passed in by pytest and is not used here.
    """
    lsst.utils.tests.init()
if __name__ == "__main__":
    # Script entry point (``python test_schema.py``): initialize the LSST
    # test utilities before handing control to unittest's runner.
    lsst.utils.tests.init()
    unittest.main()