Coverage for python/lsst/meas/transiNet/rbTransiNetTask.py: 44%
46 statements
coverage.py v7.2.7, created at 2023-06-09 03:11 -0700
1# This file is part of meas_transiNet.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
__all__ = [
    "RBTransiNetTask",
    "RBTransiNetConfig",
]
import lsst.afw.table
import lsst.geom
import lsst.pex.config
import lsst.pipe.base
import numpy as np

from . import rbTransiNetInterface
class RBTransiNetConnections(lsst.pipe.base.PipelineTaskConnections,
                             dimensions=("instrument", "visit", "detector"),
                             defaultTemplates={"coaddName": "deep", "fakesType": ""}):
    """Butler connections for `RBTransiNetTask`.

    Every input and output dataset uses the same
    ``("instrument", "visit", "detector")`` dimensions as the task itself.
    """

    # ---- Inputs ----

    # NOTE: Do we want the "ready to difference" template, or something
    # earlier? This one is warped, but not PSF-matched.
    template = lsst.pipe.base.connectionTypes.Input(
        name="{fakesType}{coaddName}Diff_templateExp",
        storageClass="ExposureF",
        dimensions=("instrument", "visit", "detector"),
        doc="Input warped template to subtract.",
    )
    science = lsst.pipe.base.connectionTypes.Input(
        name="{fakesType}calexp",
        storageClass="ExposureF",
        dimensions=("instrument", "visit", "detector"),
        doc="Input science exposure to subtract from.",
    )
    difference = lsst.pipe.base.connectionTypes.Input(
        name="{fakesType}{coaddName}Diff_differenceExp",
        storageClass="ExposureF",
        dimensions=("instrument", "visit", "detector"),
        doc="Result of subtracting convolved template from science image.",
    )
    diaSources = lsst.pipe.base.connectionTypes.Input(
        name="{fakesType}{coaddName}Diff_diaSrc",
        storageClass="SourceCatalog",
        dimensions=("instrument", "visit", "detector"),
        doc="Detected sources on the difference image.",
    )

    # ---- Outputs ----

    classifications = lsst.pipe.base.connectionTypes.Output(
        name="{fakesType}{coaddName}RealBogusSources",
        storageClass="Catalog",
        dimensions=("instrument", "visit", "detector"),
        doc="Catalog of real/bogus classifications for each diaSource, "
            "element-wise aligned with diaSources.",
    )
class RBTransiNetConfig(lsst.pipe.base.PipelineTaskConfig, pipelineConnections=RBTransiNetConnections):
    """Configuration for `RBTransiNetTask`."""

    # Identifier of the model package handed to the inference interface.
    # No default is given here, so it is expected to be set explicitly
    # (e.g. by a pipeline config override).
    modelPackageName = lsst.pex.config.Field(
        dtype=str,
        doc=("A unique identifier of a model package. ")
    )
    # Selects which repository the model package is looked up in.
    modelPackageStorageMode = lsst.pex.config.ChoiceField(
        dtype=str,
        doc=("A string that indicates _where_ and _how_ the model package is stored."),
        allowed={'local': 'packages stored in the meas_transiNet repository',
                 'neighbor': 'packages stored in the rbClassifier_data repository',
                 },
        default='neighbor',
    )
    # Side length (pixels) of the square box cut around each source. A
    # non-positive size cannot describe a valid box, so reject it at
    # config-validation time rather than deep inside `run`.
    cutoutSize = lsst.pex.config.Field(
        dtype=int,
        doc="Width/height of square cutouts to send to classifier.",
        default=256,
        check=lambda size: size > 0,
    )
