Coverage for python/lsst/ap/association/transformDiaSourceCatalog.py: 19%
151 statements
« prev ^ index » next — coverage.py v7.2.3, created at 2023-04-25 09:00 +0000
1# This file is part of ap_association
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
# Public API of this module: the pipeline task (with its config and butler
# connections) that standardizes DiaSource catalogs for the Apdb, plus the
# helper class for unpacking bit-packed flag columns read back from the Apdb.
__all__ = ("TransformDiaSourceCatalogConnections",
           "TransformDiaSourceCatalogConfig",
           "TransformDiaSourceCatalogTask",
           "UnpackApdbFlags")
27import numpy as np
28import os
29import yaml
31from lsst.daf.base import DateTime
32import lsst.pex.config as pexConfig
33import lsst.pipe.base as pipeBase
34import lsst.pipe.base.connectionTypes as connTypes
35from lsst.pipe.tasks.postprocess import TransformCatalogBaseTask, TransformCatalogBaseConfig
36from lsst.pipe.tasks.functors import Column
37from lsst.utils.timer import timeMethod
class TransformDiaSourceCatalogConnections(pipeBase.PipelineTaskConnections,
                                           dimensions=("instrument", "visit", "detector"),
                                           defaultTemplates={"coaddName": "deep", "fakesType": ""}):
    """Butler connections for TransformDiaSourceCatalogTask.

    Inputs are the DiaSource catalog (and its init-time schema) produced by
    image differencing together with the difference image itself; the output
    is the transformed per-detector DiaSource table as a DataFrame.
    """
    diaSourceSchema = connTypes.InitInput(
        doc="Schema for DIASource catalog output by ImageDifference.",
        storageClass="SourceCatalog",
        name="{fakesType}{coaddName}Diff_diaSrc_schema",
    )
    diaSourceCat = connTypes.Input(
        doc="Catalog of DiaSources produced during image differencing.",
        name="{fakesType}{coaddName}Diff_diaSrc",
        storageClass="SourceCatalog",
        dimensions=("instrument", "visit", "detector"),
    )
    diffIm = connTypes.Input(
        doc="Difference image on which the DiaSources were detected.",
        name="{fakesType}{coaddName}Diff_differenceExp",
        storageClass="ExposureF",
        dimensions=("instrument", "visit", "detector"),
    )
    diaSourceTable = connTypes.Output(
        # Previously a placeholder ("."); give the connection a real
        # description so generated pipeline docs are meaningful.
        doc="Catalog of DiaSources with calibrated values and renamed "
            "columns, ready for Apdb insertion.",
        name="{fakesType}{coaddName}Diff_diaSrcTable",
        storageClass="DataFrame",
        dimensions=("instrument", "visit", "detector"),
    )
class TransformDiaSourceCatalogConfig(TransformCatalogBaseConfig,
                                      pipelineConnections=TransformDiaSourceCatalogConnections):
    """Configuration for TransformDiaSourceCatalogTask.
    """
    flagMap = pexConfig.Field(
        dtype=str,
        doc="Yaml file specifying SciencePipelines flag fields to bit packs.",
        default=os.path.join("${AP_ASSOCIATION_DIR}",
                             "data",
                             "association-flag-map.yaml"),
    )
    flagRenameMap = pexConfig.Field(
        dtype=str,
        # Fixed doc string: the word "specifying" was duplicated.
        doc="Yaml file specifying rules to rename flag names",
        default=os.path.join("${AP_ASSOCIATION_DIR}",
                             "data",
                             "flag-rename-rules.yaml"),
    )
    doRemoveSkySources = pexConfig.Field(
        dtype=bool,
        default=False,
        doc="Input DiaSource catalog contains SkySources that should be "
            "removed before storing the output DiaSource catalog."
    )
    doPackFlags = pexConfig.Field(
        dtype=bool,
        default=True,
        doc="Do pack the flags into one integer column named 'flags'."
            "If False, instead produce one boolean column per flag."
    )

    def setDefaults(self):
        # Point the base-class functor definition file at this package's
        # DiaSource transformation spec.
        super().setDefaults()
        self.functorFile = os.path.join("${AP_ASSOCIATION_DIR}",
                                        "data",
                                        "DiaSource.yaml")
