Coverage for python / lsst / analysis / ap / plotImageSubtractionCutouts.py: 14%
363 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-22 09:09 +0000
1# This file is part of analysis_ap.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
"""Construct template/image/difference cutouts for upload to Zooniverse, or
to just view as images.
"""
26__all__ = ["PlotImageSubtractionCutoutsConfig", "PlotImageSubtractionCutoutsTask", "CutoutPath"]
28import argparse
29import functools
30import io
31import logging
32import multiprocessing
33import os
34from math import log10
36import astropy.units as u
37from lsst.daf.butler import DatasetNotFoundError
38import lsst.dax.apdb
39import lsst.geom
40import lsst.pex.config as pexConfig
41import lsst.pex.exceptions
42import lsst.pipe.base
43import lsst.utils
44import numpy as np
45import pandas as pd
46import sqlalchemy
48from . import apdb
class _ButlerCache:
    """Global class to handle butler queries, to allow lru_cache and
    `multiprocessing.Pool` to work together.

    If we redo this all to work with BPS or other parallelized systems, or get
    good butler-side caching, we could remove this lru_cache system.
    """

    # Module-scoped logger for error reporting. The class previously used
    # ``self.log`` without ever defining it, so the error path in
    # get_exposures() raised AttributeError and masked the original
    # DatasetNotFoundError.
    log = logging.getLogger(__name__)

    def set(self, butler, config):
        """Call this to store a Butler and Config instance before using the
        global class instance.

        Parameters
        ----------
        butler : `lsst.daf.butler.Butler`
            Butler instance to store.
        config : `lsst.pex.config.Config`
            Config instance to store.
        """
        self._butler = butler
        self._config = config
        # Ensure the caches are empty if we've been re-set: cached exposures
        # and catalogs may have come from a different butler/collections.
        self.get_exposures.cache_clear()
        self.get_catalog.cache_clear()

    # NOTE: lru_cache on a method keeps instances alive for the cache's
    # lifetime; acceptable here because this class is used as a single
    # module-level global, and set() clears the caches on reconfiguration.
    @functools.lru_cache(maxsize=4)
    def get_exposures(self, instrument, detector, visit):
        """Return science, template, difference exposures, using a small
        cache so we don't have to re-read files as often.

        Parameters
        ----------
        instrument : `str`
            Instrument name to define the data id.
        detector : `int`
            Detector id to define the data id.
        visit : `int`
            Visit id to define the data id.

        Returns
        -------
        exposures : `tuple` [`lsst.afw.image.ExposureF`]
            Science, template, and difference exposure for this data id.
        """
        data_id = {'instrument': instrument, 'detector': detector, 'visit': visit}
        try:
            science = self._butler.get(self._config.science_image_type, data_id)
        except DatasetNotFoundError as e:
            self.log.error(f"Cannot load {self._config.science_image_type} with data_id {data_id}: {e}")
            self.log.error("If you are working with data processed earlier than May 2025, try setting "
                           "config.science_image_type = 'initial_pvi' or 'calexp'.")
            raise

        if self._config.diff_image_type is not None:
            # Older processings: dataset names are derived from the configured
            # partial name, e.g. "goodSeeingDiff_templateExp".
            template = self._butler.get(f"{self._config.diff_image_type}_templateExp", data_id)
            difference = self._butler.get(f"{self._config.diff_image_type}_differenceExp", data_id)
        else:
            # Current dataset type names.
            template = self._butler.get("template_detector", data_id)
            difference = self._butler.get("difference_image", data_id)

        return science, template, difference

    @functools.lru_cache(maxsize=4)
    def get_catalog(self, instrument, detector, visit):
        """Return the diaSrc catalog from the butler.

        Parameters
        ----------
        instrument : `str`
            Instrument name to define the data id.
        detector : `int`
            Detector id to define the data id.
        visit : `int`
            Visit id to define the data id.

        Returns
        -------
        catalog : `lsst.afw.table.SourceCatalog`
            DiaSource catalog for this data id.
        """
        data_id = {'instrument': instrument, 'detector': detector, 'visit': visit}
        return self._butler.get(f'{self._config.diff_image_type}_diaSrc', data_id)
