Coverage for python / lsst / drp / tasks / metadetection_shear.py: 20%

238 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-05 08:35 +0000

1# This file is part of drp_tasks. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

# Names exported by ``from ... import *``: the task, its config, and the
# exception raised when no objects are produced.
__all__ = (
    "MetadetectionProcessingError",
    "MetadetectionShearConfig",
    "MetadetectionShearTask",
)

29 

30from collections.abc import Collection, Mapping, Sequence 

31from itertools import product 

32from typing import Any, ClassVar 

33 

34import esutil as eu 

35import numpy as np 

36import pyarrow as pa 

37from metadetect.lsst.masking import apply_apodized_bright_masks_mbexp, apply_apodized_edge_masks_mbexp 

38from metadetect.lsst.metacal_exposures import STEP as SHEAR_STEP 

39from metadetect.lsst.metadetect import MetadetectTask 

40from metadetect.lsst.util import extract_multiband_coadd_data 

41 

42import lsst.pipe.base.connectionTypes as cT 

43from lsst.afw.image import ExposureF 

44from lsst.afw.table import SimpleCatalog 

45from lsst.cell_coadds import MultipleCellCoadd, StitchedCoadd 

46from lsst.daf.butler import DataCoordinate, DatasetRef 

47from lsst.meas.algorithms import LoadReferenceObjectsConfig, ReferenceObjectLoader 

48from lsst.meas.base import FullIdGenerator, SkyMapIdGeneratorConfig 

49from lsst.pex.config import ConfigField, ConfigurableField, Field, FieldValidationError, ListField 

50from lsst.pipe.base import ( 

51 AlgorithmError, 

52 AnnotatedPartialOutputsError, 

53 InputQuantizedConnection, 

54 NoWorkFound, 

55 OutputQuantizedConnection, 

56 PipelineTask, 

57 PipelineTaskConfig, 

58 PipelineTaskConnections, 

59 QuantumContext, 

60 Struct, 

61) 

62from lsst.pipe.base.connectionTypes import BaseInput, Output 

63from lsst.skymap import Index2D 

64 

65 

class MetadetectionProcessingError(AlgorithmError):
    """Error signalling that metadetection processing failed for a patch.

    `MetadetectionShearTask.run` raises this when no cell yields any
    objects; `runQuantum` then annotates it for the butler.
    """

    @property
    def metadata(self) -> dict:
        # No extra diagnostic values are attached to this error.
        return dict()

72 

73 

class MetadetectionShearConnections(PipelineTaskConnections, dimensions={"patch"}):
    """Definitions of inputs and outputs for MetadetectionShearTask."""

    input_coadds = cT.Input(
        "deep_coadd_cell_predetection",
        storageClass="MultipleCellCoadd",
        doc="Per-band deep coadds.",
        multiple=True,
        dimensions={"patch", "band"},
    )

    ref_cat = cT.PrerequisiteInput(
        doc="Reference catalog used to mask bright objects.",
        name="the_monster_20250219",
        storageClass="SimpleCatalog",
        dimensions=("skypix",),
        deferLoad=True,
        multiple=True,
    )

    metadetect_catalog = cT.Output(
        "object_shear_patch",
        storageClass="ArrowTable",
        doc="Output catalog with all quantities measured inside the metacalibration loop.",
        multiple=False,
        dimensions={"patch"},
    )

    metadetect_schema = cT.InitOutput(
        "object_shear_schema",
        storageClass="ArrowSchema",
        doc="Schema of the output catalog.",
    )

    config: MetadetectionShearConfig

    def __init__(self, *, config=None):
        super().__init__(config=config)

        if not config:
            return None

        # The reference catalog is only needed when bright-object masking
        # is enabled; drop the connection otherwise.
        if not config.do_mask_bright_objects:
            del self.ref_cat

    def adjustQuantum(
        self,
        inputs: dict[str, tuple[BaseInput, Collection[DatasetRef]]],
        outputs: dict[str, tuple[Output, Collection[DatasetRef]]],
        label: str,
        data_id: DataCoordinate,
    ) -> tuple[
        Mapping[str, tuple[BaseInput, Collection[DatasetRef]]],
        Mapping[str, tuple[Output, Collection[DatasetRef]]],
    ]:
        # Docstring inherited.
        # This is a hook for customizing what is input and output to each
        # invocation of the task as early as possible, which we override here
        # to keep only coadds in ``photometry_bands`` and to skip the quantum
        # (NoWorkFound) when any of ``metadetect.shear_bands`` is missing.
        # Missing non-shear photometry bands are tolerated.
        connection, original_input_coadds = inputs["input_coadds"]
        wanted_bands = set(self.config.photometry_bands)
        bands_missing = set(wanted_bands)
        adjusted_input_coadds = []
        for ref in original_input_coadds:
            band = ref.dataId["band"]
            if band in wanted_bands:
                adjusted_input_coadds.append(ref)
                # discard (not remove) so a repeated band cannot raise KeyError.
                bands_missing.discard(band)
        if missing_shear_bands := bands_missing.intersection(self.config.metadetect.shear_bands):
            raise NoWorkFound(f"Required bands {missing_shear_bands} not present for {label}@{data_id}.")
        adjusted_inputs = {"input_coadds": (connection, adjusted_input_coadds)}
        inputs.update(adjusted_inputs)
        super().adjustQuantum(inputs, outputs, label, data_id)
        # Only connections that were actually adjusted are returned.
        return adjusted_inputs, {}

