Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1from __future__ import with_statement 

2from builtins import range 

3import unittest 

4import numpy as np 

5import os 

6import shutil 

7import tempfile 

8 

9import lsst.utils.tests 

10from lsst.sims.utils.CodeUtilities import sims_clean_up 

11from lsst.sims.catalogs.definitions import InstanceCatalog, CompoundInstanceCatalog 

12from lsst.sims.catalogs.db import fileDBObject, CatalogDBObject 

13from lsst.sims.catalogs.decorators import cached, compound 

14 

# Absolute path of the directory holding this file; the test classes below
# create their scratch directories inside it (tempfile.mkdtemp(dir=ROOT, ...)).
ROOT = os.path.abspath(os.path.dirname(__file__))

16 

17 

def setup_module(module):
    """
    Module-level setup hook; its only action is to call
    lsst.utils.tests.init().

    The `module` argument is accepted but not used.
    """
    lsst.utils.tests.init()

20 

21 

class InstanceCatalogTestCase(unittest.TestCase):
    """
    This class contains tests that verify that using cannot_be_null
    to filter the contents of an InstanceCatalog works as it should.

    Every test follows the same pattern: define a small InstanceCatalog
    subclass whose getter returns None for the rows that should be
    dropped, write the catalog to a scratch file, compare the file
    line-by-line against the expected rows, and finally check that
    iter_catalog() and iter_catalog_chunks() yield the same rows.

    The backing database (built in setUpClass) has 10 rows with columns
    id, ip1, ip2, ip3 holding ii, ii+1, ii+2, ii+3 for ii in 0..9.
    """

    @classmethod
    def setUpClass(cls):
        """
        Create a scratch directory under ROOT, write the 10-row source
        text file into it and wrap that file in a fileDBObject (cls.db).
        """
        cls.scratch_dir = tempfile.mkdtemp(dir=ROOT, prefix="scratchSpace-")

        cls.db_src_name = os.path.join(cls.scratch_dir, 'inst_cat_filter_db.txt')
        if os.path.exists(cls.db_src_name):
            os.unlink(cls.db_src_name)

        with open(cls.db_src_name, 'w') as output_file:
            output_file.write('#a header\n')
            for ii in range(10):
                output_file.write('%d %d %d %d\n' % (ii, ii+1, ii+2, ii+3))

        dtype = np.dtype([('id', int), ('ip1', int), ('ip2', int), ('ip3', int)])
        cls.db = fileDBObject(cls.db_src_name, runtable='test', dtype=dtype,
                              idColKey='id')

    @classmethod
    def tearDownClass(cls):
        """
        Call sims_clean_up(), drop the database object and delete the
        source file and the scratch directory.
        """
        sims_clean_up()

        del cls.db

        if os.path.exists(cls.db_src_name):
            os.unlink(cls.db_src_name)
        if os.path.exists(cls.scratch_dir):
            shutil.rmtree(cls.scratch_dir)

    # ------------------------------------------------------------------
    # private helpers shared by the tests below
    # ------------------------------------------------------------------

    @staticmethod
    def _remove_if_present(file_name):
        """Delete file_name if it exists; do nothing otherwise."""
        if os.path.exists(file_name):
            os.unlink(file_name)

    def _fresh_path(self, file_name):
        """
        Return the full path of file_name inside the scratch directory,
        deleting any file already present at that path.
        """
        full_name = os.path.join(self.scratch_dir, file_name)
        self._remove_if_present(full_name)
        return full_name

    def _write_and_read(self, cat, cat_name, **write_kwargs):
        """
        Write cat to cat_name (forwarding write_kwargs to write_catalog)
        and return the lines of the resulting file, header included.
        """
        cat.write_catalog(cat_name, **write_kwargs)
        with open(cat_name, 'r') as input_file:
            return input_file.readlines()

    def _check_iterators(self, make_cat, line_format, input_lines):
        """
        Verify that iter_catalog() and iter_catalog_chunks(chunk_size=2)
        produce exactly the rows found in input_lines.

        make_cat -- zero-argument callable returning a freshly
        instantiated catalog (a new catalog is built for each iterator)

        line_format -- the '%'-style format of one catalog line; it must
        contain one conversion per output column and no literal '%%'

        input_lines -- the lines read back from the written catalog
        (the first one being the header)
        """
        # one conversion specifier per output column
        n_cols = line_format.count('%')

        # test that iter_catalog returns the same result
        cat = make_cat()
        line_ct = 0
        for line in cat.iter_catalog():
            str_line = line_format % tuple(line[i_col] for i_col in range(n_cols))
            line_ct += 1
            self.assertIn(str_line, input_lines)
        self.assertEqual(line_ct, len(input_lines)-1)

        # test that iter_catalog_chunks returns the same result
        cat = make_cat()
        line_ct = 0
        for chunk, chunk_map in cat.iter_catalog_chunks(chunk_size=2):
            # chunk is indexed [column][row]
            for ix in range(len(chunk[0])):
                str_line = line_format % tuple(chunk[i_col][ix]
                                               for i_col in range(n_cols))
                line_ct += 1
                self.assertIn(str_line, input_lines)
        self.assertEqual(line_ct, len(input_lines)-1)

    # ------------------------------------------------------------------
    # tests
    # ------------------------------------------------------------------

    def test_single_filter(self):
        """
        Test filtering on a single column (ip3t is None for id >= 5,
        so only ids 0-4 should survive).
        """

        class FilteredCat(InstanceCatalog):
            column_outputs = ['id', 'ip1', 'ip2', 'ip3t']
            cannot_be_null = ['ip3t']

            @cached
            def get_ip3t(self):
                base = self.column_by_name('ip3')
                ii = self.column_by_name('id')
                return np.where(ii < 5, base, None)

        cat_name = self._fresh_path('inst_single_filter_cat.txt')
        input_lines = self._write_and_read(FilteredCat(self.db), cat_name)

        # verify that the catalog contains the expected data
        self.assertEqual(len(input_lines), 6)  # 5 data lines and a header
        for ii, line in enumerate(input_lines[1:]):
            self.assertLess(ii, 5)
            self.assertEqual(line,
                             '%d, %d, %d, %d\n' % (ii, ii+1, ii+2, ii+3))

        self._check_iterators(lambda: FilteredCat(self.db),
                              '%d, %d, %d, %d\n', input_lines)

        self._remove_if_present(cat_name)

    def test_two_filters(self):
        """
        Test a case where we filter on two columns (ip2 must be even and
        ip3 a multiple of 3, which leaves ids 0 and 6).
        """
        class FilteredCat2(InstanceCatalog):
            column_outputs = ['id', 'ip1', 'ip2t', 'ip3t']
            cannot_be_null = ['ip2t', 'ip3t']

            @cached
            def get_ip2t(self):
                base = self.column_by_name('ip2')
                return np.where(base % 2 == 0, base, None)

            @cached
            def get_ip3t(self):
                base = self.column_by_name('ip3')
                return np.where(base % 3 == 0, base, None)

        cat_name = self._fresh_path("inst_two_filter_cat.txt")
        input_lines = self._write_and_read(FilteredCat2(self.db), cat_name)

        self.assertEqual(len(input_lines), 3)  # two data lines and a header
        for i_row, line in enumerate(input_lines[1:]):
            ii = i_row*6
            ip1 = ii + 1
            ip2 = ii + 2
            ip3 = ii + 3
            self.assertEqual((ii+2) % 2, 0)
            self.assertEqual((ii+3) % 3, 0)
            self.assertEqual(line,
                             '%d, %d, %d, %d\n' % (ii, ip1, ip2, ip3))

        self._check_iterators(lambda: FilteredCat2(self.db),
                              '%d, %d, %d, %d\n', input_lines)

        self._remove_if_present(cat_name)

    def test_post_facto_filters(self):
        """
        Test a case where filters are declared at instantiation
        (same filters and expected rows as test_two_filters, but passed
        through the cannot_be_null constructor kwarg).
        """
        class FilteredCat3(InstanceCatalog):
            column_outputs = ['id', 'ip1', 'ip2t', 'ip3t']

            @cached
            def get_ip2t(self):
                base = self.column_by_name('ip2')
                return np.where(base % 2 == 0, base, None)

            @cached
            def get_ip3t(self):
                base = self.column_by_name('ip3')
                return np.where(base % 3 == 0, base, None)

        def make_cat():
            # a new filter list is built for every catalog instance
            return FilteredCat3(self.db, cannot_be_null=['ip2t', 'ip3t'])

        cat_name = self._fresh_path("inst_post_facto_filter_cat.txt")
        input_lines = self._write_and_read(make_cat(), cat_name)

        self.assertEqual(len(input_lines), 3)  # two data lines and a header
        for i_row, line in enumerate(input_lines[1:]):
            ii = i_row*6
            ip1 = ii + 1
            ip2 = ii + 2
            ip3 = ii + 3
            self.assertEqual((ii+2) % 2, 0)
            self.assertEqual((ii+3) % 3, 0)
            self.assertEqual(line,
                             '%d, %d, %d, %d\n' % (ii, ip1, ip2, ip3))

        self._check_iterators(make_cat, '%d, %d, %d, %d\n', input_lines)

        self._remove_if_present(cat_name)

    def test_compound_column(self):
        """
        Test filtering on a catalog with a compound column that is calculated
        after the filter column (example code has shown this to be a difficult case)
        """

        class FilteredCat4(InstanceCatalog):
            column_outputs = ['id', 'a', 'b', 'c', 'filter_col']
            cannot_be_null = ['filter_col']

            @compound('a', 'b', 'c')
            def get_alphabet(self):
                ii = self.column_by_name('ip3')
                return np.array([ii*ii, ii*ii*ii, ii*ii*ii*ii])

            @cached
            def get_filter_col(self):
                base = self.column_by_name('a')
                return np.where(base % 3 == 0, base/2.0, None)

        cat_name = self._fresh_path("inst_compound_column_filter_cat.txt")
        input_lines = self._write_and_read(FilteredCat4(self.db), cat_name)

        # verify that the catalog contains expected data
        self.assertEqual(len(input_lines), 5)  # 4 data lines and a header
        for i_row, line in enumerate(input_lines[1:]):
            ii = i_row*3
            ip3 = ii + 3
            self.assertEqual((ip3*ip3) % 3, 0)
            self.assertEqual(line, '%d, %d, %d, %d, %.1f\n'
                             % (ii, ip3*ip3, ip3*ip3*ip3, ip3*ip3*ip3*ip3, 0.5*(ip3*ip3)))

        self._check_iterators(lambda: FilteredCat4(self.db),
                              '%d, %d, %d, %d, %.1f\n', input_lines)

        self._remove_if_present(cat_name)

    def test_filter_on_compound_column(self):
        """
        Test filtering on a catalog that filters on a compound column
        (example code has shown this to be a difficult case)
        """

        class FilteredCat5(InstanceCatalog):
            column_outputs = ['id', 'a', 'b', 'c']
            cannot_be_null = ['c']

            @compound('a', 'b', 'c')
            def get_alphabet(self):
                ii = self.column_by_name('ip3')
                c = ii*ii*ii*ii
                return np.array([ii*ii, ii*ii*ii,
                                 np.where(c % 3 == 0, c, None)])

        cat_name = self._fresh_path("inst_actual_compound_column_filter_cat.txt")
        input_lines = self._write_and_read(FilteredCat5(self.db), cat_name)

        # verify that the catalog contains expected data
        self.assertEqual(len(input_lines), 5)  # 4 data lines and a header
        for i_row, line in enumerate(input_lines[1:]):
            ii = i_row*3
            ip3 = ii + 3
            self.assertEqual((ip3**4) % 3, 0)
            self.assertEqual(line, '%d, %d, %d, %d\n'
                             % (ii, ip3*ip3, ip3*ip3*ip3, ip3*ip3*ip3*ip3))

        self._check_iterators(lambda: FilteredCat5(self.db),
                              '%d, %d, %d, %d\n', input_lines)

        self._remove_if_present(cat_name)

    def test_filter_on_unused_compound_column(self):
        """
        Test a catalog in which cannot_be_null is a compound column, but not one
        that is written to the catalog.
        """

        class FilteredCat5b(InstanceCatalog):
            column_outputs = ['id', 'a', 'b']
            cannot_be_null = ['c']

            @compound('a', 'b', 'c')
            def get_alphabet(self):
                ii = self.column_by_name('ip3')
                c = ii*ii*ii*ii
                return np.array([ii*ii, ii*ii*ii,
                                 np.where(c % 3 == 0, c, None)])

        cat_name = self._fresh_path("inst_actual_compound_column_filter_b_cat.txt")
        input_lines = self._write_and_read(FilteredCat5b(self.db), cat_name)

        # verify that the catalog contains expected data
        self.assertEqual(len(input_lines), 5)  # 4 data lines and a header
        for i_row, line in enumerate(input_lines[1:]):
            ii = i_row*3
            ip3 = ii + 3
            self.assertEqual((ip3**4) % 3, 0)
            self.assertEqual(line, '%d, %d, %d\n'
                             % (ii, ip3*ip3, ip3*ip3*ip3))

        self._check_iterators(lambda: FilteredCat5b(self.db),
                              '%d, %d, %d\n', input_lines)

        self._remove_if_present(cat_name)

    def test_empty_chunk(self):
        """
        Test that catalog filtering behaves correctly, even when the first
        chunk is empty (the catalog is written with chunk_size=2 and only
        rows with ip1 > 5, i.e. ids 5-9, pass the filter).
        """

        class FilteredCat6(InstanceCatalog):
            column_outputs = ['id', 'filter']
            cannot_be_null = ['filter']

            @cached
            def get_filter(self):
                ii = self.column_by_name('ip1')
                return np.where(ii > 5, ii, None)

        cat_name = self._fresh_path("inst_empty_chunk_cat.txt")
        input_lines = self._write_and_read(FilteredCat6(self.db), cat_name,
                                           chunk_size=2)

        # check that the catalog contains the correct information
        self.assertEqual(len(input_lines), 6)  # 5 data lines and a header
        for i_row, line in enumerate(input_lines[1:]):
            ii = 5 + i_row  # first surviving id is 5
            self.assertGreater(ii+1, 5)
            self.assertEqual(line, '%d, %d\n' % (ii, ii+1))

        self._check_iterators(lambda: FilteredCat6(self.db),
                              '%d, %d\n', input_lines)

        self._remove_if_present(cat_name)

    def test_hidden_filter(self):
        """
        Test filtering on a column that is not written to the final catalog.
        """
        class FilteredCat7(InstanceCatalog):
            column_outputs = ['id', 'ip1']
            cannot_be_null = ['filter']

            def get_filter(self):
                ii = self.column_by_name('ip3')
                return np.where(ii < 7, ii, None)

        cat_name = self._fresh_path("inst_hidden_filter_cat.txt")
        input_lines = self._write_and_read(FilteredCat7(self.db), cat_name)

        self.assertEqual(len(input_lines), 5)  # 4 data lines and a header
        for ii, line in enumerate(input_lines[1:]):
            self.assertLess(ii+3, 7)
            self.assertEqual(line, '%d, %d\n' % (ii, ii+1))

        self._check_iterators(lambda: FilteredCat7(self.db),
                              '%d, %d\n', input_lines)

        self._remove_if_present(cat_name)

    def test_adding_filter(self):
        """
        Test that, when we use the kwarg in the constructor to add to
        cannot_be_null, the filter is appended to existing filters.
        """
        class FilteredCat8(InstanceCatalog):
            column_outputs = ['id', 'ip1']
            cannot_be_null = ['filter1']

            def get_filter1(self):
                ii = self.column_by_name('ip2')
                return np.where(ii % 2 == 0, ii, None)

            def get_filter2(self):
                ii = self.column_by_name('ip3')
                return np.where(ii > 8, ii, None)

        def make_cat():
            return FilteredCat8(self.db, cannot_be_null=['filter2'])

        cat_name = self._fresh_path("inst_adding_filter_cat.txt")

        cat = make_cat()
        # both the class-level and the constructor-supplied filter must be active
        self.assertIn('filter1', cat._cannot_be_null)
        self.assertIn('filter2', cat._cannot_be_null)

        input_lines = self._write_and_read(cat, cat_name)
        self.assertEqual(len(input_lines), 3)  # two data lines and a header
        for i_row, line in enumerate(input_lines[1:]):
            ii = i_row*2 + 6
            self.assertEqual((ii+2) % 2, 0)
            self.assertGreater(ii+3, 8)
            self.assertEqual(line, '%d, %d\n' % (ii, ii+1))

        self._check_iterators(make_cat, '%d, %d\n', input_lines)

        self._remove_if_present(cat_name)

