Coverage for tests/testColumnOrigins.py : 39%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1from builtins import str
2from builtins import range
3from builtins import object
4import os
5import tempfile
6import shutil
7import numpy as np
8import unittest
9import sqlite3
10import lsst.utils.tests
11from lsst.sims.catalogs.definitions import InstanceCatalog
12from lsst.sims.catalogs.decorators import cached, compound
13from lsst.sims.catalogs.db import CatalogDBObject
# Absolute path of the directory holding this test module; the test classes
# create their scratch directories inside it.
ROOT = os.path.abspath(os.path.split(__file__)[0])
def setup_module(module):
    """
    Module-level setup hook: initialize lsst.utils.tests.

    The ``module`` argument is accepted to satisfy the hook signature and
    is not used.
    """
    harness = lsst.utils.tests
    harness.init()
def makeTestDB(name, size=10, **kwargs):
    """
    Make a test database

    @param [in] name is a string indicating the name of the database file
    to be created

    @param [in] size is an int indicating the number of objects to include
    in the database (default=10)

    @param [in] kwargs is accepted for interface compatibility and ignored

    The database gets a single table 'testTable' with columns
    (id, aa, bb, ra, decl). Row i has id=i, aa=2*i, bb=3*i; ra is drawn
    uniformly from [0, 360) and decl uniformly from [-90, 90) using
    numpy's global random state.
    """
    conn = sqlite3.connect(name)
    try:
        c = conn.cursor()
        c.execute('''CREATE TABLE testTable
                     (id int, aa float, bb float, ra float, decl float)''')
        conn.commit()

        rows = []
        for i in range(size):
            # draw ra before decl for each row, as before, so seeded runs
            # produce the same values
            ra = np.random.sample()*360.0
            dec = (np.random.sample()-0.5)*180.0
            rows.append((i, 2.0*i, 3.0*i, ra, dec))

        # Bound parameters store the floats at full precision; the former
        # '%f' string interpolation rounded every value to 6 decimal places.
        c.executemany('INSERT INTO testTable VALUES (?, ?, ?, ?, ?)', rows)
        conn.commit()
    finally:
        # close the connection even if table creation or insertion fails
        conn.close()
class testDBObject(CatalogDBObject):
    """
    CatalogDBObject exposing the 'testTable' table written by makeTestDB.

    raJ2000 and decJ2000 are SQL expressions that multiply the stored
    degree values of 'ra' and 'decl' by pi/180 to convert them to radians.
    """
    objid = 'testDBObject'
    tableid = 'testTable'
    idColKey = 'id'
    # NOTE(review): the original comment asked "Make this implausibly large?";
    # presumably appendint is combined with object ids by CatalogDBObject to
    # build unique ids — confirm against CatalogDBObject.
    appendint = 1023
    database = 'colOriginsTestDatabase.db'
    driver = 'sqlite'
    raColName = 'ra'
    decColName = 'decl'
    # '%.17g' keeps every significant digit of pi/180; the previous '%f'
    # truncated the factor to 0.017453 (relative error ~1.7e-5).
    columns = [('objid', 'id', int),
               ('raJ2000', 'ra*%.17g' % (np.pi/180.)),
               ('decJ2000', 'decl*%.17g' % (np.pi/180.)),
               ('aa', None),
               ('bb', None)]
# Below we define mixins which calculate the variables 'cc' and 'dd' in different
# ways. The idea is to check whether InstanceCatalog correctly identifies where
# the columns come from in each case.
class mixin1(object):
    """
    Supply 'cc' and 'dd' through two separate cached getters:
    cc = aa - bb and dd = aa + bb.
    """

    @cached
    def get_cc(self):
        """Return aa - bb wrapped in a numpy array."""
        minuend = self.column_by_name('aa')
        subtrahend = self.column_by_name('bb')
        difference = minuend - subtrahend
        return np.array(difference)

    @cached
    def get_dd(self):
        """Return aa + bb wrapped in a numpy array."""
        left = self.column_by_name('aa')
        right = self.column_by_name('bb')
        total = left + right
        return np.array(total)