class RBTransiNetTask(lsst.pipe.base.PipelineTask):
    """Task for running TransiNet real/bogus classification on the output of
    the image subtraction pipeline.
    """
    _DefaultName = "rbTransiNet"
    ConfigClass = RBTransiNetConfig

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # The inference interface is built once, from the configured model
        # package, and reused by every call to `run`.
        self.interface = rbTransiNetInterface.RBTransiNetInterface(self.config.modelPackageName,
                                                                   self.config.modelPackageStorageMode)

    def run(self, template, science, difference, diaSources):
        """Score every diaSource with the real/bogus classifier.

        Parameters
        ----------
        template : `lsst.afw.image.ExposureF`
            Warped template exposure.
        science : `lsst.afw.image.ExposureF`
            Science exposure.
        difference : `lsst.afw.image.ExposureF`
            Difference image.
        diaSources : `lsst.afw.table.SourceCatalog`
            Sources detected on the difference image.

        Returns
        -------
        result : `lsst.pipe.base.Struct`
            ``classifications``
                Catalog with an ``id`` column copied from ``diaSources`` and a
                ``score`` column holding the classifier output, one row per
                diaSource in input order (`lsst.afw.table.BaseCatalog`).
        """
        cutouts = [self._make_cutouts(template, science, difference, source) for source in diaSources]
        self.log.info("Extracted %d cutouts.", len(cutouts))

        if cutouts:
            scores = self.interface.infer(cutouts)
        else:
            # Nothing to classify: skip model inference on an empty batch.
            scores = np.empty(0, dtype=np.float64)
        self.log.info("Scored %d cutouts.", len(scores))

        classifications = self._make_classifications(diaSources, scores)
        return lsst.pipe.base.Struct(classifications=classifications)

    @staticmethod
    def _make_classifications(diaSources, scores):
        """Build the output catalog pairing each diaSource id with its score.

        Parameters
        ----------
        diaSources : `lsst.afw.table.SourceCatalog`
            Sources that were classified; supplies the ``id`` column.
        scores : sequence of `float`
            Classifier scores, element-wise aligned with ``diaSources``.

        Returns
        -------
        classifications : `lsst.afw.table.BaseCatalog`
            Catalog with ``id`` and ``score`` columns, ``len(scores)`` rows.
        """
        schema = lsst.afw.table.Schema()
        # Reuse the exact id field definition of the input catalog.
        schema.addField(diaSources.schema["id"].asField())
        schema.addField("score", doc="real/bogus score of this source", type=float)

        classifications = lsst.afw.table.BaseCatalog(schema)
        classifications.resize(len(scores))
        # Column assignment is skipped for an empty catalog: there is nothing
        # to fill, and it avoids relying on column views of zero-length
        # catalogs.
        if len(classifications) > 0:
            classifications["id"] = diaSources["id"]
            classifications["score"] = scores
        return classifications

    def _make_cutouts(self, template, science, difference, source):
        """Return cutouts of each image centered at the source location.

        Parameters
        ----------
        template : `lsst.afw.image.ExposureF`
        science : `lsst.afw.image.ExposureF`
        difference : `lsst.afw.image.ExposureF`
            Exposures to cut images out of.
        source : `lsst.afw.table.SourceRecord`
            Source to make cutouts of.

        Returns
        -------
        cutouts : `lsst.meas.transiNet.CutoutInputs`
            Cutouts of each of the input images. If the cutout box does not
            lie entirely inside all three exposures, every cutout is an
            all-zero ``(cutoutSize, cutoutSize)`` float32 array instead.
        """
        extent = lsst.geom.Extent2I(self.config.cutoutSize)
        box = lsst.geom.Box2I.makeCenteredBox(source.getCentroid(), extent)

        # Take subimages only when the box fits inside *every* exposure; the
        # template and difference bounding boxes are not guaranteed to match
        # the science one, and a box crossing an image edge is expected to
        # make subimage extraction fail.
        exposures = (science, template, difference)
        if all(exposure.getBBox().contains(box) for exposure in exposures):
            science_cutout, template_cutout, difference_cutout = (
                exposure.Factory(exposure, box).image.array for exposure in exposures
            )
        else:
            # Out-of-bounds sources still get (blank) cutouts so the scores
            # stay element-wise aligned with diaSources.
            science_cutout = np.zeros((self.config.cutoutSize, self.config.cutoutSize), dtype=np.float32)
            template_cutout = np.zeros_like(science_cutout)
            difference_cutout = np.zeros_like(science_cutout)

        return rbTransiNetInterface.CutoutInputs(science=science_cutout,
                                                 template=template_cutout,
                                                 difference=difference_cutout)