Coverage for metadetect / tests / test_metadetect.py: 9%

474 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-23 08:44 +0000

1""" 

2test with super simple sim. The purpose here is not 

3to make sure it gets the right answer or anything, just 

4to test all the moving parts 

5""" 

6import time 

7import copy 

8import itertools 

9 

10import pytest 

11 

12from packaging import version 

13import ngmix 

14import numpy as np 

15 

16from metadetect import metadetect 

17from metadetect import fitting 

18from metadetect import procflags 

19from .sim import Sim 

20 

21 

# Baseline configuration shared by the tests in this module.  Each test takes
# a deep copy and overrides individual entries (most often "model").
TEST_METADETECT_CONFIG = {
    "model": "wmom",

    "weight": {
        # weight function FWHM, in arcsec
        "fwhm": 1.2,
    },

    "metacal": {
        "psf": "fitgauss",
        "types": ["noshear", "1p", "1m", "2p", "2m"],
    },

    "sx": {
        # DETECT_THRESH, expressed in units of the sky sigma
        "detect_thresh": 0.8,

        # DEBLEND_MINCONT: minimum contrast parameter for deblending
        "deblend_cont": 1e-5,

        # minimum number of pixels above threshold
        # (the original note cites DETECT_MINAREA: 6; the value used here is 4)
        "minarea": 4,

        "filter_type": "conv",

        # 7x7 convolution mask of a gaussian PSF with FWHM = 3.0 pixels.
        "filter_kernel": [
            [0.004963, 0.021388, 0.051328, 0.068707, 0.051328, 0.021388, 0.004963],  # noqa
            [0.021388, 0.092163, 0.221178, 0.296069, 0.221178, 0.092163, 0.021388],  # noqa
            [0.051328, 0.221178, 0.530797, 0.710525, 0.530797, 0.221178, 0.051328],  # noqa
            [0.068707, 0.296069, 0.710525, 0.951108, 0.710525, 0.296069, 0.068707],  # noqa
            [0.051328, 0.221178, 0.530797, 0.710525, 0.530797, 0.221178, 0.051328],  # noqa
            [0.021388, 0.092163, 0.221178, 0.296069, 0.221178, 0.092163, 0.021388],  # noqa
            [0.004963, 0.021388, 0.051328, 0.068707, 0.051328, 0.021388, 0.004963],  # noqa
        ],
    },

    "meds": {
        "min_box_size": 32,
        "max_box_size": 256,

        "box_type": "iso_radius",

        "rad_min": 4,
        "rad_fac": 2,
        "box_padding": 2,
    },

    # bit 30: check for an edge hit
    "bmask_flags": 1 << 30,

    # bit 0: the tests OR this into bmask to mark pixels for no detection
    "nodet_flags": 1,
}

77 

78 

def _show_mbobs(mer):
    """
    Display a color composite of the first three bands next to the
    segmentation map and detection image of ``mer``.

    Reads ``mer.mbobs``, ``mer.seg`` and ``mer.detim``; requires the optional
    ``images`` package, which is imported lazily here.
    """
    import images

    mbobs = mer.mbobs

    # bands are handed over in the order 2, 1, 0, each image transposed
    planes = [mbobs[band][0].image.transpose() for band in (2, 1, 0)]
    rgb = images.get_color_image(*planes, nonlinear=0.1)

    # rescale in place so the brightest value is 1
    rgb *= 1.0/rgb.max()

    panels = [rgb, mer.seg, mer.detim]
    images.view_mosaic(panels, titles=["image", "seg", "detim"])

98 

99 

