Coverage for metadetect / tests / test_fitting.py: 5%

456 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-25 08:38 +0000

1import numpy as np 

2import galsim 

3import ngmix 

4 

5import pytest 

6 

7from ngmix.gaussmom import GaussMom 

8from ngmix.moments import fwhm_to_T 

9 

10from .sim import make_mbobs_sim 

11from metadetect.fitting import ( 

12 fit_mbobs_wavg, 

13 _combine_fit_results_wavg, 

14 symmetrize_obs_weights, 

15 fit_all_psfs, 

16 _sum_bands_wavg, 

17 MOMNAME, 

18 _make_mom_res, 

19 combine_fit_res, 

20) 

21from metadetect import procflags 

22 

23 

24def _print_res(res): 

25 print("", flush=True) 

26 for name in res.dtype.names: 

27 if "flag" in name: 

28 if len(np.shape(res[name])) > 0 and np.shape(res[name])[0] > 1: 

29 for i, f in enumerate(res[name]): 

30 print( 

31 " %s[%d]: %d (%s)" % ( 

32 name, 

33 i, 

34 res[name][i], 

35 procflags.get_procflags_str(res[name][i]), 

36 ), 

37 flush=True, 

38 ) 

39 else: 

40 print( 

41 " %s: %d (%s)" % ( 

42 name, 

43 res[name], 

44 procflags.get_procflags_str(res[name]), 

45 ), 

46 flush=True, 

47 ) 

48 else: 

49 print(" %s:" % name, res[name], flush=True) 

50 

51 

def test_fit_all_psfs_same():
    """PSF fits seeded with the same RandomState must give identical results.

    Two identical simulations are built and PSF-fit with ``seed=10``. Every
    entry of each band's first PSF ``meta["result"]`` must then match.
    """
    nband = 4

    fitted = []
    for _ in range(2):
        sim = make_mbobs_sim(45, nband)
        fit_all_psfs(sim, np.random.RandomState(seed=10))
        fitted.append(sim)
    first, second = fitted

    for band in range(nband):
        res_a = first[band][0].psf.meta["result"]
        res_b = second[band][0].psf.meta["result"]
        for key in res_a:
            assert np.all(res_a[key] == res_b[key])

65 

66 

def test_fitting_fit_mbobs_wavg_flagging_nodata():
    """An empty ObsList in one band must be flagged MISSING_BAND.

    For each scenario:
    - The overall ``wmom_flags`` carries MISSING_BAND.
    - Only the emptied band has non-zero band-flux flags and non-finite
      band fluxes/errors.
    - ``wmom_g_cov`` is non-finite exactly when the emptied band is one of
      the shear bands.
    """
    # (number of bands, band to empty, shear_bands kwarg; None -> default)
    scenarios = [
        (4, 1, None),
        (7, 1, [0, 1, 2, 3]),
        (7, 5, [0, 1, 2, 3]),
    ]
    for nband, bad, shear_bands in scenarios:
        mbobs = make_mbobs_sim(45, nband)
        mbobs[bad] = ngmix.ObsList()

        extra_kw = {} if shear_bands is None else {"shear_bands": shear_bands}
        res = fit_mbobs_wavg(
            mbobs=mbobs,
            fitter=GaussMom(1.2),
            bmask_flags=0,
            **extra_kw,
        )
        _print_res(res[0])

        good = [b for b in range(nband) if b != bad]
        shear_set = range(nband) if shear_bands is None else shear_bands
        band_flags = res["wmom_band_flux_flags"]

        assert np.all((res["wmom_flags"] & procflags.MISSING_BAND) != 0)
        assert np.all((band_flags[:, bad] & procflags.MISSING_BAND) != 0)
        for b in good:
            assert np.all(band_flags[:, b] == 0)

        # Losing a shear band must invalidate the shear covariance.
        if bad in shear_set:
            assert not np.any(np.isfinite(res["wmom_g_cov"]))
        else:
            assert np.all(np.isfinite(res["wmom_g_cov"]))

        for tail in ["", "_err"]:
            col = "wmom_band_flux" + tail
            assert not np.any(np.isfinite(res[col][:, bad]))
            for b in good:
                assert np.all(np.isfinite(res[col][:, b]))
        assert np.all(res["shear_bands"] == "0123")

128 

129 

def test_fitting_fit_mbobs_wavg_flagging_edge():
    """A pixel hitting ``bmask_flags`` in one band must set EDGE_HIT.

    One pixel gets ``bmask_flags`` and another gets an unrelated bit
    (``other_flags``). The affected band must carry EDGE_HIT and
    MISSING_BAND, and the other bands must be clean. ``wmom_g_cov`` is
    non-finite exactly when the affected band is a shear band.
    """
    bmask_flags = 2**8
    other_flags = 2**3
    # (number of bands, band whose bmask is hit, shear_bands; None -> default)
    scenarios = [
        (4, 1, None),
        (7, 1, list(range(4))),
        (7, 5, list(range(4))),
    ]
    for nband, bad, shear_bands in scenarios:
        mbobs = make_mbobs_sim(45, nband)
        obs = mbobs[bad][0]
        with obs.writeable():
            obs.bmask[2, 3] = bmask_flags
            obs.bmask[3, 1] = other_flags

        extra_kw = {} if shear_bands is None else {"shear_bands": shear_bands}
        res = fit_mbobs_wavg(
            mbobs=mbobs,
            fitter=GaussMom(1.2),
            bmask_flags=bmask_flags,
            **extra_kw,
        )
        _print_res(res[0])

        good = [b for b in range(nband) if b != bad]
        shear_set = range(nband) if shear_bands is None else shear_bands
        band_flags = res["wmom_band_flux_flags"]

        assert np.all((res["wmom_flags"] & procflags.EDGE_HIT) != 0)
        for bit in [procflags.EDGE_HIT, procflags.MISSING_BAND]:
            assert np.all((band_flags[:, bad] & bit) != 0)
        for b in good:
            assert np.all(band_flags[:, b] == 0)

        if bad in shear_set:
            assert not np.any(np.isfinite(res["wmom_g_cov"]))
        else:
            assert np.all(np.isfinite(res["wmom_g_cov"]))

        for tail in ["", "_err"]:
            col = "wmom_band_flux" + tail
            assert not np.any(np.isfinite(res[col][:, bad]))
            for b in good:
                assert np.all(np.isfinite(res[col][:, b]))
        assert np.all(res["shear_bands"] == "0123")