class TransformDiaSourceCatalogTask(TransformCatalogBaseTask):
    """Transform a DiaSource catalog by calibrating and renaming columns to
    produce a table ready to insert into the Apdb.

    Parameters
    ----------
    initInputs : `dict`
        Must contain ``diaSourceSchema`` as the schema for the input catalog.
    """
    ConfigClass = TransformDiaSourceCatalogConfig
    _DefaultName = "transformDiaSourceCatalog"
    # Needed to create a valid TransformCatalogBaseTask, but unused
    inputDataset = "deepDiff_diaSrc"
    outputDataset = "deepDiff_diaSrcTable"

    def __init__(self, initInputs, **kwargs):
        super().__init__(**kwargs)
        # Functors defined by config.functorFile (see setDefaults).
        self.funcs = self.getFunctors()
        # Schema of the incoming DiaSource catalog; used to validate the
        # flag names requested in the flag map.
        self.inputSchema = initInputs['diaSourceSchema'].schema
        self._create_bit_pack_mappings()

        if not self.config.doPackFlags:
            # get the flag rename rules
            with open(os.path.expandvars(self.config.flagRenameMap)) as yaml_stream:
                self.rename_rules = list(yaml.safe_load_all(yaml_stream))

    def _create_bit_pack_mappings(self):
        """Setup all flag bit packings.

        Loads the 'DiaSource' table entry from the yaml flag map named by
        ``config.flagMap`` into ``self.bit_pack_columns`` and verifies that
        every flag listed there exists in the input schema.

        Raises
        ------
        KeyError
            Raised if a requested flag column is absent from the input
            DiaSource schema.
        """
        self.bit_pack_columns = []
        flag_map_file = os.path.expandvars(self.config.flagMap)
        with open(flag_map_file) as yaml_stream:
            table_list = list(yaml.safe_load_all(yaml_stream))
            # Only the 'DiaSource' table from the flag map is used here.
            for table in table_list:
                if table['tableName'] == 'DiaSource':
                    self.bit_pack_columns = table['columns']
                    break

        # Test that all flags requested are present in the input schemas.
        # Output schemas are flexible, however if names are not specified in
        # the Apdb schema, flag columns will not be persisted.
        for outputFlag in self.bit_pack_columns:
            bitList = outputFlag['bitList']
            for bit in bitList:
                try:
                    self.inputSchema.find(bit['name'])
                except KeyError:
                    raise KeyError(
                        "Requested column %s not found in input DiaSource "
                        "schema. Please check that the requested input "
                        "column exists." % bit['name'])

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        """Butler-facing entry point: fetch inputs, derive the exposure ID
        and band from the data ID, delegate to `run`, and persist the result.
        """
        inputs = butlerQC.get(inputRefs)
        # Pack visit+detector into a single integer ID; the returned bit
        # width (expBits) is not used here.
        # NOTE(review): dataId.pack is deprecated in newer middleware
        # versions — confirm against the pinned lsst.daf.butler release.
        expId, expBits = butlerQC.quantum.dataId.pack("visit_detector",
                                                      returnMaxBits=True)
        inputs["ccdVisitId"] = expId
        inputs["band"] = butlerQC.quantum.dataId["band"]

        outputs = self.run(**inputs)

        butlerQC.put(outputs, outputRefs)

    @timeMethod
    def run(self,
            diaSourceCat,
            diffIm,
            band,
            ccdVisitId,
            funcs=None):
        """Convert input catalog to ParquetTable/Pandas and run functors.

        Additionally, add new columns for stripping information from the
        exposure and into the DiaSource catalog.

        Parameters
        ----------
        diaSourceCat : `lsst.afw.table.SourceCatalog`
            Catalog of sources measured on the difference image.
        diffIm : `lsst.afw.image.Exposure`
            Result of subtracting template and science images.
        band : `str`
            Filter band of the science image.
        ccdVisitId : `int`
            Identifier for this detector+visit.
        funcs : `lsst.pipe.tasks.functors.Functors`
            Functors to apply to the catalog's columns.
            NOTE(review): this argument is currently unused; ``self.funcs``
            is what gets passed to ``transform`` below.

        Returns
        -------
        results : `lsst.pipe.base.Struct`
            Results struct with components.

            - ``diaSourceTable`` : Catalog of DiaSources with calibrated values
              and renamed columns.
              (`lsst.pipe.tasks.ParquetTable` or `pandas.DataFrame`)
        """
        self.log.info(
            "Transforming/standardizing the DiaSource table ccdVisitId: %i",
            ccdVisitId)

        diaSourceDf = diaSourceCat.asAstropy().to_pandas()
        if self.config.doRemoveSkySources:
            # Filter BOTH the DataFrame and the afw catalog so the two stay
            # row-aligned for the per-record computations below.
            diaSourceDf = diaSourceDf[~diaSourceDf["sky_source"]]
            diaSourceCat = diaSourceCat[~diaSourceCat["sky_source"]]

        diaSourceDf["snr"] = getSignificance(diaSourceCat)
        diaSourceDf["bboxSize"] = self.computeBBoxSizes(diaSourceCat)
        diaSourceDf["ccdVisitId"] = ccdVisitId
        diaSourceDf["filterName"] = band
        # Mid-exposure time of the difference image, as MJD (TAI).
        diaSourceDf["midPointTai"] = diffIm.getInfo().getVisitInfo().getDate().get(system=DateTime.MJD)
        # Placeholder IDs — presumably filled in downstream by association;
        # not set anywhere in this task.
        diaSourceDf["diaObjectId"] = 0
        diaSourceDf["ssObjectId"] = 0

        if self.config.doPackFlags:
            # either bitpack the flags
            self.bitPackFlags(diaSourceDf)
        else:
            # or add the individual flag functors
            # (note this mutates self.funcs, persisting across run() calls)
            self.addUnpackedFlagFunctors()
            # and remove the packed flag functor
            if 'flags' in self.funcs.funcDict:
                del self.funcs.funcDict['flags']

        df = self.transform(band,
                            diaSourceDf,
                            self.funcs,
                            dataId=None).df

        return pipeBase.Struct(
            diaSourceTable=df,
        )

    def addUnpackedFlagFunctors(self):
        """Add Column functor for each of the flags to the internal functor
        dictionary.

        Only the first table entry of ``self.bit_pack_columns`` is consulted;
        each flag is renamed per the first rule set loaded from
        ``config.flagRenameMap``.
        """
        for flag in self.bit_pack_columns[0]['bitList']:
            flagName = flag['name']
            targetName = self.funcs.renameCol(flagName, self.rename_rules[0]['flag_rename_rules'])
            self.funcs.update({targetName: Column(flagName)})

    def computeBBoxSizes(self, inputCatalog):
        """Compute the size of a square bbox that fully contains the detection
        footprint.

        The size is twice the largest centroid-to-bbox-edge distance, clamped
        at twice the footprint bbox's largest dimension.

        Parameters
        ----------
        inputCatalog : `lsst.afw.table.SourceCatalog`
            Catalog containing detected footprints.

        Returns
        -------
        outputBBoxSizes : `np.ndarray`, (N,)
            Array of bbox sizes.
        """
        # Schema validation requires that this field is int.
        outputBBoxSizes = np.empty(len(inputCatalog), dtype=int)
        for i, record in enumerate(inputCatalog):
            footprintBBox = record.getFootprint().getBBox()
            # Compute twice the size of the largest dimension of the footprint
            # bounding box. This is the largest footprint we should need to cover
            # the complete DiaSource assuming the centroid is within the bounding
            # box.
            maxSize = 2 * np.max([footprintBBox.getWidth(),
                                  footprintBBox.getHeight()])
            recX = record.getCentroid().x
            recY = record.getCentroid().y
            bboxSize = int(
                np.ceil(2 * np.max(np.fabs([footprintBBox.maxX - recX,
                                            footprintBBox.minX - recX,
                                            footprintBBox.maxY - recY,
                                            footprintBBox.minY - recY]))))
            if bboxSize > maxSize:
                bboxSize = maxSize
            outputBBoxSizes[i] = bboxSize

        return outputBBoxSizes

    def bitPackFlags(self, df):
        """Pack requested flag columns in inputRecord into single columns in
        outputRecord.

        Parameters
        ----------
        df : `pandas.DataFrame`
            DataFrame to read bits from and pack them into.
        """
        for outputFlag in self.bit_pack_columns:
            bitList = outputFlag['bitList']
            # Accumulate as uint64 so up to 64 flag bits fit per column.
            value = np.zeros(len(df), dtype=np.uint64)
            for bit in bitList:
                # Hard type the bit arrays.
                value += (df[bit['name']]*2**bit['bit']).to_numpy().astype(np.uint64)
            df[outputFlag['columnName']] = value
