Coverage for tests/testInstanceCatalog.py : 18%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1from __future__ import with_statement
2from builtins import zip
3from builtins import str
4import sys
5str_type = str
6if sys.version_info.major == 2: 6 ↛ 7line 6 didn't jump to line 7, because the condition on line 6 was never true
7 from past.builtins import str as past_str
8 str_type = past_str
9from builtins import range
10import os
11import numpy as np
12import sqlite3
13import unittest
14import tempfile
15import shutil
16import lsst.utils.tests
17from lsst.sims.utils.CodeUtilities import sims_clean_up
18from lsst.sims.utils import ObservationMetaData
19from lsst.sims.catalogs.db import CatalogDBObject
20from lsst.sims.catalogs.utils import myTestStars, makeStarTestDB
21from lsst.sims.catalogs.definitions import InstanceCatalog
22from lsst.sims.utils import Site
# Absolute path of the directory that holds this test module.  The test
# classes below create their scratch directories underneath it.
ROOT = os.path.abspath(
    os.path.dirname(__file__)
)
def setup_module(module):
    """
    Module-level set-up function: calls lsst.utils.tests.init().

    The ``module`` argument is accepted but not used.  (The name follows the
    module-setup hook convention of pytest/nose -- presumably that is how it
    gets invoked; confirm against the test runner in use.)
    """
    lsst.utils.tests.init()
def createCannotBeNullTestDB(filename=None, add_nans=True, dir=None):
    """
    Create a database to test the 'cannot_be_null' functionality in InstanceCatalog

    The database contains one table, ``testTable``, with 100 rows and the
    columns (id, n1, n2, n3, n4, n5).  n1, n2 and n3 are random floats; n4 and
    n5 each hold either the string 'None' or the string 'word'.  The random
    number generator is seeded with a fixed value, so the contents are
    reproducible from call to call.

    Any file already present at the target path is deleted first.

    Parameters
    ----------
    filename : str, optional
        Path of the database file to write.  If None, the file is named
        'cannotBeNullTest.db' and is written into ``dir`` (or into the current
        directory when ``dir`` is also None).
    add_nans : bool, optional
        If True, each of n1, n2, n3 is replaced by NaN with probability 0.5;
        NaN values are stored in the database as SQL NULL.
    dir : str, optional
        Directory for the default-named file; ignored when ``filename`` is
        given.

    Returns
    -------
    numpy.ndarray
        Structured array (fields id, n1, n2, n3, n4, n5) holding the contents
        of the database, for baseline comparison in the unit tests.

    Raises
    ------
    RuntimeError
        If the table cannot be created.
    """
    if filename is None:
        dbName = 'cannotBeNullTest.db'
        if dir is not None:
            dbName = os.path.join(dir, dbName)
    else:
        dbName = filename

    rng = np.random.RandomState(32)
    dtype = np.dtype([('id', int), ('n1', np.float64), ('n2', np.float64), ('n3', np.float64),
                      ('n4', (str_type, 40)), ('n5', (str_type, 40))])

    if os.path.exists(dbName):
        os.unlink(dbName)

    # Rows are gathered in a list and converted once at the end, rather than
    # growing a numpy array one element at a time.
    rows = []

    conn = sqlite3.connect(dbName)
    try:
        c = conn.cursor()
        try:
            c.execute('''CREATE TABLE testTable (id int, n1 float, n2 float, n3 float, n4 text, n5 text)''')
            conn.commit()
        except sqlite3.Error:
            raise RuntimeError("Error creating database.")

        for ii in range(100):

            # NOTE: the order of the random draws below determines the
            # generated data; do not reorder them.
            values = rng.random_sample(3)
            for i in range(len(values)):
                draw = rng.random_sample(1)
                if draw[0] < 0.5 and add_nans:
                    values[i] = np.nan

            draw = rng.random_sample(1)
            if draw[0] < 0.5:
                w1 = 'None'
            else:
                w1 = 'word'

            draw = rng.random_sample(1)
            if draw[0] < 0.5:
                w2 = str('None')
            else:
                w2 = str('word')

            rows.append((ii, values[0], values[1], values[2], w1, w2))

            # NaN has no SQL counterpart: bind None so that sqlite stores NULL
            db_values = [None if np.isnan(val) else float(val) for val in values]

            # bound parameters instead of string interpolation
            c.execute('INSERT INTO testTable VALUES (?, ?, ?, ?, ?, ?)',
                      (int(ii), db_values[0], db_values[1], db_values[2], w1, w2))

        conn.commit()
    finally:
        # release the connection even if an insert fails part-way through
        conn.close()

    return np.array(rows, dtype=dtype)
