Coverage for python/lsst/analysis/tools/interfaces/_task.py: 20%

# This file is part of analysis_tools.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations

"""Base class implementation for the classes needed in creating `PipelineTasks`
which execute `AnalysisTools`.

The classes defined in this module have all the required behaviors for
defining, introspecting, and executing `AnalysisTools` against an input dataset
type.

Subclasses of these tasks should specify the specific datasets to consume in
their connection classes and should specify a unique name.
"""

__all__ = ("AnalysisBaseConnections", "AnalysisBaseConfig", "AnalysisPipelineTask")

import weakref
from collections.abc import Iterable
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Mapping, MutableMapping, cast

import matplotlib.pyplot as plt
from lsst.verify import Measurement

if TYPE_CHECKING:
    from lsst.daf.butler import DeferredDatasetHandle
    from lsst.pipe.base import QuantumContext

from lsst.daf.butler import DataCoordinate
from lsst.pex.config import Field, ListField
from lsst.pex.config.configurableActions import ConfigurableActionStructField
from lsst.pipe.base import Instrument, PipelineTask, PipelineTaskConfig, PipelineTaskConnections, Struct
from lsst.pipe.base import connectionTypes as ct
from lsst.pipe.base.connections import InputQuantizedConnection, OutputQuantizedConnection
from lsst.pipe.base.pipelineIR import ConfigIR, ParametersIR

from ._actions import JointAction, MetricAction, NoMetric
from ._analysisTools import AnalysisTool
from ._interfaces import KeyedData, PlotTypes
from ._metricMeasurementBundle import MetricMeasurementBundle


# TODO: This _plotCloser function assists in closing all open plots at the
# conclusion of a PipelineTask. When DM-39114 is implemented, this function
# and all associated usage thereof should be removed.
def _plotCloser(*args):
    """Close all the plots passed as arguments."""
    for plot in args:
        plt.close(plot)
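
# A minimal sketch (not part of this module) of the pattern used below in
# ``_runTools``: ``weakref.finalize`` ties plot closure to the lifetime of
# the results object, so figures are freed once the caller drops the
# results. The names ``owner`` and ``fig`` are illustrative only:
#
#     fig = plt.figure()
#     owner = Struct(plot=fig)
#     weakref.finalize(owner, _plotCloser, fig)
#     del owner  # the finalizer fires and plt.close(fig) runs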


class AnalysisBaseConnections(
    PipelineTaskConnections, dimensions={}, defaultTemplates={"outputName": "Placeholder"}
):
    r"""Base class for Connections used for AnalysisTools PipelineTasks.

    This class has a pre-defined output connection for the
    `MetricMeasurementBundle`. The dataset type name for this connection is
    determined by the template ``outputName``.

    Output connections for plots created by `AnalysisTool`\ s are created
    dynamically when an instance of the class is created. The init method
    examines all the `AnalysisTool` actions specified in the associated
    `AnalysisBaseConfig` subclass, accumulating all the info needed to
    create the output connections.

    The dimensions for all of the output connections (metric and plot) will
    be the same as the dimensions specified for the AnalysisBaseConnections
    subclass (i.e. quantum dimensions).
    """

    metrics = ct.Output(
        doc="Metrics calculated on input dataset type",
        name="{outputName}_metrics",
        storageClass="MetricMeasurementBundle",
    )


    def __init__(self, *, config: AnalysisBaseConfig = None):  # type: ignore
        # Validate that the outputName template has been set in config. This
        # should have been checked early with the config's validate method,
        # but it is possible for someone to manually create everything in a
        # script without running validate, so also check it late here.
        if (outputName := config.connections.outputName) == "Placeholder":  # type: ignore
            raise RuntimeError(
                "Subclasses must specify an alternative value for the defaultTemplate `outputName`"
            )
        super().__init__(config=config)

        # All arguments must be passed by keyword, but python has no way to
        # require that without specifying a default, so None is used. Validate
        # that it is not None. This is largely for typing reasons, as in the
        # normal course of operation code execution paths ensure this will
        # not be None.
        assert config is not None

        for tool in config.atools:
            match tool.produce:
                case JointAction():
                    if isinstance(tool.produce.metric, NoMetric):
                        continue
                    if len(tool.produce.metric.units) != 0:
                        hasMetrics = True
                        break
                case MetricAction():
                    hasMetrics = True
                    break
        else:
            hasMetrics = False

        # Set the dimensions for the metrics output connection
        if hasMetrics:
            self.metrics = ct.Output(
                name=self.metrics.name,
                doc=self.metrics.doc,
                storageClass=self.metrics.storageClass,
                dimensions=self.dimensions,
                multiple=False,
                isCalibration=False,
            )
        else:
            # There are no metrics to produce, so remove the output connection
            self.outputs.remove("metrics")

        # Look for any conflicting names, creating a set of them, as these
        # will be added to the instance as well as recorded in the outputs
        # set.
        existingNames = set(dir(self))

        # Accumulate all the names to be used from all of the defined
        # AnalysisTools.
        names: list[str] = []
        for action in config.atools:
            if action.dynamicOutputNames:
                outNames = action.getOutputNames(config=config)
            else:
                outNames = action.getOutputNames()
            if action.parameterizedBand:
                for band in config.bands:
                    names.extend(name.format(band=band) for name in outNames)
            else:
                names.extend(outNames)

        # For each of the names found, create output connections.
        for name in names:
            name = f"{outputName}_{name}"
            if name in self.outputs or name in existingNames:
                raise NameError(
                    f"Plot with name {name} conflicts with an existing connection;"
                    " are two plots named the same?"
                )
            outConnection = ct.Output(
                name=name,
                storageClass="Plot",
                doc="Dynamic connection for plotting",
                dimensions=self.dimensions,
            )
            object.__setattr__(self, name, outConnection)
            self.outputs.add(name)
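
# A hypothetical subclass, sketching how concrete tasks typically declare the
# dataset to consume and the quantum dimensions (all names below are
# illustrative placeholders, not definitions from this package):
#
#     class ExampleAnalysisConnections(
#         AnalysisBaseConnections,
#         dimensions=("skymap", "tract"),
#         defaultTemplates={"outputName": "exampleObjectTable"},
#     ):
#         data = ct.Input(
#             doc="Catalog to be analyzed.",
#             name="objectTable_tract",
#             storageClass="DataFrame",
#             dimensions=("skymap", "tract"),
#             deferLoad=True,  # the default runQuantum expects a deferred handle
#         )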


class AnalysisBaseConfig(PipelineTaskConfig, pipelineConnections=AnalysisBaseConnections):
    """Base class for all configs used to define an `AnalysisPipelineTask`.

    This base class defines two fields that should be used in all subclasses:
    ``atools`` and ``bands``.

    The ``atools`` field is where the user configures which analysis tools
    will be run as part of this `PipelineTask`.

    The ``bands`` field specifies which bands will be looped over for
    `AnalysisTools` which support parameterized bands, i.e. each such tool is
    called once for each band in the list.
    """

    atools = ConfigurableActionStructField[AnalysisTool](
        doc="The analysis tools that are to be run by this task at execution"
    )
    # Temporarily alias these for backwards compatibility
    plots = atools
    metrics = atools
    bands = ListField[str](
        doc="Filter bands on which to run all of the actions", default=["u", "g", "r", "i", "z", "y"]
    )
    metric_tags = ListField[str](
        doc="List of tags which will be added to all configurable actions", default=[]
    )
    dataset_identifier = Field[str](doc="An identifier to be associated with output Metrics", optional=True)
    reference_package = Field[str](
        doc="A package whose version, at the time of metric upload to a "
        "time series database, will be converted to a timestamp of when "
        "that version was produced",
        default="lsst_distrib",
    )
    timestamp_version = Field[str](
        doc="Which time stamp should be used as the reference timestamp for a "
        "metric in a time series database, valid values are: "
        "reference_package_timestamp, run_timestamp, current_timestamp, "
        "and dataset_timestamp",
        default="run_timestamp",
        check=lambda x: x
        in ("reference_package_timestamp", "run_timestamp", "current_timestamp", "dataset_timestamp"),
    )


    def applyConfigOverrides(
        self,
        instrument: Instrument | None,
        taskDefaultName: str,
        pipelineConfigs: Iterable[ConfigIR] | None,
        parameters: ParametersIR,
        label: str,
    ) -> None:
        extraConfig = {}
        if (value := parameters.mapping.get("sasquatch_dataset_identifier", None)) is not None:
            extraConfig["dataset_identifier"] = value
        if (value := parameters.mapping.get("sasquatch_reference_package", None)) is not None:
            extraConfig["reference_package"] = value
        if (value := parameters.mapping.get("sasquatch_timestamp_version", None)) is not None:
            extraConfig["timestamp_version"] = value
        if extraConfig:
            newPipelineConfigs = [ConfigIR(rest=extraConfig)]
            if pipelineConfigs is not None:
                newPipelineConfigs.extend(pipelineConfigs)
            pipelineConfigs = newPipelineConfigs
        return super().applyConfigOverrides(instrument, taskDefaultName, pipelineConfigs, parameters, label)

    def freeze(self):
        # Copy the meta configuration values to each of the configured tools,
        # but only if the config has not already been frozen.
        if not self._frozen:
            for tool in self.atools:
                for tag in self.metric_tags:
                    tool.metric_tags.insert(-1, tag)
        super().freeze()

    def validate(self):
        super().validate()
        # Validate that the required connections template is set.
        if self.connections.outputName == "Placeholder":  # type: ignore
            raise RuntimeError("Connections class 'outputName' must have a config explicitly set")
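
# A sketch of how a pipeline author typically fills in these fields; the tool
# class and names here are hypothetical placeholders, not part of this module:
#
#     def exampleConfigOverride(config: AnalysisBaseConfig) -> None:
#         from lsst.analysis.tools.atools import ExampleTool  # hypothetical
#
#         config.atools.example = ExampleTool()
#         config.bands = ["g", "r", "i"]
#         config.metric_tags = ["example_tag"]
#         config.connections.outputName = "exampleOutput"  # required; see validate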


class _StandinPlotInfo(dict):
    """This class is an implementation detail to support plots in the case
    that no PlotInfo object is present in the call to run.
    """

    def __missing__(self, key):
        return ""
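
# For example, ``_StandinPlotInfo()["anything"]`` evaluates to ``""``, so
# format strings that expect plotInfo entries degrade gracefully when run is
# called without a real PlotInfo mapping.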


class AnalysisPipelineTask(PipelineTask):
    """Base class for `PipelineTasks` intended to run `AnalysisTools`.

    The run method will run all of the `AnalysisTools` defined in the config
    class.
    """

    # Typing config because type checkers don't know about our Task magic
    config: AnalysisBaseConfig
    ConfigClass = AnalysisBaseConfig

    def _runTools(self, data: KeyedData, **kwargs) -> Struct:
        results = Struct()
        results.metrics = MetricMeasurementBundle(
            dataset_identifier=self.config.dataset_identifier,
            reference_package=self.config.reference_package,
            timestamp_version=self.config.timestamp_version,
        )
        # Copy the plot info so that each action sees its own copy
        plotInfo = kwargs.get("plotInfo")
        plotKey = f"{self.config.connections.outputName}_{{name}}"
        weakrefArgs = []
        for name, action in self.config.atools.items():
            kwargs["plotInfo"] = deepcopy(plotInfo)
            actionResult = action(data, **kwargs)
            metricAccumulate = []
            for resultName, value in actionResult.items():
                match value:
                    case PlotTypes():
                        setattr(results, plotKey.format(name=resultName), value)
                        weakrefArgs.append(value)
                    case Measurement():
                        metricAccumulate.append(value)
            # Only add the metrics if there are some
            if metricAccumulate:
                results.metrics[name] = metricAccumulate
        # Wrap the return struct in a finalizer so that when results is
        # garbage collected the plots will be closed.
        # TODO: This finalize step closes all open plots at the conclusion of
        # a task. When DM-39114 is implemented, this step should no longer be
        # required and may be removed.
        weakref.finalize(results, _plotCloser, *weakrefArgs)
        return results

    def run(self, *, data: KeyedData | None = None, **kwargs) -> Struct:
        """Produce the outputs associated with this `PipelineTask`.

        Parameters
        ----------
        data : `KeyedData`
            The input data from which all `AnalysisTools` will run and produce
            outputs. As a side note, the python typing specifies that this can
            be None, but that is only due to a limitation in python where, in
            order to specify that all arguments be passed only as keywords,
            the argument must be given a default. This argument must not
            actually be None.
        **kwargs
            Additional arguments that are passed through to the
            `AnalysisTools` specified in the configuration.

        Returns
        -------
        results : `~lsst.pipe.base.Struct`
            The accumulated results of all the plots and metrics produced by
            this `PipelineTask`.

        Raises
        ------
        ValueError
            Raised if the supplied data argument is `None`.
        """
        if data is None:
            raise ValueError("data must not be None")
        if "bands" not in kwargs:
            kwargs["bands"] = list(self.config.bands)
        if "plotInfo" not in kwargs:
            kwargs["plotInfo"] = _StandinPlotInfo()
        kwargs["plotInfo"]["bands"] = kwargs["bands"]
        if "SN" not in kwargs["plotInfo"].keys():
            kwargs["plotInfo"]["SN"] = "-"
        return self._runTools(data, **kwargs)
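
    # A minimal sketch of invoking ``run`` directly, e.g. from a notebook;
    # ``ExampleAnalysisTask`` and ``table`` are illustrative and assume a
    # configured subclass and a KeyedData-compatible input:
    #
    #     task = ExampleAnalysisTask()  # hypothetical subclass
    #     results = task.run(data=table, bands=["g", "r"])
    #     bundle = results.metrics  # a MetricMeasurementBundle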


    def runQuantum(
        self,
        butlerQC: QuantumContext,
        inputRefs: InputQuantizedConnection,
        outputRefs: OutputQuantizedConnection,
    ) -> None:
        """Override the default runQuantum to load the minimal columns
        necessary to complete the action.

        Parameters
        ----------
        butlerQC : `~lsst.pipe.base.QuantumContext`
            A butler which is specialized to operate in the context of a
            `lsst.daf.butler.Quantum`.
        inputRefs : `InputQuantizedConnection`
            Data structure whose attribute names are the names that identify
            connections defined in the corresponding `PipelineTaskConnections`
            class. The values of these attributes are the
            `lsst.daf.butler.DatasetRef` objects associated with the defined
            input/prerequisite connections.
        outputRefs : `OutputQuantizedConnection`
            Data structure whose attribute names are the names that identify
            connections defined in the corresponding `PipelineTaskConnections`
            class. The values of these attributes are the
            `lsst.daf.butler.DatasetRef` objects associated with the defined
            output connections.
        """
        # TODO: This rcParams modification is a temporary solution, hiding
        # a matplotlib warning indicating too many figures have been opened.
        # When DM-39114 is implemented, this should be removed.
        plt.rcParams.update({"figure.max_open_warning": 0})
        inputs = butlerQC.get(inputRefs)
        dataId = butlerQC.quantum.dataId
        plotInfo = self.parsePlotInfo(inputs, dataId)
        data = self.loadData(inputs["data"])
        if "skymap" in inputs.keys():
            skymap = inputs["skymap"]
        else:
            skymap = None
        outputs = self.run(data=data, plotInfo=plotInfo, skymap=skymap)
        butlerQC.put(outputs, outputRefs)


    def _populatePlotInfoWithDataId(
        self, plotInfo: MutableMapping[str, Any], dataId: DataCoordinate | None
    ) -> None:
        """Update the plotInfo with the dataId values.

        Parameters
        ----------
        plotInfo : `dict`
            The plotInfo dictionary to update.
        dataId : `lsst.daf.butler.DataCoordinate`
            The dataId to use to update the plotInfo.
        """
        if dataId is not None:
            plotInfo.update(dataId.mapping)


    def parsePlotInfo(
        self, inputs: Mapping[str, Any] | None, dataId: DataCoordinate | None, connectionName: str = "data"
    ) -> Mapping[str, str]:
        """Parse the inputs and dataId to get the information needed to
        add to the figure.

        Parameters
        ----------
        inputs : `dict`
            The inputs to the task.
        dataId : `lsst.daf.butler.DataCoordinate`
            The dataId that the task is being run on.
        connectionName : `str`, optional
            Name of the input connection to use for determining table name.

        Returns
        -------
        plotInfo : `dict`
            The information needed to annotate the figure: the input table
            name, the run, and the dataId values.
        """
        if inputs is None:
            tableName = ""
            run = ""
        else:
            tableName = inputs[connectionName].ref.datasetType.name
            run = inputs[connectionName].ref.run

        # Initialize the plot info dictionary
        plotInfo = {"tableName": tableName, "run": run}

        self._populatePlotInfoWithDataId(plotInfo, dataId)
        return plotInfo


    def loadData(self, handle: DeferredDatasetHandle, names: Iterable[str] | None = None) -> KeyedData:
        """Load the minimal set of keyed data from the input dataset.

        Parameters
        ----------
        handle : `DeferredDatasetHandle`
            Handle to load the dataset with only the specified columns.
        names : `Iterable` of `str`
            The names of keys to extract from the dataset.
            If `names` is `None` then the `collectInputNames` method
            is called to generate the names.
            For most purposes these are the names of columns to load from
            a catalog or data frame.

        Returns
        -------
        result : `KeyedData`
            The dataset with only the specified keys loaded.
        """
        if names is None:
            names = self.collectInputNames()
        return cast(KeyedData, handle.get(parameters={"columns": names}))
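
    # For example, with a deferred DataFrame handle the call above reduces to
    # something like ``handle.get(parameters={"columns": ["a_col", "b_col"]})``,
    # loading only those columns; the column names here are illustrative.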


    def collectInputNames(self) -> Iterable[str]:
        """Get the names of the inputs.

        If using the default `loadData` method this will gather the names
        of the keys to be loaded from an input dataset.

        Returns
        -------
        inputs : `Iterable` of `str`
            The names of the keys in the `KeyedData` object to extract.
        """
        inputs = set()
        for band in self.config.bands:
            for action in self.config.atools:
                for key, _ in action.getFormattedInputSchema(band=band):
                    inputs.add(key)
        return inputs
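
# A sketch tying the three base classes together into a runnable task
# definition; every name below is a hypothetical placeholder:
#
#     class ExampleAnalysisConfig(
#         AnalysisBaseConfig, pipelineConnections=ExampleAnalysisConnections
#     ):
#         pass
#
#     class ExampleAnalysisTask(AnalysisPipelineTask):
#         ConfigClass = ExampleAnalysisConfig
#         _DefaultName = "exampleAnalysis"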