Coverage for python / lsst / meas / extensions / multiprofit / rebuild_coadd_multiband.py: 0%
178 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-28 09:22 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-28 09:22 +0000
1# This file is part of meas_extensions_multiprofit.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
# Public API: model rebuilders for patch-level MultiProFit coadd fits.
__all__ = ["ModelRebuilder", "PatchModelMatches", "PatchCoaddRebuilder"]
24from functools import cached_property
25from typing import Iterable
27import astropy.table
28import astropy.units as u
29import lsst.afw.table as afwTable
30import lsst.daf.butler as dafButler
31import lsst.gauss2d.fit as g2f
32import lsst.geom as geom
33from lsst.meas.extensions.scarlet.io import updateCatalogFootprints
34from lsst.pipe.base import QuantumContext, QuantumGraph
35from lsst.pipe.tasks.fit_coadd_multiband import (
36 CoaddMultibandFitBaseTemplates,
37 CoaddMultibandFitInputConnections,
38 CoaddMultibandFitTask,
39)
40from lsst.skymap import BaseSkyMap, TractInfo
41import numpy as np
42import pydantic
44from .fit_coadd_multiband import (
45 CatalogExposurePsfs,
46 CatalogSourceFitterConfigData,
47 MultiProFitSourceConfig,
48 MultiProFitSourceTask,
49)
# Mapping from astropy angular units to their lsst.geom equivalents.
# Used by astropy_unit_to_geom below; extend here if more units are needed.
astropy_to_geom_units = {
    u.arcmin: geom.arcminutes,
    u.arcsec: geom.arcseconds,
    u.mas: geom.milliarcseconds,
    u.deg: geom.degrees,
    u.rad: geom.radians,
}
def astropy_unit_to_geom(unit: u.Unit, default=None) -> geom.AngleUnit:
    """Return the lsst.geom equivalent of an astropy angular unit.

    Parameters
    ----------
    unit
        The astropy unit to look up.
    default
        Value to return when no known conversion exists.

    Returns
    -------
    unit_geom
        The corresponding lsst.geom angle unit.

    Raises
    ------
    ValueError
        Raised if the lookup (including the default) yields no unit.
    """
    converted = astropy_to_geom_units.get(unit, default)
    # A None result means neither the table nor the default supplied a unit.
    if converted is None:
        raise ValueError(f"{unit=} not found in {astropy_to_geom_units=}")
    return converted