class myCannotBeNullDBObject(CatalogDBObject):
    """
    CatalogDBObject describing the 'testTable' table that
    createCannotBeNullTestDB writes.
    """
    # identifier passed to CatalogDBObject.from_objid() by the tests
    objid = 'cannotBeNull'

    # where the data live; the 'database' default is overridden either at
    # construction time or by InstanceCatalogCannotBeNullTest.setUp
    driver = 'sqlite'
    database = 'cannotBeNullTest.db'
    tableid = 'testTable'
    idColKey = 'id'

    # explicit declaration for the n5 column
    # (presumably: name, source column, type, length -- confirm against CatalogDBObject)
    columns = [('n5', 'n5', str, 40)]
class floatCannotBeNullCatalog(InstanceCatalog):
    """
    Catalog that flags the float column n2 as cannot_be_null, so rows
    whose n2 value is null are not written
    """
    cannot_be_null = ['n2']
    column_outputs = ['id', 'n1', 'n2', 'n3', 'n4', 'n5']
class strCannotBeNullCatalog(InstanceCatalog):
    """
    This catalog class will not write rows with a null value in the n4 column
    (a string column)
    """
    column_outputs = ['id', 'n1', 'n2', 'n3', 'n4', 'n5']
    cannot_be_null = ['n4']
class unicodeCannotBeNullCatalog(InstanceCatalog):
    """
    This catalog class will not write rows with a null value in the n5 column
    (the string column that createCannotBeNullTestDB fills using builtins.str)
    """
    column_outputs = ['id', 'n1', 'n2', 'n3', 'n4', 'n5']
    cannot_be_null = ['n5']
class severalCannotBeNullCatalog(InstanceCatalog):
    """
    Catalog that flags two columns at once (n2, a float, and n4, a string)
    as cannot_be_null; rows with a null in either are not written
    """
    cannot_be_null = ['n2', 'n4']
    column_outputs = ['id', 'n1', 'n2', 'n3', 'n4', 'n5']
class CanBeNullCatalog(InstanceCatalog):
    """
    Catalog with no cannot_be_null columns: every row is written.
    """
    # name under which testCanBeNull requests this class via getCatalog()
    catalog_type = 'canBeNull'
    column_outputs = ['id', 'n1', 'n2', 'n3', 'n4', 'n5']
class testStellarCatalogClass(InstanceCatalog):
    """
    Minimal catalog that outputs only raJ2000 and decJ2000; used to test
    adding columns through the constructor.
    """
    # presumably the output format applied to float columns -- confirm
    # against InstanceCatalog
    default_formats = {'f': '%le'}
    column_outputs = ['raJ2000', 'decJ2000']
class cartoonValueCatalog(InstanceCatalog):
    """
    Catalog outputting n1 and n2 by default, plus a getter for a
    'difference' column that tests can request through the constructor.
    """
    default_formats = {'f': '%le'}
    column_outputs = ['n1', 'n2']

    def get_difference(self):
        """Return column n1 minus column n3."""
        first = self.column_by_name('n1')
        second = self.column_by_name('n3')
        delta = first - second
        return delta
