Coverage for python/lsst/ap/association/transformDiaSourceCatalog.py: 20%
152 statements
« prev ^ index » next coverage.py v6.5.0, created at 2023-04-05 01:47 -0700
1# This file is part of ap_association
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
# Public API of this module.
__all__ = ("TransformDiaSourceCatalogConnections",
           "TransformDiaSourceCatalogConfig",
           "TransformDiaSourceCatalogTask",
           "UnpackApdbFlags")
27import numpy as np
28import os
29import yaml
31from lsst.daf.base import DateTime
32import lsst.pex.config as pexConfig
33import lsst.pipe.base as pipeBase
34import lsst.pipe.base.connectionTypes as connTypes
35from lsst.pipe.tasks.postprocess import TransformCatalogBaseTask, TransformCatalogBaseConfig
36from lsst.pipe.tasks.parquetTable import ParquetTable
37from lsst.pipe.tasks.functors import Column
38from lsst.utils.timer import timeMethod
class TransformDiaSourceCatalogConnections(pipeBase.PipelineTaskConnections,
                                           dimensions=("instrument", "visit", "detector"),
                                           defaultTemplates={"coaddName": "deep", "fakesType": ""}):
    """Butler connections for TransformDiaSourceCatalogTask."""
    diaSourceSchema = connTypes.InitInput(
        doc="Schema for DIASource catalog output by ImageDifference.",
        storageClass="SourceCatalog",
        name="{fakesType}{coaddName}Diff_diaSrc_schema",
    )
    diaSourceCat = connTypes.Input(
        doc="Catalog of DiaSources produced during image differencing.",
        name="{fakesType}{coaddName}Diff_diaSrc",
        storageClass="SourceCatalog",
        dimensions=("instrument", "visit", "detector"),
    )
    diffIm = connTypes.Input(
        doc="Difference image on which the DiaSources were detected.",
        name="{fakesType}{coaddName}Diff_differenceExp",
        storageClass="ExposureF",
        dimensions=("instrument", "visit", "detector"),
    )
    diaSourceTable = connTypes.Output(
        # Was a placeholder "."; now describes the actual dataset.
        doc="Catalog of DiaSources with calibrated values and renamed "
            "columns, ready for Apdb insertion.",
        name="{fakesType}{coaddName}Diff_diaSrcTable",
        storageClass="DataFrame",
        dimensions=("instrument", "visit", "detector"),
    )
class TransformDiaSourceCatalogConfig(TransformCatalogBaseConfig,
                                      pipelineConnections=TransformDiaSourceCatalogConnections):
    """Configuration for TransformDiaSourceCatalogTask."""

    flagMap = pexConfig.Field(
        dtype=str,
        doc="Yaml file specifying SciencePipelines flag fields to bit packs.",
        default=os.path.join("${AP_ASSOCIATION_DIR}",
                             "data",
                             "association-flag-map.yaml"),
    )
    flagRenameMap = pexConfig.Field(
        dtype=str,
        # Fixed duplicated word "specifying specifying" in the doc string.
        doc="Yaml file specifying rules to rename flag names",
        default=os.path.join("${AP_ASSOCIATION_DIR}",
                             "data",
                             "flag-rename-rules.yaml"),
    )
    doRemoveSkySources = pexConfig.Field(
        dtype=bool,
        default=False,
        doc="Input DiaSource catalog contains SkySources that should be "
            "removed before storing the output DiaSource catalog."
    )
    doPackFlags = pexConfig.Field(
        dtype=bool,
        default=True,
        # Fixed missing space between the two concatenated sentences.
        doc="Do pack the flags into one integer column named 'flags'. "
            "If False, instead produce one boolean column per flag."
    )

    def setDefaults(self):
        super().setDefaults()
        # Point the base-class transform machinery at the DiaSource functor
        # definitions shipped with ap_association.
        self.functorFile = os.path.join("${AP_ASSOCIATION_DIR}",
                                        "data",
                                        "DiaSource.yaml")