200 

201 

def test_fitting_fit_mbobs_wavg_flagging_zeroweight():
    """An all-zero weight map in one band must set ZERO_WEIGHTS.

    The affected band must carry ZERO_WEIGHTS and MISSING_BAND, and the
    other bands must be clean. ``wmom_g_cov`` is non-finite exactly when the
    affected band is a shear band.
    """
    # (number of bands, band to zero-weight, shear_bands; None -> default)
    scenarios = [
        (4, 1, None),
        (7, 1, list(range(4))),
        (7, 5, list(range(4))),
    ]
    for nband, bad, shear_bands in scenarios:
        mbobs = make_mbobs_sim(45, nband)
        obs = mbobs[bad][0]
        with obs.writeable():
            obs.ignore_zero_weight = False
            obs.weight[:, :] = 0

        extra_kw = {} if shear_bands is None else {"shear_bands": shear_bands}
        res = fit_mbobs_wavg(
            mbobs=mbobs,
            fitter=GaussMom(1.2),
            bmask_flags=0,
            **extra_kw,
        )
        _print_res(res[0])

        good = [b for b in range(nband) if b != bad]
        shear_set = range(nband) if shear_bands is None else shear_bands
        band_flags = res["wmom_band_flux_flags"]

        assert np.all((res["wmom_flags"] & procflags.ZERO_WEIGHTS) != 0)
        for bit in [procflags.ZERO_WEIGHTS, procflags.MISSING_BAND]:
            assert np.all((band_flags[:, bad] & bit) != 0)
        for b in good:
            assert np.all(band_flags[:, b] == 0)

        if bad in shear_set:
            assert not np.any(np.isfinite(res["wmom_g_cov"]))
        else:
            assert np.all(np.isfinite(res["wmom_g_cov"]))

        for tail in ["", "_err"]:
            col = "wmom_band_flux" + tail
            assert not np.any(np.isfinite(res[col][:, bad]))
            for b in good:
                assert np.all(np.isfinite(res[col][:, b]))
        assert np.all(res["shear_bands"] == "0123")

270 

271 

def test_fitting_fit_mbobs_wavg_flagging_combined():
    """Missing data, a bmask hit and zero weights in three bands at once.

    The overall ``wmom_flags`` must carry all three failure bits. Each bad
    band must carry its own bit plus MISSING_BAND, and every remaining band
    must be clean. ``wmom_g_cov`` is non-finite exactly when any bad band is
    a shear band.
    """
    bmask_flags = 2**8
    other_flags = 2**3
    # (nband, emptied band, bmask-hit band, zero-weight band, shear_bands)
    scenarios = [
        (4, 0, 1, 2, None),
        (7, 0, 1, 2, list(range(4))),
        (7, 4, 6, 5, list(range(4))),
    ]
    for nband, missing, edge, zero, shear_bands in scenarios:
        mbobs = make_mbobs_sim(45, nband)
        mbobs[missing] = ngmix.ObsList()

        edge_obs = mbobs[edge][0]
        with edge_obs.writeable():
            edge_obs.bmask[2, 3] = bmask_flags
            edge_obs.bmask[3, 1] = other_flags

        zero_obs = mbobs[zero][0]
        with zero_obs.writeable():
            zero_obs.ignore_zero_weight = False
            zero_obs.weight[:, :] = 0

        extra_kw = {} if shear_bands is None else {"shear_bands": shear_bands}
        res = fit_mbobs_wavg(
            mbobs=mbobs,
            fitter=GaussMom(1.2),
            bmask_flags=bmask_flags,
            **extra_kw,
        )
        _print_res(res[0])

        bad = {missing, edge, zero}
        good = [b for b in range(nband) if b not in bad]
        shear_set = set(range(nband) if shear_bands is None else shear_bands)
        band_flags = res["wmom_band_flux_flags"]

        for bit in [
            procflags.MISSING_BAND,
            procflags.ZERO_WEIGHTS,
            procflags.EDGE_HIT,
        ]:
            assert np.all((res["wmom_flags"] & bit) != 0)

        assert np.all((band_flags[:, missing] & procflags.MISSING_BAND) != 0)
        for bit in [procflags.EDGE_HIT, procflags.MISSING_BAND]:
            assert np.all((band_flags[:, edge] & bit) != 0)
        for bit in [procflags.ZERO_WEIGHTS, procflags.MISSING_BAND]:
            assert np.all((band_flags[:, zero] & bit) != 0)
        for b in good:
            assert np.all(band_flags[:, b] == 0)

        if bad & shear_set:
            assert not np.any(np.isfinite(res["wmom_g_cov"]))
        else:
            assert np.all(np.isfinite(res["wmom_g_cov"]))

        for tail in ["", "_err"]:
            col = "wmom_band_flux" + tail
            for b in good:
                assert np.all(np.isfinite(res[col][:, b]))
            for b in sorted(bad):
                assert not np.any(np.isfinite(res[col][:, b]))
        assert np.all(res["shear_bands"] == "0123")

