Coverage for python/lsst/ap/association/transformDiaSourceCatalog.py: 20%
153 statements
« prev ^ index » next    coverage.py v7.2.5, created at 2023-05-03 03:42 -0700
# This file is part of ap_association
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
22__all__ = ("TransformDiaSourceCatalogConnections",
23 "TransformDiaSourceCatalogConfig",
24 "TransformDiaSourceCatalogTask",
25 "UnpackApdbFlags")
27import numpy as np
28import os
29import yaml
31from lsst.daf.base import DateTime
32import lsst.pex.config as pexConfig
33import lsst.pipe.base as pipeBase
34import lsst.pipe.base.connectionTypes as connTypes
35from lsst.meas.base import DetectorVisitIdGeneratorConfig
36from lsst.pipe.tasks.postprocess import TransformCatalogBaseTask, TransformCatalogBaseConfig
37from lsst.pipe.tasks.functors import Column
38from lsst.utils.timer import timeMethod
class TransformDiaSourceCatalogConnections(pipeBase.PipelineTaskConnections,
                                           dimensions=("instrument", "visit", "detector"),
                                           defaultTemplates={"coaddName": "deep", "fakesType": ""}):
    """Butler connections for TransformDiaSourceCatalogTask."""

    diaSourceSchema = connTypes.InitInput(
        doc="Schema for DIASource catalog output by ImageDifference.",
        storageClass="SourceCatalog",
        name="{fakesType}{coaddName}Diff_diaSrc_schema",
    )
    diaSourceCat = connTypes.Input(
        doc="Catalog of DiaSources produced during image differencing.",
        name="{fakesType}{coaddName}Diff_diaSrc",
        storageClass="SourceCatalog",
        dimensions=("instrument", "visit", "detector"),
    )
    diffIm = connTypes.Input(
        doc="Difference image on which the DiaSources were detected.",
        name="{fakesType}{coaddName}Diff_differenceExp",
        storageClass="ExposureF",
        dimensions=("instrument", "visit", "detector"),
    )
    diaSourceTable = connTypes.Output(
        # Was a placeholder "."; give a real description for pipeline docs.
        doc="Catalog of DiaSources with calibrated values and renamed "
            "columns, ready for Apdb insertion.",
        name="{fakesType}{coaddName}Diff_diaSrcTable",
        storageClass="DataFrame",
        dimensions=("instrument", "visit", "detector"),
    )
class TransformDiaSourceCatalogConfig(TransformCatalogBaseConfig,
                                      pipelineConnections=TransformDiaSourceCatalogConnections):
    """Config for TransformDiaSourceCatalogTask."""

    flagMap = pexConfig.Field(
        dtype=str,
        doc="Yaml file specifying SciencePipelines flag fields to bit packs.",
        default=os.path.join("${AP_ASSOCIATION_DIR}",
                             "data",
                             "association-flag-map.yaml"),
    )
    flagRenameMap = pexConfig.Field(
        dtype=str,
        # Fixed doubled word "specifying specifying" in the doc string.
        doc="Yaml file specifying rules to rename flag names",
        default=os.path.join("${AP_ASSOCIATION_DIR}",
                             "data",
                             "flag-rename-rules.yaml"),
    )
    doRemoveSkySources = pexConfig.Field(
        dtype=bool,
        default=False,
        doc="Input DiaSource catalog contains SkySources that should be "
            "removed before storing the output DiaSource catalog."
    )
    doPackFlags = pexConfig.Field(
        dtype=bool,
        default=True,
        doc="Do pack the flags into one integer column named 'flags'."
            "If False, instead produce one boolean column per flag."
    )
    idGenerator = DetectorVisitIdGeneratorConfig.make_field()

    def setDefaults(self):
        super().setDefaults()
        # Default functor file describing the DiaSource column transforms.
        self.functorFile = os.path.join("${AP_ASSOCIATION_DIR}",
                                        "data",
                                        "DiaSource.yaml")
