Coverage for python/lsst/analysis/tools/interfaces/_task.py: 20%

149 statements  

« prev     ^ index     » next       coverage.py v7.2.3, created at 2023-04-28 05:16 -0700

1# This file is part of analysis_tools. # 

2# Developed for the LSST Data Management System. 

3# This product includes software developed by the LSST Project 

4# (https://www.lsst.org). 

5# See the COPYRIGHT file at the top-level directory of this distribution 

6# for details of code ownership. 

7# 

8# This program is free software: you can redistribute it and/or modify 

9# it under the terms of the GNU General Public License as published by 

10# the Free Software Foundation, either version 3 of the License, or 

11# (at your option) any later version. 

12# 

13# This program is distributed in the hope that it will be useful, 

14# but WITHOUT ANY WARRANTY; without even the implied warranty of 

15# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

16# GNU General Public License for more details. 

17# 

18# You should have received a copy of the GNU General Public License 

19# along with this program. If not, see <https://www.gnu.org/licenses/>. 

20from __future__ import annotations 

21 

22"""Base class implementation for the classes needed in creating `PipelineTasks` 

23which execute `AnalysisTools`. 

24 

25The classes defined in this module have all the required behaviors for 

26defining, introspecting, and executing `AnalysisTools` against an input dataset 

27type. 

28 

29Subclasses of these tasks should specify specific datasets to consume in their 

30connection classes and should specify a unique name 

31""" 

32 

33__all__ = ("AnalysisBaseConnections", "AnalysisBaseConfig", "AnalysisPipelineTask") 

34 

35from copy import deepcopy 

36from typing import TYPE_CHECKING, Any, Iterable, Mapping, MutableMapping, cast 

37 

38from lsst.verify import Measurement 

39 

40if TYPE_CHECKING: 40 ↛ 41line 40 didn't jump to line 41, because the condition on line 40 was never true

41 from lsst.daf.butler import DeferredDatasetHandle 

42 

43from lsst.daf.butler import DataCoordinate 

44from lsst.pex.config import Field, ListField 

45from lsst.pex.config.configurableActions import ConfigurableActionStructField 

46from lsst.pipe.base import PipelineTask, PipelineTaskConfig, PipelineTaskConnections, Struct 

47from lsst.pipe.base import connectionTypes as ct 

48from lsst.pipe.base.butlerQuantumContext import ButlerQuantumContext 

49from lsst.pipe.base.connections import InputQuantizedConnection, OutputQuantizedConnection 

50 

51from ._actions import JointAction, MetricAction, NoMetric 

52from ._analysisTools import AnalysisTool 

53from ._interfaces import KeyedData, PlotTypes 

54from ._metricMeasurementBundle import MetricMeasurementBundle 

55 

56 

