Coverage for python/lsst/meas/transiNet/rbTransiNetTask.py: 42%
60 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-15 02:57 -0700
1# This file is part of meas_transiNet.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22__all__ = ["RBTransiNetTask", "RBTransiNetConfig"]
import lsst.afw.table
import lsst.geom
import lsst.pex.config
import lsst.pipe.base
from lsst.utils.timer import timeMethod
import numpy as np

from . import rbTransiNetInterface
from lsst.meas.transiNet.modelPackages.storageAdapterButler import StorageAdapterButler
class RBTransiNetConnections(lsst.pipe.base.PipelineTaskConnections,
                             dimensions=("instrument", "visit", "detector"),
                             defaultTemplates={"coaddName": "deep", "fakesType": ""}):
    """Butler connections of the real/bogus classification task.

    Four per-detector inputs (template, science and difference exposures
    plus the diaSource catalog), one optional prerequisite input (the
    pretrained model package) and one per-detector output catalog.
    """

    # Inputs
    # NOTE: It is still open whether the "ready to difference" template is
    # the right one to use, or something earlier. This one is warped, but
    # not PSF-matched.
    template = lsst.pipe.base.connectionTypes.Input(
        name="{fakesType}{coaddName}Diff_templateExp",
        storageClass="ExposureF",
        dimensions=("instrument", "visit", "detector"),
        doc="Input warped template to subtract.",
    )
    science = lsst.pipe.base.connectionTypes.Input(
        name="{fakesType}calexp",
        storageClass="ExposureF",
        dimensions=("instrument", "visit", "detector"),
        doc="Input science exposure to subtract from.",
    )
    difference = lsst.pipe.base.connectionTypes.Input(
        name="{fakesType}{coaddName}Diff_differenceExp",
        storageClass="ExposureF",
        dimensions=("instrument", "visit", "detector"),
        doc="Result of subtracting convolved template from science image.",
    )
    diaSources = lsst.pipe.base.connectionTypes.Input(
        name="{fakesType}{coaddName}Diff_candidateDiaSrc",
        storageClass="SourceCatalog",
        dimensions=("instrument", "visit", "detector"),
        doc="Detected sources on the difference image.",
    )
    pretrainedModel = lsst.pipe.base.connectionTypes.PrerequisiteInput(
        name=StorageAdapterButler.dataset_type_name,
        storageClass="NNModelPackagePayload",
        dimensions=(),
        doc="Pretrained neural network model (-package) for the RBClassifier.",
    )

    # Outputs
    classifications = lsst.pipe.base.connectionTypes.Output(
        name="{fakesType}{coaddName}RealBogusSources",
        storageClass="Catalog",
        dimensions=("instrument", "visit", "detector"),
        doc="Catalog of real/bogus classifications for each diaSource, "
            "element-wise aligned with diaSources.",
    )

    def __init__(self, *, config=None):
        super().__init__(config=config)

        # The model package is only a butler dataset in the "butler" storage
        # mode; in every other mode the prerequisite connection is dropped.
        model_comes_from_butler = self.config.modelPackageStorageMode == "butler"
        if not model_comes_from_butler:
            del self.pretrainedModel
class RBTransiNetConfig(lsst.pipe.base.PipelineTaskConfig, pipelineConnections=RBTransiNetConnections):
    """Configuration for `RBTransiNetTask`."""

    modelPackageName = lsst.pex.config.Field(
        optional=True,
        dtype=str,
        doc=("A unique identifier of a model package. ")
    )
    modelPackageStorageMode = lsst.pex.config.ChoiceField(
        dtype=str,
        doc=("A string that indicates _where_ and _how_ the model package is stored."),
        allowed={'local': 'packages stored in the meas_transiNet repository',
                 'neighbor': 'packages stored in the rbClassifier_data repository',
                 'butler': 'packages stored in the butler repository',
                 },
        default='neighbor',
    )
    cutoutSize = lsst.pex.config.Field(
        dtype=int,
        doc="Width/height of square cutouts to send to classifier.",
        default=256,
    )

    def validate(self):
        """Check the configuration for consistency.

        Raises
        ------
        ValueError
            Raised if ``modelPackageStorageMode`` is ``"butler"`` and
            ``modelPackageName`` is also set.
        """
        # Run the per-field validation of the base class first; overriding
        # `validate` without this call silently skips those checks.
        super().validate()

        # In the butler mode the model package is selected through the input
        # collection, so the user must not set a modelPackageName as a
        # config field.
        if self.modelPackageStorageMode == "butler" and self.modelPackageName is not None:
            raise ValueError("In a _real_ run of a pipeline when the "
                             "modelPackageStorageMode is 'butler', "
                             "the modelPackageName cannot be specified "
                             "as a config field. Pass it as a collection "
                             "name in the command-line instead.")
