Coverage for python/lsst/cp/pipe/cpCombine.py : 16%

# This file is part of cp_pipe.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
import numpy as np
import time

import lsst.pex.config as pexConfig
import lsst.pipe.base as pipeBase
import lsst.pipe.base.connectionTypes as cT
import lsst.afw.math as afwMath
import lsst.afw.image as afwImage

from lsst.geom import Point2D
from lsst.log import Log
from astro_metadata_translator import merge_headers, ObservationGroup
from astro_metadata_translator.serialize import dates_to_fits


# CalibStatsConfig/CalibStatsTask from pipe_base/constructCalibs.py
class CalibStatsConfig(pexConfig.Config):
    """Parameters controlling the measurement of background statistics.
    """
    stat = pexConfig.Field(
        dtype=int,
        default=int(afwMath.MEANCLIP),
        doc="Statistic to use to estimate background (from lsst.afw.math)",
    )
    clip = pexConfig.Field(
        dtype=float,
        default=3.0,
        doc="Clipping threshold for background",
    )
    nIter = pexConfig.Field(
        dtype=int,
        default=3,
        doc="Clipping iterations for background",
    )
    mask = pexConfig.ListField(
        dtype=str,
        default=["DETECTED", "BAD", "NO_DATA"],
        doc="Mask planes to reject",
    )


class CalibStatsTask(pipeBase.Task):
    """Measure statistics on the background.

    This can be useful for scaling the background, e.g., for flats and
    fringe frames.
    """
    ConfigClass = CalibStatsConfig

    def run(self, exposureOrImage):
        """Measure a particular statistic on an image (of some sort).

        Parameters
        ----------
        exposureOrImage : `lsst.afw.image.Exposure`, `lsst.afw.image.MaskedImage`, or `lsst.afw.image.Image`
            Exposure or image to calculate statistics on.

        Returns
        -------
        results : `float`
            Resulting statistic value.
        """
        stats = afwMath.StatisticsControl(self.config.clip, self.config.nIter,
                                          afwImage.Mask.getPlaneBitMask(self.config.mask))
        try:
            image = exposureOrImage.getMaskedImage()
        except Exception:
            try:
                image = exposureOrImage.getImage()
            except Exception:
                image = exposureOrImage

        return afwMath.makeStatistics(image, self.config.stat, stats).getValue()
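

# Example (illustrative only): the clipped statistic that CalibStatsTask wraps
# can be reproduced directly with lsst.afw.math, using the same default control
# parameters as CalibStatsConfig.  ``maskedImage`` is a hypothetical input.
#
#     sctrl = afwMath.StatisticsControl(3.0, 3,
#                                       afwImage.Mask.getPlaneBitMask(["DETECTED", "BAD", "NO_DATA"]))
#     background = afwMath.makeStatistics(maskedImage, afwMath.MEANCLIP, sctrl).getValue()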


class CalibCombineConnections(pipeBase.PipelineTaskConnections,
                              dimensions=("instrument", "detector")):
    inputExps = cT.Input(
        name="cpInputs",
        doc="Input pre-processed exposures to combine.",
        storageClass="Exposure",
        dimensions=("instrument", "detector", "exposure"),
        multiple=True,
    )
    inputScales = cT.Input(
        name="cpScales",
        doc="Input scale factors to use.",
        storageClass="StructuredDataDict",
        dimensions=("instrument", ),
        multiple=False,
    )

    outputData = cT.Output(
        name="cpProposal",
        doc="Output combined proposed calibration.",
        storageClass="ExposureF",
        dimensions=("instrument", "detector"),
        isCalibration=True,
    )

    def __init__(self, *, config=None):
        super().__init__(config=config)

        if config and config.exposureScaling != 'InputList':
            self.inputs.discard("inputScales")

        if config and len(config.calibrationDimensions) != 0:
            newDimensions = tuple(config.calibrationDimensions)
            newOutputData = cT.Output(
                name=self.outputData.name,
                doc=self.outputData.doc,
                storageClass=self.outputData.storageClass,
                dimensions=self.allConnections['outputData'].dimensions + newDimensions,
                isCalibration=True,
            )
            self.dimensions.update(config.calibrationDimensions)
            self.outputData = newOutputData

            if config.exposureScaling == 'InputList':
                newInputScales = cT.PrerequisiteInput(
                    name=self.inputScales.name,
                    doc=self.inputScales.doc,
                    storageClass=self.inputScales.storageClass,
                    dimensions=self.allConnections['inputScales'].dimensions + newDimensions
                )
                self.dimensions.update(config.calibrationDimensions)
                self.inputScales = newInputScales
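

# Example config override (illustrative, not from this file): appending an
# extra dimension to the proposed calibration and switching to per-input
# scales, which also promotes ``inputScales`` to a prerequisite input as shown
# above.  The dimension name "physical_filter" is only an example.
#
#     config.calibrationDimensions = ["physical_filter"]
#     config.exposureScaling = "InputList"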


# CalibCombineConfig/CalibCombineTask from pipe_base/constructCalibs.py
class CalibCombineConfig(pipeBase.PipelineTaskConfig,
                         pipelineConnections=CalibCombineConnections):
    """Configuration for combining calib exposures.
    """
    calibrationType = pexConfig.Field(
        dtype=str,
        default="calibration",
        doc="Name of calibration to be generated.",
    )
    calibrationDimensions = pexConfig.ListField(
        dtype=str,
        default=[],
        doc="List of updated dimensions to append to output."
    )

    exposureScaling = pexConfig.ChoiceField(
        dtype=str,
        allowed={
            "None": "No scaling used.",
            "ExposureTime": "Scale inputs by their exposure time.",
            "DarkTime": "Scale inputs by their dark time.",
            "MeanStats": "Scale inputs based on their mean values.",
            "InputList": "Scale inputs based on a list of values.",
        },
        default=None,
        doc="Scaling to be applied to each input exposure.",
    )
    scalingLevel = pexConfig.ChoiceField(
        dtype=str,
        allowed={
            "DETECTOR": "Scale by detector.",
            "AMP": "Scale by amplifier.",
        },
        default="DETECTOR",
        doc="Region to scale.",
    )
    maxVisitsToCalcErrorFromInputVariance = pexConfig.Field(
        dtype=int,
        default=5,
        doc="Maximum number of visits to estimate variance from input variance, not per-pixel spread",
    )

    doVignette = pexConfig.Field(
        dtype=bool,
        default=False,
        doc="Copy vignette polygon to output and censor vignetted pixels?"
    )

    mask = pexConfig.ListField(
        dtype=str,
        default=["SAT", "DETECTED", "INTRP"],
        doc="Mask planes to respect",
    )
    combine = pexConfig.Field(
        dtype=int,
        default=int(afwMath.MEANCLIP),
        doc="Statistic to use for combination (from lsst.afw.math)",
    )
    clip = pexConfig.Field(
        dtype=float,
        default=3.0,
        doc="Clipping threshold for combination",
    )
    nIter = pexConfig.Field(
        dtype=int,
        default=3,
        doc="Clipping iterations for combination",
    )
    stats = pexConfig.ConfigurableField(
        target=CalibStatsTask,
        doc="Background statistics configuration",
    )
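

# Example (illustrative): constructing the task with a customized
# configuration.  The specific values are placeholders, not defaults from this
# file.
#
#     config = CalibCombineConfig()
#     config.calibrationType = "dark"
#     config.exposureScaling = "DarkTime"
#     task = CalibCombineTask(config=config)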


class CalibCombineTask(pipeBase.PipelineTask,
                       pipeBase.CmdLineTask):
    """Task to combine calib exposures."""
    ConfigClass = CalibCombineConfig
    _DefaultName = 'cpCombine'

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.makeSubtask("stats")

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        inputs = butlerQC.get(inputRefs)

        dimensions = [exp.dataId.byName() for exp in inputRefs.inputExps]
        inputs['inputDims'] = dimensions

        outputs = self.run(**inputs)
        butlerQC.put(outputs, outputRefs)

    def run(self, inputExps, inputScales=None, inputDims=None):
        """Combine calib exposures for a single detector.

        Parameters
        ----------
        inputExps : `list` [`lsst.afw.image.Exposure`]
            Input list of exposures to combine.
        inputScales : `dict` [`dict` [`dict` [`float`]]], optional
            Dictionary of scales, indexed by detector (`int`),
            amplifier (`str`), and exposure (`int`).  Used for
            'InputList' scaling.
        inputDims : `list` [`dict`], optional
            List of dictionaries of input data dimensions/values.
            Each list entry should contain:

            ``"exposure"``
                exposure id value (`int`)
            ``"detector"``
                detector id value (`int`)

        Returns
        -------
        results : `lsst.pipe.base.Struct`
            The results struct containing:

            ``outputData``
                Final combined exposure generated from the inputs
                (`lsst.afw.image.Exposure`).

        Raises
        ------
        RuntimeError
            Raised if no input data is found.  Also raised if
            config.exposureScaling == InputList, and a necessary scale
            was not found.
        """
        width, height = self.getDimensions(inputExps)
        stats = afwMath.StatisticsControl(self.config.clip, self.config.nIter,
                                          afwImage.Mask.getPlaneBitMask(self.config.mask))
        numExps = len(inputExps)
        if numExps < 1:
            raise RuntimeError("No valid input data")
        if numExps < self.config.maxVisitsToCalcErrorFromInputVariance:
            stats.setCalcErrorFromInputVariance(True)

        # Create output exposure for combined data.
        combined = afwImage.MaskedImageF(width, height)
        combinedExp = afwImage.makeExposure(combined)

        # Apply scaling:
        expScales = []
        if inputDims is None:
            inputDims = [dict() for i in inputExps]

        for index, (exp, dims) in enumerate(zip(inputExps, inputDims)):
            scale = 1.0
            if exp is None:
                self.log.warn("Input %d is None (%s); unable to scale exp.", index, dims)
                continue

            if self.config.exposureScaling == "ExposureTime":
                scale = exp.getInfo().getVisitInfo().getExposureTime()
            elif self.config.exposureScaling == "DarkTime":
                scale = exp.getInfo().getVisitInfo().getDarkTime()
            elif self.config.exposureScaling == "MeanStats":
                scale = self.stats.run(exp)
            elif self.config.exposureScaling == "InputList":
                visitId = dims.get('exposure', None)
                detectorId = dims.get('detector', None)
                if visitId is None or detectorId is None:
                    raise RuntimeError(f"Could not identify scaling for input {index} ({dims})")
                if detectorId not in inputScales['expScale']:
                    raise RuntimeError(f"Could not identify a scaling for input {index}"
                                       f" detector {detectorId}")

                if self.config.scalingLevel == "DETECTOR":
                    if visitId not in inputScales['expScale'][detectorId]:
                        raise RuntimeError(f"Could not identify a scaling for input {index}"
                                           f" detector {detectorId} visit {visitId}")
                    scale = inputScales['expScale'][detectorId][visitId]
                elif self.config.scalingLevel == 'AMP':
                    scale = [inputScales['expScale'][detectorId][amp.getName()][visitId]
                             for amp in exp.getDetector()]
                else:
                    raise RuntimeError(f"Unknown scaling level: {self.config.scalingLevel}")
            elif self.config.exposureScaling == 'None':
                scale = 1.0
            else:
                raise RuntimeError(f"Unknown scaling type: {self.config.exposureScaling}.")

            expScales.append(scale)
            self.log.info("Scaling input %d by %s", index, scale)
            self.applyScale(exp, scale)

        self.combine(combined, inputExps, stats)

        self.interpolateNans(combined)

        if self.config.doVignette:
            polygon = inputExps[0].getInfo().getValidPolygon()
            VignetteExposure(combined, polygon=polygon, doUpdateMask=True,
                             doSetValue=True, vignetteValue=0.0)

        # Combine headers
        self.combineHeaders(inputExps, combinedExp,
                            calibType=self.config.calibrationType, scales=expScales)

        # Return
        return pipeBase.Struct(
            outputData=combinedExp,
        )
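
    # Example (illustrative): calling ``run`` directly with pre-computed
    # scales.  The nesting of ``inputScales`` matches the lookups above for
    # scalingLevel == "DETECTOR"; all identifiers below are placeholders.
    #
    #     inputScales = {"expScale": {detectorId: {exposureId: 1.02}}}
    #     inputDims = [{"exposure": exposureId, "detector": detectorId}]
    #     result = task.run(inputExps, inputScales=inputScales, inputDims=inputDims)
    #     combinedExp = result.outputData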

    def getDimensions(self, expList):
        """Get dimensions of the inputs.

        Parameters
        ----------
        expList : `list` [`lsst.afw.image.Exposure`]
            Exposures to check the sizes of.

        Returns
        -------
        width, height : `int`
            Unique set of input dimensions.
        """
        dimList = [exp.getDimensions() for exp in expList if exp is not None]
        return self.getSize(dimList)

    def getSize(self, dimList):
        """Determine a consistent size, given a list of image sizes.

        Parameters
        ----------
        dimList : iterable of `tuple` (`int`, `int`)
            List of dimensions.

        Raises
        ------
        RuntimeError
            If input dimensions are inconsistent.

        Returns
        -------
        width, height : `int`
            Common dimensions.
        """
        dim = set((w, h) for w, h in dimList)
        if len(dim) != 1:
            raise RuntimeError("Inconsistent dimensions: %s" % dim)
        return dim.pop()

    def applyScale(self, exposure, scale=None):
        """Apply scale to input exposure.

        This implementation applies a flux scaling: the input exposure is
        divided by the provided scale.

        Parameters
        ----------
        exposure : `lsst.afw.image.Exposure`
            Exposure to scale.
        scale : `float` or `list` [`float`], optional
            Constant scale to divide the exposure by.
        """
        if scale is not None:
            mi = exposure.getMaskedImage()
            if isinstance(scale, list):
                for amp, ampScale in zip(exposure.getDetector(), scale):
                    ampIm = mi[amp.getBBox()]
                    ampIm /= ampScale
            else:
                mi /= scale
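
    # For example (illustrative numbers): with exposureScaling == "DarkTime" a
    # 300 s dark frame is divided by 300, putting each input in counts per
    # second before combination; a per-amp scale list divides each amplifier
    # region separately.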

    def combine(self, target, expList, stats):
        """Combine multiple images.

        Parameters
        ----------
        target : `lsst.afw.image.MaskedImage`
            Output combined image to construct.
        expList : `list` [`lsst.afw.image.Exposure`]
            Input exposures to combine.
        stats : `lsst.afw.math.StatisticsControl`
            Control explaining how to combine the input images.
        """
        images = [img.getMaskedImage() for img in expList if img is not None]
        afwMath.statisticsStack(target, images, afwMath.Property(self.config.combine), stats)

    def combineHeaders(self, expList, calib, calibType="CALIB", scales=None):
        """Combine input headers to determine the set of common headers,
        supplemented by calibration inputs.

        Parameters
        ----------
        expList : `list` of `lsst.afw.image.Exposure`
            Input list of exposures to combine.
        calib : `lsst.afw.image.Exposure`
            Output calibration to construct headers for.
        calibType : `str`, optional
            OBSTYPE the output should claim.
        scales : `list` of `float`, optional
            Scale values applied to each input to record.

        Returns
        -------
        header : `lsst.daf.base.PropertyList`
            Constructed header.
        """
        # Header
        header = calib.getMetadata()
        header.set("OBSTYPE", calibType)

        # Keywords we care about
        comments = {"TIMESYS": "Time scale for all dates",
                    "DATE-OBS": "Start date of earliest input observation",
                    "MJD-OBS": "[d] Start MJD of earliest input observation",
                    "DATE-END": "End date of latest input observation",
                    "MJD-END": "[d] End MJD of latest input observation",
                    "MJD-AVG": "[d] MJD midpoint of all input observations",
                    "DATE-AVG": "Midpoint date of all input observations"}

        # Creation date
        now = time.localtime()
        calibDate = time.strftime("%Y-%m-%d", now)
        calibTime = time.strftime("%X %Z", now)
        header.set("CALIB_CREATE_DATE", calibDate)
        header.set("CALIB_CREATE_TIME", calibTime)

        # Merge input headers
        inputHeaders = [exp.getMetadata() for exp in expList if exp is not None]
        merged = merge_headers(inputHeaders, mode='drop')
        for k, v in merged.items():
            if k not in header:
                md = expList[0].getMetadata()
                comment = md.getComment(k) if k in md else None
                header.set(k, v, comment=comment)

        # Construct list of visits
        visitInfoList = [exp.getInfo().getVisitInfo() for exp in expList if exp is not None]
        for i, visit in enumerate(visitInfoList):
            if visit is None:
                continue
            header.set("CPP_INPUT_%d" % (i,), visit.getExposureId())
            header.set("CPP_INPUT_DATE_%d" % (i,), str(visit.getDate()))
            header.set("CPP_INPUT_EXPT_%d" % (i,), visit.getExposureTime())
            if scales is not None:
                header.set("CPP_INPUT_SCALE_%d" % (i,), scales[i])

        # Not yet working: DM-22302
        # Create an observation group so we can add some standard headers
        # independent of the form in the input files.
        # Use try block in case we are dealing with unexpected data headers.
        try:
            group = ObservationGroup(visitInfoList, pedantic=False)
        except Exception:
            self.log.warn("Exception making an obs group for headers. Continuing.")
            # Fall back to setting a DATE-OBS from the calibDate
            dateCards = {"DATE-OBS": "{}T00:00:00.00".format(calibDate)}
            comments["DATE-OBS"] = "Date of start of day of calibration midpoint"
        else:
            oldest, newest = group.extremes()
            dateCards = dates_to_fits(oldest.datetime_begin, newest.datetime_end)

        for k, v in dateCards.items():
            header.set(k, v, comment=comments.get(k, None))

        return header

    def interpolateNans(self, exp):
        """Interpolate over NANs in the combined image.

        NANs can result from masked areas on the CCD.  We don't want them
        getting into our science images, so we replace them with the median
        of the image.

        Parameters
        ----------
        exp : `lsst.afw.image.Exposure`
            Exposure to check for NaNs.
        """
        array = exp.getImage().getArray()
        bad = np.isnan(array)

        median = np.median(array[np.logical_not(bad)])
        count = np.sum(bad)  # count the NaN pixels being replaced, not the good ones
        array[bad] = median
        if count > 0:
            self.log.warn("Found %s NAN pixels", count)


def VignetteExposure(exposure, polygon=None,
                     doUpdateMask=True, maskPlane='BAD',
                     doSetValue=False, vignetteValue=0.0,
                     log=None):
    """Apply vignetted polygon to image pixels.

    Parameters
    ----------
    exposure : `lsst.afw.image.Exposure`
        Image to be updated.
    polygon : `lsst.afw.geom.Polygon`, optional
        Polygon defining the illuminated region; taken from the exposure's
        valid polygon if not supplied.
    doUpdateMask : `bool`, optional
        Update the exposure mask for vignetted area?
    maskPlane : `str`, optional
        Mask plane to assign.
    doSetValue : `bool`, optional
        Set image value for vignetted area?
    vignetteValue : `float`, optional
        Value to assign.
    log : `lsst.log.Log`, optional
        Log to write to.

    Raises
    ------
    RuntimeError
        Raised if no valid polygon exists.
    """
    polygon = polygon if polygon else exposure.getInfo().getValidPolygon()
    if not polygon:
        raise RuntimeError("Could not find valid polygon!")
    log = log if log else Log.getLogger(__name__.partition(".")[2])

    fullyIlluminated = True
    for corner in exposure.getBBox().getCorners():
        if not polygon.contains(Point2D(corner)):
            fullyIlluminated = False

    log.info("Exposure is fully illuminated? %s", fullyIlluminated)

    if not fullyIlluminated:
        # Scan pixels.
        mask = exposure.getMask()
        numPixels = mask.getBBox().getArea()

        xx, yy = np.meshgrid(np.arange(0, mask.getWidth(), dtype=int),
                             np.arange(0, mask.getHeight(), dtype=int))

        vignMask = np.array([not polygon.contains(Point2D(x, y)) for x, y in
                             zip(xx.reshape(numPixels), yy.reshape(numPixels))])
        vignMask = vignMask.reshape(mask.getHeight(), mask.getWidth())

        if doUpdateMask:
            bitMask = mask.getPlaneBitMask(maskPlane)
            maskArray = mask.getArray()
            maskArray[vignMask] |= bitMask
        if doSetValue:
            imageArray = exposure.getImage().getArray()
            imageArray[vignMask] = vignetteValue
        log.info("Exposure contains %d vignetted pixels.",
                 np.count_nonzero(vignMask))
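

# Example (illustrative): masking and zeroing vignetted pixels on a combined
# image, mirroring the call made in CalibCombineTask.run when doVignette is
# enabled.  ``inputExps`` and ``combined`` are placeholders.
#
#     polygon = inputExps[0].getInfo().getValidPolygon()
#     VignetteExposure(combined, polygon=polygon, doUpdateMask=True,
#                      doSetValue=True, vignetteValue=0.0)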