def test_detect(ntrial=1, show=False):
    """
    Smoke-test the detection step only: build a MEDSifier for each trial and
    count the objects in its multiband MEDS.

    Parameters
    ----------
    ntrial: int
        Number of simulated multi-band observations to process.
    show: bool
        If True, display each trial with _show_mbobs; between trials a prompt
        is read from stdin and answering "q" returns early.
    """
    pytest.importorskip("meds")
    pytest.importorskip("sxdes")
    from .. import detect

    rng = np.random.RandomState(seed=45)

    tm0 = time.time()
    nobj_meas = 0

    sim = Sim(rng)

    config = {}
    config.update(copy.deepcopy(TEST_METADETECT_CONFIG))

    for trial in range(ntrial):
        print("trial: %d/%d" % (trial+1, ntrial))

        mbobs = sim.get_mbobs()
        mer = detect.MEDSifier(
            mbobs=mbobs,
            sx_config=config["sx"],
            meds_config=config["meds"],
        )

        mbm = mer.get_multiband_meds()

        nobj = mbm.size
        nobj_meas += nobj

        if show:
            _show_mbobs(mer)
            if ntrial > 1 and trial != (ntrial-1):
                if "q" == input("hit a key: "):
                    return

    total_time = time.time()-tm0
    print("found", nobj_meas, "objects")

    # the averages are only defined for non-zero counts; skipping them avoids
    # a ZeroDivisionError when ntrial is 0 or nothing was detected
    if ntrial > 0:
        print("time per group:", total_time/ntrial)
    if nobj_meas > 0:
        print("time per object:", total_time/nobj_meas)

143 

144 

def test_detect_masking(ntrial=1, show=False):
    """
    OR the ``nodet_flags`` bit into the bmask of every observation and check
    that the MEDSifier catalog comes back empty.

    ``show`` is accepted but not used by this test.
    """
    pytest.importorskip("meds")
    pytest.importorskip("sxdes")
    from .. import detect

    sim = Sim(np.random.RandomState(seed=45))

    config = copy.deepcopy(TEST_METADETECT_CONFIG)
    nodet_flags = config["nodet_flags"]

    for itrial in range(1, ntrial + 1):
        print("trial: %d/%d" % (itrial, ntrial))

        mbobs = sim.get_mbobs()
        for obs in itertools.chain.from_iterable(mbobs):
            obs.bmask = obs.bmask | nodet_flags

        mer = detect.MEDSifier(
            mbobs=mbobs,
            sx_config=config["sx"],
            meds_config=config["meds"],
            nodet_flags=nodet_flags,
        )
        assert mer.cat.size == 0

175 

176 

177def _check_result_array(res, shear, msk, model): 

178 for col in res[shear].dtype.names: 

179 if col == "shear_bands": 

180 assert np.all(res[shear][msk][col] == "012") 

181 elif col == "det_bands": 

182 assert np.all(res[shear][msk][col] == "012") 

183 else: 

184 # admom doesn't make band fluxes 

185 if model in ["admom", "am"] and "band_flux" in col: 

186 if col.endswith("band_flux_flags"): 

187 assert np.array_equal( 

188 res[shear][msk][col], 

189 np.zeros_like(res[shear][msk][col]) 

190 + procflags.NO_ATTEMPT, 

191 ), ( 

192 "result column '%s' is not NO_ATTEMPT: %s" % ( 

193 col, res[shear][msk][col] 

194 ) 

195 ) 

196 elif any( 

197 col.endswith(s) for s in ["band_flux", "band_flux_err"] 

198 ): 

199 assert np.all(np.isnan( 

200 res[shear][msk][col], 

201 )), ( 

202 "result column '%s' is not NaN: %s" % ( 

203 col, res[shear][msk][col] 

204 ) 

205 ) 

206 else: 

207 assert np.all(np.isfinite(res[shear][msk][col])), ( 

208 "result column '%s' has NaNs: %s" % ( 

209 col, res[shear][msk][col] 

210 ) 

211 ) 

212 

213 

@pytest.mark.xfail(reason="flaky performance")
@pytest.mark.parametrize("model", ["gauss", "wmom"])
def test_metadetect_coadd_faster(model):
    """
    Compare the run time of do_metadetect with and without coadding.

    For "gauss" the coadd run must take less than 0.7 of the non-coadd time;
    for any other model the two times must agree to within 30%.  The test is
    marked xfail because timing is flaky.
    """
    pytest.importorskip("sxdes")

    config = {}
    config.update(copy.deepcopy(TEST_METADETECT_CONFIG))
    del config["model"]
    del config["weight"]
    config["fitters"] = [
        {"model": model, "coadd": False, "weight": {"fwhm": 1.2}}
    ]

    def _time_one_run():
        # every run uses identically seeded 6-band data and rng, so only the
        # config differs between the timed runs
        mbobs = Sim(
            np.random.RandomState(seed=116), config={"nband": 6}
        ).get_mbobs()
        rng = np.random.RandomState(seed=116)

        # perf_counter is monotonic, unlike the wall clock from time.time(),
        # so system clock adjustments cannot distort the measured interval
        tm0 = time.perf_counter()
        metadetect.do_metadetect(config, mbobs, rng)
        return time.perf_counter() - tm0

    # warm up once
    _time_one_run()

    no_coadd_time = _time_one_run()

    config["fitters"][0]["coadd"] = True
    coadd_time = _time_one_run()

    print("coadd|nocoadd: %f|%f" % (coadd_time, no_coadd_time))

    if model == "gauss":
        assert coadd_time < no_coadd_time*0.7, (coadd_time, no_coadd_time)
    else:
        assert np.allclose(coadd_time, no_coadd_time, atol=0, rtol=0.3)

