Coverage for python/lsst/meas/transiNet/rbTransiNetInterface.py: 30% (52 statements)
coverage.py v7.5.0, created at 2024-04-24 03:18 -0700

# This file is part of meas_transiNet.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

__all__ = ["RBTransiNetInterface", "CutoutInputs"]

import dataclasses

import numpy as np
import torch

from .modelPackages.nnModelPackage import NNModelPackage


@dataclasses.dataclass(frozen=True, kw_only=True)
class CutoutInputs:
    """Science/template/difference cutouts of a single object plus other
    metadata.
    """
    science: np.ndarray
    template: np.ndarray
    difference: np.ndarray

    label: bool | None = None
    """Known truth of whether this is a real or bogus object."""
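

# Illustrative sketch (shapes and values below are hypothetical, not part of
# the module): a CutoutInputs bundles three same-shaped cutout arrays for a
# single object, passed by keyword because the dataclass is kw_only:
#
#     sci = np.zeros((256, 256), dtype=np.float32)
#     tmpl = np.zeros((256, 256), dtype=np.float32)
#     cutout = CutoutInputs(science=sci, template=tmpl, difference=sci - tmpl)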


class RBTransiNetInterface:
    """The interface between the LSST AP pipeline and a trained pytorch-based
    RBTransiNet neural network model.

    The model package name and storage mode are read from the task's config.

    Parameters
    ----------
    task : `lsst.meas.transiNet.RBTransiNetTask`
        The task that is using this interface: the 'left side'.
    device : `str`
        Device to load and run the neural network on, e.g. 'cpu' or 'cuda:0'.
    """

    def __init__(self, task, device='cpu'):
        self.task = task

        # If the model package name is not set at this stage, it is not
        # needed (e.g. in butler mode), so fall back to a placeholder.
        self.model_package_name = task.config.modelPackageName or 'N/A'

        self.package_storage_mode = task.config.modelPackageStorageMode
        self.device = device
        self.init_model()

    def init_model(self):
        """Create and initialize an NN model.
        """

        if self.package_storage_mode == 'butler' and self.task.butler_loaded_package is None:
            raise RuntimeError("RBTransiNetInterface is trying to load a butler-mode NN model package, "
                               "but the RBTransiNetTask has not passed down a preloaded payload.")

        model_package = NNModelPackage(model_package_name=self.model_package_name,
                                       package_storage_mode=self.package_storage_mode,
                                       butler_loaded_package=self.task.butler_loaded_package)
        self.model = model_package.load(self.device)

        # Put the model in evaluation mode instead of training mode.
        self.model.eval()
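        # Note: eval() switches layers such as dropout and batch
        # normalization to inference behavior; it does not disable gradient
        # tracking (torch.no_grad() in `infer` handles that).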

    def input_to_batches(self, inputs, batchSize):
        """Convert a list of inputs to a generator of batches.

        Parameters
        ----------
        inputs : `list` [`CutoutInputs`]
            Inputs to be scored.
        batchSize : `int`
            Number of inputs per batch.

        Returns
        -------
        batches : `generator`
            Generator of batches of inputs.
        """
        for i in range(0, len(inputs), batchSize):
            yield inputs[i:i + batchSize]
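
    # Illustrative sketch (hypothetical numbers): with 150 inputs and
    # batchSize=64, input_to_batches yields batches of 64, 64, and 22 inputs:
    #
    #     batches = list(self.input_to_batches(inputs, batchSize=64))
    #     [len(b) for b in batches]   # [64, 64, 22]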

    def prepare_input(self, inputs):
        """Convert inputs from numpy arrays, etc. to a torch.tensor blob.

        Parameters
        ----------
        inputs : `list` [`CutoutInputs`]
            Inputs to be scored.

        Returns
        -------
        blob : `torch.Tensor`
            Prepared torch tensor blob to run the model on.
        labels : `list`
            Truth labels, concatenated into a single list.
        """
        cutoutsList = []
        labelsList = []
        for inp in inputs:
            # Convert each cutout to a torch tensor.
            template = torch.from_numpy(inp.template)
            science = torch.from_numpy(inp.science)
            difference = torch.from_numpy(inp.difference)

            # Stack the components to create a single blob.
            singleBlob = torch.stack((template, science, difference), dim=0)

            # And append them to the temporary list.
            cutoutsList.append(singleBlob)

            labelsList.append(inp.label)

        blob = torch.stack(cutoutsList)
        return blob, labelsList
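
    # Shape sketch (cutout size is hypothetical): for N inputs whose cutouts
    # are each (H, W), prepare_input returns a blob of shape (N, 3, H, W),
    # e.g. (64, 3, 256, 256) for a full batch of 256x256 cutouts.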

    def infer(self, inputs):
        """Return the scores of the given cutouts.

        Parameters
        ----------
        inputs : `list` [`CutoutInputs`]
            Inputs to be scored.

        Returns
        -------
        scores : `numpy.ndarray`
            Float scores for each element of ``inputs``.
        """

        # Convert the inputs to batches.
        # TODO: The batch size is set to 64 for now. Later, when
        # deploying parallel instances of the task, memory limits
        # should be taken into account, if necessary.
        batches = self.input_to_batches(inputs, batchSize=64)

        # Loop over the batches.
        for i, batch in enumerate(batches):
            torchBlob, labelsList = self.prepare_input(batch)

            # Run the model.
            with torch.no_grad():
                output_ = self.model(torchBlob)
                output = torch.sigmoid(output_)

            # And accumulate the results, always on the CPU so that batches
            # from a GPU device can be concatenated.
            if i == 0:
                scores = output.cpu()
            else:
                scores = torch.cat((scores, output.cpu()), dim=0)

        npyScores = scores.detach().numpy().ravel()
        return npyScores
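

# Illustrative usage sketch (the `task` object and cutout triplets below are
# hypothetical stand-ins; in the pipeline the interface is constructed by
# RBTransiNetTask itself):
#
#     interface = RBTransiNetInterface(task, device='cpu')
#     cutouts = [CutoutInputs(science=sci, template=tmpl, difference=diff)
#                for sci, tmpl, diff in triplets]
#     scores = interface.infer(cutouts)   # numpy array of floats in (0, 1)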