class AnalysisBaseConnections(
    PipelineTaskConnections, dimensions={}, defaultTemplates={"outputName": "Placeholder"}
):
    r"""Base class for Connections used for AnalysisTools PipelineTasks.

    This class has a pre-defined output connection for the
    MetricMeasurementMapping. The dataset type name for this connection is
    determined by the template ``outputName``.

    Output connections for plots created by `AnalysisPlot`\ s are created
    dynamically when an instance of the class is created. The init method
    examines all the `AnalysisPlot` actions specified in the associated
    `AnalysisBaseConfig` subclass accumulating all the info needed to
    create the output connections.

    The dimensions for all of the output connections (metric and plot) will
    be the same as the dimensions specified for the AnalysisBaseConnections
    subclass (i.e. quantum dimensions).
    """

    # Statically declared metrics output; __init__ either rebuilds it with the
    # subclass's quantum dimensions or removes it entirely when no configured
    # tool produces metrics.
    metrics = ct.Output(
        doc="Metrics calculated on input dataset type",
        name="{outputName}_metrics",
        storageClass="MetricMeasurementBundle",
    )

    def __init__(self, *, config: AnalysisBaseConfig = None):  # type: ignore
        # Validate that the outputName template has been set in config. This
        # should have been checked early with the configs validate method, but
        # it is possible for someone to manually create everything in a script
        # without running validate, so also check it late here.
        if (outputName := config.connections.outputName) == "Placeholder":  # type: ignore
            raise RuntimeError(
                "Subclasses must specify an alternative value for the defaultTemplate `outputName`"
            )
        super().__init__(config=config)

        # All arguments must be passed by kw, but python has not method to do
        # that without specifying a default, so None is used. Validate that
        # it is not None. This is largely for typing reasons, as in the normal
        # course of operation code execution paths ensure this will not be None
        assert config is not None

        # Determine whether any configured tool will actually produce metric
        # values. Note the ``else`` below is attached to the ``for`` loop and
        # runs only when no tool triggered a ``break``.
        for tool in config.atools:
            match tool.produce:
                case JointAction():
                    if isinstance(tool.produce.metric, NoMetric):
                        continue
                    # A JointAction yields metrics only if its metric action
                    # declares at least one unit.
                    if len(tool.produce.metric.units) != 0:
                        hasMetrics = True
                        break
                case MetricAction():
                    hasMetrics = True
                    break
        else:
            hasMetrics = False

        # Set the dimensions for the metric
        if hasMetrics:
            self.metrics = ct.Output(
                name=self.metrics.name,
                doc=self.metrics.doc,
                storageClass=self.metrics.storageClass,
                dimensions=self.dimensions,
                multiple=False,
                isCalibration=False,
            )
        else:
            # There are no metrics to produce, remove the output connection
            self.outputs.remove("metrics")

        # Look for any conflicting names, creating a set of them, as these
        # will be added to the instance as well as recorded in the outputs
        # set.
        existingNames = set(dir(self))

        # Accumulate all the names to be used from all of the defined
        # AnalysisPlots.
        names: list[str] = []
        for action in config.atools:
            if action.parameterizedBand:
                # Band-parameterized tools produce one output per configured
                # band; substitute each band into the name template.
                for band in config.bands:
                    names.extend(name.format(band=band) for name in action.getOutputNames())
            else:
                names.extend(action.getOutputNames())

        # For each of the names found, create output connections.
        for name in names:
            name = f"{outputName}_{name}"
            if name in self.outputs or name in existingNames:
                raise NameError(
                    f"Plot with name {name} conflicts with existing connection"
                    " are two plots named the same?"
                )
            outConnection = ct.Output(
                name=name,
                storageClass="Plot",
                doc="Dynamic connection for plotting",
                dimensions=self.dimensions,
            )
            # Bypass normal attribute assignment to attach the dynamically
            # created connection, then register it in the outputs set so the
            # connections machinery picks it up.
            object.__setattr__(self, name, outConnection)
            self.outputs.add(name)

159 

160 

class AnalysisBaseConfig(PipelineTaskConfig, pipelineConnections=AnalysisBaseConnections):
    """Base class for all configs used to define an `AnalysisPipelineTask`

    This base class defines two fields that should be used in all subclasses,
    atools, and bands.

    The ``atools`` field is where the user configures which analysis tools will
    be run as part of this `PipelineTask`.

    The bands field specifies which bands will be looped over for
    `AnalysisTools` which support parameterized bands. I.e. called once for
    each band in the list.
    """

    # The set of AnalysisTool actions this task will execute.
    atools = ConfigurableActionStructField[AnalysisTool](
        doc="The analysis tools that are to be run by this task at execution"
    )
    # Temporarily alias these for backwards compatibility
    plots = atools
    metrics = atools
    # Bands looped over for tools with parameterizedBand=True.
    bands = ListField[str](
        doc="Filter bands on which to run all of the actions", default=["u", "g", "r", "i", "z", "y"]
    )
    metric_tags = ListField[str](
        doc="List of tags which will be added to all configurable actions", default=[]
    )
    dataset_identifier = Field[str](doc="An identifier to be associated with output Metrics", optional=True)
    reference_package = Field[str](
        doc="A package who's version, at the time of metric upload to a "
        "time series database, will be converted to a timestamp of when "
        "that version was produced",
        default="lsst_distrib",
    )
    timestamp_version = Field[str](
        doc="Which time stamp should be used as the reference timestamp for a "
        "metric in a time series database, valid values are; "
        "reference_package_timestamp, run_timestamp, current_timestamp, "
        "and dataset_timestamp",
        default="run_timestamp",
        check=lambda x: x
        in ("reference_package_timestamp", "run_timestamp", "current_timestamp", "dataset_timestamp"),
    )

    def freeze(self):
        # Copy the meta configuration values to each of the configured tools
        # only do this if the tool has not been further specialized
        if not self._frozen:
            for tool in self.atools:
                for tag in self.metric_tags:
                    # NOTE(review): insert(-1, ...) places each tag *before*
                    # the last existing element (it appends only when the list
                    # is empty) — confirm plain append is not the intent.
                    tool.metric_tags.insert(-1, tag)
                # Propagate task-level values only onto tools still holding
                # their defaults, so per-tool specializations take precedence.
                if tool.dataset_identifier is None:
                    tool.dataset_identifier = self.dataset_identifier
                if tool.reference_package == "lsst_distrib":
                    tool.reference_package = self.reference_package
                if tool.timestamp_version == "run_timestamp":
                    tool.timestamp_version = self.timestamp_version
        super().freeze()

    def validate(self):
        super().validate()
        # Validate that the required connections template is set.
        if self.connections.outputName == "Placeholder":  # type: ignore
            raise RuntimeError("Connections class 'outputName' must have a config explicitly set")