263 

264 

@pytest.mark.parametrize("model", ["wmom", "pgauss", "ksigma", "am", "gauss"])
def test_metadetect_smoke(model):
    """
    Run do_metadetect end to end on the default sim for each fitter model and
    sanity-check every sheared catalog.
    """
    pytest.importorskip("sxdes")

    ntrial = 1
    rng = np.random.RandomState(seed=116)

    tm0 = time.time()

    sim = Sim(rng)
    config = copy.deepcopy(TEST_METADETECT_CONFIG)
    config["model"] = model
    flags_col = model + '_flags'

    for itrial in range(1, ntrial + 1):
        print("trial: %d/%d" % (itrial, ntrial))

        mbobs = sim.get_mbobs()
        res = metadetect.do_metadetect(config, mbobs, rng)

        for shear in ("noshear", "1p", "1m", "2p", "2m"):
            cat = res[shear]

            # nothing is masked in this sim
            assert np.all(cat["mfrac"] == 0)
            assert any(name.endswith("band_flux") for name in cat.dtype.names)
            assert np.any(cat["psfrec_g"] != 0)
            assert np.any(cat["psfrec_T"] != 0)

            # detailed column checks only on rows with a clean fit
            good = cat[flags_col] == 0
            _check_result_array(res, shear, good, model)

    elapsed = time.time() - tm0
    print("time per:", elapsed/ntrial)

297 

298 

@pytest.mark.parametrize("model", ["wmom", "pgauss", "ksigma", "am", "gauss"])
def test_metadetect_uberseg(model):
    """
    Run do_metadetect end to end with the MEDS ``weight_type`` set to
    "uberseg" and sanity-check every sheared catalog.
    """
    pytest.importorskip("sxdes")

    ntrial = 1
    rng = np.random.RandomState(seed=116)

    tm0 = time.time()

    sim = Sim(rng)
    config = copy.deepcopy(TEST_METADETECT_CONFIG)
    config["model"] = model
    config["meds"]["weight_type"] = "uberseg"
    flags_col = model + '_flags'

    mbobs = sim.get_mbobs()

    # NOTE(review): the result of this first run is overwritten inside the
    # loop below; it is kept because it advances the shared rng
    res = metadetect.do_metadetect(config, mbobs, rng)

    for itrial in range(1, ntrial + 1):
        print("trial: %d/%d" % (itrial, ntrial))

        res = metadetect.do_metadetect(config, mbobs, rng)

        for shear in ("noshear", "1p", "1m", "2p", "2m"):
            cat = res[shear]

            assert np.all(cat["mfrac"] == 0)
            assert any(name.endswith("band_flux") for name in cat.dtype.names)
            assert np.any(cat["psfrec_g"] != 0)
            assert np.any(cat["psfrec_T"] != 0)

            # detailed column checks only on rows with a clean fit
            good = cat[flags_col] == 0
            _check_result_array(res, shear, good, model)

    elapsed = time.time() - tm0
    print("time per:", elapsed/ntrial)

334 

335 

