Coverage for python/lsst/ap/association/transformDiaSourceCatalog.py: 20%
164 statements
« prev ^ index » next — coverage.py v7.5.1, created at 2024-05-15 09:33 +0000
1# This file is part of ap_association
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
# Public API of this module.
__all__ = ("TransformDiaSourceCatalogConnections",
           "TransformDiaSourceCatalogConfig",
           "TransformDiaSourceCatalogTask",
           "UnpackApdbFlags")
27import numpy as np
28import os
29import yaml
30import pandas as pd
32from lsst.daf.base import DateTime
33import lsst.pex.config as pexConfig
34import lsst.pipe.base as pipeBase
35import lsst.pipe.base.connectionTypes as connTypes
36from lsst.pipe.tasks.postprocess import TransformCatalogBaseTask, TransformCatalogBaseConfig
37from lsst.pipe.tasks.functors import Column
38from lsst.utils.timer import timeMethod
class TransformDiaSourceCatalogConnections(pipeBase.PipelineTaskConnections,
                                           dimensions=("instrument", "visit", "detector"),
                                           defaultTemplates={"coaddName": "deep", "fakesType": ""}):
    """Butler connections for TransformDiaSourceCatalogTask.

    The ``reliability`` input is removed from the connection set when
    ``doIncludeReliability`` is disabled in the config.
    """
    diaSourceSchema = connTypes.InitInput(
        doc="Schema for DIASource catalog output by ImageDifference.",
        storageClass="SourceCatalog",
        name="{fakesType}{coaddName}Diff_diaSrc_schema",
    )
    diaSourceCat = connTypes.Input(
        doc="Catalog of DiaSources produced during image differencing.",
        name="{fakesType}{coaddName}Diff_candidateDiaSrc",
        storageClass="SourceCatalog",
        dimensions=("instrument", "visit", "detector"),
    )
    diffIm = connTypes.Input(
        doc="Difference image on which the DiaSources were detected.",
        name="{fakesType}{coaddName}Diff_differenceExp",
        storageClass="ExposureF",
        dimensions=("instrument", "visit", "detector"),
    )
    reliability = connTypes.Input(
        doc="Reliability (e.g. real/bogus) classification of diaSourceCat sources (optional).",
        name="{fakesType}{coaddName}RealBogusSources",
        storageClass="Catalog",
        dimensions=("instrument", "visit", "detector"),
    )
    diaSourceTable = connTypes.Output(
        doc="Transformed and standardized DiaSource catalog, with calibrated "
            "values and columns renamed for the Apdb.",
        name="{fakesType}{coaddName}Diff_diaSrcTable",
        storageClass="DataFrame",
        dimensions=("instrument", "visit", "detector"),
    )

    def __init__(self, *, config=None):
        super().__init__(config=config)
        # Only request the reliability catalog when the task will use it.
        if not self.config.doIncludeReliability:
            self.inputs.remove("reliability")
class TransformDiaSourceCatalogConfig(TransformCatalogBaseConfig,
                                      pipelineConnections=TransformDiaSourceCatalogConnections):
    """Config for TransformDiaSourceCatalogTask."""
    flagMap = pexConfig.Field(
        dtype=str,
        doc="Yaml file specifying SciencePipelines flag fields to bit packs.",
        default=os.path.join("${AP_ASSOCIATION_DIR}",
                             "data",
                             "association-flag-map.yaml"),
    )
    flagRenameMap = pexConfig.Field(
        dtype=str,
        doc="Yaml file specifying rules to rename flag names",
        default=os.path.join("${AP_ASSOCIATION_DIR}",
                             "data",
                             "flag-rename-rules.yaml"),
    )
    doRemoveSkySources = pexConfig.Field(
        dtype=bool,
        default=False,
        doc="Input DiaSource catalog contains SkySources that should be "
            "removed before storing the output DiaSource catalog."
    )
    # TODO: remove on DM-41532
    doPackFlags = pexConfig.Field(
        dtype=bool,
        default=False,
        doc="Do pack the flags into one integer column named 'flags'. "
            "If False, instead produce one boolean column per flag.",
        deprecated="This field is no longer used. Will be removed after v28."
    )
    doIncludeReliability = pexConfig.Field(
        dtype=bool,
        default=False,
        doc="Include the reliability (e.g. real/bogus) classifications in the output."
    )

    def setDefaults(self):
        super().setDefaults()
        # Default functor file that defines the DiaSource column transforms.
        self.functorFile = os.path.join("${AP_ASSOCIATION_DIR}",
                                        "data",
                                        "DiaSource.yaml")