class mixin2(object):
    """Supply 'cc' (aa - bb) and 'dd' (aa + bb) from one compound getter."""

    @compound('cc', 'dd')
    def get_both(self):
        """
        Return a two-row array [aa - bb, aa + bb], with rows ordered to match
        the names given to @compound.
        """
        first, second = [self.column_by_name(col) for col in ('aa', 'bb')]
        stacked = [first - second, first + second]
        return np.array(stacked)
class mixin3(object):
    """Supply only 'cc' (aa - bb) through a cached getter; 'dd' is left alone."""

    @cached
    def get_cc(self):
        """Return aa - bb wrapped in a numpy array."""
        minuend = self.column_by_name('aa')
        subtrahend = self.column_by_name('bb')
        difference = minuend - subtrahend
        return np.array(difference)
109# Below we define catalog classes that use different combinations
110# of the mixins above to calculate the columns 'cc' and 'dd'
class testCatalogDefaults(InstanceCatalog):
    """Catalog with no getters: 'cc' and 'dd' can only come from defaults."""
    column_outputs = [
        'objid', 'aa', 'bb', 'cc', 'dd', 'raJ2000', 'decJ2000',
    ]
    default_columns = [
        ('cc', 0.0, float),
        ('dd', 1.0, float),
    ]
class testCatalogMixin1(InstanceCatalog, mixin1):
    """Catalog whose 'cc' and 'dd' getters come from mixin1."""
    column_outputs = [
        'objid', 'aa', 'bb', 'cc', 'dd', 'raJ2000', 'decJ2000',
    ]
    default_columns = [
        ('cc', 0.0, float),
        ('dd', 1.0, float),
    ]
class testCatalogMixin2(InstanceCatalog, mixin2):
    """Catalog whose 'cc' and 'dd' come from mixin2's compound getter."""
    column_outputs = [
        'objid', 'aa', 'bb', 'cc', 'dd', 'raJ2000', 'decJ2000',
    ]
    default_columns = [
        ('cc', 0.0, float),
        ('dd', 1.0, float),
    ]
class testCatalogMixin3(InstanceCatalog, mixin3):
    """Catalog where mixin3 supplies 'cc' and 'dd' has only a default."""
    column_outputs = [
        'objid', 'aa', 'bb', 'cc', 'dd', 'raJ2000', 'decJ2000',
    ]
    default_columns = [
        ('cc', 0.0, float),
        ('dd', 1.0, float),
    ]
class testCatalogMixin3Mixin1(InstanceCatalog, mixin3, mixin1):
    """
    Catalog mixing in both mixin3 and mixin1; mixin3 precedes mixin1 in the
    bases, so its get_cc shadows mixin1's while get_dd still comes from mixin1.
    """
    column_outputs = [
        'objid', 'aa', 'bb', 'cc', 'dd', 'raJ2000', 'decJ2000',
    ]
    default_columns = [
        ('cc', 0.0, float),
        ('dd', 1.0, float),
    ]
class testCatalogAunspecified(InstanceCatalog, mixin3, mixin1):
    """
    Like testCatalogMixin3Mixin1, but 'aa' is not listed in column_outputs
    and is also given a default value.
    """
    column_outputs = [
        'objid', 'bb', 'cc', 'dd', 'raJ2000', 'decJ2000',
    ]
    default_columns = [
        ('aa', -1.0, float),
        ('cc', 0.0, float),
        ('dd', 1.0, float),
    ]
