Coverage for python/lsst/meas/transiNet/rbTransiNetTask.py: 48%
48 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-11-01 10:20 +0000
1# This file is part of meas_transiNet.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
# Public API of this module: the task and its configuration class.
__all__ = [
    "RBTransiNetTask",
    "RBTransiNetConfig",
]
import numpy as np

import lsst.afw.table
import lsst.geom
import lsst.pex.config
import lsst.pipe.base
from lsst.utils.timer import timeMethod

from . import rbTransiNetInterface
class RBTransiNetConnections(
    lsst.pipe.base.PipelineTaskConnections,
    defaultTemplates={"coaddName": "deep", "fakesType": ""},
    dimensions=("instrument", "visit", "detector"),
):
    """Butler connections for `RBTransiNetTask`.

    Every connection is keyed by ``instrument``, ``visit`` and ``detector``.
    The ``fakesType`` and ``coaddName`` templates are substituted into the
    dataset type names below.
    """

    # --- Inputs ---

    # NOTE: Do we want the "ready to difference" template, or something
    # earlier? This one is warped, but not PSF-matched.
    template = lsst.pipe.base.connectionTypes.Input(
        name="{fakesType}{coaddName}Diff_templateExp",
        storageClass="ExposureF",
        dimensions=("instrument", "visit", "detector"),
        doc="Input warped template to subtract.",
    )
    science = lsst.pipe.base.connectionTypes.Input(
        name="{fakesType}calexp",
        storageClass="ExposureF",
        dimensions=("instrument", "visit", "detector"),
        doc="Input science exposure to subtract from.",
    )
    difference = lsst.pipe.base.connectionTypes.Input(
        name="{fakesType}{coaddName}Diff_differenceExp",
        storageClass="ExposureF",
        dimensions=("instrument", "visit", "detector"),
        doc="Result of subtracting convolved template from science image.",
    )
    diaSources = lsst.pipe.base.connectionTypes.Input(
        name="{fakesType}{coaddName}Diff_diaSrc",
        storageClass="SourceCatalog",
        dimensions=("instrument", "visit", "detector"),
        doc="Detected sources on the difference image.",
    )

    # --- Outputs ---
    classifications = lsst.pipe.base.connectionTypes.Output(
        name="{fakesType}{coaddName}RealBogusSources",
        storageClass="Catalog",
        dimensions=("instrument", "visit", "detector"),
        doc="Catalog of real/bogus classifications for each diaSource, "
            "element-wise aligned with diaSources.",
    )
class RBTransiNetConfig(lsst.pipe.base.PipelineTaskConfig, pipelineConnections=RBTransiNetConnections):
    """Configuration for `RBTransiNetTask`."""

    # Name of the model package handed to the TransiNet interface. No default
    # is given here, so presumably it must be set by the pipeline config.
    modelPackageName = lsst.pex.config.Field(
        dtype=str,
        doc="A unique identifier of a model package. ",
    )
    modelPackageStorageMode = lsst.pex.config.ChoiceField(
        dtype=str,
        doc="A string that indicates _where_ and _how_ the model package is stored.",
        allowed={'local': 'packages stored in the meas_transiNet repository',
                 'neighbor': 'packages stored in the rbClassifier_data repository',
                 },
        default='neighbor',
    )
    cutoutSize = lsst.pex.config.Field(
        dtype=int,
        doc="Width/height of square cutouts to send to classifier.",
        default=256,
        # Reject non-positive sizes at config validation time; otherwise the
        # task would only fail later, when cutout arrays are allocated.
        check=lambda size: size > 0,
    )
class RBTransiNetTask(lsst.pipe.base.PipelineTask):
    """Task for running TransiNet real/bogus classification on the output of
    the image subtraction pipeline.
    """
    _DefaultName = "rbTransiNet"
    ConfigClass = RBTransiNetConfig

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # The model interface is built once per task instance from the
        # configured package name and storage mode.
        self.interface = rbTransiNetInterface.RBTransiNetInterface(self.config.modelPackageName,
                                                                   self.config.modelPackageStorageMode)

    @timeMethod
    def run(self, template, science, difference, diaSources):
        """Score every diaSource with the real/bogus classifier.

        Parameters
        ----------
        template : `lsst.afw.image.ExposureF`
            Warped template exposure.
        science : `lsst.afw.image.ExposureF`
            Science exposure.
        difference : `lsst.afw.image.ExposureF`
            Difference exposure.
        diaSources : `lsst.afw.table.SourceCatalog`
            Sources detected on the difference image.

        Returns
        -------
        result : `lsst.pipe.base.Struct`
            Struct with one field, ``classifications``: a
            `lsst.afw.table.BaseCatalog` with ``id`` and ``score`` columns,
            row-aligned with ``diaSources``.
        """
        cutouts = [self._make_cutouts(template, science, difference, source) for source in diaSources]
        self.log.info("Extracted %d cutouts.", len(cutouts))
        if cutouts:
            scores = self.interface.infer(cutouts)
        else:
            # Nothing to classify; do not hand the model an empty batch.
            scores = np.zeros(0, dtype=np.float64)
        self.log.info("Scored %d cutouts.", len(scores))

        schema = lsst.afw.table.Schema()
        schema.addField(diaSources.schema["id"].asField())
        schema.addField("score", doc="real/bogus score of this source", type=float)
        classifications = lsst.afw.table.BaseCatalog(schema)
        classifications.resize(len(scores))

        # Column-wise copy keeps each row aligned with the matching diaSource;
        # this relies on the interface returning one score per cutout.
        classifications["id"] = diaSources["id"]
        classifications["score"] = scores

        return lsst.pipe.base.Struct(classifications=classifications)

    def _make_cutouts(self, template, science, difference, source):
        """Return cutouts of each image centered at the source location.

        Parameters
        ----------
        template : `lsst.afw.image.ExposureF`
        science : `lsst.afw.image.ExposureF`
        difference : `lsst.afw.image.ExposureF`
            Exposures to cut images out of.
        source : `lsst.afw.table.SourceRecord`
            Source to make cutouts of.

        Returns
        -------
        cutouts : `lsst.meas.transiNet.CutoutInputs`
            Cutouts of each of the input images. If the source centroid is
            not finite, or a ``cutoutSize`` box around it does not fit inside
            all three exposures, every cutout is instead an all-zero
            ``float32`` array of shape ``(cutoutSize, cutoutSize)``.
        """
        size = self.config.cutoutSize
        centroid = source.getCentroid()
        exposures = (science, template, difference)

        # A non-finite centroid cannot define a box. The box must also lie
        # inside every exposure, since their bounding boxes need not match.
        # Both failure cases fall through to blank cutouts below.
        if np.isfinite(centroid.getX()) and np.isfinite(centroid.getY()):
            box = lsst.geom.Box2I.makeCenteredBox(centroid, lsst.geom.Extent2I(size))
            if all(exposure.getBBox().contains(box) for exposure in exposures):
                science_cutout, template_cutout, difference_cutout = (
                    exposure.Factory(exposure, box).image.array for exposure in exposures
                )
                return rbTransiNetInterface.CutoutInputs(science=science_cutout,
                                                         template=template_cutout,
                                                         difference=difference_cutout)

        # Separate arrays per image so no cutout aliases another.
        blank = np.zeros((size, size), dtype=np.float32)
        return rbTransiNetInterface.CutoutInputs(science=blank,
                                                 template=np.zeros_like(blank),
                                                 difference=np.zeros_like(blank))