Coverage for python / lsst / meas / photoz / base / estimate_photoz_task.py: 48%
177 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-14 23:49 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-14 23:49 +0000
1# This file is part of meas_photoz_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = [
25 "EstimatePhotozAlgoConfigBase",
26 "EstimatePhotozAlgoTask",
27 "EstimatePhotozTask",
28 "EstimatePhotozTaskConfig",
29 "photozAlgoRegistry",
30]
32from abc import ABC, abstractmethod
33from typing import Any
35import numpy as np
36from astropy.table import Table
37from ceci.config import StageConfig as CeciStageConfig
38from ceci.config import StageParameter as CeciParam
39from rail.core.model import Model
40from rail.estimation.estimator import CatEstimator
41from rail.interfaces import PZFactory
43import lsst.pex.config as pexConfig
44import lsst.pipe.base.connectionTypes as cT
45from lsst.pipe.base import (
46 PipelineTask,
47 PipelineTaskConfig,
48 PipelineTaskConnections,
49 Struct,
50 Task,
51)
class EstimatePhotozConnections(
    PipelineTaskConnections,
    dimensions=("skymap", "tract"),
    defaultTemplates={"algo": "trainz"},
):
    """Butler connections for tasks that compute p(z) estimates.

    The task reads a pickled model file as a "calibration-like"
    prerequisite input together with an object table, and writes the
    resulting p(z) estimates in 'qp' format.

    The ``algo`` template (default ``"trainz"``) is substituted into the
    model and ensemble dataset type names.
    """

    # Trained estimator model; declared as a calibration prerequisite.
    photoz_model = cT.PrerequisiteInput(
        doc="Model for PZ Estimation",
        name="photoz_model_{algo}",
        storageClass="PhotozModel",
        dimensions=["instrument"],
        isCalibration=True,
    )

    # Object table; loading is deferred so that the task can request only
    # the columns it actually needs.
    objects = cT.Input(
        doc="Object table",
        name="object",
        storageClass="ArrowAstropy",
        dimensions=("skymap", "tract"),
        deferLoad=True,
    )

    # Output ensemble of per-object redshift PDFs.
    photoz_ensemble = cT.Output(
        doc="Per-object p(z) estimates",
        name="photoz_ensemble_{algo}",
        storageClass="QPEnsemble",
        dimensions=("skymap", "tract"),
    )