@pytest.mark.parametrize("model", ["wmom", "pgauss", "ksigma", "am", "gauss"])
def test_metadetect_mfrac(model):
    """
    Run do_metadetect with a random per-pixel mfrac drawn uniformly from
    [0.2, 0.8) and check that the reported mfrac columns stay in (0.45, 0.55).
    """
    pytest.importorskip("sxdes")

    ntrial = 1
    rng = np.random.RandomState(seed=53341)

    tm0 = time.time()

    sim = Sim(rng)
    config = copy.deepcopy(TEST_METADETECT_CONFIG)
    config["model"] = model
    flags_col = model + '_flags'

    for itrial in range(1, ntrial + 1):
        print("trial: %d/%d" % (itrial, ntrial))

        mbobs = sim.get_mbobs()
        for obslist in mbobs:
            obs = obslist[0]
            obs.mfrac = rng.uniform(size=obs.image.shape, low=0.2, high=0.8)

        res = metadetect.do_metadetect(config, mbobs, rng)

        for shear in ("noshear", "1p", "1m", "2p", "2m"):
            cat = res[shear]

            for mfrac_col in ("mfrac", "mfrac_img", "mfrac_noshear"):
                vals = cat[mfrac_col]
                assert np.all((vals > 0.45) & (vals < 0.55))

            # detailed column checks only on rows with a clean fit
            good = cat[flags_col] == 0
            _check_result_array(res, shear, good, model)

    elapsed = time.time() - tm0
    print("time per:", elapsed/ntrial)

380 

381 

@pytest.mark.parametrize("model", ["wmom", "pgauss", "ksigma", "am", "gauss"])
def test_metadetect_mfrac_all(model):
    """
    With mfrac set to 1 on every pixel of every band, do_metadetect must
    return None.
    """
    ntrial = 1
    rng = np.random.RandomState(seed=53341)

    sim = Sim(rng)
    config = copy.deepcopy(TEST_METADETECT_CONFIG)
    config["model"] = model

    for itrial in range(1, ntrial + 1):
        print("trial: %d/%d" % (itrial, ntrial))

        mbobs = sim.get_mbobs()
        for obslist in mbobs:
            obs = obslist[0]
            obs.mfrac = np.ones_like(obs.image)

        assert metadetect.do_metadetect(config, mbobs, rng) is None

404 

405 

@pytest.mark.parametrize("model", ["wmom", "pgauss", "ksigma", "am", "gauss"])
def test_metadetect_zero_weight_all(model):
    """
    With the weight map zeroed in every band, do_metadetect must return None.
    """
    ntrial = 1
    rng = np.random.RandomState(seed=53341)

    sim = Sim(rng)
    config = copy.deepcopy(TEST_METADETECT_CONFIG)
    config["model"] = model

    for itrial in range(1, ntrial + 1):
        print("trial: %d/%d" % (itrial, ntrial))

        mbobs = sim.get_mbobs()
        for obslist in mbobs:
            obs = obslist[0]
            obs.weight = np.zeros_like(obs.image)

        assert metadetect.do_metadetect(config, mbobs, rng) is None

429 

430 

@pytest.mark.parametrize("model", ["wmom", "pgauss", "ksigma", "am", "gauss"])
def test_metadetect_zero_weight_some(model):
    """
    With the weight map zeroed in band 1 only, do_metadetect must still
    return None.
    """
    ntrial = 1
    rng = np.random.RandomState(seed=53341)

    sim = Sim(rng)
    config = copy.deepcopy(TEST_METADETECT_CONFIG)
    config["model"] = model

    zeroed_band = 1

    for itrial in range(1, ntrial + 1):
        print("trial: %d/%d" % (itrial, ntrial))

        mbobs = sim.get_mbobs()
        for band, obslist in enumerate(mbobs):
            if band != zeroed_band:
                continue
            obs = obslist[0]
            obs.weight = np.zeros_like(obs.image)

        assert metadetect.do_metadetect(config, mbobs, rng) is None

454 

455 

@pytest.mark.parametrize("model", ["wmom", "pgauss", "ksigma", "am", "gauss"])
def test_metadetect_nodet_flags_all(model):
    """
    With the bmask of every band set to 1 on all pixels (the value of the
    configured ``nodet_flags``), do_metadetect must return None.
    """
    ntrial = 1
    rng = np.random.RandomState(seed=53341)

    sim = Sim(rng)
    config = copy.deepcopy(TEST_METADETECT_CONFIG)
    config["model"] = model

    for itrial in range(1, ntrial + 1):
        print("trial: %d/%d" % (itrial, ntrial))

        mbobs = sim.get_mbobs()
        for obslist in mbobs:
            obs = obslist[0]
            obs.bmask = np.ones_like(obs.image, dtype=np.int32)

        assert metadetect.do_metadetect(config, mbobs, rng) is None

480 

481 