# Global used within each multiprocessing worker (or single process);
# configured via butler_cache.set() before any cutouts are made.
butler_cache = _ButlerCache()
class PlotImageSubtractionCutoutsConfig(pexConfig.Config):
    """Configuration for PlotImageSubtractionCutoutsTask."""

    sizes = pexConfig.ListField(
        doc="List of widths of cutout to extract for image from science, \
            template, and difference exposures.",
        dtype=int,
        default=[30],
    )
    use_footprint = pexConfig.Field(
        # Fixed doubled word ("to to") in the user-visible doc string.
        doc="Use source footprint to define cutout region; "
            "If set, ignore `size` and use the footprint bbox instead.",
        dtype=bool,
        default=False,
    )
    url_root = pexConfig.Field(
        doc="URL that the resulting images will be served to Zooniverse from, for the manifest file. "
            "If not set, no manifest file will be written.",
        dtype=str,
        default=None,
        optional=True,
    )
    diff_image_type = pexConfig.Field(
        doc="Optional partial dataset name of template and difference image to use for cutouts; "
            "will have '_templateExp' and '_differenceExp' appended for butler.get(), respectively."
            " If not specified, use `template_detector` and `difference_image`, respectively.",
        dtype=str,
        default=None,
        optional=True
    )
    science_image_type = pexConfig.Field(
        doc="Dataset type of science image to use for cutouts; "
            "older processings could be `calexp` or `initial_pvi`.",
        dtype=str,
        default="preliminary_visit_image",
    )
    add_metadata = pexConfig.Field(
        doc="Annotate the cutouts with catalog metadata, including coordinates, fluxes, flags, etc.",
        dtype=bool,
        default=True
    )
    chunk_size = pexConfig.Field(
        doc="Chunk up files into subdirectories, with at most this many files per directory."
            " None means write all the files to one `images/` directory.",
        dtype=int,
        default=10000,
        optional=True
    )
    save_as_numpy = pexConfig.Field(
        doc="Save the raw cutout images in numpy format.",
        dtype=bool,
        default=False
    )