146 

147 

class MetadetectionShearConfig(PipelineTaskConfig, pipelineConnections=MetadetectionShearConnections):
    """Configuration definition for MetadetectionShearTask.

    `validate` requires every band in ``metadetect.shear_bands`` to also
    appear in ``photometry_bands``.
    """

    metadetect = ConfigurableField(
        doc="Configuration for metadetection.",
        target=MetadetectTask,
    )

    photometry_bands = ListField[str](
        doc=(
            "Bands expected to be present. Cells with one or more of these bands "
            "missing will be skipped. Bands other than those listed here will "
            "not be processed."
        ),
        default=["g", "r", "i", "z"],
    )

    do_mask_bright_objects = Field[bool](
        doc="Mask bright objects in coadds?",
        default=False,
    )

    ref_loader = ConfigField(
        doc="Reference object loader used for bright-object masking.",
        dtype=LoadReferenceObjectsConfig,
    )

    ref_loader_filter_name = Field[str](
        doc="Filter name from ref_loader used for bright-object masking.",
        default="monster_DES_r",
    )

    border = Field[int](
        doc="Border to apply to single cell images",
        default=50,
    )

    id_generator = SkyMapIdGeneratorConfig.make_field()

    def setDefaults(self):
        super().setDefaults()
        # Shear is measured on r, i, z using all five metacal counterfactuals.
        self.metadetect.shear_bands = ["r", "i", "z"]
        self.metadetect.metacal.types = ["noshear", "1p", "1m", "2p", "2m"]

    def validate(self):
        super().validate()
        shear_bands = self.metadetect.shear_bands
        if shear_bands is None:
            # Nothing to cross-check when no shear bands are configured.
            return
        if not set(shear_bands) <= set(self.photometry_bands):
            raise FieldValidationError(
                self.__class__.metadetect,
                self,
                "photometry_bands must be a list of bands that is a superset of metadetect.shear_bands",
            )

200 

201 