def find_patches(tract_info: TractInfo, ra_array, dec_array, unit: geom.AngleUnit) -> list[int]:
    """Find the patches containing a list of ra/dec values within a tract.

    Parameters
    ----------
    tract_info
        The TractInfo object for the tract.
    ra_array
        The array of right ascension values.
    dec_array
        The array of declination values (must be same length as ra_array).
    unit
        The unit of the RA/dec values.

    Returns
    -------
    patches
        A list of patches containing the specified RA/dec values.
    """
    wcs = tract_info.wcs
    dim_x, dim_y = (tract_info.patch_inner_dimensions[idx] for idx in range(2))
    patches = []
    # One sequential patch index per coordinate pair, preserving input order.
    for ra, dec in zip(ra_array, dec_array, strict=True):
        pixel = geom.Point2I(wcs.skyToPixel(geom.SpherePoint(ra, dec, units=unit)))
        # Integer division by the inner patch dimensions gives the patch grid
        # coordinates, which map to a single sequential index.
        patches.append(
            tract_info.getSequentialPatchIndexFromPair((pixel.x // dim_x, pixel.y // dim_y))
        )
    return patches
def get_radec_unit(table: astropy.table.Table, coord_ra: str, coord_dec: str, default=None):
    """Get the RA/dec units for columns in a table.

    Parameters
    ----------
    table
        The table to determine units for.
    coord_ra
        The key of the right ascension column.
    coord_dec
        The key of the declination column.
    default
        The default value to return if no unit is found.

    Returns
    -------
    unit
        The unit of the RA/dec columns or None if none is found.

    Raises
    ------
    ValueError
        Raised if the units are inconsistent.
    """
    unit_ra = astropy_unit_to_geom(table[coord_ra].unit, default=default)
    unit_dec = astropy_unit_to_geom(table[coord_dec].unit, default=default)
    # Both coordinate columns must agree on a single angular unit.
    if unit_ra != unit_dec:
        units = {coord: table[coord].unit for coord in (coord_ra, coord_dec)}
        raise ValueError(f"Reference table has inconsistent {units=}")
    return unit_ra
class DataLoader(pydantic.BaseModel):
    """A collection of data that can be used to rebuild models."""

    model_config = pydantic.ConfigDict(arbitrary_types_allowed=True, frozen=True)

    catexps: list[CatalogExposurePsfs] = pydantic.Field(
        doc="List of MultiProFit catalog-exposure-psf objects used to fit PSF-convolved models",
    )
    catalog_multi: afwTable.SourceCatalog = pydantic.Field(
        doc="Patch-level multiband reference catalog (deepCoadd_ref)",
    )

    @cached_property
    def channels(self) -> tuple[g2f.Channel, ...]:
        # One channel per catexp, in the same order as self.catexps.
        channels = tuple(g2f.Channel.get(catexp.band) for catexp in self.catexps)
        return channels

    @classmethod
    def from_butler(
        cls, butler: dafButler.Butler, data_id: dict, bands: Iterable[str], name_coadd=None, **kwargs
    ):
        """Construct a DataLoader from a Butler and dataId.

        Parameters
        ----------
        butler
            The butler to load from.
        data_id
            Key-value pairs for the {name_coadd}Coadd_* dataId. Must include
            a "patch" key; a "band" key is added per band internally.
        bands
            The list of bands to load.
        name_coadd
            The prefix of the Coadd datasettype name.
        **kwargs
            Additional keyword arguments to pass to the init method for
            `CoaddMultibandFitInputConnections`.

        Returns
        -------
        data_loader
            An initialized DataLoader.

        Raises
        ------
        ValueError
            Raised if bands contains duplicate entries.
        """
        bands = tuple(bands)
        if len(set(bands)) != len(bands):
            raise ValueError(f"{bands=} is not a set")
        if name_coadd is None:
            name_coadd = CoaddMultibandFitBaseTemplates["name_coadd"]

        catalog_multi = butler.get(
            CoaddMultibandFitInputConnections.cat_ref.name.format(name_coadd=name_coadd), **data_id, **kwargs
        )

        catexps = {}
        for band in bands:
            # Make a band-specific copy rather than mutating the caller's
            # data_id in place: previously every catexp was handed the same
            # dict object, whose "band" key ended up set to the final band
            # after the loop finished.
            data_id_band = dict(data_id, band=band)
            catalog = butler.get(
                CoaddMultibandFitInputConnections.cats_meas.name.format(name_coadd=name_coadd),
                **data_id_band,
                **kwargs,
            )
            exposure = butler.get(
                CoaddMultibandFitInputConnections.coadds.name.format(name_coadd=name_coadd),
                **data_id_band,
                **kwargs,
            )
            models_scarlet = butler.get(
                CoaddMultibandFitInputConnections.models_scarlet.name.format(name_coadd=name_coadd),
                **data_id_band,
                **kwargs,
            )
            # Rebuild heavy footprints in the measurement catalog from the
            # scarlet model data for this band's coadd.
            updateCatalogFootprints(
                modelData=models_scarlet,
                catalog=catalog,
                band=band,
                imageForRedistribution=exposure,
                removeScarletData=True,
                updateFluxColumns=False,
            )
            # The config and table are harmless dummies
            catexps[band] = CatalogExposurePsfs(
                catalog=catalog,
                exposure=exposure,
                table_psf_fits=astropy.table.Table(),
                dataId=data_id_band,
                id_tract_patch=data_id["patch"],
                channel=g2f.Channel.get(band),
                config_fit=MultiProFitSourceConfig(),
            )
        return cls(
            catalog_multi=catalog_multi,
            catexps=list(catexps.values()),
        )

    def load_deblended_object(
        self,
        idx_row: int,
    ) -> list[g2f.ObservationD]:
        """Load a deblended object from catexps.

        Parameters
        ----------
        idx_row
            The index of the object to load.

        Returns
        -------
        observations
            The observations of the object (deblended if it is a child).
        """
        observations = []
        for catexp in self.catexps:
            observations.append(catexp.get_source_observation(catexp.get_catalog()[idx_row]))
        return observations
class ModelRebuilder(DataLoader):
    """A rebuilder of MultiProFit models from their inputs and best-fit
    parameter values.
    """

    fit_results: astropy.table.Table = pydantic.Field(doc="Multiprofit model fit results")
    task_fit: MultiProFitSourceTask = pydantic.Field(doc="The task")

    @cached_property
    def config_data(self) -> CatalogSourceFitterConfigData:
        # Built once from self.channels and the fit task's config; cached
        # because construction is deterministic for a frozen model.
        config_data = self.make_config_data()
        return config_data

    @classmethod
    def from_quantumGraph(
        cls,
        butler: dafButler.Butler,
        quantumgraph: QuantumGraph,
        dataId: dict | None = None,
    ):
        """Make a rebuilder from a butler and quantumgraph.

        Parameters
        ----------
        butler
            The butler that the quantumgraph was built for.
        quantumgraph
            The quantum graph file from a CoaddMultibandFitTask using the
            MultiProFitSourceTask.
        dataId
            The dataId for the fit, including skymap, tract and patch.
            If None, the first output quantum is used.

        Returns
        -------
        rebuilder
            A ModelRebuilder instance initialized with the necessary kwargs.

        Raises
        ------
        ValueError
            Raised if dataId is not found in the graph's output quanta, or
            if the graph's task/subtask are not of the expected types.
        """
        if dataId is None:
            # Default to the first output quantum when no dataId is given.
            quantum = next(iter(quantumgraph.outputQuanta)).quantum
        else:
            quantum = None
            # Linear scan for the quantum whose dataId matches exactly.
            for node in quantumgraph.outputQuanta:
                if node.quantum.dataId.to_simple().dataId == dataId:
                    quantum = node.quantum
                    break
            if quantum is None:
                raise ValueError(
                    f"{dataId=} not found in {[x.quantum.dataId for x in quantumgraph.outputQuanta]=}"
                )
        # Assumes the graph contains a single task - TODO confirm this holds
        # for all graphs this is used with.
        taskDef = next(iter(quantumgraph.iterTaskGraph()))
        butlerQC = QuantumContext(butler, quantum)
        config = butler.get(f"{taskDef.label}_config")
        # I have no idea what to put for initInputs.
        # quantum.initInputs looks wrong - the values can be lists
        # quantumgraph.initInputRefs(taskDef) returns a list of DatasetRefs...
        # ... but I'm not sure how to map that to connection names?
        task: CoaddMultibandFitTask = taskDef.taskClass(config=config, initInputs={})
        if not isinstance(task, CoaddMultibandFitTask):
            raise ValueError(f"{task=} type={type(task)} !isinstance of {CoaddMultibandFitTask=}")
        task_fit: MultiProFitSourceTask = task.fit_coadd_multiband
        if not isinstance(task_fit, MultiProFitSourceTask):
            raise ValueError(f"{task_fit=} type={type(task_fit)} !isinstance of {MultiProFitSourceTask=}")
        # Re-load the task's original inputs and rebuild its catexps, then
        # pair them with the already-written fit output catalog.
        inputRefs, outputRefs = taskDef.connections.buildDatasetRefs(quantum)
        inputs = butlerQC.get(inputRefs)
        catexps = task.build_catexps(butlerQC, inputRefs, inputs)
        catexps = [task_fit.make_CatalogExposurePsfs(catexp) for catexp in catexps]
        cat_output: astropy.table.Table = butler.get(outputRefs.cat_output, storageClass="ArrowAstropy")
        return cls(
            catexps=catexps,
            task_fit=task_fit,
            catalog_multi=inputs["cat_ref"],
            fit_results=cat_output,
        )

    def make_config_data(self) -> CatalogSourceFitterConfigData:
        """Make a ConfigData object out of self's channels and fit task
        config.
        """
        config_data = CatalogSourceFitterConfigData(channels=self.channels, config=self.task_fit.config)
        return config_data

    def make_model(
        self,
        idx_row: int,
        config_data: CatalogSourceFitterConfigData | None = None,
        init: bool = True,
    ) -> g2f.ModelD:
        """Make a ModelD for a single row from the originally fitted catalog.

        Parameters
        ----------
        idx_row
            The index of the row to make a model for.
        config_data
            The model configuration data object. Defaults to
            self.config_data.
        init
            Whether to initialize the model parameters as they would have been
            prior to fitting.

        Returns
        -------
        model
            The rebuilt model.
        """
        if config_data is None:
            config_data = self.config_data
        model = self.task_fit.get_model(
            idx_row=idx_row,
            catalog_multi=self.catalog_multi,
            catexps=self.catexps,
            config_data=config_data,
            results=self.fit_results,
            set_flux_limits=False,
        )
        # NOTE(review): set_model assigns best-fit values from fit_results,
        # which appears to differ from the "prior to fitting" wording above -
        # confirm intended semantics of init.
        if init:
            self.set_model(idx_row, config_data)
        return model

    def set_model(self, idx_row: int, config_data: CatalogSourceFitterConfigData | None = None) -> None:
        """Set model parameters to the best-fit values for a given row.

        Parameters
        ----------
        idx_row
            The index of the row in the fit parameter table to initialize from.
        config_data
            The model configuration data object. Defaults to
            self.config_data.
        """
        if config_data is None:
            config_data = self.config_data
        row = self.fit_results[idx_row]
        prefix = config_data.config.prefix_column
        offsets = {}
        # Centroid columns are stored with a configurable pixel offset;
        # subtract it when restoring parameter values.
        offset_cen = config_data.config.centroid_pixel_offset
        if offset_cen != 0:
            offsets[g2f.CentroidXParameterD] = -offset_cen
            offsets[g2f.CentroidYParameterD] = -offset_cen
        for key, param in config_data.parameters.items():
            param.value = row[f"{prefix}{key}"] + offsets.get(type(param), 0.0)
class PatchModelMatches(pydantic.BaseModel):
    """Storage for MultiProFit tables matched to a reference catalog."""

    model_config = pydantic.ConfigDict(arbitrary_types_allowed=True, frozen=True)

    # Matched ref/measurement table for one algorithm; may be None.
    matches: astropy.table.Table | None = pydantic.Field(doc="Catalogs of matches")
    # None for non-MultiProFit models (those rebuilt via DataLoader only).
    quantumgraph: QuantumGraph | None = pydantic.Field(doc="Quantum graph for fit task")
    # ModelRebuilder when a quantumgraph is available, else a plain DataLoader.
    rebuilder: DataLoader | ModelRebuilder | None = pydantic.Field(doc="MultiProFit object model rebuilder")
class PatchCoaddRebuilder(pydantic.BaseModel):
    """A rebuilder for patch-level coadd catalog/exposure fits."""

    model_config = pydantic.ConfigDict(arbitrary_types_allowed=True, frozen=True)

    # doc must be passed by keyword: pydantic.Field's first positional
    # parameter is the default value, so the previous positional string made
    # this required field silently optional with a bogus string default.
    matches: dict[str, PatchModelMatches] = pydantic.Field(doc="Model matches by algorithm name")
    name_model_ref: str = pydantic.Field(doc="The name of the reference model in matches")
    objects: astropy.table.Table = pydantic.Field(doc="Object table")
    objects_multiprofit: astropy.table.Table | None = pydantic.Field(doc="Object table for MultiProFit fits")
    reference: astropy.table.Table = pydantic.Field(doc="Reference object table")

    skymap: str = pydantic.Field(doc="The skymap name")
    tract: int = pydantic.Field(doc="The tract index")
    patch: int = pydantic.Field(doc="The patch index")

    @classmethod
    def from_butler(
        cls,
        butler: dafButler.Butler,
        skymap: str,
        tract: int,
        patch: int,
        collection_merged: str,
        matches: dict[str, QuantumGraph | None],
        bands: Iterable[str] | None = None,
        name_model_ref: str | None = None,
        format_collection: str = "{run}",
        load_multiprofit: bool = True,
        dataset_type_ref: str = "truth_summary",
    ):
        """Construct a PatchCoaddRebuilder from a single Butler collection.

        Parameters
        ----------
        butler
            The butler to load from.
        skymap
            The skymap for the collection.
        tract
            The skymap tract id.
        patch
            The skymap patch id.
        collection_merged
            The name of the collection with the merged objectTable(s).
        matches
            A dictionary of model names with corresponding QuantumGraphs.
            These may be None but must be provided for MultiProFit model
            reconstruction to be possible.
        bands
            The list of bands to load data for.
        name_model_ref
            The name of the model to use as a reference. Must be a key in
            `matches`.
        format_collection
            A format string for the output collection(s) defined in the
            `matches` QuantumGraphs.
        load_multiprofit
            Whether to attempt to load an objectTable_tract_multiprofit.
        dataset_type_ref
            The dataset type of the reference catalog.

        Returns
        -------
        rebuilder
            The fully-configured PatchCoaddRebuilder.

        Raises
        ------
        ValueError
            Raised if name_model_ref is None and no entry in matches has a
            quantumgraph.
        """
        if name_model_ref is None:
            # Default to the first model that has a quantumgraph.
            for name, quantumgraph in matches.items():
                if quantumgraph is not None:
                    name_model_ref = name
                    break
            if name_model_ref is None:
                raise ValueError("Must supply name_model_ref or at least one matches with a quantumgraph")
        dataId = dict(skymap=skymap, tract=tract, patch=patch)
        objects = butler.get(
            "objectTable_tract", collections=[collection_merged], storageClass="ArrowAstropy", **dataId
        )
        objects = objects[objects["patch"] == patch]
        if load_multiprofit:
            objects_multiprofit = butler.get(
                "objectTable_tract_multiprofit",
                collections=[collection_merged],
                storageClass="ArrowAstropy",
                **dataId,
            )
            objects_multiprofit = objects_multiprofit[objects_multiprofit["patch"] == patch]
        else:
            objects_multiprofit = None
        reference = butler.get(
            dataset_type_ref, collections=[collection_merged], storageClass="ArrowAstropy", **dataId
        )
        skymap_tract = butler.get(BaseSkyMap.SKYMAP_DATASET_TYPE_NAME, skymap=skymap)[tract]
        unit_coord_ref = get_radec_unit(reference, "ra", "dec", default=geom.degrees)
        if "patch" not in reference.columns:
            # Assign patches from sky coordinates when the reference catalog
            # has no patch column at all.
            patches = find_patches(skymap_tract, reference["ra"], reference["dec"], unit=unit_coord_ref)
            reference["patch"] = patches
        elif reference["patch"].dtype != int:
            # the ci_imsim truth_summary still has string patches
            index_patch = skymap_tract[patch].index
            str_patch = f"{index_patch.y},{index_patch.x}"
            reference = reference[
                (reference["patch"] == str_patch) & (reference["is_unique_truth_entry"] == True)  # noqa: E712
            ]
            del reference["patch"]
            reference["patch"] = patch
        reference = reference[reference["patch"] == patch]
        points = skymap_tract.wcs.skyToPixel(
            [geom.SpherePoint(row["ra"], row["dec"], units=geom.degrees) for row in reference]
        )
        reference["x"] = [point.x for point in points]
        reference["y"] = [point.y for point in points]
        matches_name = {}
        for name, quantumgraph in matches.items():
            is_mpf = quantumgraph is not None
            matched = butler.get(
                f"matched_{dataset_type_ref}_objectTable_tract{'_multiprofit' if is_mpf else ''}",
                collections=[
                    (
                        format_collection.format(run=quantumgraph.metadata["output"], name=name)
                        if is_mpf
                        else collection_merged
                    )
                ],
                storageClass="ArrowAstropy",
                **dataId,
            )
            # unmatched ref objects don't have a patch set
            # should probably be fixed in diff_matched
            # but need to decide priority on matched - ref first? or target?
            unit_coord_ref = get_radec_unit(
                matched,
                "refcat_ra",
                "refcat_dec",
                default=geom.degrees,
            )
            # Rows with no patch (masked or negative) but finite ref coords
            # get a patch assigned from their sky position.
            unmatched = (
                matched["patch"].mask if np.ma.is_masked(matched["patch"]) else ~(matched["patch"] >= 0)
            ) & np.isfinite(matched["refcat_ra"])
            patches_unmatched = find_patches(
                skymap_tract,
                matched["refcat_ra"][unmatched],
                matched["refcat_dec"][unmatched],
                unit=unit_coord_ref,
            )
            matched["patch"][np.where(unmatched)[0]] = patches_unmatched
            matched = matched[matched["patch"] == patch]
            rebuilder = (
                ModelRebuilder.from_quantumGraph(butler, quantumgraph, dataId=dataId)
                if is_mpf
                else DataLoader.from_butler(
                    butler, data_id=dataId, bands=bands, collections=[collection_merged]
                )
            )
            matches_name[name] = PatchModelMatches(
                matches=matched, quantumgraph=quantumgraph, rebuilder=rebuilder
            )
        return cls(
            matches=matches_name,
            objects=objects,
            objects_multiprofit=objects_multiprofit,
            reference=reference,
            skymap=skymap,
            tract=tract,
            patch=patch,
            name_model_ref=name_model_ref,
        )