class TransformDiaSourceCatalogTask(TransformCatalogBaseTask):
    """Transform a DiaSource catalog by calibrating and renaming columns to
    produce a table ready to insert into the Apdb.

    Parameters
    ----------
    initInputs : `dict`
        Must contain ``diaSourceSchema`` as the schema for the input catalog.
    """
    ConfigClass = TransformDiaSourceCatalogConfig
    _DefaultName = "transformDiaSourceCatalog"
    # Needed to create a valid TransformCatalogBaseTask, but unused
    inputDataset = "deepDiff_diaSrc"
    outputDataset = "deepDiff_diaSrcTable"

    def __init__(self, initInputs, **kwargs):
        super().__init__(**kwargs)
        self.funcs = self.getFunctors()
        self.inputSchema = initInputs['diaSourceSchema'].schema
        self._create_bit_pack_mappings()

        if not self.config.doPackFlags:
            # Load the rules used to rename unpacked flag columns.
            with open(os.path.expandvars(self.config.flagRenameMap)) as yaml_stream:
                self.rename_rules = list(yaml.safe_load_all(yaml_stream))

    def _create_bit_pack_mappings(self):
        """Setup all flag bit packings.

        Reads the flag map yaml file configured in ``config.flagMap`` and
        stores the ``DiaSource`` table's column/bit definitions in
        ``self.bit_pack_columns``.

        Raises
        ------
        KeyError
            Raised if a flag requested in the flag map is not present in the
            input DiaSource schema.
        """
        self.bit_pack_columns = []
        flag_map_file = os.path.expandvars(self.config.flagMap)
        with open(flag_map_file) as yaml_stream:
            table_list = list(yaml.safe_load_all(yaml_stream))
            for table in table_list:
                if table['tableName'] == 'DiaSource':
                    self.bit_pack_columns = table['columns']
                    break

        # Test that all flags requested are present in the input schemas.
        # Output schemas are flexible, however if names are not specified in
        # the Apdb schema, flag columns will not be persisted.
        for outputFlag in self.bit_pack_columns:
            bitList = outputFlag['bitList']
            for bit in bitList:
                try:
                    self.inputSchema.find(bit['name'])
                except KeyError as e:
                    # Chain the original lookup failure so the underlying
                    # schema error is not lost from the traceback.
                    raise KeyError(
                        f"Requested column {bit['name']} not found in input "
                        "DiaSource schema. Please check that the requested "
                        "input column exists.") from e

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        """Fetch inputs, derive ``ccdVisitId`` and ``band`` from the data ID,
        run the task, and persist the outputs.
        """
        inputs = butlerQC.get(inputRefs)
        expId, expBits = butlerQC.quantum.dataId.pack("visit_detector",
                                                      returnMaxBits=True)
        inputs["ccdVisitId"] = expId
        inputs["band"] = butlerQC.quantum.dataId["band"]

        outputs = self.run(**inputs)

        butlerQC.put(outputs, outputRefs)

    @timeMethod
    def run(self,
            diaSourceCat,
            diffIm,
            band,
            ccdVisitId,
            funcs=None):
        """Convert input catalog to ParquetTable/Pandas and run functors.

        Additionally, add new columns for stripping information from the
        exposure and into the DiaSource catalog.

        Parameters
        ----------
        diaSourceCat : `lsst.afw.table.SourceCatalog`
            Catalog of sources measured on the difference image.
        diffIm : `lsst.afw.image.Exposure`
            Result of subtracting template and science images.
        band : `str`
            Filter band of the science image.
        ccdVisitId : `int`
            Identifier for this detector+visit.
        funcs : `lsst.pipe.tasks.functors.Functors`
            Functors to apply to the catalog's columns.

        Returns
        -------
        results : `lsst.pipe.base.Struct`
            Results struct with components.

            - ``diaSourceTable`` : Catalog of DiaSources with calibrated
              values and renamed columns.
              (`lsst.pipe.tasks.ParquetTable` or `pandas.DataFrame`)
        """
        self.log.info(
            "Transforming/standardizing the DiaSource table ccdVisitId: %i",
            ccdVisitId)

        diaSourceDf = diaSourceCat.asAstropy().to_pandas()
        if self.config.doRemoveSkySources:
            # Drop sky sources from both views so derived columns line up.
            diaSourceDf = diaSourceDf[~diaSourceDf["sky_source"]]
            diaSourceCat = diaSourceCat[~diaSourceCat["sky_source"]]

        diaSourceDf["snr"] = getSignificance(diaSourceCat)
        diaSourceDf["bboxSize"] = self.computeBBoxSizes(diaSourceCat)
        diaSourceDf["ccdVisitId"] = ccdVisitId
        diaSourceDf["filterName"] = band
        diaSourceDf["midPointTai"] = diffIm.getInfo().getVisitInfo().getDate().get(system=DateTime.MJD)
        diaSourceDf["diaObjectId"] = 0
        diaSourceDf["ssObjectId"] = 0

        if self.config.doPackFlags:
            # either bitpack the flags
            self.bitPackFlags(diaSourceDf)
        else:
            # or add the individual flag functors
            self.addUnpackedFlagFunctors()
            # and remove the packed flag functor
            if 'flags' in self.funcs.funcDict:
                del self.funcs.funcDict['flags']

        df = self.transform(band,
                            ParquetTable(dataFrame=diaSourceDf),
                            self.funcs,
                            dataId=None).df

        return pipeBase.Struct(
            diaSourceTable=df,
        )

    def addUnpackedFlagFunctors(self):
        """Add Column functor for each of the flags to the internal functor
        dictionary.
        """
        for flag in self.bit_pack_columns[0]['bitList']:
            flagName = flag['name']
            targetName = self.funcs.renameCol(flagName, self.rename_rules[0]['flag_rename_rules'])
            self.funcs.update({targetName: Column(flagName)})

    def computeBBoxSizes(self, inputCatalog):
        """Compute the size of a square bbox that fully contains the detection
        footprint.

        Parameters
        ----------
        inputCatalog : `lsst.afw.table.SourceCatalog`
            Catalog containing detected footprints.

        Returns
        -------
        outputBBoxSizes : `np.ndarray`, (N,)
            Array of bbox sizes.
        """
        # Schema validation requires that this field is int.
        outputBBoxSizes = np.empty(len(inputCatalog), dtype=int)
        for i, record in enumerate(inputCatalog):
            footprintBBox = record.getFootprint().getBBox()
            # Compute twice the size of the largest dimension of the footprint
            # bounding box. This is the largest footprint we should need to
            # cover the complete DiaSource assuming the centroid is within the
            # bounding box.
            maxSize = 2 * np.max([footprintBBox.getWidth(),
                                  footprintBBox.getHeight()])
            recX = record.getCentroid().x
            recY = record.getCentroid().y
            # Size needed to cover the footprint when centered on the
            # centroid, capped at maxSize.
            bboxSize = int(
                np.ceil(2 * np.max(np.fabs([footprintBBox.maxX - recX,
                                            footprintBBox.minX - recX,
                                            footprintBBox.maxY - recY,
                                            footprintBBox.minY - recY]))))
            if bboxSize > maxSize:
                bboxSize = maxSize
            outputBBoxSizes[i] = bboxSize

        return outputBBoxSizes

    def bitPackFlags(self, df):
        """Pack requested flag columns in inputRecord into single columns in
        outputRecord.

        Parameters
        ----------
        df : `pandas.DataFrame`
            DataFrame to read bits from and pack them into.
        """
        for outputFlag in self.bit_pack_columns:
            bitList = outputFlag['bitList']
            value = np.zeros(len(df), dtype=np.uint64)
            for bit in bitList:
                # Hard type the bit arrays.
                value += (df[bit['name']]*2**bit['bit']).to_numpy().astype(np.uint64)
            df[outputFlag['columnName']] = value
