Coverage for tests/testParallelCatalogWriter.py : 22%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1from __future__ import with_statement
2from builtins import range
3from builtins import super
4import unittest
5import sqlite3
6import os
7import numpy as np
8import tempfile
9import shutil
11import lsst.utils.tests
12from lsst.sims.utils.CodeUtilities import sims_clean_up
13from lsst.sims.catalogs.definitions import parallelCatalogWriter
14from lsst.sims.catalogs.definitions import InstanceCatalog
15from lsst.sims.catalogs.decorators import compound, cached
16from lsst.sims.catalogs.db import CatalogDBObject
19ROOT = os.path.abspath(os.path.dirname(__file__))
def setup_module(module):
    # Module-level pytest hook: initialize the LSST test framework
    # (required by lsst.utils.tests.MemoryTestCase below).
    lsst.utils.tests.init()
class DbClass(CatalogDBObject):
    """CatalogDBObject pointed at the 'test' table of the scratch sqlite DB."""
    tableid = 'test'
    idColKey = 'id'
    objid = 'parallel_writer_test_db'
    driver = 'sqlite'
    # no network connection: the database is a local sqlite file
    host = None
    port = None
class ParallelCatClass1(InstanceCatalog):
    """Catalog of id**2 that keeps only rows whose 'ii' value is odd."""
    column_outputs = ['id', 'test1', 'ii']
    cannot_be_null = ['valid1']

    @compound('test1', 'valid1')
    def get_values(self):
        parity_col = self.column_by_name('ii')
        squared = self.column_by_name('id')**2
        # 'valid1' is None wherever 'ii' is even, so those rows are dropped
        validity = np.where(parity_col % 2 == 1, parity_col, None)
        return np.array([squared, validity])
class ParallelCatClass2(InstanceCatalog):
    """Catalog of id**3 that keeps only rows whose 'id' value is odd."""
    column_outputs = ['id', 'test2', 'ii']
    cannot_be_null = ['valid2']

    @compound('test2', 'valid2')
    def get_values(self):
        ident = self.column_by_name('id')
        # 'valid2' is None wherever 'id' is even, so those rows are dropped
        validity = np.where(ident % 2 == 1, ident, None)
        return np.array([ident**3, validity])
class ParallelCatClass3(InstanceCatalog):
    """Catalog of id**4 that keeps only rows whose 'id' is a multiple of 5."""
    column_outputs = ['id', 'test3', 'ii']
    cannot_be_null = ['valid3']

    @cached
    def get_test3(self):
        # value-added column: fourth power of the row id
        return self.column_by_name('id')**4

    @cached
    def get_valid3(self):
        ident = self.column_by_name('id')
        # None (row rejected) unless the id is divisible by 5
        return np.where(ident % 5 == 0, ident, None)
class ControlCatalog(InstanceCatalog):
    """Unfiltered catalog of every (id, ii) row; reference for the tests."""
    column_outputs = ['id', 'ii']