387 

388 

def _flagging_obj_res(flux_flags=0, drop=()):
    """Build one per-band object fit result for the combine-flagging table.

    Parameters
    ----------
    flux_flags : int, optional
        Value stored under ``"flux_flags"``.
    drop : iterable of str, optional
        Keys to delete from the otherwise complete result, to simulate a fit
        that is missing a field.

    Returns
    -------
    dict
        A fresh dict with ``"flux_flags"``, ``"flux"``, ``"flux_err"``,
        ``MOMNAME`` and ``MOMNAME + "_cov"``, minus any keys in ``drop``.
    """
    res = {
        "flux_flags": flux_flags,
        "flux": 1,
        "flux_err": 1,
        MOMNAME: np.array([0, 0, 0.5, 0.5, 1, 1]),
        MOMNAME + "_cov": np.diag(np.ones(6)),
    }
    for key in drop:
        del res[key]
    return res


def _flagging_psf_res(drop=()):
    """Build one per-band PSF fit result for the combine-flagging table.

    Returns a fresh dict with ``MOMNAME`` and ``MOMNAME + "_cov"``, minus any
    keys listed in ``drop``.
    """
    res = {MOMNAME: np.ones(6), MOMNAME + "_cov": np.diag(np.ones(6))}
    for key in drop:
        del res[key]
    return res