@pytest.mark.parametrize("model", ["wmom", "pgauss", "ksigma", "am", "gauss"])
def test_metadetect_nodet_flags_some(model):
    """
    With the bmask of band 1 only set to 1 on all pixels (the value of the
    configured ``nodet_flags``), do_metadetect must still return None.
    """
    ntrial = 1
    rng = np.random.RandomState(seed=53341)

    sim = Sim(rng)
    config = copy.deepcopy(TEST_METADETECT_CONFIG)
    config["model"] = model

    flagged_band = 1

    for itrial in range(1, ntrial + 1):
        print("trial: %d/%d" % (itrial, ntrial))

        mbobs = sim.get_mbobs()
        for band, obslist in enumerate(mbobs):
            if band != flagged_band:
                continue
            obs = obslist[0]
            obs.bmask = np.ones_like(obs.image, dtype=np.int32)

        assert metadetect.do_metadetect(config, mbobs, rng) is None

507 

508 

@pytest.mark.skipif(
    version.parse(ngmix.__version__) < version.parse("2.1.0"),
    reason="ngmix version 2.1.0 or greater is needed for smoothing prepsf moments",
)
@pytest.mark.parametrize("model", ["pgauss", "ksigma"])
def test_metadetect_fitter_fwhm_smooth(model):
    """
    Check that ``weight.fwhm_smooth`` reaches the fitter and that, averaged
    over clean fits, smoothing raises T and lowers the g covariance.
    """
    pytest.importorskip("sxdes")

    nband = 3
    rng = np.random.RandomState(seed=116)

    mbobs = Sim(rng, config={"nband": nband}).get_mbobs()

    config = copy.deepcopy(TEST_METADETECT_CONFIG)
    config["model"] = model
    config["weight"]["fwhm"] = 1.2

    def _run():
        # both runs share the same rng object; it is not reseeded in between
        runner = metadetect.Metadetect(config, mbobs, rng)
        runner.go()
        return runner

    md = _run()
    res = md.result
    assert md._fitters[0].fwhm_smooth == 0

    config["weight"]["fwhm_smooth"] = 0.8
    md = _run()
    res_smooth = md.result
    assert md._fitters[0].fwhm_smooth == 0.8

    flags_col = model + "_flags"
    t_col = model + "_T"
    gcov_col = model + "_g_cov"

    for shear in ("noshear", "1p", "1m", "2p", "2m"):
        plain = res[shear]
        smooth = res_smooth[shear]
        ok = plain[flags_col] == 0
        ok_smooth = smooth[flags_col] == 0

        assert np.mean(smooth[t_col][ok_smooth]) > np.mean(plain[t_col][ok])
        assert (
            np.mean(smooth[gcov_col][ok_smooth, 0, 0])
            < np.mean(plain[gcov_col][ok, 0, 0])
        )

554 

555 

@pytest.mark.parametrize("model", ["pgauss", "ksigma"])
def test_metadetect_fitter_fwhm_reg(model):
    """
    Check that setting ``weight.fwhm_reg`` to 0.8 writes its results under
    the "<model>_reg0.80" prefix, leaving T unchanged but altering g.
    """
    pytest.importorskip("sxdes")

    nband = 3
    mbobs = Sim(
        np.random.RandomState(seed=116), config={"nband": nband}
    ).get_mbobs()

    config = copy.deepcopy(TEST_METADETECT_CONFIG)
    config["model"] = model
    config["weight"]["fwhm"] = 1.2

    def _run():
        # each run gets its own identically seeded rng
        runner = metadetect.Metadetect(
            config, mbobs, np.random.RandomState(seed=116),
        )
        runner.go()
        return runner.result

    res = _run()

    config["weight"]["fwhm_reg"] = 0.8
    res_reg = _run()

    reg_prefix = model + "_reg0.80"

    for shear in ("noshear", "1p", "1m", "2p", "2m"):
        plain = res[shear]
        reg = res_reg[shear]
        ok = plain[model + "_flags"] == 0
        ok_reg = reg[reg_prefix + "_flags"] == 0

        assert np.allclose(reg[reg_prefix + "_T"][ok_reg], plain[model + "_T"][ok])
        assert not np.allclose(
            reg[reg_prefix + "_g"][ok_reg], plain[model + "_g"][ok]
        )

597 

598 