class MetadetectionShearTask(PipelineTask):
    """A PipelineTask that measures shear using metadetection."""

    _DefaultName: ClassVar[str] = "metadetectionShear"
    ConfigClass: ClassVar[type[MetadetectionShearConfig]] = MetadetectionShearConfig

    config: MetadetectionShearConfig

    def __init__(self, *, initInputs: dict[str, Any] | None = None, **kwargs: Any):
        super().__init__(initInputs=initInputs, **kwargs)
        # Same name as the ``metadetect_schema`` InitOutput connection.
        self.metadetect_schema = self.make_metadetect_schema(self.config)
        self.makeSubtask("metadetect")

    @classmethod
    def make_metadetect_schema(cls, config: MetadetectionShearConfig) -> pa.Schema:
        """Construct a PyArrow Schema for this task's main output catalog.

        Parameters
        ----------
        config : `MetadetectionShearConfig`
            Configuration that may be used to control details of the schema.

        Returns
        -------
        object_schema : `pyarrow.Schema`
            Schema for the object catalog produced by this task. Each field's
            metadata should include both a 'doc' entry and a 'unit' entry.
        """

        def _field(name, type_, doc, unit="", nullable=False):
            # Every field carries 'doc' then 'unit' metadata entries.
            return pa.field(name, type_, nullable=nullable, metadata={"doc": doc, "unit": unit})

        fields = [
            # Fields from pipeline bookkeeping.
            _field(
                "shearObjectId",
                pa.int64(),
                "Unique identifier for a ShearObject, specific "
                "to a single metacalibration counterfactual image.",
            ),
            _field("tract", pa.int64(), "ID of the tract on which this measurement was made."),
            _field("patch", pa.int64(), "ID of the patch within the tract on which this measurement was made."),
            _field(
                "cell_x", pa.int32(), "Column of the cell within the patch on which this measurement was made."
            ),
            _field("cell_y", pa.int32(), "Row of the cell within the patch on which this measurement was made."),
            # Fields from metadetection (generic).
            _field(
                "metaStep",
                pa.string(),
                "Type of artificial shear applied to image. One of: 'ns', '1p', '1m', '2p', '2m'.",
            ),
            _field("image_flags", pa.int32(), "Flags for the image on which this measurement was made."),
            _field("x", pa.float32(), "Centroid (tract, x-axis) of the detected ShearObject."),
            _field("y", pa.float32(), "Centroid (tract, y-axis) of the detected ShearObject."),
            _field("ra", pa.float64(), "Detected Right Ascension of the ShearObject.", "degrees"),
            _field("dec", pa.float64(), "Detected Declination of the ShearObject.", "degrees"),
            # Original PSF measurements.
            _field("psfOriginal_flags", pa.int32(), "Flags for the original PSF measurement."),
            _field(
                "psfOriginal_e1", pa.float32(), "Distortion-style e1 of the original PSF from adaptive moments."
            ),
            _field(
                "psfOriginal_e2", pa.float32(), "Distortion-style e2 of the original PSF from adaptive moments."
            ),
            _field(
                "psfOriginal_T",
                pa.float32(),
                "Trace (<x^2> + <y^2>) measurement of the original PSF from adaptive moments.",
                "arcseconds squared",
            ),
            _field("bmask_flags", pa.int32(), "`bmask` flags for the ShearObject"),
            _field("ormask_flags", pa.int32(), "`ored` mask flags for the ShearObject"),
            _field("mfrac", pa.float32(), "Gaussian-weighted masked fraction for the ShearObject."),
            # Fields that come only from gauss algorithm.
            # Reconvolved PSF measurements (gauss).
            _field(
                "gauss_psfReconvolved_flags",
                pa.int32(),
                "Flags for reconvolved PSF (measured with gauss algorithm).",
            ),
            _field(
                "gauss_psfReconvolved_g1",
                pa.float32(),
                "Reduced-shear g1 of the reconvolved PSF (measured with gauss algorithm).",
            ),
            _field(
                "gauss_psfReconvolved_g2",
                pa.float32(),
                "Reduced-shear g2 of the reconvolved PSF (measured with gauss algorithm).",
            ),
            _field(
                "gauss_psfReconvolved_T",
                pa.float32(),
                "Trace (<x^2> + <y^2>) of the reconvolved PSF (measured with gauss algorithm).",
                "arcseconds squared",
            ),
            # Object measurements (gauss algorithm).
            _field(
                "gauss_g1",
                pa.float32(),
                "Reduced-shear g1 measurement of the ShearObject (measured with gauss algorithm).",
            ),
            _field(
                "gauss_g2",
                pa.float32(),
                "Reduced-shear g2 measurement of the ShearObject (measured with gauss algorithm).",
            ),
            _field(
                "gauss_g1_g1_Cov",
                pa.float32(),
                "Auto-covariance of g1 measurement of the ShearObject (measured with gauss algorithm).",
            ),
            _field(
                "gauss_g1_g2_Cov",
                pa.float32(),
                "Cross-covariance of g1 and g2 measurement of the ShearObject (measured with gauss algorithm).",
            ),
            _field(
                "gauss_g2_g2_Cov",
                pa.float32(),
                "Auto-covariance of g2 measurement of the ShearObject (measured with gauss algorithm).",
            ),
        ]

        for alg_name in ("gauss", "pgauss"):
            by_alg = f"(measured with {alg_name} algorithm)."
            fields += [
                _field(f"{alg_name}_snr", pa.float32(), f"Signal-to-noise ratio measure of the ShearObject {by_alg}"),
                _field(
                    f"{alg_name}_T",
                    pa.float32(),
                    f"Trace (<x^2> + <y^2>) measurement of the ShearObject {by_alg}",
                    "arcseconds squared",
                ),
                _field(
                    f"{alg_name}_TErr",
                    pa.float32(),
                    f"Uncertainty in the trace measurement of the ShearObject {by_alg}",
                    "arcseconds squared",
                ),
                _field(
                    f"{alg_name}_shape_flags",
                    pa.int32(),
                    f"Flags for the second order moments measurement of the ShearObject {by_alg}",
                ),
                _field(f"{alg_name}_object_flags", pa.int32(), f"Flags for the ShearObject measurement {by_alg}"),
                _field(f"{alg_name}_flags", pa.int32(), f"Overall flags for {alg_name} measurement algorithm."),
            ]

            # Per-band quantities, typically fluxes and associated quantities.
            for b in config.photometry_bands:
                # Shear bands are required by adjustQuantum; fluxes in the
                # other photometry bands are declared nullable.
                optional = b not in config.metadetect.shear_bands
                fields += [
                    _field(
                        f"{b}_{alg_name}Flux_flags",
                        pa.int32(),
                        f"Flags set for flux in {b} band measured with {alg_name} algorithm.",
                    ),
                    _field(
                        f"{b}_{alg_name}Flux",
                        pa.float32(),
                        f"Flux in {b} band (measured with {alg_name} algorithm).",
                        nullable=optional,
                    ),
                    _field(
                        f"{b}_{alg_name}FluxErr",
                        pa.float32(),
                        f"Flux uncertainty in {b} band (measured with {alg_name} algorithm).",
                        nullable=optional,
                    ),
                ]

        return pa.schema(
            fields,
            metadata={
                "shear_step": str(SHEAR_STEP),
                "shear_bands": "".join(sorted(config.metadetect.shear_bands)),
            },
        )

    def runQuantum(
        self,
        qc: QuantumContext,
        inputRefs: InputQuantizedConnection,
        outputRefs: OutputQuantizedConnection,
    ) -> None:
        # Docstring inherited.

        id_generator = self.config.id_generator.apply(qc.quantum.dataId)

        ref_cat = None
        if self.config.do_mask_bright_objects:
            ref_loader = ReferenceObjectLoader(
                dataIds=[ref.datasetRef.dataId for ref in inputRefs.ref_cat],
                refCats=[qc.get(ref) for ref in inputRefs.ref_cat],
                name=self.config.connections.ref_cat,
                config=self.config.ref_loader,
                log=self.log,
            )
            # NOTE(review): loadRegion appears to return a Struct (refCat,
            # fluxField) rather than a SimpleCatalog as `run` annotates;
            # harmless while `run` ignores ref_cat — confirm before using it.
            ref_cat = ref_loader.loadRegion(
                qc.quantum.dataId.region, filterName=self.config.ref_loader_filter_name
            )

        # Read the coadds and put them in the order defined by
        # config.photometry_bands (note that each MultipleCellCoadd object also
        # knows its own band, if that's needed). Butler ref order is not
        # guaranteed to match, so iterate over the configured bands.
        refs_by_band = {ref.dataId["band"]: ref for ref in inputRefs.input_coadds}
        coadds_by_band = {
            band: qc.get(refs_by_band[band]) for band in self.config.photometry_bands if band in refs_by_band
        }

        try:
            outputs = self.run(
                patch_coadds=coadds_by_band,
                id_generator=id_generator,
                ref_cat=ref_cat,
            )
        except AlgorithmError as err:
            # We know there are no actual outputs in this case, but this is
            # still the right exception to raise (it's just badly named).
            raise AnnotatedPartialOutputsError.annotate(err, self, log=self.log) from err
        qc.put(outputs, outputRefs)

    def run(
        self,
        *,
        patch_coadds: Mapping[str, MultipleCellCoadd],
        id_generator: FullIdGenerator,
        ref_cat: SimpleCatalog | None,
    ) -> Struct:
        """Run metadetection on a patch.

        Each cell of the patch grid is processed independently; a cell whose
        processing raises is logged and skipped.

        Parameters
        ----------
        patch_coadds : `~collections.abc.Mapping` [ `str`, \
                `~lsst.cell_coadds.MultipleCellCoadd` ]
            Per-band, per-patch coadds keyed by band name, in the order
            specified by `MetadetectionShearConfig.photometry_bands`. Must
            contain the first of ``metadetect.shear_bands``.
        id_generator : `~lsst.meas.base.FullIdGenerator`
            Generator for object IDs and to seed the random number generator.
        ref_cat : `lsst.afw.table.SimpleCatalog`, optional
            Reference catalog to use when masking bright stars. Currently
            not used by this method (bright-star masking is still a TODO in
            `process_cell`).

        Returns
        -------
        results : `lsst.pipe.base.Struct`
            Structure with the following attributes:

            - ``metadetect_catalog`` [ `pyarrow.Table` ]: the output object
              catalog for the patch, with schema equal to `metadetect_schema`.

        Raises
        ------
        MetadetectionProcessingError
            Raised if no cell produced any objects.
        """
        seed = id_generator.catalog_id
        self.rng = np.random.RandomState(seed)
        idstart = 0

        grid = patch_coadds[self.config.metadetect.shear_bands[0]].grid
        nx_cells, ny_cells = grid.shape
        single_cell_tables: list[pa.Table] = []
        for nx, ny in product(range(nx_cells), range(ny_cells)):
            cell_id = Index2D(nx, ny)
            bbox = grid.bbox_of(cell_id)
            cell_coadds = [patch_coadd.stitch(bbox) for patch_coadd in patch_coadds.values()]
            self.log.debug("Processing cell %s %s", nx, ny)

            try:
                res = self.process_cell(cell_coadds, cell_id=cell_id)
            except Exception as e:
                # Deliberate best-effort: one failing cell must not sink the patch.
                self.log.error("Failed to process cell %s %s: %s", nx, ny, e)
                continue

            if len(res) > 0:
                # IDs are contiguous across cells within the patch.
                res["id"] = id_generator.arange(idstart, idstart + len(res))
                # TODO: Avoid back and forth conversion between array and dict.
                da = self._dictify(
                    res,
                    tract=id_generator.data_id.tract.id,
                    patch=id_generator.data_id.patch.id,
                )
                table = pa.Table.from_pydict(da, self.metadetect_schema)

                single_cell_tables.append(table)
                idstart += len(res)

        if not single_cell_tables:
            raise MetadetectionProcessingError("No objects found in any cell")

        # TODO: DM-53796 De-duplicate objects before concatenation.
        return Struct(
            metadetect_catalog=pa.concat_tables(single_cell_tables),
        )

    def process_cell(
        self,
        cell_coadds: Sequence[StitchedCoadd],
        cell_id: Index2D,
    ) -> np.ndarray | list:
        """Run metadetection on a single cell.

        Parameters
        ----------
        cell_coadds : `~collections.abc.Sequence` [ \
                `~lsst.cell_coadds.StitchedCoadd` ]
            Per-band, per-cell coadds, in the order specified by
            `MetadetectionShearConfig.photometry_bands`.
        cell_id : `~lsst.skymap.Index2D`
            The cell ID for the cell being processed.

        Returns
        -------
        comb_res : `numpy.ndarray` or `list`
            Structured array with the metadetection results of all shear
            types concatenated, plus per-band flux columns and cell
            bookkeeping columns; an empty list if metadetection produced no
            results. `run` converts it to the `metadetect_schema` layout.
        """
        coadd_data = self._cell_to_coadd_data(cell_coadds)
        # TODO get bright star etc. info as input
        bright_info = []

        apply_apodized_edge_masks_mbexp(**coadd_data)

        if len(bright_info) > 0:
            apply_apodized_bright_masks_mbexp(bright_info=bright_info, **coadd_data)

        mask_frac = _get_mask_frac(
            coadd_data["mfrac_mbexp"],
            trim_pixels=0,
        )

        res = self.metadetect.run(rng=self.rng, **coadd_data)

        return _make_comb_data(
            cell_coadd=cell_coadds[0],
            res=res,
            mask_frac=mask_frac,
            bands=[cell_coadd.band for cell_coadd in cell_coadds],
            cell_id=cell_id,
        )

    @staticmethod
    def _cell_to_coadd_data(cell_coadds: Sequence[StitchedCoadd]):
        """Build metadetect's multi-band input from per-cell stitched coadds.

        For each band this collects the coadd exposure, the first noise
        realization (``noise_index=0``) and a mask-fraction exposure (a deep
        copy of the coadd exposure whose image is replaced by the cell's
        mask fractions), then hands the per-band dicts to
        ``extract_multiband_coadd_data``. The result is unpacked as keyword
        arguments into the metadetect masking and run calls and includes an
        ``"mfrac_mbexp"`` entry.
        """
        coadd_data_list = []
        for cell_coadd in cell_coadds:
            coadd_data = {}
            coadd_data["coadd_exp"] = cell_coadd.asExposure()
            coadd_data["coadd_noise_exp"] = cell_coadd.asExposure(noise_index=0)
            coadd_data["coadd_mfrac_exp"] = ExposureF(coadd_data["coadd_exp"], deep=True)
            coadd_data["coadd_mfrac_exp"].image = cell_coadd.mask_fractions
            coadd_data_list.append(coadd_data)

        return extract_multiband_coadd_data(coadd_data_list)

    def _dictify(self, data, tract: int, patch: int):
        """Rename combined metadetect columns to output-schema column names.

        Parameters
        ----------
        data : `numpy.ndarray`
            Structured array from `process_cell`, with an ``id`` column.
        tract : `int`
            Tract ID broadcast into the ``tract`` column.
        patch : `int`
            Patch ID broadcast into the ``patch`` column.

        Returns
        -------
        output : `dict` [ `str`, `numpy.ndarray` ]
            Column arrays keyed by output name. Columns missing from ``data``
            are filled with 1 for flag columns and NaN otherwise. Entries
            marked "dropped" below are not in the schema and are ignored by
            ``pa.Table.from_pydict``.
        """
        # TODO: Move this to a better location after DP2.
        mapping = {
            "bmask": "bmask_flags",
            "cell_x": "cell_x",
            "cell_y": "cell_y",
            "col": "x",
            "col_diff": "x_offset",  # dropped.
            "dec": "dec",
            "gauss_flags": "gauss_flags",
            "gauss_g_1": "gauss_g1",
            "gauss_g_2": "gauss_g2",
            "gauss_g_cov_11": "gauss_g1_g1_Cov",
            "gauss_g_cov_12": "gauss_g1_g2_Cov",  # same as 21.
            "gauss_g_cov_22": "gauss_g2_g2_Cov",
            "gauss_obj_flags": "gauss_object_flags",
            "gauss_psf_flags": "gauss_psfReconvolved_flags",
            "gauss_psf_g_1": "gauss_psfReconvolved_g1",
            "gauss_psf_g_2": "gauss_psfReconvolved_g2",
            "gauss_psf_T": "gauss_psfReconvolved_T",
            "gauss_s2n": "gauss_snr",
            "gauss_T": "gauss_T",
            "gauss_T_err": "gauss_TErr",
            "gauss_T_flags": "gauss_shape_flags",
            "gauss_T_ratio": "gauss_T_ratio",  # dropped.
            "id": "shearObjectId",
            "mfrac": "mfrac",
            "ormask": "ormask_flags",
            "pgauss_flags": "pgauss_flags",
            "pgauss_obj_flags": "pgauss_object_flags",
            "pgauss_s2n": "pgauss_snr",
            "pgauss_T": "pgauss_T",
            "pgauss_T_err": "pgauss_TErr",
            "pgauss_T_flags": "pgauss_shape_flags",
            "pgauss_T_ratio": "pgauss_T_ratio",  # dropped.
            "psfrec_flags": "psfOriginal_flags",
            "psfrec_g_1": "psfOriginal_e1",
            "psfrec_g_2": "psfOriginal_e2",
            "psfrec_T": "psfOriginal_T",
            "ra": "ra",
            "row": "y",
            "row_diff": "y_offset",  # dropped.
            "shear_type": "metaStep",
            "stamp_flags": "image_flags",
        }

        for b, alg_name in product(self.config.photometry_bands, ("gauss", "pgauss")):
            mapping[f"{alg_name}_band_flux_{b}"] = f"{b}_{alg_name}Flux"
            mapping[f"{alg_name}_band_flux_err_{b}"] = f"{b}_{alg_name}FluxErr"
            mapping[f"{alg_name}_band_flux_flags_{b}"] = f"{b}_{alg_name}Flux_flags"

        ids = data["id"]
        present = set(data.dtype.names)
        output = {}
        for name, column in mapping.items():
            if name in present:
                output[column] = data[name]
            elif "flags" in name.lower():
                # A missing flag column is reported as "flag set".
                output[column] = np.ones_like(ids, dtype=np.int32)
            else:
                # A missing measurement is reported as NaN.
                output[column] = np.full_like(ids, np.nan, dtype=np.float32)

        # Both columns are int64 in the output schema.
        output["tract"] = np.full_like(ids, tract, dtype=np.int64)
        output["patch"] = np.full_like(ids, patch, dtype=np.int64)

        return output