class RBTransiNetTask(lsst.pipe.base.PipelineTask):
    """Task for running TransiNet real/bogus classification on the output of
    the image subtraction pipeline.
    """

    _DefaultName = "rbTransiNet"
    ConfigClass = RBTransiNetConfig

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # Model package handed over by the butler, if any. It is set in
        # `run` before the interface object is constructed from this task.
        self.butler_loaded_package = None

    @timeMethod
    def run(self, template, science, difference, diaSources, pretrainedModel=None):
        """Score every diaSource with the TransiNet real/bogus classifier.

        Parameters
        ----------
        template : `lsst.afw.image.ExposureF`
        science : `lsst.afw.image.ExposureF`
        difference : `lsst.afw.image.ExposureF`
            Exposures to cut images out of.
        diaSources : `lsst.afw.table.SourceCatalog`
            Sources detected on the difference image; one cutout triplet is
            made for each record.
        pretrainedModel : optional
            Pretrained model package loaded from the butler, or `None` when
            no model package is passed in.

        Returns
        -------
        result : `lsst.pipe.base.Struct`
            Struct with a single component:

            ``classifications``
                Catalog with the ``id`` of each diaSource and its real/bogus
                ``score`` (`lsst.afw.table.BaseCatalog`).
        """
        # Create the TransiNet interface object.
        # Note: assuming each quantum creates one instance of this task,
        # this is a proper place for doing this since loading of the model
        # is run only once. However, if in the future we come up with a
        # design in which one task instance is used for multiple quanta,
        # this will need to be moved somewhere else -- e.g. to the __init__
        # method, or even to runQuantum.
        self.butler_loaded_package = pretrainedModel  # This will be used by the interface
        self.interface = rbTransiNetInterface.RBTransiNetInterface(self)

        cutouts = [self._make_cutouts(template, science, difference, source) for source in diaSources]
        self.log.info("Extracted %d cutouts.", len(cutouts))
        scores = self.interface.infer(cutouts)
        self.log.info("Scored %d cutouts.", len(scores))

        # The output catalog reuses the "id" field definition of the input
        # catalog and adds one "score" column.
        schema = lsst.afw.table.Schema()
        schema.addField(diaSources.schema["id"].asField())
        schema.addField("score", doc="real/bogus score of this source", type=float)
        classifications = lsst.afw.table.BaseCatalog(schema)
        classifications.resize(len(scores))

        classifications["id"] = diaSources["id"]
        classifications["score"] = scores

        return lsst.pipe.base.Struct(classifications=classifications)

    def _make_cutouts(self, template, science, difference, source):
        """Return cutouts of each image centered at the source location.

        Parameters
        ----------
        template : `lsst.afw.image.ExposureF`
        science : `lsst.afw.image.ExposureF`
        difference : `lsst.afw.image.ExposureF`
            Exposures to cut images out of.
        source : `lsst.afw.table.SourceRecord`
            Source to make cutouts of.

        Returns
        -------
        cutouts : `lsst.meas.transiNet.CutoutInputs`
            Cutouts of each of the input images. All three cutouts are
            zero-filled if the cutout box is not fully contained in every
            input exposure.
        """
        size = self.config.cutoutSize
        extent = lsst.geom.Extent2I(size)
        box = lsst.geom.Box2I.makeCenteredBox(source.getCentroid(), extent)

        # Cut the images out only if the box lies inside all three
        # exposures; otherwise (most probably an out-of-border box) fall
        # back to empty cutouts instead of failing on the sub-image.
        exposures = (science, template, difference)
        if all(exposure.getBBox().contains(box) for exposure in exposures):
            science_cutout = science.Factory(science, box).image.array
            template_cutout = template.Factory(template, box).image.array
            difference_cutout = difference.Factory(difference, box).image.array
        else:
            science_cutout = np.zeros((size, size), dtype=np.float32)
            template_cutout = np.zeros_like(science_cutout)
            difference_cutout = np.zeros_like(science_cutout)

        return rbTransiNetInterface.CutoutInputs(science=science_cutout,
                                                 template=template_cutout,
                                                 difference=difference_cutout)