def test_metadetect_fitter_multi_meas():
    """
    Run five fitters in one Metadetect pass via the "fitters" config list and
    check that each one produces its own, distinguishable columns.
    """
    pytest.importorskip("sxdes")

    nband = 3
    mbobs = Sim(
        np.random.RandomState(seed=116), config={"nband": nband}
    ).get_mbobs()

    config = copy.deepcopy(TEST_METADETECT_CONFIG)
    # the top-level model/weight entries are replaced by the fitters list
    del config["model"]
    del config["weight"]
    config["fitters"] = [
        {"model": "wmom", "weight": {"fwhm": 1.2}},
        {"model": "pgauss", "weight": {"fwhm": 2.0}},
        {"model": "pgauss", "weight": {"fwhm": 2.0, "fwhm_reg": 0.8}},
        {"model": "am"},
        {"model": "gauss"},
    ]

    md = metadetect.Metadetect(
        config, mbobs, np.random.RandomState(seed=116),
    )
    md.go()
    res = md.result

    model = "pgauss"
    reg_prefix = model + "_reg0.80"

    for shear in ("noshear", "1p", "1m", "2p", "2m"):
        cat = res[shear]
        ok = cat[model + "_flags"] == 0
        ok_reg = cat[reg_prefix + "_flags"] == 0
        ok_wmom = cat["wmom_flags"] == 0

        ref_T = cat[model + "_T"]
        ref_g = cat[model + "_g"]

        # the regularized pgauss has the same T but a different g
        assert np.allclose(cat[reg_prefix + "_T"][ok_reg], ref_T[ok])
        assert not np.allclose(cat[reg_prefix + "_g"][ok_reg], ref_g[ok])

        assert not np.allclose(cat["wmom_T"][ok_wmom], ref_T[ok])
        assert not np.allclose(cat["wmom_g"][ok_wmom], ref_g[ok])

        # admom can fail so look at intersection
        both = (cat["am_flags"] == 0) & ok
        assert not np.allclose(cat["am_T"][both], ref_T[both])
        assert not np.allclose(cat["am_g"][both], ref_g[both])

        # gauss can fail so look at intersection
        both = (cat["gauss_flags"] == 0) & ok
        assert not np.allclose(cat["gauss_T"][both], ref_T[both])
        assert not np.allclose(cat["gauss_g"][both], ref_g[both])

670 

671 

@pytest.mark.parametrize("model", ["wmom", "pgauss", "ksigma"])
@pytest.mark.parametrize("nband,nshear", [(3, 2), (1, 1), (4, 2), (3, 1)])
def test_metadetect_flux(model, nband, nshear):
    """
    Run do_metadetect for every ``nshear``-sized combination of the ``nband``
    bands and check the band labels and the shape of the band-flux columns.
    """
    pytest.importorskip("sxdes")

    ntrial = 1
    rng = np.random.RandomState(seed=116)

    tm0 = time.time()

    sim = Sim(rng, config={"nband": nband})
    config = copy.deepcopy(TEST_METADETECT_CONFIG)
    config["model"] = model

    # a single band yields scalar fluxes, several bands a vector per object
    flux_shape = (nband,) if nband > 1 else tuple()

    for itrial in range(1, ntrial + 1):
        print("trial: %d/%d" % (itrial, ntrial))

        mbobs = sim.get_mbobs()
        for shear_bands in itertools.combinations(range(nband), nshear):
            res = metadetect.do_metadetect(
                config, mbobs, rng, shear_band_combs=[shear_bands],
                det_band_combs="shear_bands",
            )
            band_label = "".join("%s" % b for b in shear_bands)

            for shear in ("noshear", "1p", "1m", "2p", "2m"):
                cat = res[shear]
                assert np.all(cat["mfrac"] == 0)
                assert np.all(cat["shear_bands"] == band_label)
                assert np.all(cat["det_bands"] == band_label)

                for name in cat.dtype.names:
                    if name.endswith("band_flux"):
                        assert cat[name][0].shape == flux_shape

    elapsed = time.time() - tm0
    print("time per:", elapsed/ntrial)

716 

717 