614 

615 

class CompoundInstanceCatalogTestCase(unittest.TestCase):
    """
    This class contains tests that verify that using cannot_be_null
    to filter the contents of a CompoundInstanceCatalog works as it
    should.

    The backing data (built in setUpClass) has 10 rows with columns
    id, ip1, ip2, ip3 holding ii, ii+1, ii+2, ii+3 for ii in 0..9.
    """

    @classmethod
    def setUpClass(cls):
        """
        Create a scratch directory under ROOT, write the 10-row source
        text file into it and load it through fileDBObject with
        database=cls.db_name.
        """
        cls.scratch_dir = tempfile.mkdtemp(dir=ROOT, prefix="scratchSpace-")

        cls.db_src_name = os.path.join(cls.scratch_dir, 'compound_cat_filter_db.txt')
        if os.path.exists(cls.db_src_name):
            os.unlink(cls.db_src_name)

        cls.db_name = os.path.join(cls.scratch_dir, 'compound_cat_filter_db.db')
        if os.path.exists(cls.db_name):
            os.unlink(cls.db_name)

        with open(cls.db_src_name, 'w') as output_file:
            output_file.write('#a header\n')
            for ii in range(10):
                output_file.write('%d %d %d %d\n' % (ii, ii+1, ii+2, ii+3))

        dtype = np.dtype([('id', int), ('ip1', int), ('ip2', int), ('ip3', int)])
        # The returned object is intentionally discarded.
        # NOTE(review): presumably this call populates the sqlite file at
        # cls.db_name that the CatalogDBObject subclasses in the tests
        # open -- confirm against fileDBObject.
        fileDBObject(cls.db_src_name, runtable='test', dtype=dtype,
                     idColKey='id', database=cls.db_name)

    @classmethod
    def tearDownClass(cls):
        """
        Call sims_clean_up() and delete the source file, the database
        file and the scratch directory.
        """
        sims_clean_up()

        if os.path.exists(cls.db_src_name):
            os.unlink(cls.db_src_name)

        if os.path.exists(cls.db_name):
            os.unlink(cls.db_name)

        if os.path.exists(cls.scratch_dir):
            shutil.rmtree(cls.scratch_dir)

    def test_compound_cat(self):
        """
        Test that a CompoundInstanceCatalog made up of InstanceCatalog classes that
        each filter on a different condition gives the correct outputs.
        """

        class FilteringCatClass1(InstanceCatalog):
            column_outputs = ['id', 'ip1t']
            cannot_be_null = ['ip1t']

            @cached
            def get_ip1t(self):
                base = self.column_by_name('ip1')
                output = []
                for bb in base:
                    if bb % 2 == 0:
                        output.append(bb)
                    else:
                        output.append(None)
                return np.array(output)

        class FilteringCatClass2(InstanceCatalog):
            column_outputs = ['id', 'ip2t']
            cannot_be_null = ['ip2t']

            @cached
            def get_ip2t(self):
                base = self.column_by_name('ip2')
                ii = self.column_by_name('id')
                return np.where(ii < 4, base, None)

        class FilteringCatClass3(InstanceCatalog):
            column_outputs = ['id', 'ip3t']
            cannot_be_null = ['ip3t']

            @cached
            def get_ip3t(self):
                base = self.column_by_name('ip3')
                ii = self.column_by_name('id')
                return np.where(ii > 5, base, None)

        class DbClass(CatalogDBObject):
            host = None
            port = None
            database = self.db_name
            driver = 'sqlite'
            tableid = 'test'
            objid = 'silliness'
            idColKey = 'id'

        class DbClass1(DbClass):
            objid = 'silliness1'

        class DbClass2(DbClass):
            objid = 'silliness2'

        class DbClass3(DbClass):
            objid = 'silliness3'

        cat = CompoundInstanceCatalog([FilteringCatClass1,
                                       FilteringCatClass2,
                                       FilteringCatClass3],
                                      [DbClass1, DbClass2, DbClass3])

        cat_name = os.path.join(self.scratch_dir, "compound_filter_output.txt")
        if os.path.exists(cat_name):
            os.unlink(cat_name)

        cat.write_catalog(cat_name)

        with open(cat_name, 'r') as input_file:
            input_lines = input_file.readlines()

        # header + 5 rows (odd ids) + 4 rows (id < 4) + 4 rows (id > 5)
        self.assertEqual(len(input_lines), 14)

        # given that we know what the contents of each sub-catalog should be
        # and how they should be ordered, loop through the lines of the output
        # catalog, verifying that every line is where it ought to be
        # (line 0 is the header, so numbering of data lines starts at 1)
        for i_line, line in enumerate(input_lines[1:], 1):
            if i_line < 6:
                ii = 2*(i_line-1) + 1
                self.assertEqual((ii+1) % 2, 0)
                self.assertEqual(line, '%d, %d\n' % (ii, ii+1))
            elif i_line < 10:
                ii = i_line - 6
                self.assertLess(ii, 4)
                self.assertEqual(line, '%d, %d\n' % (ii, ii+2))
            else:
                ii = i_line - 10 + 6
                self.assertGreater(ii, 5)
                self.assertEqual(line, '%d, %d\n' % (ii, ii+3))

        if os.path.exists(cat_name):
            os.unlink(cat_name)

    def test_compound_cat_compound_column(self):
        """
        Test filtering a CompoundInstanceCatalog on a compound column
        """

        class FilteringCatClass4(InstanceCatalog):
            column_outputs = ['id', 'a', 'b']
            cannot_be_null = ['a']

            @compound('a', 'b')
            def get_alphabet(self):
                a = self.column_by_name('ip1')
                a = a*a
                b = self.column_by_name('ip2')
                b = b*0.25
                return np.array([np.where(a % 2 == 0, a, None), b])

        class FilteringCatClass5(InstanceCatalog):
            column_outputs = ['id', 'a', 'b', 'filter']
            cannot_be_null = ['b', 'filter']

            @compound('a', 'b')
            def get_alphabet(self):
                ii = self.column_by_name('ip1')
                return np.array([self.column_by_name('ip2')**3,
                                 np.where(ii % 2 == 1, ii, None)])

            @cached
            def get_filter(self):
                ii = self.column_by_name('ip1')
                return np.where(ii % 3 != 0, ii, None)

        class DbClass(CatalogDBObject):
            host = None
            port = None
            database = self.db_name
            driver = 'sqlite'
            tableid = 'test'
            idColKey = 'id'

        class DbClass4(DbClass):
            objid = 'silliness4'

        class DbClass5(DbClass):
            objid = 'silliness5'

        cat_name = os.path.join(self.scratch_dir, "compound_cat_compound_filter_cat.txt")
        if os.path.exists(cat_name):
            os.unlink(cat_name)

        cat = CompoundInstanceCatalog([FilteringCatClass4, FilteringCatClass5], [DbClass4, DbClass5])
        cat.write_catalog(cat_name)

        # now make sure that the catalog contains the expected data
        with open(cat_name, 'r') as input_file:
            input_lines = input_file.readlines()
        self.assertEqual(len(input_lines), 9)  # 8 data lines and a header

        first_cat_lines = ['1, 4, 0.75\n', '3, 16, 1.25\n',
                           '5, 36, 1.75\n', '7, 64, 2.25\n',
                           '9, 100, 2.75\n']

        second_cat_lines = ['0, 8, 1, 1\n', '4, 216, 5, 5\n',
                            '6, 512, 7, 7\n']

        # line 0 is the header, so numbering of data lines starts at 1
        for i_line, line in enumerate(input_lines[1:], 1):
            if i_line < 6:
                self.assertEqual(line, first_cat_lines[i_line-1])
            else:
                self.assertEqual(line, second_cat_lines[i_line-6])

        if os.path.exists(cat_name):
            os.unlink(cat_name)

830 

831 

class MemoryTestClass(lsst.utils.tests.MemoryTestCase):
    # No overrides: every test in this class is inherited from
    # lsst.utils.tests.MemoryTestCase.
    # NOTE(review): presumably a memory/resource leak check -- confirm
    # against the lsst.utils.tests documentation.
    pass

834 

835 

if __name__ == "__main__":
    # Executed as a script: call lsst.utils.tests.init() and then hand
    # control to the unittest runner.
    lsst.utils.tests.init()
    unittest.main()