# NOTE: ``[x] * n`` deliberately repeats the *same* dict object n times.
@pytest.mark.parametrize("purpose,kwargs,psf_flags,model_flags,flux_flags", [
    (
        "no data at all",
        dict(
            all_res=[],
            all_psf_res=[],
            all_is_shear_band=[],
            all_wgts=[],
            all_flags=[],
        ),
        ["MISSING_BAND"],
        ["MISSING_BAND", "PSF_FAILURE"],
        ["MISSING_BAND"],
    ),
    (
        "everything failed",
        dict(
            all_res=[None, None, None, None],
            all_psf_res=[None, None, None, None],
            all_is_shear_band=[True, True, False, False],
            all_wgts=[0, 0, 0, 0],
            all_flags=[0, 0, 0, 0],
        ),
        ["MISSING_BAND", "ZERO_WEIGHTS"],
        ["MISSING_BAND", "PSF_FAILURE", "ZERO_WEIGHTS"],
        [["MISSING_BAND"]] * 4,
    ),
    (
        "everything failed w/ input flags that should be in the output",
        dict(
            all_res=[None, None, None, None],
            all_psf_res=[None, None, None, None],
            all_is_shear_band=[True, True, False, False],
            all_wgts=[0, 0, 0, 0],
            all_flags=[1, 0, 1, 0],
        ),
        ["MISSING_BAND", "ZERO_WEIGHTS", "NO_ATTEMPT"],
        ["MISSING_BAND", "PSF_FAILURE", "NO_ATTEMPT", "ZERO_WEIGHTS"],
        [
            ["MISSING_BAND", "NO_ATTEMPT"],
            ["MISSING_BAND"],
            ["MISSING_BAND", "NO_ATTEMPT"],
            ["MISSING_BAND"],
        ],
    ),
    (
        "we mark weights zero vs not for failures",
        dict(
            all_res=[None, None, None, None],
            all_psf_res=[None, None, None, None],
            all_is_shear_band=[True, True, False, False],
            all_wgts=[1, 1, 1, 1],
            all_flags=[0, 0, 0, 0],
        ),
        ["MISSING_BAND"],
        ["MISSING_BAND", "PSF_FAILURE"],
        [["MISSING_BAND"]] * 4,
    ),
    (
        "everything is fine one band",
        dict(
            all_res=[_flagging_obj_res()] * 1,
            all_psf_res=[_flagging_psf_res()] * 1,
            all_is_shear_band=[True],
            all_wgts=[1],
            all_flags=[0],
        ),
        [],
        [],
        [],
    ),
    (
        "everything is fine for more than one band",
        dict(
            all_res=[_flagging_obj_res()] * 4,
            all_psf_res=[_flagging_psf_res()] * 4,
            all_is_shear_band=[True, True, False, False],
            all_wgts=[1, 1, 1, 1],
            all_flags=[0, 0, 0, 0],
        ),
        [],
        [],
        [[]] * 4,
    ),
    (
        "extra shear bands",
        dict(
            all_res=[_flagging_obj_res()] * 2,
            all_psf_res=[_flagging_psf_res()] * 1,
            all_is_shear_band=[True],
            all_wgts=[1],
            all_flags=[0],
        ),
        ["INCONSISTENT_BANDS"],
        ["INCONSISTENT_BANDS", "PSF_FAILURE"],
        [["INCONSISTENT_BANDS"], ["INCONSISTENT_BANDS"]],
    ),
    (
        "extra PSF bands",
        dict(
            all_res=[_flagging_obj_res()],
            all_psf_res=[_flagging_psf_res()] * 2,
            all_is_shear_band=[True],
            all_wgts=[1],
            all_flags=[0],
        ),
        ["INCONSISTENT_BANDS"],
        ["INCONSISTENT_BANDS", "PSF_FAILURE"],
        ["INCONSISTENT_BANDS"],
    ),
    (
        "extra weights",
        dict(
            all_res=[_flagging_obj_res()],
            all_psf_res=[_flagging_psf_res()],
            all_is_shear_band=[True],
            all_wgts=[1, 1],
            all_flags=[0],
        ),
        ["INCONSISTENT_BANDS"],
        ["INCONSISTENT_BANDS", "PSF_FAILURE"],
        ["INCONSISTENT_BANDS"],
    ),
    (
        "extra flags",
        dict(
            all_res=[_flagging_obj_res()],
            all_psf_res=[_flagging_psf_res()],
            all_is_shear_band=[True],
            all_wgts=[1],
            all_flags=[0, 0],
        ),
        ["INCONSISTENT_BANDS"],
        ["INCONSISTENT_BANDS", "PSF_FAILURE"],
        ["INCONSISTENT_BANDS"],
    ),
    (
        "extra shear bands and flags",
        dict(
            all_res=[_flagging_obj_res()],
            all_psf_res=[_flagging_psf_res()],
            all_is_shear_band=[True, False],
            all_wgts=[1],
            all_flags=[0, 0],
        ),
        ["INCONSISTENT_BANDS"],
        ["INCONSISTENT_BANDS", "PSF_FAILURE"],
        ["INCONSISTENT_BANDS"],
    ),
    (
        "flag a single shear",
        dict(
            all_res=[_flagging_obj_res()] * 4,
            all_psf_res=[_flagging_psf_res()] * 4,
            all_is_shear_band=[True, True, False, False],
            all_wgts=[1, 1, 1, 1],
            all_flags=[1, 0, 0, 0],
        ),
        ["NO_ATTEMPT"],
        ["NO_ATTEMPT", "PSF_FAILURE"],
        [["NO_ATTEMPT"], [], [], []],
    ),
    (
        "flag a shear res is fine",
        dict(
            all_res=(
                [_flagging_obj_res(flux_flags=1)]
                + [_flagging_obj_res()] * 3
            ),
            all_psf_res=[_flagging_psf_res()] * 4,
            all_is_shear_band=[True, True, False, False],
            all_wgts=[1, 1, 1, 1],
            all_flags=[0, 0, 0, 0],
        ),
        [],
        [],
        [["NO_ATTEMPT"], [], [], []],
    ),
    (
        "zero weight a shear",
        dict(
            all_res=[_flagging_obj_res()] * 4,
            all_psf_res=[_flagging_psf_res()] * 4,
            all_is_shear_band=[True, True, False, False],
            all_wgts=[0, 1, 1, 1],
            all_flags=[0, 0, 0, 0],
        ),
        ["ZERO_WEIGHTS"],
        ["ZERO_WEIGHTS", "PSF_FAILURE"],
        [[]] * 4,
    ),
    (
        "missing a shear res",
        dict(
            all_res=[None] + [_flagging_obj_res()] * 3,
            all_psf_res=[_flagging_psf_res()] * 4,
            all_is_shear_band=[True, True, False, False],
            all_wgts=[1, 1, 1, 1],
            all_flags=[0, 0, 0, 0],
        ),
        ["MISSING_BAND"],
        ["MISSING_BAND", "PSF_FAILURE"],
        [["MISSING_BAND"], [], [], []],
    ),
    (
        "zero weight a flux",
        dict(
            all_res=[_flagging_obj_res()] * 4,
            all_psf_res=[_flagging_psf_res()] * 4,
            all_is_shear_band=[True, True, False, False],
            all_wgts=[1, 1, 1, 0],
            all_flags=[0, 0, 0, 0],
        ),
        [],
        [],
        [[]] * 4,
    ),
    (
        "flag a flux",
        dict(
            all_res=[_flagging_obj_res()] * 4,
            all_psf_res=[_flagging_psf_res()] * 4,
            all_is_shear_band=[True, True, False, False],
            all_wgts=[1, 1, 1, 1],
            all_flags=[0, 0, 0, 1],
        ),
        [],
        [],
        [[], [], [], ["NO_ATTEMPT"]],
    ),
    (
        "flag a flux in res",
        dict(
            all_res=(
                [_flagging_obj_res()] * 3
                + [_flagging_obj_res(flux_flags=1)]
            ),
            all_psf_res=[_flagging_psf_res()] * 4,
            all_is_shear_band=[True, True, False, False],
            all_wgts=[1, 1, 1, 1],
            all_flags=[0, 0, 0, 0],
        ),
        [],
        [],
        [[], [], [], ["NO_ATTEMPT"]],
    ),
    (
        "missing a flux res",
        dict(
            all_res=[_flagging_obj_res()] * 3 + [None],
            all_psf_res=[_flagging_psf_res()] * 4,
            all_is_shear_band=[True, True, False, False],
            all_wgts=[1, 1, 1, 1],
            all_flags=[0, 0, 0, 0],
        ),
        [],
        [],
        [[], [], [], ["MISSING_BAND"]],
    ),
    (
        "missing flux",
        dict(
            all_res=(
                [_flagging_obj_res()] * 3
                + [_flagging_obj_res(drop=("flux",))]
            ),
            all_psf_res=[_flagging_psf_res()] * 4,
            all_is_shear_band=[True, True, False, False],
            all_wgts=[1, 1, 1, 1],
            all_flags=[0, 0, 0, 0],
        ),
        [],
        [],
        [[], [], [], ["NOMOMENTS_FAILURE"]],
    ),
    (
        "missing flux_err",
        dict(
            all_res=(
                [_flagging_obj_res()] * 3
                + [_flagging_obj_res(drop=("flux_err",))]
            ),
            all_psf_res=[_flagging_psf_res()] * 4,
            all_is_shear_band=[True, True, False, False],
            all_wgts=[1, 1, 1, 1],
            all_flags=[0, 0, 0, 0],
        ),
        [],
        [],
        [[], [], [], ["NOMOMENTS_FAILURE"]],
    ),
    (
        "missing flux_flags",
        dict(
            all_res=(
                [_flagging_obj_res()] * 3
                + [_flagging_obj_res(drop=("flux_flags",))]
            ),
            all_psf_res=[_flagging_psf_res()] * 4,
            all_is_shear_band=[True, True, False, False],
            all_wgts=[1, 1, 1, 1],
            all_flags=[0, 0, 0, 0],
        ),
        [],
        [],
        [[], [], [], ["NOMOMENTS_FAILURE"]],
    ),
    (
        "missing mom",
        dict(
            all_res=(
                [_flagging_obj_res(drop=(MOMNAME,))]
                + [_flagging_obj_res()] * 3
            ),
            all_psf_res=[_flagging_psf_res()] * 4,
            all_is_shear_band=[True, True, False, False],
            all_wgts=[1, 1, 1, 1],
            all_flags=[0, 0, 0, 0],
        ),
        ["NOMOMENTS_FAILURE"],
        ["NOMOMENTS_FAILURE", "PSF_FAILURE"],
        [[]] * 4,
    ),
    (
        "missing mom_cov",
        dict(
            all_res=(
                [_flagging_obj_res(drop=(MOMNAME + "_cov",))]
                + [_flagging_obj_res()] * 3
            ),
            all_psf_res=[_flagging_psf_res()] * 4,
            all_is_shear_band=[True, True, False, False],
            all_wgts=[1, 1, 1, 1],
            all_flags=[0, 0, 0, 0],
        ),
        ["NOMOMENTS_FAILURE"],
        ["NOMOMENTS_FAILURE", "PSF_FAILURE"],
        [[]] * 4,
    ),
    (
        "missing psf mom",
        dict(
            all_res=[_flagging_obj_res()] * 4,
            all_psf_res=(
                [_flagging_psf_res(drop=(MOMNAME,))]
                + [_flagging_psf_res()] * 3
            ),
            all_is_shear_band=[True, True, False, False],
            all_wgts=[1, 1, 1, 1],
            all_flags=[0, 0, 0, 0],
        ),
        ["NOMOMENTS_FAILURE"],
        ["PSF_FAILURE"],
        [[]] * 4,
    ),
    (
        "missing psf mom cov",
        dict(
            all_res=[_flagging_obj_res()] * 4,
            all_psf_res=(
                [_flagging_psf_res(drop=(MOMNAME + "_cov",))]
                + [_flagging_psf_res()] * 3
            ),
            all_is_shear_band=[True, True, False, False],
            all_wgts=[1, 1, 1, 1],
            all_flags=[0, 0, 0, 0],
        ),
        ["NOMOMENTS_FAILURE"],
        ["PSF_FAILURE"],
        [[]] * 4,
    ),
    (
        "missing psf mom for flux",
        dict(
            all_res=[_flagging_obj_res()] * 4,
            all_psf_res=(
                [_flagging_psf_res()] * 3
                + [_flagging_psf_res(drop=(MOMNAME,))]
            ),
            all_is_shear_band=[True, True, False, False],
            all_wgts=[1, 1, 1, 1],
            all_flags=[0, 0, 0, 0],
        ),
        [],
        [],
        [[]] * 4,
    ),
    (
        "malformed psf mom and missing psf mom cov for flux",
        dict(
            all_res=[_flagging_obj_res()] * 4,
            all_psf_res=(
                [_flagging_psf_res()] * 3
                + [{MOMNAME: np.diag(np.ones(6))}]
            ),
            all_is_shear_band=[True, True, False, False],
            all_wgts=[1, 1, 1, 1],
            all_flags=[0, 0, 0, 0],
        ),
        [],
        [],
        [[]] * 4,
    ),
    (
        "negative/cancelling weights somehow",
        dict(
            all_res=[_flagging_obj_res()] * 5,
            all_psf_res=[_flagging_psf_res()] * 5,
            all_is_shear_band=[True, True, True, False, False],
            all_wgts=[-1, 1, 0, 1, 1],
            all_flags=[0, 0, 0, 0, 0],
        ),
        ["ZERO_WEIGHTS"],
        ["ZERO_WEIGHTS", "PSF_FAILURE"],
        [[]] * 5,
    ),
])
def test_fitting_combine_fit_results_wavg_flagging(
    purpose, kwargs, psf_flags, model_flags, flux_flags
):
    """Check the flag columns produced by ``_combine_fit_results_wavg``.

    Parameters
    ----------
    purpose : str
        Human-readable case description, used in assertion messages.
    kwargs : dict
        Per-band inputs forwarded to ``_combine_fit_results_wavg``.
    psf_flags, model_flags : list of str
        ``procflags`` attribute names expected in ``<model>_psf_flags`` and
        ``<model>_obj_flags``.
    flux_flags : list of str or list of list of str
        Expected names for ``<model>_band_flux_flags``. This is a flat list
        when that column is a scalar, or one list per band otherwise.
    """
    def _expected_value(flag_names):
        # OR together the procflags bits named in flag_names.
        fval = 0
        for flag in flag_names:
            fval |= getattr(procflags, flag)
        return fval

    def _check_one(val, flag_names):
        fval = _expected_value(flag_names)
        if val != fval:
            # Name the first missing bit before the generic equality check.
            for flag in flag_names:
                assert (val & getattr(procflags, flag)) != 0, (
                    "%s: flag val %s failed!" % (purpose, flag)
                )
        assert val == fval, purpose

    def _check_flags(val, flags):
        if np.shape(val) == tuple():
            _check_one(val, flags)
        else:
            # Per-band values: compare each band against its own list of
            # names (flags[i]), not the outer list of lists.
            for i, _val in enumerate(val):
                _check_one(_val, flags[i])

    # Column-name prefix for every output field read below.
    model = "wwmom"
    shear_bands = [i for i, b in enumerate(kwargs["all_is_shear_band"]) if b]

    data = _combine_fit_results_wavg(
        model=model, shear_bands=shear_bands, fwhm_reg=0, **kwargs
    )
    print()
    _print_res(data[0])
    _check_flags(data[model + "_psf_flags"][0], psf_flags)
    _check_flags(data[model + "_obj_flags"][0], model_flags)
    _check_flags(data[model + "_band_flux_flags"][0], flux_flags)

    # The overall flags must be the OR of the object flags and every band's
    # flux flags.
    if len(flux_flags) > 1 and isinstance(flux_flags[0], list):
        dff = 0
        for f in data[model + "_band_flux_flags"][0]:
            dff |= f
    else:
        dff = data[model + "_band_flux_flags"][0]
    assert (
        data[model + "_flags"][0] ==
        (data[model + "_obj_flags"][0] | dff)
    ), purpose
    if data[model + "_psf_flags"][0] != 0:
        assert (data[model + "_obj_flags"][0] & procflags.PSF_FAILURE) != 0, purpose