class PlotImageSubtractionCutoutsTask(lsst.pipe.base.Task):
    """Generate template/science/difference image cutouts of DiaSources and an
    optional manifest for upload to a Zooniverse project.

    Parameters
    ----------
    output_path : `str`
        The path to write the output to; manifest goes here, while the
        images themselves go into ``output_path/images/``.
    """

    ConfigClass = PlotImageSubtractionCutoutsConfig
    _DefaultName = "plotImageSubtractionCutouts"

    def __init__(self, *, output_path, **kwargs):
        super().__init__(**kwargs)
        self._output_path = output_path
        # PNG cutouts are chunked under ``images/``; raw arrays (written only
        # when config.save_as_numpy is set) are chunked under ``numpy/``.
        self.cutout_path = CutoutPath(output_path, chunk_size=self.config.chunk_size)
        self.numpy_path = CutoutPath(output_path, chunk_size=self.config.chunk_size,
                                     subdirectory='numpy')

    def _reduce_kwargs(self):
        # To allow pickling of this Task: output_path is a constructor-only
        # argument, so it must be carried through the reduce protocol.
        kwargs = super()._reduce_kwargs()
        kwargs["output_path"] = self._output_path
        return kwargs

    def run(self, data, butler, njobs=0):
        """Generate cutout images and a manifest for upload to Zooniverse
        from a collection of DiaSources.

        Parameters
        ----------
        data : `pandas.DataFrame`
            The DiaSources to extract cutouts for. Must contain at least these
            fields: ``ra, dec, diaSourceId, detector, visit, instrument``.
        butler : `lsst.daf.butler.Butler`
            The butler connection to use to load the data; create it with the
            collections you wish to load images from.
        njobs : `int`, optional
            Number of multiprocessing jobs to make cutouts with; default of 0
            means don't use multiprocessing at all.

        Returns
        -------
        source_ids : `list` [`int`]
            DiaSourceIds of cutout images that were generated.
        """
        result = self.write_images(data, butler, njobs=njobs)
        self.write_manifest(result)
        self.log.info("Wrote %d images to %s", len(result), self._output_path)
        return result

    def write_manifest(self, sources):
        """Save a Zooniverse manifest attaching image URLs to source ids.

        No-op (with an info log message) if ``config.url_root`` is not set.

        Parameters
        ----------
        sources : `list` [`int`]
            The diaSourceIds of the sources that had cutouts successfully
            made.
        """
        if self.config.url_root is not None:
            manifest = self._make_manifest(sources)
            manifest.to_csv(os.path.join(self._output_path, "manifest.csv"), index=False)
        else:
            self.log.info("No url_root config provided, so no Zooniverse manifest file was written.")

    def _make_manifest(self, sources):
        """Return a Zooniverse manifest attaching image URLs to source ids.

        Parameters
        ----------
        sources : `list` [`int`]
            The diaSourceIds of the sources that had cutouts successfully
            made.

        Returns
        -------
        manifest : `pandas.DataFrame`
            The formatted URL manifest for upload to Zooniverse.
        """
        # Reuse CutoutPath so manifest URLs mirror the on-disk layout.
        # NOTE(review): this CutoutPath has no chunk_size, so the URLs omit
        # the chunk subdirectory that write_images uses when
        # config.chunk_size is set — confirm the URLs resolve in that case.
        cutout_path = CutoutPath(self.config.url_root)
        manifest = pd.DataFrame()
        manifest["external_id"] = sources
        manifest["location:1"] = [cutout_path(x, f'{x}.png') for x in sources]
        manifest["metadata:diaSourceId"] = sources
        return manifest

    def write_images(self, data, butler, njobs=0):
        """Make the 3-part cutout images for each requested source and write
        them to disk.

        Creates ``images/`` and ``numpy/`` subdirectories if they
        do not already exist; images are written there as PNG and npy files.

        Parameters
        ----------
        data : `pandas.DataFrame`
            The DiaSources to extract cutouts for. Must contain at least these
            fields: ``ra, dec, diaSourceId, detector, visit, instrument``.
        butler : `lsst.daf.butler.Butler`
            The butler connection to use to load the data; create it with the
            collections you wish to load images from.
        njobs : `int`, optional
            Number of multiprocessing jobs to make cutouts with; default of 0
            means don't use multiprocessing at all.

        Returns
        -------
        sources : `list`
            DiaSourceIds that had cutouts made.
        """
        # Ignore divide-by-zero and log-of-negative-value messages.
        seterr_dict = np.seterr(divide="ignore", invalid="ignore")

        # Exclude index if they are replicated in columns.
        indexNotInColumns = not any(index in data.columns for index in data.index.names)

        sources = []
        butler_cache.set(butler, self.config)
        if njobs > 0:
            # NOTE(review): unlike the serial path below, this branch does not
            # skip sources whose cutout PNG already exists — confirm intended.
            with multiprocessing.Pool(njobs) as pool:
                sources = pool.map(self._do_one_source, data.to_records(index=indexNotInColumns))
        else:
            for i, source in enumerate(data.to_records(index=indexNotInColumns)):
                if not self.cutout_path.exists(source["diaSourceId"],
                                               f'{source["diaSourceId"]}.png'):
                    id = self._do_one_source(source)
                    sources.append(id)

        # restore numpy error message state
        np.seterr(**seterr_dict)
        # Only return successful ids, not failures.
        return [s for s in sources if s is not None]

    def _do_one_source(self, source):
        """Make cutouts for one diaSource.

        Parameters
        ----------
        source : `numpy.record`
            DiaSource record for this cutout, to add metadata to the image.

        Returns
        -------
        diaSourceId : `int` or None
            Id of the source that was generated, or None if there was an
            error.
        """
        try:
            center = lsst.geom.SpherePoint(source["ra"], source["dec"], lsst.geom.degrees)
            science, template, difference = butler_cache.get_exposures(source["instrument"],
                                                                       source["detector"],
                                                                       source["visit"])
            if self.config.use_footprint:
                catalog = butler_cache.get_catalog(source["instrument"],
                                                   source["detector"],
                                                   source["visit"])
                # The input catalogs must be sorted.
                if not catalog.isSorted():
                    data_id = {'instrument': source["instrument"],
                               'detector': source["detector"],
                               'visit': source["visit"]}
                    msg = f"{self.config.diff_image_type}_diaSrc catalog for {data_id} is not sorted!"
                    raise RuntimeError(msg)
                record = catalog.find(source['diaSourceId'])
                footprint = record.getFootprint()

            scale = science.wcs.getPixelScale(science.getBBox().getCenter()).asArcseconds()
            image = self.generate_image(science, template, difference, center, scale,
                                        dia_source_id=source['diaSourceId'],
                                        save_as_numpy=self.config.save_as_numpy,
                                        source=source if self.config.add_metadata else None,
                                        footprint=footprint if self.config.use_footprint else None)
            self.cutout_path.mkdir(source["diaSourceId"])
            with open(self.cutout_path(source["diaSourceId"],
                                       f'{source["diaSourceId"]}.png'), "wb") as outfile:
                outfile.write(image.getbuffer())
            return source["diaSourceId"]
        except (LookupError, lsst.pex.exceptions.Exception) as e:
            # Per-source failures are logged and swallowed so the remaining
            # sources keep processing; the caller filters out the None.
            self.log.error(
                f"{e.__class__.__name__} processing diaSourceId {source['diaSourceId']}: {e}"
            )
            return None
        except Exception:
            # Ensure other exceptions are interpretable when multiprocessing.
            import traceback
            traceback.print_exc()
            raise

    def generate_image(self, science, template, difference, center, scale, dia_source_id=None,
                       save_as_numpy=False, source=None, footprint=None):
        """Get a 3-part cutout image to save to disk, for a single source.

        Parameters
        ----------
        science : `lsst.afw.image.ExposureF`
            Science exposure to include in the cutout.
        template : `lsst.afw.image.ExposureF`
            Matched template exposure to include in the cutout.
        difference : `lsst.afw.image.ExposureF`
            Matched science minus template exposure to include in the cutout.
        center : `lsst.geom.SpherePoint`
            Center of the source to be cut out of each image.
        scale : `float`
            Pixel scale in arcseconds.
        dia_source_id : `int`, optional
            DiaSourceId to use in the filename, if saving to disk.
        save_as_numpy : `bool`, optional
            Save the raw cutout images in numpy format.
        source : `numpy.record`, optional
            DiaSource record for this cutout, to add metadata to the image.
        footprint : `lsst.afw.detection.Footprint`, optional
            Detected source footprint; if specified, extract a square
            surrounding the footprint bbox, otherwise use ``config.size``.

        Returns
        -------
        image: `io.BytesIO`
            The generated image, to be output to a file or displayed on screen.
        """
        numpy_cutouts = {}
        if not self.config.use_footprint:
            sizes = self.config.sizes
            cutout_science, cutout_template, cutout_difference = [], [], []
            for i, s in enumerate(sizes):
                extent = lsst.geom.Extent2I(s, s)
                science_cutout = science.getCutout(center, extent)
                template_cutout = template.getCutout(center, extent)
                difference_cutout = difference.getCutout(center, extent)
                if save_as_numpy:
                    self.numpy_path.mkdir(dia_source_id)
                    numpy_cutouts[f"sci_{s}"] = science_cutout.image.array
                    numpy_cutouts[f"temp_{s}"] = template_cutout.image.array
                    numpy_cutouts[f"diff_{s}"] = difference_cutout.image.array
                    # NOTE(review): numpy_cutouts accumulates across sizes, so
                    # earlier sizes are re-saved (overwritten in place) on
                    # each later iteration; harmless but redundant.
                    for cutout_type, cutout in numpy_cutouts.items():
                        outfile = self.numpy_path(dia_source_id, f'{dia_source_id}_{cutout_type}.npy')
                        np.save(outfile, np.expand_dims(cutout, axis=0))
                cutout_science.append(science_cutout)
                cutout_template.append(template_cutout)
                cutout_difference.append(difference_cutout)
        else:
            if self.config.save_as_numpy:
                raise RuntimeError("Cannot save as numpy when using footprints.")
            cutout_science = [science.getCutout(footprint.getBBox())]
            cutout_template = [template.getCutout(footprint.getBBox())]
            cutout_difference = [difference.getCutout(footprint.getBBox())]
            extent = footprint.getBBox().getDimensions()
            # Plot a square equal to the largest dimension.
            sizes = [extent.x if extent.x > extent.y else extent.y]

        return self._plot_cutout(cutout_science,
                                 cutout_template,
                                 cutout_difference,
                                 scale,
                                 sizes,
                                 source=source)

    def _plot_cutout(self, science, template, difference, scale, sizes, source=None):
        """Plot the cutouts for a source in one image.

        Parameters
        ----------
        science : `list` [`lsst.afw.image.ExposureF`]
            List of cutout Science exposure(s) to include in the image.
        template : `list` [`lsst.afw.image.ExposureF`]
            List of cutout template exposure(s) to include in the image.
        difference : `list` [`lsst.afw.image.ExposureF`]
            List of cutout science minus template exposure(s) to include
            in the image.
        scale : `float`
            Pixel scale in arcseconds.
        sizes : `list` [`int`]
            List of x/y dimensions of the images passed in, to set imshow
            extent.
        source : `numpy.record`, optional
            DiaSource record for this cutout, to add metadata to the image.

        Returns
        -------
        image: `io.BytesIO`
            The generated image, to be output to a file via
            `image.write(filename)` or displayed on screen.
        """
        import astropy.visualization as aviz
        import matplotlib
        matplotlib.use("AGG")
        # Force matplotlib defaults
        matplotlib.rcParams.update(matplotlib.rcParamsDefault)
        import matplotlib.pyplot as plt
        from matplotlib import cm

        # TODO DM-32014: how do we color masked pixels (including edges)?

        def plot_one_image(ax, data, size, name=None):
            """Plot a normalized image on an axis."""
            if name == "Difference":
                norm = aviz.ImageNormalize(
                    # focus on a rect of dim 15 at the center of the image.
                    data[data.shape[0] // 2 - 7:data.shape[0] // 2 + 8,
                         data.shape[1] // 2 - 7:data.shape[1] // 2 + 8],
                    interval=aviz.MinMaxInterval(),
                    stretch=aviz.AsinhStretch(a=0.1),
                )
            else:
                norm = aviz.ImageNormalize(
                    data,
                    interval=aviz.MinMaxInterval(),
                    stretch=aviz.AsinhStretch(a=0.1),
                )
            ax.imshow(data, cmap=cm.bone, interpolation="none", norm=norm,
                      extent=(0, size, 0, size), origin="lower", aspect="equal")
            # Draw a 1-arcsecond scale bar (yellow line on a blue backing).
            x_line = 1
            y_line = 1
            ax.plot((x_line, x_line + 1.0/scale), (y_line, y_line), color="blue", lw=6)
            ax.plot((x_line, x_line + 1.0/scale), (y_line, y_line), color="yellow", lw=2)
            ax.axis("off")
            if name is not None:
                ax.set_title(name)

        try:
            len_sizes = len(sizes)
            fig, axs = plt.subplots(len_sizes, 3, constrained_layout=True)
            if len_sizes == 1:
                # One row: axs is a 1-d array of 3 axes.
                plot_one_image(axs[0], template[0].image.array, sizes[0], "Template")
                plot_one_image(axs[1], science[0].image.array, sizes[0], "Science")
                plot_one_image(axs[2], difference[0].image.array, sizes[0], "Difference")
            else:
                # Multiple rows: axs is 2-d; only the first row gets titles.
                plot_one_image(axs[0][0], template[0].image.array, sizes[0], "Template")
                plot_one_image(axs[0][1], science[0].image.array, sizes[0], "Science")
                plot_one_image(axs[0][2], difference[0].image.array, sizes[0], "Difference")
                for i in range(1, len(axs)):
                    plot_one_image(axs[i][0], template[i].image.array, sizes[i], None)
                    plot_one_image(axs[i][1], science[i].image.array, sizes[i], None)
                    plot_one_image(axs[i][2], difference[i].image.array, sizes[i], None)
            if source is not None:
                _annotate_image(fig, source, len_sizes)

            output = io.BytesIO()
            plt.savefig(output, bbox_inches="tight", format="png")
            output.seek(0)  # to ensure opening the image starts from the front
        finally:
            # NOTE(review): if plt.subplots itself raises, ``fig`` is unbound
            # and this close() raises NameError, masking the original error.
            plt.close(fig)

        return output
def _annotate_image(fig, source, len_sizes):
    """Annotate the cutouts image with metadata and flags.

    Parameters
    ----------
    fig : `matplotlib.Figure`
        Figure to be annotated.
    source : `numpy.record`
        DiaSource record of the object being plotted.
    len_sizes : `int`
        Length of the ``size`` array set in configuration.
    """
    # Names of flags fields to add a flag label to the image, using any().
    flags_psf = ["psfFlux_flag", "psfFlux_flag_noGoodPixels", "psfFlux_flag_edge"]
    flags_aperture = ["apFlux_flag", "apFlux_flag_apertureTruncated"]
    flags_forced = ["forced_PsfFlux_flag", "forced_PsfFlux_flag_noGoodPixels",
                    "forced_PsfFlux_flag_edge"]
    flags_edge = ["pixelFlags_edge"]
    flags_interp = ["pixelFlags_interpolated", "pixelFlags_interpolatedCenter"]
    flags_saturated = ["pixelFlags_saturated", "pixelFlags_saturatedCenter"]
    flags_cr = ["pixelFlags_cr", "pixelFlags_crCenter"]
    flags_bad = ["pixelFlags_bad"]
    flags_suspect = ["pixelFlags_suspect", "pixelFlags_suspectCenter"]
    flags_centroid = ["centroid_flag"]
    flags_shape = ["shape_flag", "shape_flag_no_pixels", "shape_flag_not_contained",
                   "shape_flag_parent_source"]

    # Field labels turn red when any of the corresponding flags are set.
    flag_color = "red"
    text_color = "grey"

    # Five rows of text; positions depend on whether the figure has one row
    # of cutouts or several.
    if len_sizes == 1:
        heights = [0.95, 0.91, 0.87, 0.83, 0.79]
    else:
        heights = [1.2, 1.15, 1.1, 1.05, 1.0]

    # NOTE: fig.text coordinates are in fractions of the figure.
    # Row 0: identifiers (id, instrument, detector, visit, band).
    fig.text(0, heights[0], "diaSourceId:", color=text_color)
    fig.text(0.145, heights[0], f"{source['diaSourceId']}")
    fig.text(0.43, heights[0], f"{source['instrument']}", fontweight="bold")
    fig.text(0.64, heights[0], "detector:", color=text_color)
    fig.text(0.74, heights[0], f"{source['detector']}")
    fig.text(0.795, heights[0], "visit:", color=text_color)
    fig.text(0.85, heights[0], f"{source['visit']}")
    fig.text(0.95, heights[0], f"{source['band']}")

    # Row 1: coordinates and detection quality.
    fig.text(0.0, heights[1], "ra:", color=text_color)
    fig.text(0.037, heights[1], f"{source['ra']:.8f}")
    fig.text(0.21, heights[1], "dec:", color=text_color)
    fig.text(0.265, heights[1], f"{source['dec']:+.8f}")
    fig.text(0.50, heights[1], "detection S/N:", color=text_color)
    fig.text(0.66, heights[1], f"{source['snr']:6.1f}")
    fig.text(0.75, heights[1], "PSF chi2:", color=text_color)
    fig.text(0.85, heights[1], f"{source['psfChi2']/source['psfNdata']:6.2f}")

    # Row 2: PSF flux and pixel-flag labels.
    fig.text(0.0, heights[2], "PSF (nJy):", color=flag_color if any(source[flags_psf]) else text_color)
    fig.text(0.25, heights[2], f"{source['psfFlux']:8.1f}", horizontalalignment='right')
    fig.text(0.252, heights[2], "+/-", color=text_color)
    fig.text(0.29, heights[2], f"{source['psfFluxErr']:8.1f}")
    fig.text(0.40, heights[2], "S/N:", color=text_color)
    fig.text(0.45, heights[2], f"{abs(source['psfFlux']/source['psfFluxErr']):6.2f}")

    # NOTE: yellow is hard to read on white; use goldenrod instead.
    if any(source[flags_edge]):
        fig.text(0.55, heights[2], "EDGE", color="goldenrod", fontweight="bold")
    if any(source[flags_interp]):
        fig.text(0.62, heights[2], "INTERP", color="green", fontweight="bold")
    if any(source[flags_saturated]):
        fig.text(0.72, heights[2], "SAT", color="green", fontweight="bold")
    if any(source[flags_cr]):
        fig.text(0.77, heights[2], "CR", color="magenta", fontweight="bold")
    if any(source[flags_bad]):
        fig.text(0.81, heights[2], "BAD", color="red", fontweight="bold")
    if source['isDipole']:
        fig.text(0.87, heights[2], "DIPOLE", color="indigo", fontweight="bold")

    # Row 3: aperture flux and measurement-flag labels.
    fig.text(0.0, heights[3], "ap (nJy):", color=flag_color if any(source[flags_aperture]) else text_color)
    fig.text(0.25, heights[3], f"{source['apFlux']:8.1f}", horizontalalignment='right')
    fig.text(0.252, heights[3], "+/-", color=text_color)
    fig.text(0.29, heights[3], f"{source['apFluxErr']:8.1f}")
    fig.text(0.40, heights[3], "S/N:", color=text_color)
    fig.text(0.45, heights[3], f"{abs(source['apFlux']/source['apFluxErr']):#6.2f}")

    if any(source[flags_suspect]):
        fig.text(0.55, heights[3], "SUS", color="goldenrod", fontweight="bold")
    if any(source[flags_centroid]):
        fig.text(0.60, heights[3], "CENTROID", color="red", fontweight="bold")
    if any(source[flags_shape]):
        fig.text(0.73, heights[3], "SHAPE", color="red", fontweight="bold")
    # Future option: to add two more flag flavors to the legend,
    # use locations 0.80 and 0.87

    # rb score: red below 0.5, green at or above.
    if source['reliability'] is not None and np.isfinite(source['reliability']):
        fig.text(0.73, heights[4], f"RB:{source['reliability']:.03f}",
                 color='#e41a1c' if source['reliability'] < 0.5 else '#4daf4a',
                 fontweight="bold")

    # Row 4: forced science flux and AB magnitude.
    fig.text(0.0, heights[4], "sci (nJy):", color=flag_color if any(source[flags_forced]) else text_color)
    fig.text(0.25, heights[4], f"{source['scienceFlux']:8.1f}", horizontalalignment='right')
    fig.text(0.252, heights[4], "+/-", color=text_color)
    fig.text(0.29, heights[4], f"{source['scienceFluxErr']:8.1f}")
    fig.text(0.40, heights[4], "S/N:", color=text_color)
    fig.text(0.45, heights[4], f"{abs(source['scienceFlux']/source['scienceFluxErr']):6.2f}")
    fig.text(0.55, heights[4], "ABmag:", color=text_color)
    fig.text(0.635, heights[4], f"{(source['scienceFlux']*u.nanojansky).to_value(u.ABmag):.3f}")
class CutoutPath:
    """Manage paths to image cutouts with filenames based on diaSourceId.

    Supports local files, and id-chunked directories.

    Parameters
    ----------
    root : `str`
        Root file path to manage.
    chunk_size : `int`, optional
        At most this many files per directory. Must be a power of 10.
    subdirectory : `str`, optional
        Name of the subdirectory to hold the files below ``root``.

    Raises
    ------
    RuntimeError
        Raised if chunk_size is not a power of 10.
    """

    def __init__(self, root, chunk_size=None, subdirectory='images'):
        self._root = root
        # Integer power-of-10 check: the previous float-based test
        # (log10(chunk_size) != int(log10(chunk_size))) was vulnerable to
        # floating-point rounding for large values, and raised ValueError
        # instead of the documented RuntimeError for chunk_size <= 0.
        if chunk_size is not None and not self._is_power_of_ten(chunk_size):
            raise RuntimeError(f"CutoutPath file chunk_size must be a power of 10, got {chunk_size}.")
        self._chunk_size = chunk_size
        self._subdirectory = subdirectory

    @staticmethod
    def _is_power_of_ten(value):
        """Return True if ``value`` is a positive integer power of 10
        (including 10**0 == 1)."""
        if value < 1:
            return False
        while value % 10 == 0:
            value //= 10
        return value == 1

    def directory(self, id):
        """Return the directory to store the output in.

        Parameters
        ----------
        id : `int`
            Source id to create the path for.

        Returns
        -------
        directory: `str`
            Directory for this file.
        """
        if self._chunk_size is not None:
            # Round the id down to the nearest chunk boundary, so at most
            # chunk_size consecutive ids share a directory.
            chunk = (id // self._chunk_size)*self._chunk_size
            return os.path.join(self._root, f"{self._subdirectory}/{chunk}")
        else:
            return os.path.join(self._root, self._subdirectory)

    def __call__(self, id, filename):
        """Return the full path to a diaSource cutout.

        Parameters
        ----------
        id : `int`
            Source id to create the path for.
        filename: `str`
            Filename to write.

        Returns
        -------
        path : `str`
            Full path to the requested file.
        """
        return os.path.join(self.directory(id), filename)

    def exists(self, id, filename):
        """Return True if the file already exists.

        Parameters
        ----------
        id : `int`
            Source id to create the path for.
        filename: `str`
            Filename to write.

        Returns
        -------
        exists : `bool`
            Does the supplied filename exist?
        """
        return os.path.exists(os.path.join(self.directory(id), filename))

    def mkdir(self, id):
        """Make the directory tree to write this cutout id to.

        Parameters
        ----------
        id : `int`
            Source id to create the path for.
        """
        os.makedirs(self.directory(id), exist_ok=True)
def build_argparser():
    """Construct an argument parser for the ``plotImageSubtractionCutouts``
    script.

    Returns
    -------
    argparser : `argparse.ArgumentParser`
        The argument parser that defines the ``plotImageSubtractionCutouts``
        command-line interface.
    """
    arg_parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="More information is available at https://pipelines.lsst.io.",
    )

    # Exactly one APDB connection flavor must be selected.
    database = arg_parser.add_mutually_exclusive_group(required=True)
    database.add_argument(
        "--sqlitefile", default=None,
        help="Path to sqlite file to load from; required for sqlite connection.")
    database.add_argument(
        "--namespace", default=None,
        help="Postgres namespace (aka schema) to connect to; "
             " required for postgres connections.")

    arg_parser.add_argument(
        "--postgres_url",
        default="rubin@usdf-prompt-processing-dev.slac.stanford.edu/lsst-devl",
        help="Postgres connection path, or default (None) to use ApdbPostgresQuery default.")

    arg_parser.add_argument(
        "--limit", type=int, default=5,
        help="Number of sources to load from the APDB (default=5), or the "
             "number of sources to load per 'page' when `--all` is set. "
             "This should be significantly larger (100x or more) than the value of `-j`, "
             "to ensure efficient use of each process.")
    arg_parser.add_argument(
        "--all", action="store_true", default=False,
        help="Process all the sources; --limit then becomes the 'page size' to chunk the DB into.")

    arg_parser.add_argument(
        "-j", "--jobs", type=int, default=0,
        help="Number of processes to use when generating cutouts. "
             "Specify 0 (the default) to not use multiprocessing at all. "
             "Note that `--limit` determines how efficiently each process is filled.")

    arg_parser.add_argument(
        "-C", "--configFile",
        help="File containing the PlotImageSubtractionCutoutsConfig to load.")
    arg_parser.add_argument(
        "--collections", nargs="*",
        help=(
            "Butler collection(s) to load data from."
            " If not specified, will search all butler collections, "
            "which may be very slow."
        ))

    # Positional arguments: data source and destination.
    arg_parser.add_argument("repo", help="Path to Butler repository to load data from.")
    arg_parser.add_argument(
        "outputPath",
        help="Path to write the output images and manifest to; "
             "manifest is written here, while the images go to `OUTPUTPATH/images/`.")

    arg_parser.add_argument(
        "--reliabilityMin", type=float, default=None,
        help="Minimum reliability value (default=None) on which to filter the DiaSources.")
    arg_parser.add_argument(
        "--reliabilityMax", type=float, default=None,
        help="Maximum reliability value (default=None) on which to filter the DiaSources.")
    return arg_parser
def _make_apdbQuery(sqlitefile=None, postgres_url=None, namespace=None):
    """Return a query connection to the specified APDB.

    Parameters
    ----------
    sqlitefile : `str`, optional
        SQLite file to load APDB from; if set, postgres kwargs are ignored.
    postgres_url : `str`, optional
        Postgres connection URL to connect to APDB.
    namespace : `str`, optional
        Postgres schema to load from; required with postgres_url.

    Returns
    -------
    apdb_query : `lsst.analysis.ap.ApdbQuery`
        Query instance to use to load data from APDB.

    Raises
    ------
    RuntimeError
        Raised if the APDB connection kwargs are invalid in some way.
    """
    # SQLite takes precedence over any postgres settings.
    if sqlitefile is not None:
        return apdb.ApdbSqliteQuery(sqlitefile)
    # Postgres needs both a URL and a schema name.
    if postgres_url is not None and namespace is not None:
        return apdb.ApdbPostgresQuery(namespace, postgres_url)
    raise RuntimeError("Cannot handle database connection args: "
                       f"sqlitefile={sqlitefile}, postgres_url={postgres_url}, namespace={namespace}")
def select_sources(apdb_query, limit, reliabilityMin=None, reliabilityMax=None):
    """Load DiaSources from an APDB, yielding them in pages of ``limit``.

    This is a generator: each iteration queries one page of sources ordered
    by (visit, detector, diaSourceId), and stops when a page comes back
    empty.

    Parameters
    ----------
    apdb_query : `lsst.analysis.ap.ApdbQuery`
        APDB query interface to load from.
    limit : `int`
        Number of sources to select from the APDB per page.
    reliabilityMin : `float`
        Minimum reliability value on which to filter the DiaSources.
    reliabilityMax : `float`
        Maximum reliability value on which to filter the DiaSources.

    Yields
    ------
    sources : `pandas.DataFrame`
        One page of loaded DiaSource data.
    """
    offset = 0
    try:
        while True:
            with apdb_query.connection as connection:
                table = apdb_query._tables["DiaSource"]
                query = table.select()
                if reliabilityMin is not None:
                    query = query.where(table.columns['reliability'] >= reliabilityMin)
                if reliabilityMax is not None:
                    query = query.where(table.columns['reliability'] <= reliabilityMax)
                # Deterministic ordering so LIMIT/OFFSET paging is stable.
                query = query.order_by(table.columns["visit"],
                                       table.columns["detector"],
                                       table.columns["diaSourceId"])
                query = query.limit(limit).offset(offset)
                sources = pd.read_sql_query(query, connection)
                if len(sources) == 0:
                    break
                apdb_query._fill_from_instrument(sources)

            yield sources
            offset += limit
    finally:
        # NOTE(review): ``connection`` is only bound after the first ``with``
        # body runs; if the initial connection attempt raises, this close()
        # raises NameError and masks the original error. It also re-closes a
        # connection the context manager already exited — confirm intended.
        connection.close()
def len_sources(apdb_query, namespace=None):
    """Return the number of DiaSources in the supplied APDB.

    Parameters
    ----------
    apdb_query : `lsst.analysis.ap.ApdbQuery`
        APDB query interface to load from.
    namespace : `str`, optional
        Postgres schema to load data from.

    Returns
    -------
    count : `int`
        Number of diaSources in this APDB.
    """
    with apdb_query.connection as connection:
        # Point the postgres session at the right schema before counting;
        # SQLite connections pass namespace=None and skip this.
        if namespace:
            connection.execute(sqlalchemy.text(f"SET search_path TO {namespace}"))
        result = connection.execute(sqlalchemy.text('select count(*) FROM "DiaSource";'))
        count = result.scalar()
    return count
def run_cutouts(args):
    """Run PlotImageSubtractionCutoutsTask on the parsed commandline arguments.

    Parameters
    ----------
    args : `argparse.Namespace`
        Parsed commandline arguments.
    """
    # We have to initialize the logger manually on the commandline.
    logging.basicConfig(
        level=logging.INFO, format="{name} {levelname}: {message}", style="{"
    )

    butler = lsst.daf.butler.Butler(args.repo, collections=args.collections)
    apdb_query = _make_apdbQuery(sqlitefile=args.sqlitefile,
                                 postgres_url=args.postgres_url,
                                 namespace=args.namespace)

    config = PlotImageSubtractionCutoutsConfig()
    if args.configFile is not None:
        config.load(os.path.expanduser(args.configFile))
    config.freeze()
    cutouts = PlotImageSubtractionCutoutsTask(config=config, output_path=args.outputPath)

    # DiaSource columns written to the csv exports below.
    cols_to_export = ["diaSourceId", "visit", "detector", "diaObjectId",
                      "ssObjectId", "midpointMjdTai", "ra", "dec", "x", "y",
                      "apFlux", "apFluxErr", "snr", "psfFlux", "psfFluxErr",
                      "isDipole", "trailLength", "band", "extendedness",
                      "pixelFlags_bad", "pixelFlags_cr", "pixelFlags_crCenter",
                      "pixelFlags_edge", "pixelFlags_interpolated", "pixelFlags_interpolatedCenter",
                      "pixelFlags_offimage", "pixelFlags_saturated", "pixelFlags_saturatedCenter",
                      "pixelFlags_suspect", "pixelFlags_suspectCenter", "pixelFlags_streak",
                      "pixelFlags_streakCenter", "pixelFlags_injected", "pixelFlags_injectedCenter",
                      "pixelFlags_injected_template", "pixelFlags_injected_templateCenter"]
    if config.save_as_numpy:
        # save the RB output up front so we can use partial runs
        data = select_sources(apdb_query, args.limit, args.reliabilityMin, args.reliabilityMax)
        # this is inefficient but otherwise we don't use the same query
        all_data = pd.concat([d[cols_to_export] for d in data])
        all_data.to_csv(os.path.join(args.outputPath, "all_diasources.csv.gz"), index=False)

    getter = select_sources(apdb_query, args.limit, args.reliabilityMin, args.reliabilityMax)
    # Keep every processed batch (only when needed for the export below),
    # otherwise the "--all" path would export rows from the final batch only.
    processed = []
    # Process just one block of length "limit", or all sources in the database?
    if not args.all:
        data = next(getter)
        if config.save_as_numpy:
            processed.append(data)
        sources = cutouts.run(data, butler, njobs=args.jobs)
    else:
        sources = []
        count = len_sources(apdb_query, args.namespace)
        for i, data in enumerate(getter):
            if config.save_as_numpy:
                processed.append(data)
            sources.extend(cutouts.write_images(data, butler, njobs=args.jobs))
            print(f"Completed {i+1} batches of {args.limit} size, out of {count} diaSources.")
        cutouts.write_manifest(sources)

    if config.save_as_numpy and processed:
        # Write a dataframe with only diasources successfully written,
        # drawn from every processed batch, not just the last one.
        exported = pd.concat(processed)
        exported.loc[exported['diaSourceId'].isin(sources), cols_to_export].to_csv(
            os.path.join(args.outputPath, "exported_diasources.csv.gz"), index=False)

    print(f"Generated {len(sources)} diaSource cutouts to {args.outputPath}.")
def main():
    # Parse the commandline and hand the result straight to the task runner.
    parsed_args = build_argparser().parse_args()
    run_cutouts(parsed_args)