class TransformDiaSourceCatalogTask(TransformCatalogBaseTask):
    """Transform a DiaSource catalog by calibrating and renaming columns to
    produce a table ready to insert into the Apdb.

    Parameters
    ----------
    initInputs : `dict`
        Must contain ``diaSourceSchema`` as the schema for the input catalog.
    """
    ConfigClass = TransformDiaSourceCatalogConfig
    _DefaultName = "transformDiaSourceCatalog"
    # Needed to create a valid TransformCatalogBaseTask, but unused
    inputDataset = "deepDiff_diaSrc"
    outputDataset = "deepDiff_diaSrcTable"

    def __init__(self, initInputs, **kwargs):
        super().__init__(**kwargs)
        self.funcs = self.getFunctors()
        self.inputSchema = initInputs['diaSourceSchema'].schema
        self._create_bit_pack_mappings()

        # doPackFlags is deprecated (see config); unpacked flags need the
        # rename rules to map schema flag names to Apdb column names.
        if not self.config.doPackFlags:
            # get the flag rename rules
            with open(os.path.expandvars(self.config.flagRenameMap)) as yaml_stream:
                self.rename_rules = list(yaml.safe_load_all(yaml_stream))

    def _create_bit_pack_mappings(self):
        """Setup all flag bit packings.

        Loads the flag map yaml file and caches the ``DiaSource`` table's
        column specification in ``self.bit_pack_columns``.

        Raises
        ------
        KeyError
            Raised if a flag listed in the flag map is not present in the
            input DiaSource schema.
        """
        self.bit_pack_columns = []
        flag_map_file = os.path.expandvars(self.config.flagMap)
        with open(flag_map_file) as yaml_stream:
            table_list = list(yaml.safe_load_all(yaml_stream))
            for table in table_list:
                if table['tableName'] == 'DiaSource':
                    self.bit_pack_columns = table['columns']
                    break

        # Test that all flags requested are present in the input schemas.
        # Output schemas are flexible, however if names are not specified in
        # the Apdb schema, flag columns will not be persisted.
        for outputFlag in self.bit_pack_columns:
            bitList = outputFlag['bitList']
            for bit in bitList:
                try:
                    self.inputSchema.find(bit['name'])
                except KeyError:
                    raise KeyError(
                        "Requested column %s not found in input DiaSource "
                        "schema. Please check that the requested input "
                        "column exists." % bit['name'])

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        # Standard PipelineTask entry point: fetch the inputs, add the band
        # from the quantum data ID, run the task, and persist the outputs.
        inputs = butlerQC.get(inputRefs)
        inputs["band"] = butlerQC.quantum.dataId["band"]

        outputs = self.run(**inputs)

        butlerQC.put(outputs, outputRefs)

    @timeMethod
    def run(self,
            diaSourceCat,
            diffIm,
            band,
            reliability=None):
        """Convert input catalog to ParquetTable/Pandas and run functors.

        Additionally, add new columns for stripping information from the
        exposure and into the DiaSource catalog.

        Parameters
        ----------
        diaSourceCat : `lsst.afw.table.SourceCatalog`
            Catalog of sources measured on the difference image.
        diffIm : `lsst.afw.image.Exposure`
            Result of subtracting template and science images.
        band : `str`
            Filter band of the science image.
        reliability : `lsst.afw.table.SourceCatalog`, optional
            Reliability (e.g. real/bogus) scores, row-matched to
            ``diaSourceCat``. Only used when
            ``self.config.doIncludeReliability`` is set.

        Returns
        -------
        results : `lsst.pipe.base.Struct`
            Results struct with components.

            - ``diaSourceTable`` : Catalog of DiaSources with calibrated values
              and renamed columns.
              (`lsst.pipe.tasks.ParquetTable` or `pandas.DataFrame`)
        """
        self.log.info(
            "Transforming/standardizing the DiaSource table for visit,detector: %i, %i",
            diffIm.visitInfo.id, diffIm.detector.getId())

        diaSourceDf = diaSourceCat.asAstropy().to_pandas()
        if self.config.doRemoveSkySources:
            # Filter both views so the catalog-derived columns below stay
            # row-matched with the dataframe.
            diaSourceDf = diaSourceDf[~diaSourceDf["sky_source"]]
            diaSourceCat = diaSourceCat[~diaSourceCat["sky_source"]]

        diaSourceDf["time_processed"] = DateTime.now().toPython()
        diaSourceDf["snr"] = getSignificance(diaSourceCat)
        diaSourceDf["bboxSize"] = self.computeBBoxSizes(diaSourceCat)
        diaSourceDf["visit"] = diffIm.visitInfo.id
        # int16 instead of uint8 because databases don't like unsigned bytes.
        diaSourceDf["detector"] = np.int16(diffIm.detector.getId())
        diaSourceDf["band"] = band
        diaSourceDf["midpointMjdTai"] = diffIm.visitInfo.date.get(system=DateTime.MJD)
        # Association fills these in later; 0 means "not yet associated".
        diaSourceDf["diaObjectId"] = 0
        diaSourceDf["ssObjectId"] = 0

        if self.config.doIncludeReliability:
            reliabilityDf = reliability.asAstropy().to_pandas()
            # This uses the pandas index to match scores with diaSources
            # but it will silently fill with NaNs if they don't match.
            diaSourceDf = pd.merge(diaSourceDf, reliabilityDf,
                                   how="left", on="id", validate="1:1")
            diaSourceDf = diaSourceDf.rename(columns={"score": "reliability"})
            # All-NaN reliability means the merge keys did not line up.
            if np.sum(diaSourceDf["reliability"].isna()) == len(diaSourceDf):
                self.log.warning("Reliability identifiers did not match diaSourceIds")
        else:
            diaSourceDf["reliability"] = np.float32(np.nan)

        if self.config.doPackFlags:
            # either bitpack the flags
            self.bitPackFlags(diaSourceDf)
        else:
            # or add the individual flag functors
            self.addUnpackedFlagFunctors()
            # and remove the packed flag functor
            if 'flags' in self.funcs.funcDict:
                del self.funcs.funcDict['flags']

        df = self.transform(band,
                            diaSourceDf,
                            self.funcs,
                            dataId=None).df

        return pipeBase.Struct(
            diaSourceTable=df,
        )

    def addUnpackedFlagFunctors(self):
        """Add Column functor for each of the flags to the internal functor
        dictionary.
        """
        # Only the first flag table ('DiaSource') is used; flag names are
        # renamed according to the rules loaded in __init__.
        for flag in self.bit_pack_columns[0]['bitList']:
            flagName = flag['name']
            targetName = self.funcs.renameCol(flagName, self.rename_rules[0]['flag_rename_rules'])
            self.funcs.update({targetName: Column(flagName)})

    def computeBBoxSizes(self, inputCatalog):
        """Compute the size of a square bbox that fully contains the detection
        footprint.

        Parameters
        ----------
        inputCatalog : `lsst.afw.table.SourceCatalog`
            Catalog containing detected footprints.

        Returns
        -------
        outputBBoxSizes : `np.ndarray`, (N,)
            Array of bbox sizes.
        """
        # Schema validation requires that this field is int.
        outputBBoxSizes = np.empty(len(inputCatalog), dtype=int)
        for i, record in enumerate(inputCatalog):
            footprintBBox = record.getFootprint().getBBox()
            # Compute twice the size of the largest dimension of the footprint
            # bounding box. This is the largest footprint we should need to cover
            # the complete DiaSource assuming the centroid is within the bounding
            # box.
            maxSize = 2 * np.max([footprintBBox.getWidth(),
                                  footprintBBox.getHeight()])
            recX = record.getCentroid().x
            recY = record.getCentroid().y
            # Side of the smallest square centered on the centroid that still
            # covers the whole footprint bbox, capped at maxSize.
            bboxSize = int(
                np.ceil(2 * np.max(np.fabs([footprintBBox.maxX - recX,
                                            footprintBBox.minX - recX,
                                            footprintBBox.maxY - recY,
                                            footprintBBox.minY - recY]))))
            if bboxSize > maxSize:
                bboxSize = maxSize
            outputBBoxSizes[i] = bboxSize

        return outputBBoxSizes

    def bitPackFlags(self, df):
        """Pack requested flag columns in inputRecord into single columns in
        outputRecord.

        Parameters
        ----------
        df : `pandas.DataFrame`
            DataFrame to read bits from and pack them into.
        """
        for outputFlag in self.bit_pack_columns:
            bitList = outputFlag['bitList']
            value = np.zeros(len(df), dtype=np.uint64)
            for bit in bitList:
                # Hard type the bit arrays.
                value += (df[bit['name']]*2**bit['bit']).to_numpy().astype(np.uint64)
            df[outputFlag['columnName']] = value