1065 

1066 

@pytest.mark.parametrize("mom_norm", [None, [0.3, 0.9, 0.8, 0.6]])
def test_fitting_sum_bands_wavg_weighting(mom_norm):
    """Check how ``_sum_bands_wavg`` weights and sums per-band moments.

    Two shear bands with weights 0.2 and 0.5 are summed twice: once with an
    extra set of weighting results (``all_wgt_res``) and once without. When
    ``mom_norm`` is given, each result dict gets a ``_norm`` entry and the
    expected sums are divided by the matching norm.
    """
    def _band(val, cov_val):
        return {
            MOMNAME: np.ones(6) * val,
            MOMNAME + "_cov": np.diag(np.ones(6)) * cov_val,
        }

    all_is_shear_band = [True, True]
    all_wgts = [0.2, 0.5]
    all_flags = [0, 0]
    all_res = [_band(3, 3.1), _band(7, 7.1)]
    all_wgt_res = [_band(6, 6.1), _band(2, 2.1)]

    if mom_norm is None:
        # unit norms leave every expected value below unchanged
        n0 = n1 = n2 = n3 = 1
    else:
        n0, n1, n2, n3 = mom_norm
        for rdict, nval in zip(all_res + all_wgt_res, mom_norm):
            rdict[MOMNAME + "_norm"] = nval

    fac0 = (6 / n2) / (3 / n0)
    fac1 = (2 / n3) / (7 / n1)

    # return value holds raw_mom, raw_mom_cov, wgt_sum, final_flags,
    # used_shear_bands, flux, flux_var, flux_wgt_sum
    sums_wgt = _sum_bands_wavg(
        all_res=all_res,
        all_is_shear_band=all_is_shear_band,
        all_wgts=all_wgts,
        all_flags=all_flags,
        all_wgt_res=all_wgt_res,
    )
    assert sums_wgt["wgt_sum"] == 0.7
    np.testing.assert_allclose(
        sums_wgt["raw_mom"],
        [0.2 * 3 * fac0 / n0 + 0.5 * 7 * fac1 / n1] * 6,
    )
    np.testing.assert_allclose(
        np.diag(sums_wgt["raw_mom_cov"]),
        [0.2**2 * 3.1 * fac0**2 / n0**2 + 0.5**2 * 7.1 * fac1**2 / n1**2] * 6,
    )

    sums = _sum_bands_wavg(
        all_res=all_res,
        all_is_shear_band=all_is_shear_band,
        all_wgts=all_wgts,
        all_flags=all_flags,
        all_wgt_res=None,
    )
    assert sums["wgt_sum"] == 0.7
    np.testing.assert_allclose(
        sums["raw_mom"],
        [0.2 * 3 / n0 + 0.5 * 7 / n1] * 6,
    )
    np.testing.assert_allclose(
        np.diag(sums["raw_mom_cov"]),
        [0.2**2 * 3.1 / n0**2 + 0.5**2 * 7.1 / n1**2] * 6,
    )

    # everything but the moments should be the same
    for key in [
        "wgt_sum", "final_flags", "used_shear_bands", "flux", "flux_var",
    ]:
        assert sums[key] == sums_wgt[key], (key, sums[key])