class testColumnOrigins(unittest.TestCase):
    """
    Check that InstanceCatalog._column_origins reports the right origin
    ('the database', 'default column', or the mixin that supplied the
    getter) for every column of the test catalog classes.
    """

    # Columns that every catalog in this test case should report as coming
    # from the database.
    _databaseColumns = ('objid', 'raJ2000', 'decJ2000', 'aa', 'bb')

    @classmethod
    def setUpClass(cls):
        cls.scratch_dir = tempfile.mkdtemp(dir=ROOT, prefix='scratchSpace-')
        cls.dbName = os.path.join(cls.scratch_dir, 'colOriginsTestDatabase.db')
        makeTestDB(cls.dbName)

    @classmethod
    def tearDownClass(cls):
        if os.path.exists(cls.scratch_dir):
            shutil.rmtree(cls.scratch_dir)

    def setUp(self):
        self.myDBobject = testDBObject(database=self.dbName)

    def tearDown(self):
        del self.myDBobject

    @staticmethod
    def _originName(origin):
        """
        Return the last dot-separated component of str(origin) after
        stripping "'>".

        For a value whose str() is "<class 'pkg.module.mixin1'>" this yields
        'mixin1', i.e. the name of the class that supplied the column.
        """
        return str(origin).replace("'>", '').split('.')[-1]

    def _assertDatabaseColumns(self, catalog):
        """Assert that every column in _databaseColumns came from 'the database'."""
        for name in self._databaseColumns:
            self.assertEqual(catalog._column_origins[name], 'the database',
                             msg='origin of column %s' % name)

    def testDefaults(self):
        """
        Test case where the columns cc and dd come from defaults
        """
        myCatalog = testCatalogDefaults(self.myDBobject)
        self._assertDatabaseColumns(myCatalog)
        self.assertEqual(myCatalog._column_origins['cc'], 'default column')
        self.assertEqual(myCatalog._column_origins['dd'], 'default column')

    def testMixin1(self):
        """
        Test case where the columns cc and dd come from non-compound getters
        """
        myCatalog = testCatalogMixin1(self.myDBobject)
        self._assertDatabaseColumns(myCatalog)
        # the last component of the origin name should be the providing mixin
        self.assertEqual(self._originName(myCatalog._column_origins['cc']), 'mixin1')
        self.assertEqual(self._originName(myCatalog._column_origins['dd']), 'mixin1')

    def testMixin2(self):
        """
        Test case where the columns cc and dd come from a compound getter
        """
        myCatalog = testCatalogMixin2(self.myDBobject)
        self._assertDatabaseColumns(myCatalog)
        self.assertEqual(self._originName(myCatalog._column_origins['cc']), 'mixin2')
        self.assertEqual(self._originName(myCatalog._column_origins['dd']), 'mixin2')

    def testMixin3(self):
        """
        Test case where cc comes from a mixin and dd comes from the default
        """
        myCatalog = testCatalogMixin3(self.myDBobject)
        self._assertDatabaseColumns(myCatalog)
        self.assertEqual(self._originName(myCatalog._column_origins['cc']), 'mixin3')
        self.assertEqual(str(myCatalog._column_origins['dd']), 'default column')

    def testMixin3Mixin1(self):
        """
        Test case where one mixin overwrites another for calculating cc
        """
        myCatalog = testCatalogMixin3Mixin1(self.myDBobject)
        self._assertDatabaseColumns(myCatalog)
        self.assertEqual(self._originName(myCatalog._column_origins['cc']), 'mixin3')
        self.assertEqual(self._originName(myCatalog._column_origins['dd']), 'mixin1')

    def testAunspecified(self):
        """
        Test case where aa is not specified in the catalog (and has a default)
        """
        myCatalog = testCatalogAunspecified(self.myDBobject)
        # aa is still expected to come from the database despite its default
        self._assertDatabaseColumns(myCatalog)
        self.assertEqual(self._originName(myCatalog._column_origins['cc']), 'mixin3')
        self.assertEqual(self._originName(myCatalog._column_origins['dd']), 'mixin1')
class myDummyCatalogClass(InstanceCatalog):
    """
    Catalog defining a default column, a plain getter and a compound getter;
    used to check what ends up in _all_available_columns.
    """

    default_columns = [('sillyDefault', 2.0, float)]

    def get_cc(self):
        """Return aa + 1."""
        aa = self.column_by_name('aa')
        return aa + 1.0

    @compound('dd', 'ee', 'ff')
    def get_compound(self):
        """Return the rows aa + 2, aa + 3 and aa + 4."""
        return np.array([self.column_by_name('aa') + offset
                         for offset in (2.0, 3.0, 4.0)])