@pytest.mark.parametrize("coadd", [True, False])
@pytest.mark.parametrize("det_bands", [None, "shear_bands", "single"])
@pytest.mark.parametrize("model", ["wmom", "pgauss", "am", "gauss"])
def test_metadetect_multiband(model, det_bands, coadd):
    """
    test full metadetection w/ multiple bands

    Parameters
    ----------
    model: str
        Fitter model written into the config.
    det_bands: None or str
        How the detection band combinations are chosen: None and
        "shear_bands" are forwarded to do_metadetect unchanged, while
        "single" passes ``[0]`` once per shear band combination.
    coadd: bool
        Value written to ``config["coadd"]``.
    """
    pytest.importorskip("sxdes")

    nband = 3
    ntrial = 1
    rng = np.random.RandomState(seed=116)

    tm0 = time.time()

    sim = Sim(rng, config={"nband": nband})
    config = {}
    config.update(copy.deepcopy(TEST_METADETECT_CONFIG))
    config["model"] = model
    config["coadd"] = coadd

    for trial in range(ntrial):
        print("trial: %d/%d" % (trial+1, ntrial))

        mbobs = sim.get_mbobs()

        # all bands together, then every pair, then every single band
        shear_band_combs = [list(range(nband))]
        for ncomb in (2, 1):
            shear_band_combs += [
                list(comb)
                for comb in itertools.combinations(list(range(nband)), ncomb)
            ]

        det_band_combs = (
            det_bands
            if det_bands != "single"
            else [[0]] * len(shear_band_combs)
        )
        res = metadetect.do_metadetect(
            config, mbobs, rng, shear_band_combs=shear_band_combs,
            det_band_combs=det_band_combs,
        )

        # expand the shorthand inputs into the explicit combinations the
        # output is checked against
        if det_band_combs is None:
            expected_det_combs = [list(range(nband))] * len(shear_band_combs)
        elif det_band_combs == "shear_bands":
            expected_det_combs = shear_band_combs
        else:
            expected_det_combs = det_band_combs

        for shear in ["noshear", "1p", "1m", "2p", "2m"]:
            assert np.all(res[shear]["mfrac"] == 0)

            # the loop names must not reuse ``det_bands``: rebinding the
            # parametrized argument would corrupt any trial after the first
            for det_comb, shear_comb in zip(expected_det_combs, shear_band_combs):
                assert np.any(
                    (
                        res[shear]["shear_bands"]
                        == "".join("%s" % b for b in shear_comb)
                    )
                    & (
                        res[shear]["det_bands"]
                        == "".join("%s" % b for b in det_comb)
                    )
                )

            for c in res[shear].dtype.names:
                if c.endswith("band_flux"):
                    if nband > 1:
                        assert res[shear][c][0].shape == (nband,)
                    else:
                        assert res[shear][c][0].shape == tuple()

    total_time = time.time()-tm0
    print("time per:", total_time/ntrial)

788 

789 

def test_metadetect_with_color_is_same():
    """
    Check that passing a constant ``color_key_func`` together with
    ``color_dep_mbobs`` holding the same mbobs gives results identical to a
    run without the color arguments.
    """
    pytest.importorskip("sxdes")

    model = "wmom"
    nband = 3
    ntrial = 1

    tm0 = time.time()

    config = copy.deepcopy(TEST_METADETECT_CONFIG)
    config["model"] = model

    def _fresh_mbobs():
        # identically seeded sim so both runs see the same data
        sim_rng = np.random.RandomState(seed=116)
        return Sim(sim_rng, config={"nband": nband}).get_mbobs()

    for itrial in range(1, ntrial + 1):
        print("trial: %d/%d" % (itrial, ntrial))

        # all bands together, then every pair, then every single band
        shear_band_combs = [list(range(nband))]
        for ncomb in (2, 1):
            shear_band_combs += [
                list(comb)
                for comb in itertools.combinations(list(range(nband)), ncomb)
            ]

        mbobs = _fresh_mbobs()
        rng = np.random.RandomState(seed=11)
        res = metadetect.do_metadetect(
            config, mbobs, rng, shear_band_combs=shear_band_combs,
        )

        mbobs = _fresh_mbobs()
        rng = np.random.RandomState(seed=11)
        res_color = metadetect.do_metadetect(
            config, mbobs, rng, shear_band_combs=shear_band_combs,
            color_key_func=lambda x: "blah", color_dep_mbobs={"blah": mbobs},
        )

        for shear in ("noshear", "1p", "1m", "2p", "2m"):
            plain = res[shear]
            colored = res_color[shear]

            for col in plain.dtype.names:
                assert col in colored.dtype.names
                if col in ("shear_bands", "det_bands"):
                    assert np.array_equal(plain[col], colored[col])
                else:
                    # zero tolerances: values must match exactly, NaNs included
                    np.testing.assert_allclose(
                        plain[col],
                        colored[col],
                        atol=0,
                        rtol=0,
                        equal_nan=True,
                    )

            for shear_bands in shear_band_combs:
                band_label = "".join("%s" % b for b in shear_bands)
                assert np.any(plain["shear_bands"] == band_label)
                assert np.any(colored["shear_bands"] == band_label)

    elapsed = time.time() - tm0
    print("time per:", elapsed/ntrial)