class ParallelWriterTestCase(unittest.TestCase):
    """Exercise parallelCatalogWriter against a small scratch sqlite database.

    The two public tests run the same validation, once with the default
    chunking and once with chunk_size=7; the shared body lives in
    _run_writer_and_check().
    """

    @classmethod
    def setUpClass(cls):
        # Create a scratch directory and a 100-row sqlite table test(id, ii),
        # with 'ii' drawn from a seeded RNG so results are reproducible.
        cls.scratch_dir = tempfile.mkdtemp(dir=ROOT, prefix="ParallelWriterTestCase")
        cls.db_name = os.path.join(cls.scratch_dir, 'parallel_test_db.db')
        if os.path.exists(cls.db_name):
            os.unlink(cls.db_name)

        rng = np.random.RandomState(88)
        conn = sqlite3.connect(cls.db_name)
        c = conn.cursor()
        c.execute('''CREATE TABLE test (id int, ii int)''')
        for ii in range(100):
            c.execute('''INSERT INTO test VALUES(%i, %i)''' % (ii, rng.randint(0, 101)))
        conn.commit()
        conn.close()

    @classmethod
    def tearDownClass(cls):
        # Release cached connections before deleting the database file.
        sims_clean_up()
        if os.path.exists(cls.db_name):
            os.unlink(cls.db_name)
        if os.path.exists(cls.scratch_dir):
            shutil.rmtree(cls.scratch_dir)

    def _run_writer_and_check(self, chunk_size=None):
        """Run parallelCatalogWriter on three filtered catalogs and validate.

        Parameters
        ----------
        chunk_size : int or None
            Forwarded to parallelCatalogWriter (None is its default).

        Writes par_test[1-3].txt into the scratch dir, checks each file
        against its catalog's cannot_be_null cut and value-added column,
        then cross-checks every row of the unfiltered ControlCatalog.
        """
        db = DbClass(database=self.db_name)

        class_dict = {os.path.join(self.scratch_dir, 'par_test1.txt'): ParallelCatClass1(db),
                      os.path.join(self.scratch_dir, 'par_test2.txt'): ParallelCatClass2(db),
                      os.path.join(self.scratch_dir, 'par_test3.txt'): ParallelCatClass3(db)}

        for file_name in class_dict:
            if os.path.exists(file_name):
                os.unlink(file_name)

        parallelCatalogWriter(class_dict, chunk_size=chunk_size)

        dtype = np.dtype([('id', int), ('test', int), ('ii', int)])
        data1 = np.genfromtxt(os.path.join(self.scratch_dir, 'par_test1.txt'), dtype=dtype, delimiter=',')
        data2 = np.genfromtxt(os.path.join(self.scratch_dir, 'par_test2.txt'), dtype=dtype, delimiter=',')
        data3 = np.genfromtxt(os.path.join(self.scratch_dir, 'par_test3.txt'), dtype=dtype, delimiter=',')

        # verify that the contents of the catalogs fit with the constraints
        # in cannot_be_null
        self.assertEqual(len(np.where(data1['ii'] % 2 == 0)[0]), 0)
        self.assertEqual(len(np.where(data2['id'] % 2 == 0)[0]), 0)
        self.assertEqual(len(np.where(data3['id'] % 5 != 0)[0]), 0)

        # verify that the added value columns came out to the correct value
        np.testing.assert_array_equal(data1['id']**2, data1['test'])
        np.testing.assert_array_equal(data2['id']**3, data2['test'])
        np.testing.assert_array_equal(data3['id']**4, data3['test'])

        # now verify, row by row against the unfiltered control catalog, that
        # every row appears in exactly the catalogs whose validity cut it passes
        control_cat = ControlCatalog(db)
        ct = 0
        ct_in_1 = 0
        ct_in_2 = 0
        ct_in_3 = 0
        for control_data in control_cat.iter_catalog():
            ct += 1

            # catalog 1 keeps rows with odd 'ii'
            if control_data[1] % 2 == 0:
                self.assertNotIn(control_data[0], data1['id'])
            else:
                ct_in_1 += 1
                self.assertIn(control_data[0], data1['id'])
                dex = np.where(data1['id'] == control_data[0])[0][0]
                self.assertEqual(control_data[1], data1['ii'][dex])

            # catalog 2 keeps rows with odd 'id'
            if control_data[0] % 2 == 0:
                self.assertNotIn(control_data[0], data2['id'])
            else:
                ct_in_2 += 1
                self.assertIn(control_data[0], data2['id'])
                dex = np.where(data2['id'] == control_data[0])[0][0]
                self.assertEqual(control_data[1], data2['ii'][dex])

            # catalog 3 keeps rows with id divisible by 5
            if control_data[0] % 5 != 0:
                self.assertNotIn(control_data[0], data3['id'])
            else:
                ct_in_3 += 1
                self.assertIn(control_data[0], data3['id'])
                dex = np.where(data3['id'] == control_data[0])[0][0]
                self.assertEqual(control_data[1], data3['ii'][dex])

        # the per-catalog counts account for every written row, and the
        # control catalog saw all 100 database rows
        self.assertEqual(ct_in_1, len(data1['id']))
        self.assertEqual(ct_in_2, len(data2['id']))
        self.assertEqual(ct_in_3, len(data3['id']))
        self.assertEqual(ct, 100)

        for file_name in class_dict:
            if os.path.exists(file_name):
                os.unlink(file_name)

    def test_parallel_writing(self):
        """
        Test that parallelCatalogWriter gets the right columns in it
        """
        self._run_writer_and_check()

    def test_parallel_writing_chunk_size(self):
        """
        Test that parallelCatalogWriter gets the right columns in it
        when chunk_size is not None (this is a repeat of test_parallel_writing)
        """
        self._run_writer_and_check(chunk_size=7)
class MemoryTestClass(lsst.utils.tests.MemoryTestCase):
    """Standard LSST hook that checks for resource/memory leaks after the tests."""
    pass
if __name__ == "__main__":
    # Allow running this file directly: initialize the LSST test framework
    # (normally done by the pytest setup_module hook), then run unittest.
    setup_module(None)
    unittest.main()