class UnpackApdbFlags:
    """Class for unpacking bits from integer flag fields stored in the Apdb.

    Attributes
    ----------
    flag_map_file : `str`
        Absolute or relative path to a yaml file specifying mappings of flags
        to integer bits.
    table_name : `str`
        Name of the Apdb table the integer bit data are coming from.
    """

    def __init__(self, flag_map_file, table_name):
        self.bit_pack_columns = []
        flag_map_file = os.path.expandvars(flag_map_file)
        with open(flag_map_file) as yaml_stream:
            table_list = list(yaml.safe_load_all(yaml_stream))
            for table in table_list:
                if table['tableName'] == table_name:
                    self.bit_pack_columns = table['columns']
                    break

        # Mapping of packed-column name -> list of (flagName, bool) used as
        # the numpy structured dtype for the unpacked output.
        self.output_flag_columns = {}

        for column in self.bit_pack_columns:
            names = []
            for bit in column["bitList"]:
                names.append((bit["name"], bool))
            self.output_flag_columns[column["columnName"]] = names

    def unpack(self, input_flag_values, flag_name):
        """Determine individual boolean flags from an input array of unsigned
        ints.

        Parameters
        ----------
        input_flag_values : array-like of type uint
            Array of integer flags to unpack.
        flag_name : `str`
            Apdb column name of integer flags to unpack. Names of packed int
            flags are given by the flag_map_file.

        Returns
        -------
        output_flags : `numpy.ndarray`
            Numpy structured array of booleans, one field per flag.
        """
        bit_names_types = self.output_flag_columns[flag_name]
        output_flags = np.zeros(len(input_flag_values), dtype=bit_names_types)

        # Look up each flag's bit position from the flag map rather than
        # assuming positions follow the list order; the previous use of the
        # enumeration index silently mis-unpacked maps with non-contiguous
        # or reordered bits.
        bit_positions = {}
        for column in self.bit_pack_columns:
            if column['columnName'] == flag_name:
                bit_positions = {bit['name']: bit['bit']
                                 for bit in column['bitList']}
                break

        for bit_name, _ in bit_names_types:
            masked_bits = np.bitwise_and(input_flag_values,
                                         2**bit_positions[bit_name])
            # Any nonzero masked value casts to True in the bool field.
            output_flags[bit_name] = masked_bits

        return output_flags

    def flagExists(self, flagName, columnName='flags'):
        """Check if named flag is in the bitpacked flag set.

        Parameters
        ----------
        flagName : `str`
            Flag name to search for.
        columnName : `str`, optional
            Name of bitpacked flag column to search in.

        Returns
        -------
        flagExists : `bool`
            `True` if `flagName` is present in `columnName`.

        Raises
        ------
        ValueError
            Raised if `columnName` is not defined.
        """
        if columnName not in self.output_flag_columns:
            raise ValueError(f'column {columnName} not in flag map: {self.output_flag_columns}')

        return flagName in [c[0] for c in self.output_flag_columns[columnName]]

    def makeFlagBitMask(self, flagNames, columnName='flags'):
        """Return a bitmask corresponding to the supplied flag names.

        Parameters
        ----------
        flagNames : `list` [`str`]
            Flag names to include in the bitmask.
        columnName : `str`, optional
            Name of bitpacked flag column.

        Returns
        -------
        bitmask : `np.uint64`
            Bitmask corresponding to the supplied flag names given the loaded
            configuration.

        Raises
        ------
        ValueError
            Raised if a flag in `flagName` is not included in `columnName`.
        """
        bitmask = np.uint64(0)

        for flag in flagNames:
            if not self.flagExists(flag, columnName=columnName):
                raise ValueError(f"flag '{flag}' not included in '{columnName}' flag column")

        for outputFlag in self.bit_pack_columns:
            if outputFlag['columnName'] == columnName:
                bitList = outputFlag['bitList']
                for bit in bitList:
                    if bit['name'] in flagNames:
                        bitmask += np.uint64(2**bit['bit'])

        return bitmask
def getSignificance(catalog):
    """Return the significance value of the first peak in each source
    footprint, or NaN for sources without a usable significance value.

    Parameters
    ----------
    catalog : `lsst.afw.table.SourceCatalog`
        Catalog to process.

    Returns
    -------
    significance : `np.ndarray`, (N,)
        Significance of the first peak in each source footprint; NaN when
        the footprint has no peaks or the peaks lack a significance field.
    """
    result = np.full(len(catalog), np.nan)
    for i, record in enumerate(catalog):
        peaks = record.getFootprint().peaks
        # Guard against footprints with no peaks; previously peaks[0]
        # would raise IndexError in that case.
        if len(peaks) > 0 and "significance" in peaks.schema:
            result[i] = peaks[0]["significance"]
    return result