class UnpackApdbFlags:
    """Class for unpacking bits from integer flag fields stored in the Apdb.

    Attributes
    ----------
    flag_map_file : `str`
        Absolute or relative path to a yaml file specifying mappings of flags
        to integer bits.
    table_name : `str`
        Name of the Apdb table the integer bit data are coming from.
    """

    def __init__(self, flag_map_file, table_name):
        self.bit_pack_columns = []
        flag_map_file = os.path.expandvars(flag_map_file)
        with open(flag_map_file) as yaml_stream:
            table_list = list(yaml.safe_load_all(yaml_stream))
            # Only the entry matching table_name is used.
            for table in table_list:
                if table['tableName'] == table_name:
                    self.bit_pack_columns = table['columns']
                    break

        # Mapping from packed column name to the ordered list of
        # (flag name, bool) pairs it contains; used as the structured dtype
        # of the unpacked output array.
        self.output_flag_columns = {}

        for column in self.bit_pack_columns:
            names = []
            for bit in column["bitList"]:
                names.append((bit["name"], bool))
            self.output_flag_columns[column["columnName"]] = names

    def unpack(self, input_flag_values, flag_name):
        """Determine individual boolean flags from an input array of unsigned
        ints.

        Parameters
        ----------
        input_flag_values : array-like of type uint
            Array of integer flags to unpack.
        flag_name : `str`
            Apdb column name of integer flags to unpack. Names of packed int
            flags are given by the flag_map_file.

        Returns
        -------
        output_flags : `numpy.ndarray`
            Numpy named tuple of booleans.
        """
        bit_names_types = self.output_flag_columns[flag_name]
        output_flags = np.zeros(len(input_flag_values), dtype=bit_names_types)

        # Fixed: mask with the bit position declared in the flag map
        # (bit['bit']) instead of each flag's positional index in the list.
        # The two agree only when the map lists bits contiguously from zero,
        # which the format does not guarantee.
        for column in self.bit_pack_columns:
            if column['columnName'] == flag_name:
                for bit in column['bitList']:
                    masked_bits = np.bitwise_and(input_flag_values,
                                                 2**bit['bit'])
                    # Non-zero masked values cast to True in the bool field.
                    output_flags[bit['name']] = masked_bits
                break

        return output_flags

    def flagExists(self, flagName, columnName='flags'):
        """Check if named flag is in the bitpacked flag set.

        Parameters
        ----------
        flagName : `str`
            Flag name to search for.
        columnName : `str`, optional
            Name of bitpacked flag column to search in.

        Returns
        -------
        flagExists : `bool`
            `True` if `flagName` is present in `columnName`.

        Raises
        ------
        ValueError
            Raised if `columnName` is not defined.
        """
        if columnName not in self.output_flag_columns:
            raise ValueError(f'column {columnName} not in flag map: {self.output_flag_columns}')

        return flagName in [c[0] for c in self.output_flag_columns[columnName]]

    def makeFlagBitMask(self, flagNames, columnName='flags'):
        """Return a bitmask corresponding to the supplied flag names.

        Parameters
        ----------
        flagNames : `list` [`str`]
            Flag names to include in the bitmask.
        columnName : `str`, optional
            Name of bitpacked flag column.

        Returns
        -------
        bitmask : `np.uint64`
            Bitmask corresponding to the supplied flag names given the loaded
            configuration.

        Raises
        ------
        ValueError
            Raised if a flag in `flagNames` is not included in `columnName`.
        """
        bitmask = np.uint64(0)

        # Validate all names up front so the error names the offending flag.
        for flag in flagNames:
            if not self.flagExists(flag, columnName=columnName):
                raise ValueError(f"flag '{flag}' not included in '{columnName}' flag column")

        for outputFlag in self.bit_pack_columns:
            if outputFlag['columnName'] == columnName:
                bitList = outputFlag['bitList']
                for bit in bitList:
                    if bit['name'] in flagNames:
                        bitmask += np.uint64(2**bit['bit'])

        return bitmask
def getSignificance(catalog):
    """Extract the significance of the first footprint peak of each source.

    Sources whose peak table lacks a "significance" field get NaN.

    Parameters
    ----------
    catalog : `lsst.afw.table.SourceCatalog`
        Catalog to process.

    Returns
    -------
    significance : `np.ndarray`, (N,)
        Significance of the first peak in each source footprint.
    """
    values = []
    for record in catalog:
        peaks = record.getFootprint().peaks
        has_significance = "significance" in peaks.schema
        values.append(peaks[0]["significance"] if has_significance else np.nan)
    return np.array(values, dtype=float)