1175 

1176 

def test_fitting_symmetrize_obs_weights_all_zero():
    """Weight maps that are, or symmetrize to, all zeros stay all zero.

    In both cases a new observation is returned with ``ignore_zero_weight``
    set to False.
    """
    # case 1: the weight map is already entirely zero
    zero_obs = ngmix.Observation(
        image=np.zeros((13, 13)),
        weight=np.zeros((13, 13)),
        ignore_zero_weight=False,
    )
    zero_sym = symmetrize_obs_weights(zero_obs)
    assert zero_sym is not zero_obs
    assert zero_sym.ignore_zero_weight is False
    assert np.all(zero_sym.weight == 0)

    # case 2: only the two leftmost columns carry weight; after
    # symmetrization no pixel keeps any weight
    edge_wgt = np.zeros((13, 13))
    edge_wgt[:, :2] = 1
    edge_obs = ngmix.Observation(image=np.zeros((13, 13)), weight=edge_wgt)
    edge_sym = symmetrize_obs_weights(edge_obs)
    assert edge_sym is not edge_obs
    assert edge_sym.ignore_zero_weight is False
    assert np.all(edge_sym.weight == 0)
    assert not np.array_equal(edge_sym.weight, edge_obs.weight)

1199 

1200 

def test_fitting_symmetrize_obs_weights_none():
    """A weight map with no zeros comes back unchanged, in a new object."""
    obs = ngmix.Observation(image=np.zeros((13, 13)), weight=np.ones((13, 13)))

    sym_obs = symmetrize_obs_weights(obs)

    assert sym_obs is not obs
    assert sym_obs.ignore_zero_weight is True
    assert np.all(sym_obs.weight == 1)
    assert np.array_equal(sym_obs.weight, obs.weight)