class InstanceCatalogMetaDataTest(unittest.TestCase):
    """
    This class will test how Instance catalog handles the metadata
    class variables (pointingRA, pointingDec, etc.)
    """

    @classmethod
    def setUpClass(cls):
        # one scratch directory and one star database shared by every test
        cls.scratch_dir = tempfile.mkdtemp(dir=ROOT, prefix="scratchSpace-")
        cls.database = os.path.join(cls.scratch_dir, 'testInstanceCatalogDatabase.db')
        makeStarTestDB(filename=cls.database)

    @classmethod
    def tearDownClass(cls):
        sims_clean_up()
        if os.path.exists(cls.scratch_dir):
            shutil.rmtree(cls.scratch_dir)

    def setUp(self):
        self.myDB = myTestStars(driver='sqlite', database=self.database)

    def tearDown(self):
        del self.myDB

    def _assertOnlyTheseColumns(self, catalog, expected):
        """
        Assert that catalog.iter_column_names() yields every name in
        ``expected`` exactly once and nothing else.

        An unexpected (or repeated) name is reported as a test failure.
        Returns the list of names that were actually generated, in order.
        """
        remaining = list(expected)  # work on a copy; the caller's list is untouched
        generated = []
        for col in catalog.iter_column_names():
            generated.append(col)
            if col not in remaining:
                self.fail('column %s returned; should not be there' % col)
            remaining.remove(col)

        self.assertEqual(len(remaining), 0)
        return generated

    def testObsMetaDataAssignment(self):
        """
        Test that you get an error when you pass something that is not
        ObservationMetaData as obs_metadata
        """

        xx = 5.0
        self.assertRaises(ValueError, testStellarCatalogClass, self.myDB, obs_metadata=xx)

    def testColumnArg(self):
        """
        A unit test to make sure that the code allowing you to add
        new column_outputs to an InstanceCatalog using its constructor
        works properly.
        """
        mjd = 5120.0
        RA = 1.5
        Dec = -1.1
        rotSkyPos = -10.0

        testSite = Site(longitude=2.0, latitude=-1.0, height=4.0,
                        temperature=100.0, pressure=500.0, humidity=0.1,
                        lapseRate=0.1)

        testObsMD = ObservationMetaData(site=testSite,
                                        mjd=mjd, pointingRA=RA,
                                        pointingDec=Dec,
                                        rotSkyPos=rotSkyPos,
                                        bandpassName='z')

        # make sure the correct column names are returned
        # according to class definition
        testCat = testStellarCatalogClass(self.myDB, obs_metadata=testObsMD)
        self._assertOnlyTheseColumns(testCat, ['raJ2000', 'decJ2000'])

        # make sure that new column names can be added
        newColumns = ['properMotionRa', 'properMotionDec']
        testCat = testStellarCatalogClass(self.myDB, obs_metadata=testObsMD, column_outputs=newColumns)
        columnsShouldBe = ['raJ2000', 'decJ2000', 'properMotionRa', 'properMotionDec']
        self._assertOnlyTheseColumns(testCat, columnsShouldBe)

        # make sure that, if we include a duplicate column in newColumns,
        # the column is not duplicated
        newColumns = ['properMotionRa', 'properMotionDec', 'raJ2000']
        testCat = testStellarCatalogClass(self.myDB, obs_metadata=testObsMD, column_outputs=newColumns)

        for col in columnsShouldBe:
            self.assertIn(col, testCat._actually_calculated_columns)

        generatedColumns = self._assertOnlyTheseColumns(testCat, columnsShouldBe)
        self.assertEqual(len(generatedColumns), 4)

        # the header line of the written catalog must list all four columns
        cat_name = os.path.join(self.scratch_dir, 'testArgCatalog.txt')
        testCat.write_catalog(cat_name)
        with open(cat_name, 'r') as inCat:
            lines = inCat.readlines()
            header = lines[0]
            header = header.strip('#')
            header = header.strip('\n')
            header = header.split(', ')
            self.assertIn('raJ2000', header)
            self.assertIn('decJ2000', header)
            self.assertIn('properMotionRa', header)
            self.assertIn('properMotionDec', header)
        if os.path.exists(cat_name):
            os.unlink(cat_name)

    def testArgValues(self):
        """
        Test that columns added using the constructor args return the correct value
        """
        with lsst.utils.tests.getTempFilePath(".db") as dbName:
            baselineData = createCannotBeNullTestDB(filename=dbName, add_nans=False)
            db = myCannotBeNullDBObject(driver='sqlite', database=dbName)
            dtype = np.dtype([('n1', float), ('n2', float), ('n3', float), ('difference', float)])
            cat = cartoonValueCatalog(db, column_outputs=['n3', 'difference'])

            columns = ['n1', 'n2', 'n3', 'difference']
            for col in columns:
                self.assertIn(col, cat._actually_calculated_columns)

            cat_name = os.path.join(self.scratch_dir, 'cartoonValCat.txt')
            cat.write_catalog(cat_name)
            testData = np.genfromtxt(cat_name, dtype=dtype, delimiter=',')
            for testLine, controlLine in zip(testData, baselineData):
                self.assertAlmostEqual(testLine[0], controlLine['n1'], 6)
                self.assertAlmostEqual(testLine[1], controlLine['n2'], 6)
                self.assertAlmostEqual(testLine[2], controlLine['n3'], 6)
                self.assertAlmostEqual(testLine[3], controlLine['n1']-controlLine['n3'], 6)

            if os.path.exists(cat_name):
                os.unlink(cat_name)

    def testAllCalculatedColumns(self):
        """
        Unit test to make sure that _actually_calculated_columns contains all of the dependent columns
        """
        class otherCartoonValueCatalog(InstanceCatalog):
            column_outputs = ['n1', 'n2', 'difference']

            def get_difference(self):
                n1 = self.column_by_name('n1')
                n3 = self.column_by_name('n3')
                return n1-n3

        with lsst.utils.tests.getTempFilePath(".db") as dbName:
            createCannotBeNullTestDB(filename=dbName, add_nans=False)
            db = myCannotBeNullDBObject(driver='sqlite', database=dbName)
            cat = otherCartoonValueCatalog(db)
            # n3 is not a requested output; it is only needed by get_difference
            columns = ['n1', 'n2', 'n3', 'difference']
            for col in columns:
                self.assertIn(col, cat._actually_calculated_columns)