class EstimatePhotozAlgoConfigBase(
    pexConfig.Config,
):
    """Base class for configurations of algorithm-specific p(z)
    estimation tasks.

    This class mostly just translates the RAIL configuration
    parameters to pex.config parameters.

    Subclasses will just have to set `estimator_class` and `stage_name`
    and invoke `_make_fields` once in the module.
    """

    @classmethod
    @abstractmethod
    def estimator_class(cls) -> type[CatEstimator]:
        """Return the type of the estimator's RAIL class."""
        raise NotImplementedError("Subclasses must specify an estimator class")

    # This should be a property but py3.13+ don't allow it
    @classmethod
    @abstractmethod
    def stage_name(cls) -> str:
        """Return the RAIL stage name for the estimator."""
        raise NotImplementedError("Subclasses must define a RAIL stage name")

    # Extinction coefficients; see https://ui.adsabs.harvard.edu/abs/1989ApJ...345..245C/abstract
    # Also in rail.utils.catalog_utils.RubinCatalogConfig.a_env
    default_a_env_values = dict(
        u=4.81,
        g=3.64,
        r=2.70,
        i=2.06,
        z=1.58,
        y=1.31,
    )

    # These appear in many DESC repos, for example:
    # https://github.com/LSSTDESC/TXPipe/blob/00ebe7476fd5d9529f5bbc4d73fcef0629d134c7/examples/dp0.2/config.yml#L47
    # They seem to be 10y WFD limits. Origin unclear.
    default_mag_limit_10y_values = dict(
        u=27.79,
        g=29.04,
        r=29.06,
        i=28.62,
        z=27.98,
        y=27.05,
    )

    # These appear to be from Roman-Rubin simulations:
    # https://github.com/LSSTDESC/rail_base/blob/v1.2.1/src/rail/utils/catalog_utils.py#L207
    # Presumably max 5y depth, and more useful for now
    default_mag_limit_values = dict(
        u=24.0,
        g=27.66,
        r=27.25,
        i=26.6,
        z=26.24,
        y=25.35,
    )

    def get_band_a_env_dict(self) -> dict[str, float]:
        """Return the a_env value for each band in ``bands_to_convert``.

        The values are always taken from ``default_a_env_values``; the
        result is keyed by band name.
        """
        return {band_: self.default_a_env_values[band_] for band_ in self.bands_to_convert}

    def get_mag_lim_dict(self) -> dict[str, float]:
        """Return the magnitude limit for each band in ``bands_to_convert``.

        The values come from ``default_mag_limit_values``; the result is
        keyed by magnitude column name (``mag_template`` formatted with
        the band), not by band.
        """
        return {
            self.mag_template.format(band=band): self.default_mag_limit_values[band]
            for band in self.bands_to_convert
        }

    def get_flux_names(self) -> dict[str, str]:
        """Return a dict mapping band to flux column name."""
        return {band: self.flux_column_template.format(band=band) for band in self.bands_to_convert}

    def get_flux_err_names(self) -> dict[str, str]:
        """Return a dict mapping band to flux error column name."""
        return {band: self.flux_err_column_template.format(band=band) for band in self.bands_to_convert}

    def get_mag_names(self) -> dict[str, str]:
        """Return a dict mapping band to mag column name."""
        return {band: self.mag_template.format(band=band) for band in self.bands_to_convert}

    def get_mag_err_names(self) -> dict[str, str]:
        """Return a dict mapping band to mag error column name."""
        return {band: self.mag_err_template.format(band=band) for band in self.bands_to_convert}

    mag_offset = pexConfig.Field(doc="Magnitude offset", dtype=float, default=31.4)
    deredden = pexConfig.Field[bool](
        doc="Apply dereddening",
        default=True,
    )
    band_ref = pexConfig.Field[str](
        doc="Name of the most reliable reference band, if needed",
        default="i",
    )
    bands_to_convert = pexConfig.ListField[str](
        doc="Names of bands to convert fluxs to mags for RAIL",
        default=["u", "g", "r", "i", "z", "y"],
    )
    # For the four column templates below, the cModel equivalents
    # ("{band}_cModelFlux", "{band}_cModelFluxErr", "{band}_cModelMag",
    # "{band}_cModelMagErr") are a possible alternative set of defaults.
    flux_column_template = pexConfig.Field[str](
        doc="Template for flux column names",
        default="{band}_gaap1p0Flux",
    )
    flux_err_column_template = pexConfig.Field[str](
        doc="Template for flux error column names",
        default="{band}_gaap1p0FluxErr",
    )
    mag_template = pexConfig.Field[str](
        doc="Template for magnitude names",
        default="{band}_gaap1p0Mag",
    )
    mag_err_template = pexConfig.Field[str](
        doc="Template for magntitude error names",
        default="{band}_gaap1p0MagErr",
    )
    nondetect_val = pexConfig.Field[float](
        doc="Magnitude to set for non-detections",
        default=np.nan,
    )
    band_a_env = pexConfig.DictField[str, float](
        doc="Reddening parameters",
        default=default_a_env_values,
    )

    def freeze(self) -> None:
        """Fill in the derived RAIL parameters, then freeze the config.

        The derived values must be assigned before the config is frozen,
        so `_finalize` only runs when the config is not frozen yet.
        """
        if not self._frozen:
            self._finalize()
        super().freeze()

    def _finalize(self) -> None:
        """Set the RAIL parameters that are derived from other fields.

        Each parameter is only assigned when the config actually has an
        attribute of that name.
        """
        # These calls will fail if it's already frozen.
        if hasattr(self, "ref_band"):
            self.ref_band = self.mag_template.format(band=self.band_ref)
        if hasattr(self, "bands"):
            # This is a list of mag columns in RAIL, not bands
            self.bands = list(self.get_mag_names().values())
        if hasattr(self, "err_bands"):
            self.err_bands = list(self.get_mag_err_names().values())
        if hasattr(self, "mag_limits"):
            self.mag_limits = self.get_mag_lim_dict()
        if hasattr(self, "band_a_env"):
            # NOTE(review): this replaces whatever was configured in
            # band_a_env with the class defaults for bands_to_convert;
            # confirm that discarding user overrides is intended.
            self.band_a_env = self.get_band_a_env_dict()

    @classmethod
    def _make_fields(cls) -> None:
        """Import the RAIL estimation stage.

        This method loops through the stage config parameters and converts
        RAIL/Ceci parameters to corresponding pex.config parameters.

        It should be called exactly once, immediately after the definition
        of every subclass of this base class.

        Raises
        ------
        RuntimeError
            Raised if the method was already called for this class, or if
            a scalar RAIL parameter clashes with an existing class
            attribute that is not a `pexConfig.Field` or that has a
            different dtype.
        """
        if hasattr(cls, "__fields_made__"):
            if cls.__fields_made__ is not True:
                raise RuntimeError(f"{cls.__fields_made__=} exists but is not True")
            raise RuntimeError(f"{cls=} called _make_fields twice")

        stage_class = cls.estimator_class()
        for key, val in stage_class.config_options.items():
            if isinstance(val, CeciStageConfig):
                val = val.get(key)
            if not isinstance(val, CeciParam):
                # Only Ceci parameters can be translated.
                continue

            if val.dtype in [bool, int, float, str]:
                attr = getattr(cls, key, None)
                if attr is None:
                    setattr(
                        cls,
                        key,
                        pexConfig.Field(doc=val.msg, dtype=val.dtype, default=val.default),
                    )
                    continue
                if not isinstance(attr, pexConfig.Field):
                    # Report the type of the clashing attribute; the key
                    # itself is always a string.
                    raise RuntimeError(f"{cls=} {key=} exists but is of {type(attr)=}, not Field")
                if attr.dtype != val.dtype:
                    raise RuntimeError(f"{cls=} {key=} exists but {attr.dtype=} != {val.dtype=}")
                # NOTE(review): attr may be a Field object inherited from a
                # base class, in which case this in-place update is visible
                # to every class sharing that Field - confirm this is
                # acceptable.
                attr.default = val.default
                attr.doc = f"{val.msg} (overriding base doc='{attr.doc}')"
            elif val.dtype is list:
                # this is a hack, but it works: infer the item type from
                # the first default element, falling back to str.
                if val.default:
                    item_type = type(val.default[0])
                else:
                    item_type = str
                setattr(
                    cls,
                    key,
                    pexConfig.ListField(doc=val.msg, dtype=item_type, default=val.default),
                )
            elif val.dtype is dict:
                setattr(
                    cls,
                    key,
                    pexConfig.DictField(doc=val.msg, keytype=str, default=val.default),
                )
        cls.__fields_made__ = True