class TransformDiaSourceCatalogTask(TransformCatalogBaseTask):
    """Transform a DiaSource catalog by calibrating and renaming columns to
    produce a table ready to insert into the Apdb.

    Parameters
    ----------
    initInputs : `dict`
        Must contain ``diaSourceSchema`` as the schema for the input catalog.
    """
    ConfigClass = TransformDiaSourceCatalogConfig
    _DefaultName = "transformDiaSourceCatalog"
    # Needed to create a valid TransformCatalogBaseTask, but unused
    inputDataset = "deepDiff_diaSrc"
    outputDataset = "deepDiff_diaSrcTable"

    def __init__(self, initInputs, **kwargs):
        super().__init__(**kwargs)
        self.funcs = self.getFunctors()
        self.inputSchema = initInputs['diaSourceSchema'].schema
        self._create_bit_pack_mappings()

        if not self.config.doPackFlags:
            # Load the rules used to rename unpacked flag columns.
            with open(os.path.expandvars(self.config.flagRenameMap)) as yaml_stream:
                self.rename_rules = list(yaml.safe_load_all(yaml_stream))

    def _create_bit_pack_mappings(self):
        """Setup all flag bit packings.

        Raises
        ------
        KeyError
            Raised if a flag named in the flag map is not present in the
            input DiaSource schema.
        """
        self.bit_pack_columns = []
        flag_map_file = os.path.expandvars(self.config.flagMap)
        with open(flag_map_file) as yaml_stream:
            table_list = list(yaml.safe_load_all(yaml_stream))
            for table in table_list:
                if table['tableName'] == 'DiaSource':
                    self.bit_pack_columns = table['columns']
                    break

        # Test that all flags requested are present in the input schemas.
        # Output schemas are flexible, however if names are not specified in
        # the Apdb schema, flag columns will not be persisted.
        for outputFlag in self.bit_pack_columns:
            bitList = outputFlag['bitList']
            for bit in bitList:
                try:
                    self.inputSchema.find(bit['name'])
                except KeyError as e:
                    # Chain the original lookup failure so the traceback
                    # shows where the schema lookup actually failed.
                    raise KeyError(
                        "Requested column %s not found in input DiaSource "
                        "schema. Please check that the requested input "
                        "column exists." % bit['name']) from e

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        inputs = butlerQC.get(inputRefs)
        idGenerator = self.config.idGenerator.apply(butlerQC.quantum.dataId)
        inputs["ccdVisitId"] = idGenerator.catalog_id
        inputs["band"] = butlerQC.quantum.dataId["band"]

        outputs = self.run(**inputs)

        butlerQC.put(outputs, outputRefs)

    @timeMethod
    def run(self,
            diaSourceCat,
            diffIm,
            band,
            ccdVisitId):
        """Convert input catalog to ParquetTable/Pandas and run functors.

        Additionally, add new columns for stripping information from the
        exposure and into the DiaSource catalog.

        Parameters
        ----------
        diaSourceCat : `lsst.afw.table.SourceCatalog`
            Catalog of sources measured on the difference image.
        diffIm : `lsst.afw.image.Exposure`
            Result of subtracting template and science images.
        band : `str`
            Filter band of the science image.
        ccdVisitId : `int`
            Identifier for this detector+visit.

        Returns
        -------
        results : `lsst.pipe.base.Struct`
            Results struct with components.

            - ``diaSourceTable`` : Catalog of DiaSources with calibrated
              values and renamed columns.
              (`lsst.pipe.tasks.ParquetTable` or `pandas.DataFrame`)
        """
        self.log.info(
            "Transforming/standardizing the DiaSource table ccdVisitId: %i",
            ccdVisitId)

        diaSourceDf = diaSourceCat.asAstropy().to_pandas()
        if self.config.doRemoveSkySources:
            # Drop sky sources from both views so row order stays aligned.
            diaSourceDf = diaSourceDf[~diaSourceDf["sky_source"]]
            diaSourceCat = diaSourceCat[~diaSourceCat["sky_source"]]

        diaSourceDf["snr"] = getSignificance(diaSourceCat)
        diaSourceDf["bboxSize"] = self.computeBBoxSizes(diaSourceCat)
        diaSourceDf["ccdVisitId"] = ccdVisitId
        diaSourceDf["filterName"] = band
        diaSourceDf["midPointTai"] = diffIm.getInfo().getVisitInfo().getDate().get(system=DateTime.MJD)
        # Association fills these in later; initialize to "unassociated".
        diaSourceDf["diaObjectId"] = 0
        diaSourceDf["ssObjectId"] = 0

        if self.config.doPackFlags:
            # either bitpack the flags
            self.bitPackFlags(diaSourceDf)
        else:
            # or add the individual flag functors
            self.addUnpackedFlagFunctors()
            # and remove the packed flag functor
            if 'flags' in self.funcs.funcDict:
                del self.funcs.funcDict['flags']

        df = self.transform(band,
                            diaSourceDf,
                            self.funcs,
                            dataId=None).df

        return pipeBase.Struct(
            diaSourceTable=df,
        )

    def addUnpackedFlagFunctors(self):
        """Add Column functor for each of the flags to the internal functor
        dictionary.
        """
        for flag in self.bit_pack_columns[0]['bitList']:
            flagName = flag['name']
            targetName = self.funcs.renameCol(flagName, self.rename_rules[0]['flag_rename_rules'])
            self.funcs.update({targetName: Column(flagName)})

    def computeBBoxSizes(self, inputCatalog):
        """Compute the size of a square bbox that fully contains the
        detection footprint.

        Parameters
        ----------
        inputCatalog : `lsst.afw.table.SourceCatalog`
            Catalog containing detected footprints.

        Returns
        -------
        outputBBoxSizes : `np.ndarray`, (N,)
            Array of bbox sizes.
        """
        # Schema validation requires that this field is int.
        outputBBoxSizes = np.empty(len(inputCatalog), dtype=int)
        for i, record in enumerate(inputCatalog):
            footprintBBox = record.getFootprint().getBBox()
            # Compute twice the size of the largest dimension of the footprint
            # bounding box. This is the largest footprint we should need to
            # cover the complete DiaSource assuming the centroid is within the
            # bounding box.
            maxSize = 2 * np.max([footprintBBox.getWidth(),
                                  footprintBBox.getHeight()])
            recX = record.getCentroid().x
            recY = record.getCentroid().y
            bboxSize = int(
                np.ceil(2 * np.max(np.fabs([footprintBBox.maxX - recX,
                                            footprintBBox.minX - recX,
                                            footprintBBox.maxY - recY,
                                            footprintBBox.minY - recY]))))
            # Clamp to the footprint-derived maximum; a centroid far outside
            # the bbox would otherwise inflate the size unboundedly.
            if bboxSize > maxSize:
                bboxSize = maxSize
            outputBBoxSizes[i] = bboxSize

        return outputBBoxSizes

    def bitPackFlags(self, df):
        """Pack requested flag columns in inputRecord into single columns in
        outputRecord.

        Parameters
        ----------
        df : `pandas.DataFrame`
            DataFrame to read bits from and pack them into.
        """
        for outputFlag in self.bit_pack_columns:
            bitList = outputFlag['bitList']
            value = np.zeros(len(df), dtype=np.uint64)
            for bit in bitList:
                # Hard type the bit arrays.
                value += (df[bit['name']]*2**bit['bit']).to_numpy().astype(np.uint64)
            df[outputFlag['columnName']] = value