871 

872 

def _make_comb_data(
    cell_coadd,
    res,
    mask_frac,
    bands,
    cell_id,
):
    """Flatten metadetection results for one cell into one structured array.

    Parameters
    ----------
    cell_coadd : `~lsst.cell_coadds.StitchedCoadd`
        Any coadd of the cell; only its ``identifiers`` (tract, patch) are used.
    res : mapping or array
        Metadetection output. If it has no ``keys`` attribute it is treated as
        the ``noshear`` result alone. ``None`` entries are skipped.
    mask_frac : `float`
        Cell mask fraction, stored in the ``mask_frac`` column.
    bands : `list` [ `str` ]
        Band names in the same order as the per-band axis of the
        ``*_band_flux*`` result columns.
    cell_id : `~lsst.skymap.Index2D`
        Cell index, stored in ``cell_x``/``cell_y``.

    Returns
    -------
    output : `numpy.ndarray` or `list`
        Concatenation of all shear types with extra per-band, bookkeeping and
        shear-type columns added, or an empty list if nothing was produced.
    """
    idinfo = cell_coadd.identifiers
    flux_algorithms = ("gauss", "pgauss")
    flux_quantities = (("flux_flags", "i4"), ("flux", "f4"), ("flux_err", "f4"))

    copy_dt = [
        # we will copy out of arrays to these
        ("psfrec_g_1", "f4"),
        ("psfrec_g_2", "f4"),
        ("gauss_psf_g_1", "f4"),
        ("gauss_psf_g_2", "f4"),
        ("gauss_g_1", "f4"),
        ("gauss_g_2", "f4"),
        ("gauss_g_cov_11", "f4"),
        ("gauss_g_cov_12", "f4"),
        ("gauss_g_cov_21", "f4"),
        ("gauss_g_cov_22", "f4"),
    ]

    # Per-band scalar columns, e.g. gauss_band_flux_flags_r.
    for b in bands:
        for alg_name in flux_algorithms:
            for quantity, dtype in flux_quantities:
                copy_dt.append((f"{alg_name}_band_{quantity}_{b}", dtype))

    add_dt = [
        ("id", "u8"),
        ("tract", "u4"),
        ("patch_x", "u1"),
        ("patch_y", "u1"),
        ("cell_x", "u1"),
        ("cell_y", "u1"),
        ("shear_type", "U2"),
        ("mask_frac", "f4"),
        ("primary", bool),
    ] + copy_dt

    if not hasattr(res, "keys"):
        res = {"noshear": res}

    dlist = []
    for stype in res.keys():
        data = res[stype]
        if data is None:
            continue

        newdata = eu.numpy_util.add_fields(data, add_dt)
        newdata["psfrec_g_1"] = newdata["psfrec_g"][:, 0]
        newdata["psfrec_g_2"] = newdata["psfrec_g"][:, 1]

        newdata["gauss_psf_g_1"] = newdata["gauss_psf_g"][:, 0]
        newdata["gauss_psf_g_2"] = newdata["gauss_psf_g"][:, 1]
        newdata["gauss_g_1"] = newdata["gauss_g"][:, 0]
        newdata["gauss_g_2"] = newdata["gauss_g"][:, 1]

        newdata["gauss_g_cov_11"] = newdata["gauss_g_cov"][:, 0, 0]
        newdata["gauss_g_cov_12"] = newdata["gauss_g_cov"][:, 0, 1]
        newdata["gauss_g_cov_21"] = newdata["gauss_g_cov"][:, 1, 0]
        newdata["gauss_g_cov_22"] = newdata["gauss_g_cov"][:, 1, 1]

        # With several bands the band-flux columns are 2-D (row, band); with
        # a single band they are copied as-is.
        # To-do make compatible with a single band better than this.
        multi_band = len(bands) > 1
        for i, b in enumerate(bands):
            for alg_name in flux_algorithms:
                for quantity, _ in flux_quantities:
                    combined = newdata[f"{alg_name}_band_{quantity}"]
                    newdata[f"{alg_name}_band_{quantity}_{b}"] = combined[:, i] if multi_band else combined

        newdata["tract"] = idinfo.tract
        newdata["patch_x"] = idinfo.patch.x
        newdata["patch_y"] = idinfo.patch.y
        newdata["cell_x"] = cell_id.x
        newdata["cell_y"] = cell_id.y
        newdata["mask_frac"] = mask_frac

        # "noshear" is abbreviated to fit the 2-character shear_type column.
        newdata["shear_type"] = "ns" if stype == "noshear" else stype

        dlist.append(newdata)

    if dlist:
        return eu.numpy_util.combine_arrlist(dlist)
    return []

980 

981 

982def _get_mask_frac(mfrac_mbexp, trim_pixels=0): 

983 """ 

984 get the average mask frac for each band and then return the max of those 

985 """ 

986 

987 mask_fracs = [] 

988 for mfrac_exp in mfrac_mbexp: 

989 mfrac = mfrac_exp.image.array 

990 dim = mfrac.shape[0] 

991 mfrac = mfrac[ 

992 trim_pixels : dim - trim_pixels - 1, 

993 trim_pixels : dim - trim_pixels - 1, 

994 ] 

995 mask_fracs.append(mfrac.mean()) 

996 

997 return max(mask_fracs)