Coverage for python/lsst/meas/transiNet/rbTransiNetInterface.py: 33%
49 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-12-08 11:35 +0000
« prev ^ index » next coverage.py v7.3.2, created at 2023-12-08 11:35 +0000
1# This file is part of meas_transiNet.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22__all__ = ["RBTransiNetInterface", "CutoutInputs"]
24import numpy as np
25import dataclasses
26import torch
28from .modelPackages.nnModelPackage import NNModelPackage
31@dataclasses.dataclass(frozen=True, kw_only=True)
32class CutoutInputs:
33 """Science/template/difference cutouts of a single object plus other
34 metadata.
35 """
36 science: np.ndarray
37 template: np.ndarray
38 difference: np.ndarray
40 label: bool = None
41 """Known truth of whether this is a real or bogus object."""
class RBTransiNetInterface:
    """
    The interface between the LSST AP pipeline and a trained pytorch-based
    RBTransiNet neural network model.

    Parameters
    ----------
    model_package_name : `str`
        Name of the model package to load.
    package_storage_mode : {'local', 'neighbor'}
        Storage mode of the model package
    device : `str`
        Device to load and run the neural network on, e.g. 'cpu' or 'cuda:0'
    """

    def __init__(self, model_package_name, package_storage_mode, device='cpu'):
        self.model_package_name = model_package_name
        self.package_storage_mode = package_storage_mode
        self.device = device
        self.init_model()

    def init_model(self):
        """Create and initialize an NN model.

        Loads the package named ``self.model_package_name`` onto
        ``self.device``, stores the result in ``self.model`` and switches it
        to evaluation mode.
        """
        model_package = NNModelPackage(self.model_package_name, self.package_storage_mode)
        self.model = model_package.load(self.device)

        # Put the model in evaluation mode instead of training mode.
        self.model.eval()

    def input_to_batches(self, inputs, batchSize):
        """Split a list of inputs into consecutive batches.

        Parameters
        ----------
        inputs : `list` [`CutoutInputs`]
            Inputs to be scored.
        batchSize : `int`
            Maximum number of inputs per batch; the last batch may be
            smaller.

        Yields
        ------
        batch : `list` [`CutoutInputs`]
            Consecutive slices of ``inputs``, in their original order.

        Raises
        ------
        ValueError
            Raised when iteration starts if ``batchSize`` is less than 1.
        """
        if batchSize < 1:
            raise ValueError(f"batchSize must be a positive integer, got {batchSize}.")
        for start in range(0, len(inputs), batchSize):
            yield inputs[start:start + batchSize]

    def prepare_input(self, inputs):
        """
        Convert inputs from numpy arrays, etc. to a torch.tensor blob.

        Parameters
        ----------
        inputs : `list` [`CutoutInputs`]
            Inputs to be scored. Must not be empty (`torch.stack` rejects an
            empty sequence).

        Returns
        -------
        blob : `torch.Tensor`
            CPU tensor of shape ``(len(inputs), 3, *cutout_shape)``; along
            axis 1 the cutouts are ordered template, science, difference.
        labels : `list`
            The ``label`` of each input, in input order.
        """
        cutoutsList = []
        labelsList = []
        for inp in inputs:
            # Convert each cutout to a torch tensor
            template = torch.from_numpy(inp.template)
            science = torch.from_numpy(inp.science)
            difference = torch.from_numpy(inp.difference)

            # Stack the components to create a single blob. NOTE(review): the
            # channel order presumably has to match the order used when the
            # model was trained — do not reorder without confirming.
            singleBlob = torch.stack((template, science, difference), dim=0)

            # And append them to the temporary list
            cutoutsList.append(singleBlob)

            labelsList.append(inp.label)

        blob = torch.stack(cutoutsList)
        return blob, labelsList

    def infer(self, inputs, batchSize=64):
        """Return the scores of a list of cutouts.

        Parameters
        ----------
        inputs : `list` [`CutoutInputs`]
            Inputs to be scored.
        batchSize : `int`, optional
            Number of inputs passed through the model at once.

        Returns
        -------
        scores : `numpy.ndarray`
            Float scores (sigmoid of the model output) for each element of
            ``inputs``, in order. Empty if ``inputs`` is empty.
        """
        # TODO: When deploying parallel instances of the task, memory limits
        # should be taken into account when choosing batchSize, if necessary.
        batchScores = []
        with torch.no_grad():
            for batch in self.input_to_batches(inputs, batchSize=batchSize):
                torchBlob, _ = self.prepare_input(batch)
                # prepare_input builds CPU tensors, but the model was loaded
                # on self.device; the input has to live on the same device.
                output = torch.sigmoid(self.model(torchBlob.to(self.device)))
                # Move every batch (including the first) back to the CPU so
                # the batches can be concatenated and converted to numpy.
                batchScores.append(output.cpu())

        if not batchScores:
            # No inputs: nothing was run through the model.
            return np.empty(0, dtype=np.float32)

        # Concatenate once rather than per batch to avoid quadratic copying.
        return torch.cat(batchScores, dim=0).numpy().ravel()