1211 

1212 

def test_fitting_symmetrize_obs_weights():
    """A single zero-weight corner pixel zeroes all four corners."""
    input_wgt = np.array([
        [1, 1, 0],
        [1, 1, 1],
        [1, 1, 1],
    ])
    expected_wgt = np.array([
        [0, 1, 0],
        [1, 1, 1],
        [0, 1, 0],
    ])
    obs = ngmix.Observation(image=np.zeros((3, 3)), weight=input_wgt)

    sym_obs = symmetrize_obs_weights(obs)

    assert sym_obs is not obs
    assert sym_obs.ignore_zero_weight is True
    assert not np.array_equal(sym_obs.weight, obs.weight)
    assert np.array_equal(sym_obs.weight, expected_wgt)

1233 

1234 

def test_fitting_fit_mbobs_wavg_wmom_tratio():
    """``wmom_T_ratio`` is ~1 for a simulated star and > 1.5 otherwise."""
    fitter = GaussMom(1.2)
    seed = 10
    nband = 3

    def _fit_tratio(simulate_star):
        # same three-band sim for both cases; only the source type changes
        mbobs = make_mbobs_sim(
            seed,
            nband,
            simulate_star=simulate_star,
            noise_scale=1e-4,
            band_flux_factors=[0.1, 2.0, 5.0],
            band_image_sizes=[39, 45, 67],
        )
        res = fit_mbobs_wavg(mbobs=mbobs, fitter=fitter, bmask_flags=0)
        _print_res(res[0])
        return res["wmom_T_ratio"]

    star_tratio = _fit_tratio(True)
    assert np.allclose(star_tratio, 1.0)

    other_tratio = _fit_tratio(False)
    assert not np.allclose(other_tratio, 1.0)
    assert other_tratio[0] > 1.5

1265 

1266 

@pytest.mark.parametrize("fwhm_reg", [0, 0.8])
@pytest.mark.parametrize("has_nan", [True, False])
@pytest.mark.parametrize("zero_flux", [True, False])
@pytest.mark.parametrize("neg_flux_var", [True, False])
def test_make_mom_res(fwhm_reg, has_nan, zero_flux, neg_flux_var):
    """Compare ``_make_mom_res`` output with a direct ``GaussMom`` fit.

    The raw sums from a GaussMom fit of a sheared Gaussian are passed back
    through ``_make_mom_res``. Parametrized options: NaNs in the first two
    sums, a non-positive flux, a negative flux variance, and a non-zero
    regularizing FWHM.
    """
    # --- draw a sheared Gaussian on a 107x107 grid
    image_size = 107
    cen = (image_size - 1)/2
    gs_wcs = galsim.ShearWCS(0.125, galsim.Shear(g1=0, g2=0)).jacobian()
    obj = galsim.Gaussian(fwhm=0.9).shear(g1=-0.1, g2=0.3).withFlux(400)
    im = obj.drawImage(
        nx=image_size, ny=image_size, wcs=gs_wcs, method='no_pixel',
    ).array
    noise = np.sqrt(np.sum(im**2)) / 1e2
    wgt = np.ones_like(im) / noise**2

    # --- reference moments from a direct fit
    jac = ngmix.Jacobian(
        y=cen, x=cen,
        dudx=gs_wcs.dudx, dudy=gs_wcs.dudy,
        dvdx=gs_wcs.dvdx, dvdy=gs_wcs.dvdy,
    )
    obs = ngmix.Observation(image=im, jacobian=jac, weight=wgt)
    res = GaussMom(fwhm=1.2).go(obs=obs)

    if has_nan:
        res["sums"][:2] = np.nan

    # flux is divided by 1.15 and its variance by 1.15**2, so flux and
    # flux_err change while their ratio (s2n) should not
    raw_flux = -1 if zero_flux else res["flux"] / 1.15
    raw_flux_var = -1 if neg_flux_var else res["sums_cov"][5, 5] / 1.15**2

    res_reg = _make_mom_res(
        raw_mom=res["sums"].copy(),
        raw_mom_cov=res["sums_cov"].copy(),
        raw_flux=raw_flux,
        raw_flux_var=raw_flux_var,
        fwhm_reg=fwhm_reg,
    )

    sums = res["sums"]
    sums_reg = res_reg["sums"]
    T_reg = fwhm_to_T(fwhm_reg)

    # --- the raw sums
    if has_nan:
        assert np.isnan(sums_reg[0])
        assert np.isnan(sums_reg[1])
        assert np.all(np.isfinite(sums_reg[2:]))
    else:
        assert np.allclose(sums[[0, 1]], sums_reg[[0, 1]])
    # regularization shifts sums[4] by T_reg * sums[5]; 2, 3, 5 untouched
    assert np.allclose(sums[4] + T_reg * sums[5], sums_reg[4])
    if fwhm_reg > 0:
        assert not np.allclose(sums[4], sums_reg[4])
    assert np.allclose(sums[[2, 3, 5]], sums_reg[[2, 3, 5]])

    # --- derived size / shape columns
    for col in ["T", "T_err", "T_flags"]:
        assert np.allclose(res[col], res_reg[col])
    shapes_should_match = not fwhm_reg > 0
    for col in ["e1", "e2", "e", "e_err", "e_cov"]:
        assert np.allclose(res[col], res_reg[col]) == shapes_should_match

    # --- flux, variance and s2n handling
    flags = res_reg["flags"]
    if zero_flux:
        assert (flags & ngmix.flags.NONPOS_FLUX) != 0
    else:
        assert (flags & ngmix.flags.NONPOS_FLUX) == 0
        assert not np.allclose(res["flux"], res_reg["flux"])

    if neg_flux_var:
        assert (flags & ngmix.flags.NONPOS_VAR) != 0
        assert (res_reg["flux_flags"] & ngmix.flags.NONPOS_VAR) != 0
        assert np.isnan(res_reg["flux_err"])
        assert np.isnan(res_reg["s2n"])
    else:
        assert (flags & ngmix.flags.NONPOS_VAR) == 0
        assert (res_reg["flux_flags"] & ngmix.flags.NONPOS_VAR) == 0
        assert not np.allclose(res["flux_err"], res_reg["flux_err"])

    s2n_match = np.allclose(res["s2n"], res_reg["s2n"])
    if zero_flux or neg_flux_var:
        assert not s2n_match
    else:
        assert s2n_match

