Coverage for tests/testFilteringCatalogs.py : 5%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1from __future__ import with_statement
2from builtins import range
3import unittest
4import numpy as np
5import os
6import shutil
7import tempfile
9import lsst.utils.tests
10from lsst.sims.utils.CodeUtilities import sims_clean_up
11from lsst.sims.catalogs.definitions import InstanceCatalog, CompoundInstanceCatalog
12from lsst.sims.catalogs.db import fileDBObject, CatalogDBObject
13from lsst.sims.catalogs.decorators import cached, compound
# Absolute path of the directory holding this test file; the scratch
# directories used by the test cases are created underneath it.
ROOT = os.path.dirname(os.path.abspath(__file__))
def setup_module(module):
    """
    Module-level setup hook.

    Runs the LSST test-utility initialisation; the ``module`` argument is
    accepted to satisfy the hook signature and is not used.
    """
    lsst.utils.tests.init()
class InstanceCatalogTestCase(unittest.TestCase):
    """
    This class contains tests that verify that using cannot_be_null to
    filter the contents of an InstanceCatalog works as it should.

    Every test runs against the same ten-row database created in
    setUpClass, whose rows are (id, ip1, ip2, ip3) = (ii, ii+1, ii+2, ii+3)
    for ii in 0..9.  Each test writes a catalog to disk, checks the lines
    that were written, and then checks that iter_catalog() and
    iter_catalog_chunks() yield the same rows.
    """

    @classmethod
    def setUpClass(cls):
        """
        Create the scratch directory, the text file holding the ten source
        rows, and the fileDBObject (cls.db) that the tests query.
        """
        cls.scratch_dir = tempfile.mkdtemp(dir=ROOT, prefix="scratchSpace-")

        cls.db_src_name = os.path.join(cls.scratch_dir, 'inst_cat_filter_db.txt')
        if os.path.exists(cls.db_src_name):
            os.unlink(cls.db_src_name)

        with open(cls.db_src_name, 'w') as output_file:
            output_file.write('#a header\n')
            for ii in range(10):
                output_file.write('%d %d %d %d\n' % (ii, ii+1, ii+2, ii+3))

        dtype = np.dtype([('id', int), ('ip1', int), ('ip2', int), ('ip3', int)])
        cls.db = fileDBObject(cls.db_src_name, runtable='test', dtype=dtype,
                              idColKey='id')

    @classmethod
    def tearDownClass(cls):
        """
        Drop the database object and delete the files and directory that
        setUpClass created.
        """
        sims_clean_up()

        del cls.db

        if os.path.exists(cls.db_src_name):
            os.unlink(cls.db_src_name)
        if os.path.exists(cls.scratch_dir):
            shutil.rmtree(cls.scratch_dir)

    @staticmethod
    def _remove_if_present(file_name):
        """
        Delete file_name if it exists; do nothing otherwise.
        """
        if os.path.exists(file_name):
            os.unlink(file_name)

    def _write_and_read(self, cat, file_name, **write_kwargs):
        """
        Write cat to file_name inside the scratch directory and return the
        lines of the resulting file (header line included).

        Any stale copy of the file is removed first.  The file is also
        registered for removal when the test finishes, so it is cleaned up
        even if a later assertion fails.

        write_kwargs are forwarded unchanged to cat.write_catalog().
        """
        cat_name = os.path.join(self.scratch_dir, file_name)
        self._remove_if_present(cat_name)
        self.addCleanup(self._remove_if_present, cat_name)

        cat.write_catalog(cat_name, **write_kwargs)
        with open(cat_name, 'r') as input_file:
            return input_file.readlines()

    def _assert_iterators_agree(self, make_cat, fmt, input_lines):
        """
        Check that iter_catalog() and iter_catalog_chunks() return the same
        rows that write_catalog() wrote.

        make_cat is a zero-argument callable returning a freshly
        constructed catalog; it is called once per iteration method so
        that no state is shared between the two passes.

        fmt is the %-format string of one written line.  The number of
        columns read from each row is the number of conversion specifiers
        in fmt (none of the formats used here contain a literal '%%').

        input_lines is the list of lines read back from the written
        catalog, header included.
        """
        n_cols = fmt.count('%')
        n_data_lines = len(input_lines) - 1  # everything except the header

        # test that iter_catalog returns the same result
        cat = make_cat()
        line_ct = 0
        for row in cat.iter_catalog():
            str_line = fmt % tuple(row[i_col] for i_col in range(n_cols))
            line_ct += 1
            self.assertIn(str_line, input_lines)
        self.assertEqual(line_ct, n_data_lines)

        # test that iter_catalog_chunks returns the same result
        cat = make_cat()
        line_ct = 0
        for chunk, chunk_map in cat.iter_catalog_chunks(chunk_size=2):
            for ix in range(len(chunk[0])):
                str_line = fmt % tuple(chunk[i_col][ix] for i_col in range(n_cols))
                line_ct += 1
                self.assertIn(str_line, input_lines)
        self.assertEqual(line_ct, n_data_lines)

    def test_single_filter(self):
        """
        Test filtering on a single column
        """

        class FilteredCat(InstanceCatalog):
            column_outputs = ['id', 'ip1', 'ip2', 'ip3t']
            cannot_be_null = ['ip3t']

            @cached
            def get_ip3t(self):
                base = self.column_by_name('ip3')
                ii = self.column_by_name('id')
                return np.where(ii < 5, base, None)

        def make_cat():
            return FilteredCat(self.db)

        fmt = '%d, %d, %d, %d\n'
        input_lines = self._write_and_read(make_cat(), 'inst_single_filter_cat.txt')

        # verify that the catalog contains the expected data
        self.assertEqual(len(input_lines), 6)  # 5 data lines and a header
        for ii, line in enumerate(input_lines[1:]):
            self.assertLess(ii, 5)
            self.assertEqual(line, fmt % (ii, ii+1, ii+2, ii+3))

        self._assert_iterators_agree(make_cat, fmt, input_lines)

    def test_two_filters(self):
        """
        Test a case where we filter on two columns.
        """
        class FilteredCat2(InstanceCatalog):
            column_outputs = ['id', 'ip1', 'ip2t', 'ip3t']
            cannot_be_null = ['ip2t', 'ip3t']

            @cached
            def get_ip2t(self):
                base = self.column_by_name('ip2')
                return np.where(base % 2 == 0, base, None)

            @cached
            def get_ip3t(self):
                base = self.column_by_name('ip3')
                return np.where(base % 3 == 0, base, None)

        def make_cat():
            return FilteredCat2(self.db)

        fmt = '%d, %d, %d, %d\n'
        input_lines = self._write_and_read(make_cat(), "inst_two_filter_cat.txt")

        self.assertEqual(len(input_lines), 3)  # two data lines and a header
        for i_row, line in enumerate(input_lines[1:]):
            # surviving rows are those with id a multiple of 6
            ii = i_row*6
            ip1 = ii + 1
            ip2 = ii + 2
            ip3 = ii + 3
            self.assertEqual((ii+2) % 2, 0)
            self.assertEqual((ii+3) % 3, 0)
            self.assertEqual(line, fmt % (ii, ip1, ip2, ip3))

        self._assert_iterators_agree(make_cat, fmt, input_lines)

    def test_post_facto_filters(self):
        """
        Test a case where filters are declared at instantiation
        """
        class FilteredCat3(InstanceCatalog):
            column_outputs = ['id', 'ip1', 'ip2t', 'ip3t']

            @cached
            def get_ip2t(self):
                base = self.column_by_name('ip2')
                return np.where(base % 2 == 0, base, None)

            @cached
            def get_ip3t(self):
                base = self.column_by_name('ip3')
                return np.where(base % 3 == 0, base, None)

        def make_cat():
            # a new list literal on every call, so catalogs never share it
            return FilteredCat3(self.db, cannot_be_null=['ip2t', 'ip3t'])

        fmt = '%d, %d, %d, %d\n'
        input_lines = self._write_and_read(make_cat(), "inst_post_facto_filter_cat.txt")

        self.assertEqual(len(input_lines), 3)  # two data lines and a header
        for i_row, line in enumerate(input_lines[1:]):
            ii = i_row*6
            ip1 = ii + 1
            ip2 = ii + 2
            ip3 = ii + 3
            self.assertEqual((ii+2) % 2, 0)
            self.assertEqual((ii+3) % 3, 0)
            self.assertEqual(line, fmt % (ii, ip1, ip2, ip3))

        self._assert_iterators_agree(make_cat, fmt, input_lines)

    def test_compound_column(self):
        """
        Test filtering on a catalog with a compound column that is calculated
        after the filter column (example code has shown this to be a difficult case)
        """

        class FilteredCat4(InstanceCatalog):
            column_outputs = ['id', 'a', 'b', 'c', 'filter_col']
            cannot_be_null = ['filter_col']

            @compound('a', 'b', 'c')
            def get_alphabet(self):
                ii = self.column_by_name('ip3')
                return np.array([ii*ii, ii*ii*ii, ii*ii*ii*ii])

            @cached
            def get_filter_col(self):
                base = self.column_by_name('a')
                return np.where(base % 3 == 0, base/2.0, None)

        def make_cat():
            return FilteredCat4(self.db)

        fmt = '%d, %d, %d, %d, %.1f\n'
        input_lines = self._write_and_read(make_cat(),
                                           "inst_compound_column_filter_cat.txt")

        # verify that the catalog contains expected data
        self.assertEqual(len(input_lines), 5)  # 4 data lines and a header
        for i_row, line in enumerate(input_lines[1:]):
            ii = i_row*3
            ip3 = ii + 3
            self.assertEqual((ip3*ip3) % 3, 0)
            self.assertEqual(line, fmt % (ii, ip3*ip3, ip3*ip3*ip3,
                                          ip3*ip3*ip3*ip3, 0.5*(ip3*ip3)))

        self._assert_iterators_agree(make_cat, fmt, input_lines)

    def test_filter_on_compound_column(self):
        """
        Test filtering on a catalog that filters on a compound column
        (example code has shown this to be a difficult case)
        """

        class FilteredCat5(InstanceCatalog):
            column_outputs = ['id', 'a', 'b', 'c']
            cannot_be_null = ['c']

            @compound('a', 'b', 'c')
            def get_alphabet(self):
                ii = self.column_by_name('ip3')
                c = ii*ii*ii*ii
                return np.array([ii*ii, ii*ii*ii,
                                 np.where(c % 3 == 0, c, None)])

        def make_cat():
            return FilteredCat5(self.db)

        fmt = '%d, %d, %d, %d\n'
        input_lines = self._write_and_read(make_cat(),
                                           "inst_actual_compound_column_filter_cat.txt")

        # verify that the catalog contains expected data
        self.assertEqual(len(input_lines), 5)  # 4 data lines and a header
        for i_row, line in enumerate(input_lines[1:]):
            ii = i_row*3
            ip3 = ii + 3
            self.assertEqual((ip3**4) % 3, 0)
            self.assertEqual(line, fmt % (ii, ip3*ip3, ip3*ip3*ip3, ip3*ip3*ip3*ip3))

        self._assert_iterators_agree(make_cat, fmt, input_lines)

    def test_filter_on_unused_compound_column(self):
        """
        Test a catalog in which cannot_be_null is a compound column, but not one
        that is written to the catalog.
        """

        class FilteredCat5b(InstanceCatalog):
            column_outputs = ['id', 'a', 'b']
            cannot_be_null = ['c']

            @compound('a', 'b', 'c')
            def get_alphabet(self):
                ii = self.column_by_name('ip3')
                c = ii*ii*ii*ii
                return np.array([ii*ii, ii*ii*ii,
                                 np.where(c % 3 == 0, c, None)])

        def make_cat():
            return FilteredCat5b(self.db)

        fmt = '%d, %d, %d\n'
        input_lines = self._write_and_read(make_cat(),
                                           "inst_actual_compound_column_filter_b_cat.txt")

        # verify that the catalog contains expected data
        self.assertEqual(len(input_lines), 5)  # 4 data lines and a header
        for i_row, line in enumerate(input_lines[1:]):
            ii = i_row*3
            ip3 = ii + 3
            self.assertEqual((ip3**4) % 3, 0)
            self.assertEqual(line, fmt % (ii, ip3*ip3, ip3*ip3*ip3))

        self._assert_iterators_agree(make_cat, fmt, input_lines)

    def test_empty_chunk(self):
        """
        Test that catalog filtering behaves correctly, even when the first
        chunk is empty
        """

        class FilteredCat6(InstanceCatalog):
            column_outputs = ['id', 'filter']
            cannot_be_null = ['filter']

            @cached
            def get_filter(self):
                ii = self.column_by_name('ip1')
                return np.where(ii > 5, ii, None)

        def make_cat():
            return FilteredCat6(self.db)

        fmt = '%d, %d\n'
        # chunk_size=2 means the first chunks contain only filtered-out rows
        input_lines = self._write_and_read(make_cat(), "inst_empty_chunk_cat.txt",
                                           chunk_size=2)

        # check that the catalog contains the correct information
        self.assertEqual(len(input_lines), 6)  # 5 data lines and a header
        for i_row, line in enumerate(input_lines[1:]):
            ii = 5 + i_row
            self.assertGreater(ii+1, 5)
            self.assertEqual(line, fmt % (ii, ii+1))

        self._assert_iterators_agree(make_cat, fmt, input_lines)

    def test_hidden_filter(self):
        """
        Test filtering on a column that is not written to the final catalog.
        """
        class FilteredCat7(InstanceCatalog):
            column_outputs = ['id', 'ip1']
            cannot_be_null = ['filter']

            def get_filter(self):
                ii = self.column_by_name('ip3')
                return np.where(ii < 7, ii, None)

        def make_cat():
            return FilteredCat7(self.db)

        fmt = '%d, %d\n'
        input_lines = self._write_and_read(make_cat(), "inst_hidden_filter_cat.txt")

        self.assertEqual(len(input_lines), 5)  # 4 data lines and a header
        for ii, line in enumerate(input_lines[1:]):
            self.assertLess(ii+3, 7)
            self.assertEqual(line, fmt % (ii, ii+1))

        self._assert_iterators_agree(make_cat, fmt, input_lines)

    def test_adding_filter(self):
        """
        Test that, when we use the kwarg in the constructor to add to
        cannot_be_null, the filter is appended to existing filters.
        """
        class FilteredCat8(InstanceCatalog):
            column_outputs = ['id', 'ip1']
            cannot_be_null = ['filter1']

            def get_filter1(self):
                ii = self.column_by_name('ip2')
                return np.where(ii % 2 == 0, ii, None)

            def get_filter2(self):
                ii = self.column_by_name('ip3')
                return np.where(ii > 8, ii, None)

        def make_cat():
            # a new list literal on every call, so catalogs never share it
            return FilteredCat8(self.db, cannot_be_null=['filter2'])

        cat = make_cat()
        self.assertIn('filter1', cat._cannot_be_null)
        self.assertIn('filter2', cat._cannot_be_null)

        fmt = '%d, %d\n'
        input_lines = self._write_and_read(cat, "inst_adding_filter_cat.txt")

        self.assertEqual(len(input_lines), 3)  # two data lines and a header
        for i_row, line in enumerate(input_lines[1:]):
            ii = i_row*2 + 6
            self.assertEqual((ii+2) % 2, 0)
            self.assertGreater(ii+3, 8)
            self.assertEqual(line, fmt % (ii, ii+1))

        self._assert_iterators_agree(make_cat, fmt, input_lines)