class InstanceCatalogCannotBeNullTest(unittest.TestCase):
    """
    Tests of the cannot_be_null filtering done by InstanceCatalog, run against
    the database written by createCannotBeNullTestDB.
    """

    def setUp(self):
        self.scratch_dir = tempfile.mkdtemp(dir=ROOT, prefix='scratchSpace-')
        # Force the class to understand where the DB is meant to be
        myCannotBeNullDBObject.database = os.path.join(self.scratch_dir,
                                                       'cannotBeNullTest.db')
        self.baselineOutput = createCannotBeNullTestDB(dir=self.scratch_dir)

    def tearDown(self):
        sims_clean_up()
        del self.baselineOutput
        # remove any copy of the database left in the current directory
        if os.path.exists('cannotBeNullTest.db'):
            os.unlink('cannotBeNullTest.db')
        if os.path.exists(self.scratch_dir):
            shutil.rmtree(self.scratch_dir)

    @staticmethod
    def _readCatalog(fileName):
        """
        Read a comma-delimited catalog with columns (id, n1, n2, n3, n4, n5)
        from fileName and return it as a numpy structured array.
        """
        dtype = np.dtype([('id', int), ('n1', np.float64), ('n2', np.float64), ('n3', np.float64),
                          ('n4', (str_type, 40)), ('n5', (str_type, 40))])
        return np.genfromtxt(fileName, dtype=dtype, delimiter=',')

    def testCannotBeNull(self):
        """
        Test to make sure that the code for filtering out rows with null values
        in key rows works.
        """

        # each of these classes flags a different column with a different datatype as cannot_be_null
        availableCatalogs = [floatCannotBeNullCatalog, strCannotBeNullCatalog, unicodeCannotBeNullCatalog,
                             severalCannotBeNullCatalog]
        dbobj = CatalogDBObject.from_objid('cannotBeNull')

        ct_n2 = 0  # number of rows in floatCannotBeNullCatalog
        ct_n4 = 0  # number of rows in strCannotBeNullCatalog
        ct_n2_n4 = 0  # number of rows in severalCannotBeNullCatalog

        # every catalog is written to (and overwrites) the same file
        fileName = os.path.join(self.scratch_dir, 'cannotBeNullTestFile.txt')

        for catClass in availableCatalogs:
            cat = catClass(dbobj)
            cat.write_catalog(fileName)
            testData = self._readCatalog(fileName)

            ct_good = 0  # a counter to keep track of the rows read in from the catalog
            ct_total = len(self.baselineOutput)

            for i in range(ct_total):

                # self.baselineOutput contains all of the rows from the dbobj
                # first, we must assess whether or not the row we are currently
                # testing would, in fact, pass the cannot_be_null test
                validLine = True
                for col_name in cat.cannot_be_null:
                    value = self.baselineOutput[col_name][i]
                    if isinstance(value, (str, str_type)):
                        # string columns hold the literal text 'None' for a null
                        if value.strip().lower() == 'none':
                            validLine = False
                    elif np.isnan(value):
                        validLine = False

                if not validLine:
                    continue

                if catClass is floatCannotBeNullCatalog:
                    ct_n2 += 1
                elif catClass is strCannotBeNullCatalog:
                    ct_n4 += 1
                elif catClass is severalCannotBeNullCatalog:
                    ct_n2_n4 += 1

                # if the row in self.baselineOutput should be in the catalog, we now check
                # that baseline and testData agree on column values (there are some gymnastics
                # here because you cannot do an == on NaN's)
                for (k, xx) in enumerate(self.baselineOutput[i]):
                    if k < 4:
                        if not np.isnan(xx):
                            msg = ('k: %d -- %s %s -- %s' %
                                   (k, str(xx), str(testData[ct_good][k]), cat.cannot_be_null))
                            self.assertAlmostEqual(xx, testData[ct_good][k], 3, msg=msg)
                        else:
                            # np.nan, not np.NaN: the capitalised alias was removed in NumPy 2.0
                            np.testing.assert_equal(testData[ct_good][k], np.nan)
                    else:
                        msg = ('%s (%s) is not %s (%s)' %
                               (xx, type(xx), testData[ct_good][k], type(testData[ct_good][k])))
                        self.assertEqual(xx.strip(), testData[ct_good][k].strip(), msg=msg)
                ct_good += 1

            self.assertEqual(ct_good, len(testData))  # make sure that we tested all of the testData rows
            msg = '%d >= %d' % (ct_good, ct_total)
            # make sure that some rows did not make it into the catalog
            self.assertLess(ct_good, ct_total, msg=msg)

        # make sure that severalCannotBeNullCatalog weeded out rows that were individually in
        # floatCannotBeNullCatalog or strCannotBeNullCatalog
        self.assertGreater(ct_n2, ct_n2_n4)
        self.assertGreater(ct_n4, ct_n2_n4)

        if os.path.exists(fileName):
            os.unlink(fileName)

    def testCannotBeNull_pre_screen(self):
        """
        Check that writing a catalog with self._pre_screen = True produces
        the same results as writing one with self._pre_screen = False.
        """

        # each of these classes flags a different column with a different datatype as cannot_be_null
        availableCatalogs = [floatCannotBeNullCatalog, strCannotBeNullCatalog, unicodeCannotBeNullCatalog,
                             severalCannotBeNullCatalog]
        dbobj = CatalogDBObject.from_objid('cannotBeNull')

        fileName = os.path.join(self.scratch_dir, 'cannotBeNullTestFile_prescreen.txt')
        control_fileName = os.path.join(self.scratch_dir, 'cannotBeNullTestFile_prescreen_control.txt')

        for catClass in availableCatalogs:
            cat = catClass(dbobj)
            cat._pre_screen = True
            control_cat = catClass(dbobj)
            cat.write_catalog(fileName)
            control_cat.write_catalog(control_fileName)

            with open(fileName, 'r') as test_file:
                test_lines = test_file.readlines()
            with open(control_fileName, 'r') as control_file:
                control_lines = control_file.readlines()

            # the two files must contain the same lines
            for line in control_lines:
                self.assertIn(line, test_lines)
            for line in test_lines:
                self.assertIn(line, control_lines)

            if os.path.exists(fileName):
                os.unlink(fileName)
            if os.path.exists(control_fileName):
                os.unlink(control_fileName)

    def testCanBeNull(self):
        """
        Test to make sure that we can still write all rows to catalogs,
        even those with null values in key columns
        """
        dbobj = CatalogDBObject.from_objid('cannotBeNull')
        cat = dbobj.getCatalog('canBeNull')
        fileName = os.path.join(self.scratch_dir, 'canBeNullTestFile.txt')
        cat.write_catalog(fileName)
        testData = self._readCatalog(fileName)

        for i in range(len(self.baselineOutput)):
            # make sure that all of the rows in self.baselineOutput are represented in
            # testData
            for (k, xx) in enumerate(self.baselineOutput[i]):
                if k < 4:
                    if not np.isnan(xx):
                        self.assertAlmostEqual(xx, testData[i][k], 3)
                    else:
                        # np.nan, not np.NaN: the capitalised alias was removed in NumPy 2.0
                        np.testing.assert_equal(testData[i][k], np.nan)
                else:
                    msg = '%s is not %s' % (xx, testData[i][k])
                    self.assertEqual(xx.strip(), testData[i][k].strip(), msg=msg)

        # the loop above must have visited all 100 baseline rows
        self.assertEqual(i, 99)
        self.assertEqual(len(testData), len(self.baselineOutput))

        if os.path.exists(fileName):
            os.unlink(fileName)
class MemoryTestClass(lsst.utils.tests.MemoryTestCase):
    """Adds the checks inherited from lsst.utils.tests.MemoryTestCase to this module; defines nothing of its own."""
# Script entry point: initialise the LSST test utilities, then hand control
# to unittest's command-line runner.
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()