860 

861 

@pytest.mark.parametrize("mask_region", [1, 7])
def test_fill_in_mask_col(mask_region):
    """
    Check metadetect._fill_in_mask_col for a single object at
    (row, col) = (31.2, 51.7) on a random 100x100 integer mask.

    With ``mask_region`` equal to 1 the value must be the mask at the nearest
    pixel (31, 52); otherwise it must be the bitwise OR over the square
    window reaching ``mask_region`` pixels on each side of that pixel.
    """
    rng = np.random.RandomState(seed=10)

    rows = np.array([31.2])
    cols = np.array([51.7])
    mask = rng.randint(low=0, high=64, size=(100, 100))

    vals = metadetect._fill_in_mask_col(
        mask_region=mask_region,
        rows=rows,
        cols=cols,
        mask=mask)

    # nearest-pixel position of the object
    row = 31
    col = 52
    if mask_region == 1:
        assert vals[0] == mask[row, col]
    else:
        window = mask[
            row-mask_region:row+mask_region+1,
            col-mask_region:col+mask_region+1
        ]
        # axis=None ORs every pixel of the window; the default axis=0
        # followed by [0] would only cover the window's first column
        assert vals[0] == np.bitwise_or.reduce(window, axis=None)

887 

888 

def test_get_psf_stats():
    """
    Exercise metadetect._get_psf_stats on fitted PSFs: the nominal case, a
    non-zero second argument, negated weights, and hand-set PSF results
    whose weighted means are known.
    """
    rng = np.random.RandomState(seed=10)
    mbobs = Sim(rng).get_mbobs()
    fitting.fit_all_psfs(mbobs, rng)

    stat_keys = ("g1", "g2", "T")

    # nominal case: clean flags and finite statistics
    psf_stats = metadetect._get_psf_stats(mbobs, 0)
    assert psf_stats["flags"] == 0
    for key in stat_keys:
        assert np.isfinite(psf_stats[key])

    # a non-zero second argument is OR-ed with PSF_FAILURE in the flags
    psf_stats = metadetect._get_psf_stats(mbobs, 2)
    assert psf_stats["flags"] == (procflags.PSF_FAILURE | 2)
    for key in stat_keys:
        assert not np.isfinite(psf_stats[key])

    # negated weights give a PSF failure
    for obslist in mbobs:
        for obs in obslist:
            obs.weight = -1.0*obs.weight
    psf_stats = metadetect._get_psf_stats(mbobs, 0)
    assert psf_stats["flags"] == procflags.PSF_FAILURE
    for key in stat_keys:
        assert not np.isfinite(psf_stats[key])

    # give each band a constant weight and a known PSF shape and size
    band_index = np.arange(len(mbobs))
    e1s = band_index + 0.1
    e2s = 2*band_index + 0.1
    Ts = 3*band_index + 0.1
    wgts = band_index + 1
    for obslist, wgt, e1, e2, psf_T in zip(mbobs, wgts, e1s, e2s, Ts):
        for obs in obslist:
            obs.weight = 0*obs.weight + wgt
            obs.psf.meta["result"]["e"] = (e1, e2)
            obs.psf.meta["result"]["T"] = psf_T

    # the statistics must now be the weight-averaged inputs
    wsum = np.sum(wgts)
    psf_stats = metadetect._get_psf_stats(mbobs, 0)
    assert psf_stats["flags"] == 0
    assert np.allclose(psf_stats["g1"], np.sum(wgts * e1s)/wsum)
    assert np.allclose(psf_stats["g2"], np.sum(wgts * e2s)/wsum)
    assert np.allclose(psf_stats["T"], np.sum(wgts * Ts)/wsum)