class CompoundInstanceCatalogTestCase(unittest.TestCase):
    """
    This class contains tests that verify that using cannot_be_null to
    filter the contents of a CompoundInstanceCatalog works as it should.

    The tests share a ten-row sqlite database created in setUpClass, whose
    rows are (id, ip1, ip2, ip3) = (ii, ii+1, ii+2, ii+3) for ii in 0..9.
    """

    @classmethod
    def setUpClass(cls):
        """
        Create the scratch directory, write the ten source rows to a text
        file and load them into the database file cls.db_name.
        """
        cls.scratch_dir = tempfile.mkdtemp(dir=ROOT, prefix="scratchSpace-")

        cls.db_src_name = os.path.join(cls.scratch_dir, 'compound_cat_filter_db.txt')
        if os.path.exists(cls.db_src_name):
            os.unlink(cls.db_src_name)

        cls.db_name = os.path.join(cls.scratch_dir, 'compound_cat_filter_db.db')
        if os.path.exists(cls.db_name):
            os.unlink(cls.db_name)

        with open(cls.db_src_name, 'w') as output_file:
            output_file.write('#a header\n')
            for ii in range(10):
                output_file.write('%d %d %d %d\n' % (ii, ii+1, ii+2, ii+3))

        dtype = np.dtype([('id', int), ('ip1', int), ('ip2', int), ('ip3', int)])
        # The returned object is not kept; the tests connect to the
        # database file through their own CatalogDBObject subclasses.
        fileDBObject(cls.db_src_name, runtable='test', dtype=dtype,
                     idColKey='id', database=cls.db_name)

    @classmethod
    def tearDownClass(cls):
        """
        Delete the files and directory that setUpClass created.
        """
        sims_clean_up()

        if os.path.exists(cls.db_src_name):
            os.unlink(cls.db_src_name)

        if os.path.exists(cls.db_name):
            os.unlink(cls.db_name)

        if os.path.exists(cls.scratch_dir):
            shutil.rmtree(cls.scratch_dir)

    def test_compound_cat(self):
        """
        Test that a CompoundInstanceCatalog made up of InstanceCatalog classes that
        each filter on a different condition gives the correct outputs.
        """

        class FilteringCatClass1(InstanceCatalog):
            column_outputs = ['id', 'ip1t']
            cannot_be_null = ['ip1t']

            @cached
            def get_ip1t(self):
                base = self.column_by_name('ip1')
                output = []
                for bb in base:
                    if bb % 2 == 0:
                        output.append(bb)
                    else:
                        output.append(None)
                return np.array(output)

        class FilteringCatClass2(InstanceCatalog):
            column_outputs = ['id', 'ip2t']
            cannot_be_null = ['ip2t']

            @cached
            def get_ip2t(self):
                base = self.column_by_name('ip2')
                ii = self.column_by_name('id')
                return np.where(ii < 4, base, None)

        class FilteringCatClass3(InstanceCatalog):
            column_outputs = ['id', 'ip3t']
            cannot_be_null = ['ip3t']

            @cached
            def get_ip3t(self):
                base = self.column_by_name('ip3')
                ii = self.column_by_name('id')
                return np.where(ii > 5, base, None)

        class DbClass(CatalogDBObject):
            host = None
            port = None
            database = self.db_name
            driver = 'sqlite'
            tableid = 'test'
            objid = 'silliness'
            idColKey = 'id'

        class DbClass1(DbClass):
            objid = 'silliness1'

        class DbClass2(DbClass):
            objid = 'silliness2'

        class DbClass3(DbClass):
            objid = 'silliness3'

        cat = CompoundInstanceCatalog([FilteringCatClass1,
                                       FilteringCatClass2,
                                       FilteringCatClass3],
                                      [DbClass1, DbClass2, DbClass3])

        cat_name = os.path.join(self.scratch_dir, "compound_filter_output.txt")
        if os.path.exists(cat_name):
            os.unlink(cat_name)

        cat.write_catalog(cat_name)

        with open(cat_name, 'r') as input_file:
            input_lines = input_file.readlines()

        self.assertEqual(len(input_lines), 14)  # 13 data lines and a header

        # given that we know what the contents of each sub-catalog should be
        # and how they should be ordered, loop through the lines of the output
        # catalog, verifying that every line is where it ought to be
        # (i_line is the position in the file; the header at 0 is skipped)
        for i_line, line in enumerate(input_lines[1:], 1):
            if i_line < 6:
                ii = 2*(i_line-1) + 1
                self.assertEqual((ii+1) % 2, 0)
                self.assertEqual(line, '%d, %d\n' % (ii, ii+1))
            elif i_line < 10:
                ii = i_line - 6
                self.assertLess(ii, 4)
                self.assertEqual(line, '%d, %d\n' % (ii, ii+2))
            else:
                ii = i_line - 10 + 6
                self.assertGreater(ii, 5)
                self.assertEqual(line, '%d, %d\n' % (ii, ii+3))

        if os.path.exists(cat_name):
            os.unlink(cat_name)

    def test_compound_cat_compound_column(self):
        """
        Test filtering a CompoundInstanceCatalog on a compound column
        """

        class FilteringCatClass4(InstanceCatalog):
            column_outputs = ['id', 'a', 'b']
            cannot_be_null = ['a']

            @compound('a', 'b')
            def get_alphabet(self):
                a = self.column_by_name('ip1')
                a = a*a
                b = self.column_by_name('ip2')
                b = b*0.25
                return np.array([np.where(a % 2 == 0, a, None), b])

        class FilteringCatClass5(InstanceCatalog):
            column_outputs = ['id', 'a', 'b', 'filter']
            cannot_be_null = ['b', 'filter']

            @compound('a', 'b')
            def get_alphabet(self):
                ii = self.column_by_name('ip1')
                return np.array([self.column_by_name('ip2')**3,
                                 np.where(ii % 2 == 1, ii, None)])

            @cached
            def get_filter(self):
                ii = self.column_by_name('ip1')
                return np.where(ii % 3 != 0, ii, None)

        class DbClass(CatalogDBObject):
            host = None
            port = None
            database = self.db_name
            driver = 'sqlite'
            tableid = 'test'
            idColKey = 'id'

        class DbClass4(DbClass):
            objid = 'silliness4'

        class DbClass5(DbClass):
            objid = 'silliness5'

        cat_name = os.path.join(self.scratch_dir, "compound_cat_compound_filter_cat.txt")
        if os.path.exists(cat_name):
            os.unlink(cat_name)

        cat = CompoundInstanceCatalog([FilteringCatClass4, FilteringCatClass5],
                                      [DbClass4, DbClass5])
        cat.write_catalog(cat_name)

        # now make sure that the catalog contains the expected data
        with open(cat_name, 'r') as input_file:
            input_lines = input_file.readlines()
        self.assertEqual(len(input_lines), 9)  # 8 data lines and a header

        first_cat_lines = ['1, 4, 0.75\n', '3, 16, 1.25\n',
                           '5, 36, 1.75\n', '7, 64, 2.25\n',
                           '9, 100, 2.75\n']

        second_cat_lines = ['0, 8, 1, 1\n', '4, 216, 5, 5\n',
                            '6, 512, 7, 7\n']

        # everything after the header must be the first sub-catalog's rows
        # followed by the second sub-catalog's rows, in that order
        self.assertEqual(input_lines[1:], first_cat_lines + second_cat_lines)

        if os.path.exists(cat_name):
            os.unlink(cat_name)
class MemoryTestClass(lsst.utils.tests.MemoryTestCase):
    """
    Brings the checks defined by lsst.utils.tests.MemoryTestCase into this
    module; nothing is added or overridden here.
    """
if __name__ == "__main__":
    # Run directly as a script: initialise the LSST test utilities, then
    # hand control to unittest's command-line runner.
    lsst.utils.tests.init()
    unittest.main()