class UnpackApdbFlags:
    """Class for unpacking bits from integer flag fields stored in the Apdb.

    Attributes
    ----------
    flag_map_file : `str`
        Absolute or relative path to a yaml file specifying mappings of flags
        to integer bits.
    table_name : `str`
        Name of the Apdb table the integer bit data are coming from.
    """

    def __init__(self, flag_map_file, table_name):
        self.bit_pack_columns = []
        flag_map_file = os.path.expandvars(flag_map_file)
        with open(flag_map_file) as yaml_stream:
            table_list = list(yaml.safe_load_all(yaml_stream))
            for table in table_list:
                if table['tableName'] == table_name:
                    self.bit_pack_columns = table['columns']
                    break

        # Map each packed column name to its list of (flag name, bool dtype)
        # pairs, preserving the order of the flag map.
        self.output_flag_columns = {}

        for column in self.bit_pack_columns:
            names = []
            for bit in column["bitList"]:
                names.append((bit["name"], bool))
            self.output_flag_columns[column["columnName"]] = names

    def unpack(self, input_flag_values, flag_name):
        """Determine individual boolean flags from an input array of unsigned
        ints.

        Parameters
        ----------
        input_flag_values : array-like of type uint
            Array of integer flags to unpack.
        flag_name : `str`
            Apdb column name of integer flags to unpack. Names of packed int
            flags are given by the flag_map_file.

        Returns
        -------
        output_flags : `numpy.ndarray`
            Numpy named tuple of booleans.
        """
        bit_names_types = self.output_flag_columns[flag_name]
        output_flags = np.zeros(len(input_flag_values), dtype=bit_names_types)

        # BUGFIX: mask with the true bit position recorded in the flag map
        # (bit['bit']), not the flag's list index. The two only coincide when
        # the map is contiguous starting at bit 0.
        bit_positions = {}
        for column in self.bit_pack_columns:
            if column['columnName'] == flag_name:
                for bit in column['bitList']:
                    bit_positions[bit['name']] = bit['bit']
                break

        for bit_name, _ in bit_names_types:
            masked_bits = np.bitwise_and(input_flag_values,
                                         2**bit_positions[bit_name])
            # Any nonzero masked value sets the boolean field to True.
            output_flags[bit_name] = masked_bits

        return output_flags

    def flagExists(self, flagName, columnName='flags'):
        """Check if named flag is in the bitpacked flag set.

        Parameters
        ----------
        flagName : `str`
            Flag name to search for.
        columnName : `str`, optional
            Name of bitpacked flag column to search in.

        Returns
        -------
        flagExists : `bool`
            `True` if `flagName` is present in `columnName`.

        Raises
        ------
        ValueError
            Raised if `columnName` is not defined.
        """
        if columnName not in self.output_flag_columns:
            raise ValueError(f'column {columnName} not in flag map: {self.output_flag_columns}')

        return flagName in [c[0] for c in self.output_flag_columns[columnName]]

    def makeFlagBitMask(self, flagNames, columnName='flags'):
        """Return a bitmask corresponding to the supplied flag names.

        Parameters
        ----------
        flagNames : `list` [`str`]
            Flag names to include in the bitmask.
        columnName : `str`, optional
            Name of bitpacked flag column.

        Returns
        -------
        bitmask : `np.uint64`
            Bitmask corresponding to the supplied flag names given the loaded
            configuration.

        Raises
        ------
        ValueError
            Raised if a flag in `flagName` is not included in `columnName`.
        """
        bitmask = np.uint64(0)

        # Validate all requested flags before accumulating any bits.
        for flag in flagNames:
            if not self.flagExists(flag, columnName=columnName):
                raise ValueError(f"flag '{flag}' not included in '{columnName}' flag column")

        for outputFlag in self.bit_pack_columns:
            if outputFlag['columnName'] == columnName:
                bitList = outputFlag['bitList']
                for bit in bitList:
                    if bit['name'] in flagNames:
                        bitmask += np.uint64(2**bit['bit'])

        return bitmask
def getSignificance(catalog):
    """Return the significance value of the first peak in each source
    footprint, or NaN for footprints without peaks or whose peaks lack a
    significance field.

    Parameters
    ----------
    catalog : `lsst.afw.table.SourceCatalog`
        Catalog to process.

    Returns
    -------
    significance : `np.ndarray`, (N,)
        Significance of the first peak in each source footprint.
    """
    result = np.full(len(catalog), np.nan)
    for i, record in enumerate(catalog):
        peaks = record.getFootprint().peaks
        # Guard against empty peak lists: indexing peaks[0] on a peakless
        # footprint would raise IndexError; leave NaN in place instead.
        if len(peaks) > 0 and "significance" in peaks.schema:
            result[i] = peaks[0]["significance"]
    return result