224 

225 

226class _StandinPlotInfo(dict): 

227 """This class is an implementation detail to support plots in the instance 

228 no PlotInfo object is present in the call to run. 

229 """ 

230 

231 def __missing__(self, key): 

232 return "" 

233 

234 

class AnalysisPipelineTask(PipelineTask):
    """Base class for `PipelineTasks` intended to run `AnalysisTools`.

    The run method will run all of the `AnalysisTools` defined in the config
    class.
    """

    # Typing config because type checkers dont know about our Task magic
    config: AnalysisBaseConfig
    ConfigClass = AnalysisBaseConfig

    def _runTools(self, data: KeyedData, **kwargs) -> Struct:
        """Run every configured `AnalysisTool` on ``data`` and accumulate the
        results into a single `~lsst.pipe.base.Struct`.

        Plot results are attached to the struct under attribute names of the
        form ``<outputName>_<resultName>`` (matching the dynamic output
        connections); `~lsst.verify.Measurement` results are gathered per-tool
        into ``results.metrics``.
        """
        results = Struct()
        results.metrics = MetricMeasurementBundle()  # type: ignore
        # copy plot info to be sure each action sees its own copy
        plotInfo = kwargs.get("plotInfo")
        plotKey = f"{self.config.connections.outputName}_{{name}}"
        for name, action in self.config.atools.items():
            kwargs["plotInfo"] = deepcopy(plotInfo)
            actionResult = action(data, **kwargs)
            metricAccumulate = []
            # Dispatch each produced value by type: plots become struct
            # attributes, measurements are collected for this tool.
            for resultName, value in actionResult.items():
                match value:
                    case PlotTypes():
                        setattr(results, plotKey.format(name=resultName), value)
                    case Measurement():
                        metricAccumulate.append(value)
            # only add the metrics if there are some
            if metricAccumulate:
                results.metrics[name] = metricAccumulate
        return results

    def run(self, *, data: KeyedData | None = None, **kwargs) -> Struct:
        """Produce the outputs associated with this `PipelineTask`

        Parameters
        ----------
        data : `KeyedData`
            The input data from which all `AnalysisTools` will run and produce
            outputs. A side note, the python typing specifies that this can be
            None, but this is only due to a limitation in python where in order
            to specify that all arguments be passed only as keywords the
            argument must be given a default. This argument must not actually
            be None.
        **kwargs
            Additional arguments that are passed through to the `AnalysisTools`
            specified in the configuration.

        Returns
        -------
        results : `~lsst.pipe.base.Struct`
            The accumulated results of all the plots and metrics produced by
            this `PipelineTask`.

        Raises
        ------
        ValueError
            Raised if the supplied data argument is `None`
        """
        if data is None:
            raise ValueError("data must not be none")
        # Supply defaults for the keyword arguments the tools expect when the
        # caller has not provided them.
        if "bands" not in kwargs:
            kwargs["bands"] = list(self.config.bands)
        if "plotInfo" not in kwargs:
            kwargs["plotInfo"] = _StandinPlotInfo()
        kwargs["plotInfo"]["bands"] = kwargs["bands"]
        if "SN" not in kwargs["plotInfo"].keys():
            kwargs["plotInfo"]["SN"] = "-"
        return self._runTools(data, **kwargs)

    def runQuantum(
        self,
        butlerQC: ButlerQuantumContext,
        inputRefs: InputQuantizedConnection,
        outputRefs: OutputQuantizedConnection,
    ) -> None:
        """Override default runQuantum to load the minimal columns necessary
        to complete the action.

        Parameters
        ----------
        butlerQC : `ButlerQuantumContext`
            A butler which is specialized to operate in the context of a
            `lsst.daf.butler.Quantum`.
        inputRefs : `InputQuantizedConnection`
            Datastructure whose attribute names are the names that identify
            connections defined in corresponding `PipelineTaskConnections`
            class. The values of these attributes are the
            `lsst.daf.butler.DatasetRef` objects associated with the defined
            input/prerequisite connections.
        outputRefs : `OutputQuantizedConnection`
            Datastructure whose attribute names are the names that identify
            connections defined in corresponding `PipelineTaskConnections`
            class. The values of these attributes are the
            `lsst.daf.butler.DatasetRef` objects associated with the defined
            output connections.
        """
        inputs = butlerQC.get(inputRefs)
        dataId = butlerQC.quantum.dataId
        plotInfo = self.parsePlotInfo(inputs, dataId)
        # Load only the columns the configured tools actually require.
        data = self.loadData(inputs["data"])
        # A skymap input is optional in subclasses; pass None when absent.
        if "skymap" in inputs.keys():
            skymap = inputs["skymap"]
        else:
            skymap = None
        outputs = self.run(data=data, plotInfo=plotInfo, skymap=skymap)
        butlerQC.put(outputs, outputRefs)

    def _populatePlotInfoWithDataId(
        self, plotInfo: MutableMapping[str, Any], dataId: DataCoordinate | None
    ) -> None:
        """Update the plotInfo with the dataId values.

        Parameters
        ----------
        plotInfo : `dict`
            The plotInfo dictionary to update.
        dataId : `lsst.daf.butler.DataCoordinate`
            The dataId to use to update the plotInfo.
        """
        if dataId is not None:
            # Copy each dimension's value into plotInfo keyed by its name.
            for dataInfo in dataId:
                plotInfo[dataInfo.name] = dataId[dataInfo.name]

    def parsePlotInfo(
        self, inputs: Mapping[str, Any] | None, dataId: DataCoordinate | None, connectionName: str = "data"
    ) -> Mapping[str, str]:
        """Parse the inputs and dataId to get the information needed to
        to add to the figure.

        Parameters
        ----------
        inputs: `dict`
            The inputs to the task
        dataCoordinate: `lsst.daf.butler.DataCoordinate`
            The dataId that the task is being run on.
        connectionName: `str`, optional
            Name of the input connection to use for determining table name.

        Returns
        -------
        plotInfo : `dict`
        """

        if inputs is None:
            tableName = ""
            run = ""
        else:
            # The dataset type name and run collection are read off the
            # DeferredDatasetHandle's underlying DatasetRef.
            tableName = inputs[connectionName].ref.datasetType.name
            run = inputs[connectionName].ref.run

        # Initialize the plot info dictionary
        plotInfo = {"tableName": tableName, "run": run}

        self._populatePlotInfoWithDataId(plotInfo, dataId)
        return plotInfo

    def loadData(self, handle: DeferredDatasetHandle, names: Iterable[str] | None = None) -> KeyedData:
        """Load the minimal set of keyed data from the input dataset.

        Parameters
        ----------
        handle : `DeferredDatasetHandle`
            Handle to load the dataset with only the specified columns.
        names : `Iterable` of `str`
            The names of keys to extract from the dataset.
            If `names` is `None` then the `collectInputNames` method
            is called to generate the names.
            For most purposes these are the names of columns to load from
            a catalog or data frame.

        Returns
        -------
        result: `KeyedData`
            The dataset with only the specified keys loaded.
        """
        if names is None:
            names = self.collectInputNames()
        # The "columns" read parameter restricts what is materialized from
        # the butler datastore.
        return cast(KeyedData, handle.get(parameters={"columns": names}))

    def collectInputNames(self) -> Iterable[str]:
        """Get the names of the inputs.

        If using the default `loadData` method this will gather the names
        of the keys to be loaded from an input dataset.

        Returns
        -------
        inputs : `Iterable` of `str`
            The names of the keys in the `KeyedData` object to extract.

        """
        # Union of every key any tool needs for any configured band.
        inputs = set()
        for band in self.config.bands:
            for action in self.config.atools:
                for key, _ in action.getFormattedInputSchema(band=band):
                    inputs.add(key)
        return inputs

432 return inputs