Coverage for python/lsst/meas/transiNet/rbTransiNetInterface.py: 43%
40 statements
« prev ^ index » next coverage.py v7.2.4, created at 2023-04-29 03:40 -0700
« prev ^ index » next coverage.py v7.2.4, created at 2023-04-29 03:40 -0700
1# This file is part of meas_transiNet.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
# Names exported by ``from <module> import *``.
__all__ = [
    "RBTransiNetInterface",
    "CutoutInputs",
]
24import numpy as np
25import dataclasses
26import torch
28from .modelPackages.nnModelPackage import NNModelPackage
31@dataclasses.dataclass(frozen=True, kw_only=True)
32class CutoutInputs:
33 """Science/template/difference cutouts of a single object plus other
34 metadata.
35 """
36 science: np.ndarray
37 template: np.ndarray
38 difference: np.ndarray
40 label: bool = None
41 """Known truth of whether this is a real or bogus object."""
class RBTransiNetInterface:
    """
    The interface between the LSST AP pipeline and a trained pytorch-based
    RBTransiNet neural network model.

    Parameters
    ----------
    model_package_name : `str`
        Name of the model package to load.
    package_storage_mode : {'local', 'neighbor'}
        Storage mode of the model package.
    device : `str`
        Device to load and run the neural network on, e.g. 'cpu' or 'cuda:0'.
    """

    def __init__(self, model_package_name, package_storage_mode, device='cpu'):
        self.model_package_name = model_package_name
        self.package_storage_mode = package_storage_mode
        self.device = device
        self.init_model()

    def init_model(self):
        """Create and initialize an NN model.

        Loads the model package named ``self.model_package_name`` (using
        ``self.package_storage_mode``) onto ``self.device``, stores it in
        ``self.model`` and switches it to evaluation mode.
        """
        model_package = NNModelPackage(self.model_package_name, self.package_storage_mode)
        self.model = model_package.load(self.device)

        # Put the model in evaluation mode instead of training mode.
        self.model.eval()

    def prepare_input(self, inputs):
        """
        Convert inputs from numpy arrays, etc. to a torch.tensor blob.

        Parameters
        ----------
        inputs : `list` [`CutoutInputs`]
            Inputs to be scored.

        Returns
        -------
        blob : `torch.Tensor`
            Prepared torch tensor blob to run the model on, with shape
            ``(len(inputs), 3, *cutout_shape)``. Along axis 1 the channels
            are, in order: template, science, difference.
        labels : `list`
            Truth labels, concatenated into a single list (same order as
            ``inputs``).

        Raises
        ------
        RuntimeError
            Raised by `torch.stack` if ``inputs`` is empty or the cutouts do
            not all have the same shape.
        """
        cutoutsList = []
        labelsList = []
        for inp in inputs:
            # Convert each cutout to a torch tensor
            template = torch.from_numpy(inp.template)
            science = torch.from_numpy(inp.science)
            difference = torch.from_numpy(inp.difference)

            # Stack the components to create a single (3, ...) blob;
            # the channel order here is what the model expects.
            singleBlob = torch.stack((template, science, difference), dim=0)

            # And append them to the temporary list
            cutoutsList.append(singleBlob)

            labelsList.append(inp.label)

        torchBlob = torch.stack(cutoutsList)
        return torchBlob, labelsList

    def infer(self, inputs):
        """Return the real/bogus scores of a batch of cutouts.

        Parameters
        ----------
        inputs : `list` [`CutoutInputs`]
            Inputs to be scored.

        Returns
        -------
        scores : `numpy.ndarray`
            Float scores (sigmoid of the model output) for each element of
            ``inputs``, flattened to 1-D.
        """
        blob, _ = self.prepare_input(inputs)

        # The model was loaded onto ``self.device``; the input must live on
        # the same device or the forward pass fails for e.g. 'cuda:0'.
        blob = blob.to(self.device)

        # Pure inference: don't build an autograd graph.
        with torch.no_grad():
            result = self.model(blob)
            scores = torch.sigmoid(result)

        # ``.cpu()`` is a no-op for CPU tensors but is required before
        # ``.numpy()`` when the model ran on a GPU.
        npyScores = scores.detach().cpu().numpy().ravel()

        return npyScores