1369 

1370 

@pytest.mark.parametrize("all_res,expected,raises", [
    ([None], None, ""),
    ([None, None], None, ""),
    (
        [
            None,
            np.zeros(0, dtype=[("a_flags", "i8"), ("a", "f4")]),
            np.zeros(0, dtype=[("b_flags", "i8"), ("b", "f8")]),
        ],
        np.zeros(0, dtype=[
            ("a_flags", "i8"),
            ("a", "f4"),
            ("b_flags", "i8"),
            ("b", "f8"),
        ]),
        "",
    ),
    (
        [
            None,
            np.zeros(1, dtype=[("a_flags", "i8"), ("a", "f4")]),
            np.zeros(0, dtype=[("b_flags", "i8"), ("b", "f4")]),
        ],
        None,
        "All fit results must be the same length!",
    ),
    (
        [
            np.zeros(1, dtype=[("flags", "i8"), ("a", "f4")]),
            None,
        ],
        None,
        "All fit results must zero length if one is None!",
    ),
    (
        [
            None,
            np.zeros(1, dtype=[("flags", "i8"), ("a", "f4")]),
        ],
        None,
        "All fit results must be the same length!",
    ),
    (
        [
            np.zeros(1, dtype=[("a_flags", "i8"), ("a", "f4"), ("shear_bands", "i4")]),
            np.ones(1, dtype=[("b_flags", "i8"), ("b", "f8"), ("shear_bands", "i4")]),
        ],
        None,
        "Inconsistent column values for shear_bands when combining results!",
    ),
])
def test_combine_fit_res_nodata(all_res, expected, raises):
    """Check ``combine_fit_res`` on ``None`` / empty inputs and mismatches.

    A non-empty ``raises`` is a substring expected in the RuntimeError
    message. Otherwise the combined result must equal ``expected``, or be
    ``None`` when ``expected`` is ``None``.
    """
    if raises:
        with pytest.raises(RuntimeError) as excinfo:
            combine_fit_res(all_res)
        assert raises in str(excinfo.value)
        return

    res = combine_fit_res(all_res)
    if expected is None:
        assert res is None
    else:
        np.testing.assert_array_equal(expected, res)

1434 

1435 

def test_combine_fit_res():
    """Columns from each result are merged into one structured array.

    The ``shear_bands`` column, present with equal values in both inputs,
    carries through to the combined result.
    """
    a_res = np.zeros(2, dtype=[("a", "f4"), ("shear_bands", "i4")])
    b_res = np.zeros(2, dtype=[("b", "f8"), ("shear_bands", "i4")])
    a_res["a"] = [0.3, 2.3]
    b_res["b"] = [1.4, 4.6]
    for arr in (a_res, b_res):
        arr["shear_bands"] = [0, 2]

    res = combine_fit_res([a_res, b_res])

    expected_cols = {
        "a": np.array([0.3, 2.3], dtype="f4"),
        "b": np.array([1.4, 4.6], dtype="f8"),
        "shear_bands": np.array([0, 2], dtype="i4"),
    }
    for col, expected in expected_cols.items():
        np.testing.assert_array_equal(res[col], expected)

1451 np.testing.assert_array_equal(res["shear_bands"], np.array([0, 2], dtype="i4"))