Coverage for python / lsst / images / fields / _product.py: 36%

80 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-01 08:36 +0000

1# This file is part of lsst-images. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# Use of this source code is governed by a 3-clause BSD-style 

10# license that can be found in the LICENSE file. 

11 

12from __future__ import annotations 

13 

14__all__ = ("ProductField", "ProductFieldSerializationModel") 

15 

16from collections.abc import Iterable 

17from typing import TYPE_CHECKING, Any, Literal, final 

18 

19import astropy.units 

20import numpy as np 

21import pydantic 

22 

23from .._geom import Bounds, Box 

24from .._image import Image 

25from ..serialization import ArchiveTree, InputArchive, OutputArchive 

26from ._base import BaseField 

27 

28if TYPE_CHECKING: 

29 try: 

30 from lsst.afw.math import ProductBoundedField as LegacyProductBoundedField 

31 except ImportError: 

32 type LegacyProductBoundedField = Any # type: ignore[no-redef] 

33 

34 from ._concrete import Field, FieldSerializationModel 

35 

36 

@final
class ProductField(BaseField):
    """A field that multiplies other fields lazily.

    The bounds of the product are the intersection of the bounds of all
    operands.  The unit of the product is the product of the units of all
    operands that have one; operands whose unit is `None` are skipped when
    combining units.

    Parameters
    ----------
    operands : `~collections.abc.Iterable` [ `BaseField` ]
        The fields to multiply together.

    Raises
    ------
    ValueError
        Raised if ``operands`` is empty.
    """

    def __init__(self, operands: Iterable[Field]):
        self._operands = tuple(operands)
        if not self._operands:
            raise ValueError("At least one operand must be provided.")
        first, *rest = self._operands
        self._bounds = first.bounds
        self._unit = first.unit
        for operand in rest:
            # The product is only defined where every operand is defined.
            self._bounds = self._bounds.intersection(operand.bounds)
            # Operands without a unit leave the accumulated unit unchanged.
            if operand.unit is None:
                continue
            self._unit = operand.unit if self._unit is None else self._unit * operand.unit

    @property
    def bounds(self) -> Bounds:
        return self._bounds

    @property
    def unit(self) -> astropy.units.UnitBase | None:
        return self._unit

    @property
    def operands(self) -> tuple[Field, ...]:
        """The fields that are multiplied together
        (`tuple` [`BaseField`, ...]).
        """
        return self._operands

    def evaluate(
        self, *, x: np.ndarray, y: np.ndarray, quantity: bool = False
    ) -> np.ndarray | astropy.units.Quantity:
        # Evaluate every operand as a plain (unitless) array and multiply the
        # results.  The product is formed out of place: multiplying in place
        # into the first operand's result would mutate an array that operand
        # produced, would fail if that array has a smaller broadcast shape
        # (e.g. 0-d) than later factors, and would lose precision or raise if
        # a later factor has a wider dtype.
        first, *rest = self._operands
        result = first(x=x, y=y, quantity=False)
        for operand in rest:
            result = result * operand(x=x, y=y, quantity=False)
        if quantity:
            # When no operand carries a unit, ``self.unit`` is None and
            # ``array * None`` raises TypeError; report a dimensionless
            # quantity instead.
            unit = self.unit if self.unit is not None else astropy.units.dimensionless_unscaled
            return result * unit
        return result

    def render(self, bbox: Box | None = None, *, dtype: np.typing.DTypeLike | None = None) -> Image:
        if bbox is None:
            bbox = self.bounds.bbox
        # Start from an image of ones carrying the combined unit, then
        # multiply each operand's rendering into its pixel array.
        result = Image(1.0, bbox=bbox, dtype=dtype, unit=self.unit)
        for operand in self._operands:
            result.array *= operand.render(bbox, dtype=dtype).array
        return result

    def multiply_constant(
        self, factor: float | astropy.units.Quantity | astropy.units.UnitBase
    ) -> ProductField:
        # Fold the constant into the last operand so that the new product has
        # the same number of factors as this one.
        new_operands = list(self._operands[:-1])
        new_operands.append(self._operands[-1] * factor)
        return ProductField(new_operands)

    def serialize(self, archive: OutputArchive[Any]) -> ProductFieldSerializationModel:
        """Serialize the field to an output archive."""
        return ProductFieldSerializationModel(
            operands=[operand.serialize(archive) for operand in self._operands]
        )

    @staticmethod
    def deserialize(model: ProductFieldSerializationModel, archive: InputArchive[Any]) -> ProductField:
        """Deserialize the field from an input archive."""
        from ._concrete import deserialize_field

        return ProductField([deserialize_field(operand, archive) for operand in model.operands])

    @staticmethod
    def _get_archive_tree_type(
        pointer_type: type[Any],
    ) -> type[ProductFieldSerializationModel]:
        """Return the serialization model type for this object for an archive
        type that uses the given pointer type.
        """
        return ProductFieldSerializationModel

    @staticmethod
    def from_legacy(
        legacy: LegacyProductBoundedField, unit: astropy.units.UnitBase | None = None
    ) -> ProductField:
        """Convert from a legacy `lsst.afw.math.ProductBoundedField`.

        Parameters
        ----------
        legacy : `lsst.afw.math.ProductBoundedField`
            Legacy product field to convert.
        unit : `astropy.units.UnitBase`, optional
            Unit to attach; it is applied to the last factor only, so the
            product carries it exactly once.

        Returns
        -------
        field : `ProductField`
            The converted field.
        """
        from ._concrete import field_from_legacy

        legacy_factors = legacy.getFactors()
        operands = [field_from_legacy(f) for f in legacy_factors[:-1]]
        operands.append(field_from_legacy(legacy_factors[-1], unit=unit))
        return ProductField(operands)

    def to_legacy(self) -> LegacyProductBoundedField:
        """Convert to a legacy `lsst.afw.math.ProductBoundedField`."""
        from lsst.afw.math import ProductBoundedField

        # Not all Field types have a to_legacy, since they don't all have an
        # afw analog. But we just let that "no method" exception propagate.
        return ProductBoundedField(
            [operand.to_legacy() for operand in self._operands]  # type: ignore[union-attr]
        )

class ProductFieldSerializationModel(ArchiveTree):
    """Archive-tree model used to persist a `ProductField`.

    Each entry of ``operands`` is the serialized form of one factor, stored
    in the same order in which the factors are multiplied.
    """

    # Serialized factors of the product (empty list when not provided).
    operands: list[FieldSerializationModel] = pydantic.Field(default_factory=list)

    # Literal tag naming this model's field type; presumably used to select
    # this model among the serialized field variants -- confirm in _concrete.
    field_type: Literal["PRODUCT"] = "PRODUCT"

    def finish_deserialize(self, archive: InputArchive) -> ProductField:
        # Reconstruction logic lives on the field class; delegate to it.
        return ProductField.deserialize(model=self, archive=archive)