# Registry used by `EstimatePhotozTaskConfig.photoz_algo` to select the
# algorithm-specific subtask.
photozAlgoRegistry = pexConfig.makeRegistry(
    doc="A registry of photometric redshift estimation algorithm subtasks",
)
class EstimatePhotozAlgoTask(Task, ABC):
    """Task for algorithm-specific p(z) estimation.

    This provides almost all of the functionality
    needed to run RAIL p(z) algorithms.

    Parameters
    ----------
    **kwargs
        Additional keyword arguments to pass to super().__init__.
    """

    ConfigClass = EstimatePhotozAlgoConfigBase

    # Factor converting a relative flux error into a magnitude error:
    # mag_err = flux_err / (flux * mag_conv).
    mag_conv = np.log(10) * 0.4

    def __init__(self, **kwargs: Any):
        super().__init__(**kwargs)

    @staticmethod
    def _flux_to_mag(
        flux_vals: np.ndarray,
        mag_offset: float,
        nondetect_val: float,
    ) -> np.ndarray:
        """Convert flux to magnitude.

        Parameters
        ----------
        flux_vals : np.array
            Input flux values (presumably nJy, given the default
            ``mag_offset`` of 31.4 - TODO confirm).
        mag_offset : float
            Magnitude offset (corresponding to a flux of 1.)
        nondetect_val : float
            Value to set for non-detections

        Returns
        -------
        mags : np.array
            Magnitude values. Entries whose flux is not strictly positive
            (zero, negative or NaN) are set to ``nondetect_val``.
        """
        flux_vals = np.asarray(flux_vals)
        # Always produce a floating-point result so that NaN and fractional
        # magnitudes can be stored even when the input is an integer array.
        out_dtype = flux_vals.dtype if np.issubdtype(flux_vals.dtype, np.floating) else np.float64
        mags = np.full(flux_vals.shape, nondetect_val, dtype=out_dtype)
        # Strictly positive only: log10(0) would yield an infinite magnitude.
        detected = flux_vals > 0
        mags[detected] = -2.5 * np.log10(flux_vals[detected]) + mag_offset
        return mags

    @staticmethod
    def _flux_err_to_mag_err(
        flux_vals: np.ndarray,
        flux_err_vals: np.ndarray,
        mag_conv: float,
        nondetect_val: float = np.nan,
    ) -> np.ndarray:
        """Convert flux error to magnitude error.

        Parameters
        ----------
        flux_vals : np.array
            Input flux values (same units as ``flux_err_vals``).
        flux_err_vals : np.array
            Input flux errors (same units as ``flux_vals``).
        mag_conv : float
            Magnitude to flux conversion (typically np.log(10)*0.4)
        nondetect_val : float
            Value to set for non-detections

        Returns
        -------
        mags_errs : np.array
            Magnitude errors. Entries whose flux is not strictly positive
            (zero, negative or NaN) are set to ``nondetect_val``.
        """
        flux_vals = np.asarray(flux_vals)
        flux_err_vals = np.asarray(flux_err_vals)
        # Always produce a floating-point result (see _flux_to_mag).
        out_dtype = flux_vals.dtype if np.issubdtype(flux_vals.dtype, np.floating) else np.float64
        mag_errs = np.full(flux_vals.shape, nondetect_val, dtype=out_dtype)
        # Strictly positive only: a zero flux would divide by zero.
        detected = flux_vals > 0
        mag_errs[detected] = flux_err_vals[detected] / (flux_vals[detected] * mag_conv)
        return mag_errs

    @staticmethod
    def _deredden_mags(
        data: dict[str, np.ndarray],
        a_env_dict: dict[str, float],
        mag_names: dict[str, str],
        nondetect_val: float,
    ) -> dict[str, np.ndarray]:
        """Deredden the magnitudes.

        Parameters
        ----------
        data : dict[str, np.array]
            Input data. Must contain an ``"ebv"`` entry and one magnitude
            entry for every band in ``a_env_dict``. Modified in place.
        a_env_dict : dict[str, float]
            Reddening parameters for bands
        mag_names : dict[str, str]
            Mapping from bands to magnitudes
        nondetect_val : float
            Value to set for non-detections

        Returns
        -------
        mags : dict[str, np.array]
            The same dict that was passed in, with dereddened mags.
            Non-finite input magnitudes are replaced by ``nondetect_val``.
        """
        ebv = data["ebv"]
        for band_, a_env_ in a_env_dict.items():
            mag_name = mag_names[band_]
            raw_mag = data[mag_name]
            data[mag_name] = np.where(
                np.isfinite(raw_mag),
                raw_mag - ebv * a_env_,
                nondetect_val,
            )
        return data

    def _get_mags_and_errs(
        self,
        fluxes: Table,
        mag_offset: float,
    ) -> dict[str, np.ndarray]:
        """Fill and return a numpy dict with mags and mag errors.

        Parameters
        ----------
        fluxes : Table
            Input fluxes and flux errors
        mag_offset : float
            Magnitude offset (corresponding to a flux of 1.)

        Returns
        -------
        mags : dict[str, np.array]
            Numpy dict with mags and mag errors, keyed by the configured
            magnitude and magnitude-error column names.
        """
        # get all the column names we will use; all four dicts are keyed
        # by the same bands (config.bands_to_convert)
        flux_names = self.config.get_flux_names()
        flux_err_names = self.config.get_flux_err_names()
        mag_names = self.config.get_mag_names()
        mag_err_names = self.config.get_mag_err_names()
        nondetect_val = self.config.nondetect_val

        mag_dict = {}
        # loop over bands, make mags and mag errors and fill dict
        for band, flux_name in flux_names.items():
            # asarray strips any astropy column wrapper
            flux_vals = np.asarray(fluxes[flux_name])
            flux_err_vals = np.asarray(fluxes[flux_err_names[band]])
            mag_dict[mag_names[band]] = self._flux_to_mag(
                flux_vals,
                mag_offset,
                nondetect_val,
            )
            mag_dict[mag_err_names[band]] = self._flux_err_to_mag_err(
                flux_vals,
                flux_err_vals,
                self.mag_conv,
                nondetect_val,
            )
        return mag_dict

    def init(
        self,
        photoz_model: Model,
    ) -> None:
        """Set up the RAIL stage to compute photo-zs.

        The built stage is stored on the task and used by `run`.

        Parameters
        ----------
        photoz_model : Model
            Model used by the p(z) estimation algorithm.
        """
        # pop the pipeline task config options
        # so that we can pass the rest to RAIL
        rail_kwargs = self.config.toDict().copy()
        for key in ["saveLogOutput", "stage_name", "mag_offset", "connections"]:
            rail_kwargs.pop(key, None)
        rail_kwargs["output_mode"] = "return"

        # Build the RAIL stage
        self._stage = PZFactory.build_stage_instance(
            self.config.stage_name(),
            self.config.estimator_class(),
            model_path=photoz_model.data,
            input_path="dummy.in",
            **rail_kwargs,
        )
        self._stage._initialize_run()

    def col_names(
        self,
    ) -> list[str]:
        """Get the list of column names to read from the input data.

        Returns
        -------
        columns : list[str]
            Flux columns, then flux error columns, plus ``"ebv"`` when
            dereddening is enabled.
        """
        columns = list(self.config.get_flux_names().values()) + list(
            self.config.get_flux_err_names().values()
        )
        if self.config.deredden:
            columns += ["ebv"]
        return columns

    def run(
        self,
        fluxes: Table,
    ) -> Struct:
        """Run a p(z) estimation algorithm.

        `init` must have been called beforehand to build the RAIL stage.

        Parameters
        ----------
        fluxes : Table
            Fluxes used to compute the redshifts.

        Returns
        -------
        result : Struct
            Struct with a single ``photoz_ensemble`` attribute holding
            the p(z) PDFs (a qp.Ensemble object).
        """
        n_obj = len(fluxes)
        # Convert fluxes to mags
        mags = self._get_mags_and_errs(fluxes, self.config.mag_offset)

        # De-redden
        if self.config.deredden:
            # asarray will convert an astropy column to an array w/o units
            mags["ebv"] = np.asarray(fluxes["ebv"])
            mags = self._deredden_mags(
                mags,
                self.config.band_a_env,
                self.config.get_mag_names(),
                self.config.nondetect_val,
            )

        # Pass the mags to RAIL and get back the p(z) pdfs
        # as a qp.Ensemble object
        photoz_pdfs = PZFactory.estimate_single_pz(self._stage, mags, n_obj)
        return Struct(photoz_ensemble=photoz_pdfs)