class myDependentColumnsClass_shouldPass(InstanceCatalog):
    """Catalog whose 'dd' getter uses 'ee' when available and 'bb' otherwise."""

    def get_dd(self):
        """Return aa + ee if 'ee' is an available column, else aa + bb."""
        source = 'ee' if 'ee' in self._all_available_columns else 'bb'
        offset = self.column_by_name(source)
        return self.column_by_name('aa') + offset
class myDependentColumnsClass_shouldFail(InstanceCatalog):
    """
    Like myDependentColumnsClass_shouldPass, but also defines a getter for
    'ee' that needs the column 'doesNotExist'.
    """

    def get_cc(self):
        """Return aa + 1."""
        aa = self.column_by_name('aa')
        return aa + 1.0

    def get_dd(self):
        """Return aa + ee if 'ee' is an available column, else aa + bb."""
        source = 'ee' if 'ee' in self._all_available_columns else 'bb'
        offset = self.column_by_name(source)
        return self.column_by_name('aa') + offset

    def get_ee(self):
        """Return aa + doesNotExist."""
        # 'doesNotExist' is provided neither by testDBObject nor by this class
        base = self.column_by_name('aa')
        return base + self.column_by_name('doesNotExist')
class AllAvailableColumns(unittest.TestCase):
    """
    Verify that the InstanceCatalog attribute self._all_available_columns
    contains every column it should: database table columns, columns defined
    by the CatalogDBObject, getter-defined columns (including compound
    getters) and default columns.
    """

    # Every column myDummyCatalogClass built on testDBObject should expose.
    _expectedColumns = ('cc', 'dd', 'ee', 'ff', 'compound',  # getters
                        'id', 'aa', 'bb', 'ra', 'decl',       # table columns
                        'decJ2000', 'raJ2000', 'objid',       # testDBObject.columns
                        'sillyDefault')                       # default_columns

    @classmethod
    def setUpClass(cls):
        cls.scratch_dir = tempfile.mkdtemp(dir=ROOT, prefix='scratchSpace-')
        cls.dbName = os.path.join(cls.scratch_dir, 'allGettersTestDatabase.db')
        makeTestDB(cls.dbName)

    @classmethod
    def tearDownClass(cls):
        if os.path.exists(cls.scratch_dir):
            shutil.rmtree(cls.scratch_dir)

    def setUp(self):
        self.db = testDBObject(database=self.dbName)

    def testAllGetters(self):
        """
        Test that the self._all_available_columns list contains all of the
        columns defined in an InstanceCatalog and its CatalogDBObject.
        """
        cat = myDummyCatalogClass(self.db, column_outputs=['aa'])
        for name in self._expectedColumns:
            self.assertIn(name, cat._all_available_columns)

    def testDependentColumns(self):
        """
        Check that self._all_available_columns can be used to choose, at
        construction time, how a column is calculated (i.e. if a column
        exists, use it to calculate another column; if it does not, ignore it).

        Both catalog classes compute 'dd' as aa + ee when 'ee' is an available
        column and as aa + bb otherwise.
        myDependentColumnsClass_shouldPass defines no getter for 'ee', so 'dd'
        falls back to 'bb' and the constructor should succeed.
        myDependentColumnsClass_shouldFail defines a getter for 'ee' that
        requires the nonexistent column 'doesNotExist', so any catalog that
        requests 'dd' should fail to construct with a ValueError.
        """
        myDependentColumnsClass_shouldPass(self.db, column_outputs=['dd'])

        # as long as we do not request the column 'dd', this should work
        myDependentColumnsClass_shouldFail(self.db, column_outputs=['cc'])

        # requesting 'dd' pulls in 'ee', which depends on the fictitious
        # column 'doesNotExist', so construction should raise
        with self.assertRaises(ValueError):
            myDependentColumnsClass_shouldFail(self.db, column_outputs=['dd'])
class MemoryTestClass(lsst.utils.tests.MemoryTestCase):
    """Hook in lsst.utils.tests.MemoryTestCase; all behavior is inherited."""
if __name__ == "__main__":
    # Initialize the LSST test utilities, then hand over to unittest's runner.
    lsst.utils.tests.init()
    unittest.main()