class UnpackApdbFlags:
    """Class for unpacking bits from integer flag fields stored in the Apdb.

    Parameters
    ----------
    flag_map_file : `str`
        Absolute or relative path to a yaml file specifying mappings of flags
        to integer bits.
    table_name : `str`
        Name of the Apdb table the integer bit data are coming from.
    """

    def __init__(self, flag_map_file, table_name):
        self.bit_pack_columns = []
        flag_map_file = os.path.expandvars(flag_map_file)
        with open(flag_map_file) as yaml_stream:
            table_list = list(yaml.safe_load_all(yaml_stream))
            for table in table_list:
                if table['tableName'] == table_name:
                    self.bit_pack_columns = table['columns']
                    break

        # Map columnName -> {flag name: bit position} for O(1) lookups in
        # unpack/flagExists/makeFlagBitMask.
        self.output_flag_columns = {}
        for column in self.bit_pack_columns:
            self.output_flag_columns[column["columnName"]] = {
                bit["name"]: bit["bit"] for bit in column["bitList"]
            }

    def unpack(self, input_flag_values, flag_name):
        """Determine individual boolean flags from an input array of unsigned
        ints.

        Parameters
        ----------
        input_flag_values : array-like of type uint
            Array of integer packed bit flags to unpack.
        flag_name : `str`
            Apdb column name from the loaded file, e.g. "flags".

        Returns
        -------
        output_flags : `numpy.ndarray`
            Numpy structured array of booleans, one column per flag in the
            loaded file.
        """
        bit_map = self.output_flag_columns[flag_name]
        output_flags = np.zeros(len(input_flag_values),
                                dtype=[(name, bool) for name in bit_map])

        for name, bit in bit_map.items():
            # Nonzero masked values become True on assignment to the bool
            # column.
            output_flags[name] = np.bitwise_and(input_flag_values, 2**bit)

        return output_flags

    def flagExists(self, flagName, columnName='flags'):
        """Check if named flag is in the bitpacked flag set.

        Parameters
        ----------
        flagName : `str`
            Flag name to search for.
        columnName : `str`, optional
            Name of bitpacked flag column to search in.

        Returns
        -------
        flagExists : `bool`
            `True` if `flagName` is present in `columnName`.

        Raises
        ------
        ValueError
            Raised if `columnName` is not defined.
        """
        if columnName not in self.output_flag_columns:
            raise ValueError(f'column {columnName} not in flag map: {self.output_flag_columns}')

        # Direct dict membership test; no need to materialize a list.
        return flagName in self.output_flag_columns[columnName]

    def makeFlagBitMask(self, flagNames, columnName='flags'):
        """Return a bitmask corresponding to the supplied flag names.

        Parameters
        ----------
        flagNames : `list` [`str`]
            Flag names to include in the bitmask.
        columnName : `str`, optional
            Name of bitpacked flag column.

        Returns
        -------
        bitmask : `np.uint64`
            Bitmask corresponding to the supplied flag names given the loaded
            configuration.

        Raises
        ------
        ValueError
            Raised if a flag in `flagNames` is not included in `columnName`.
        """
        bitmask = np.uint64(0)

        for flag in flagNames:
            if not self.flagExists(flag, columnName=columnName):
                raise ValueError(f"flag '{flag}' not included in '{columnName}' flag column")

        # output_flag_columns already maps name -> bit for this column.
        for name, bit in self.output_flag_columns[columnName].items():
            if name in flagNames:
                bitmask += np.uint64(2**bit)

        return bitmask
448def getSignificance(catalog):
449 """Return the significance value of the first peak in each source
450 footprint, or NaN for peaks without a significance field.
452 Parameters
453 ----------
454 catalog : `lsst.afw.table.SourceCatalog`
455 Catalog to process.
457 Returns
458 -------
459 significance : `np.ndarray`, (N,)
460 Signficance of the first peak in each source footprint.
461 """
462 result = np.full(len(catalog), np.nan)
463 for i, record in enumerate(catalog):
464 peaks = record.getFootprint().peaks
465 if "significance" in peaks.schema:
466 result[i] = peaks[0]["significance"]
467 return result