class EstimatePhotozTaskConfig(
    PipelineTaskConfig,
    pipelineConnections=EstimatePhotozConnections,
):
    """Configuration for the EstimatePhotozTask PipelineTask."""

    # The algorithm-specific subtask is picked from photozAlgoRegistry.
    photoz_algo = photozAlgoRegistry.makeField(
        doc="Algorithm specific configuration p(z) estimation task",
    )
class EstimatePhotozTask(PipelineTask):
    """PipelineTask for p(z) estimation.

    Parameters
    ----------
    initInputs
        Initialization inputs to pass to super().__init__.
    **kwargs
        Additional keyword arguments to pass to super().__init__.
    """

    ConfigClass = EstimatePhotozTaskConfig
    _DefaultName = "estimatePhotoz"

    def __init__(self, initInputs: dict, **kwargs):
        super().__init__(initInputs=initInputs, **kwargs)
        # Becomes True once the algorithm subtask has been set up with a
        # model, so that later calls can skip that step.
        self._initialized = False
        self.makeSubtask("photoz_algo")

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        """Load the inputs, run the estimation and store the outputs.

        Only the columns requested by the algorithm subtask are read from
        the deferred ``objects`` input; the resulting table is passed to
        `run` as ``fluxes``.
        """
        inputs = butlerQC.get(inputRefs)
        inputs["fluxes"] = inputs.pop("objects").get(
            parameters=dict(columns=self.photoz_algo.col_names()),
        )
        outputs = self.run(**inputs, skip_init=self._initialized)
        butlerQC.put(outputs, outputRefs)

    def run(
        self,
        *,
        photoz_model: Model,
        fluxes: Table,
        skip_init: bool = False,
    ) -> Struct:
        """Estimate p(z) for a table of fluxes.

        Parameters
        ----------
        photoz_model : Model
            Model used by the p(z) estimation algorithm. Only used when
            the algorithm subtask is initialized by this call.
        fluxes : Table
            Fluxes used to compute the redshifts.
        skip_init : bool, optional
            If True, do not (re-)initialize the algorithm subtask.

        Returns
        -------
        result : Struct
            Struct with a single ``photoz_ensemble`` attribute.
        """
        if not skip_init:
            self.photoz_algo.init(photoz_model)
            # Only record success once init has completed; otherwise a
            # failed init would make later calls skip initialization.
            self._initialized = True

        ret_struct = self.photoz_algo.run(fluxes)
        return Struct(photoz_ensemble=ret_struct.photoz_ensemble)