Coverage for python/lsst/analysis/tools/interfaces/_task.py: 20%

168 statements  


# This file is part of analysis_tools.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

"""Base class implementation for the classes needed in creating `PipelineTasks`
which execute `AnalysisTools`.

The classes defined in this module have all the required behaviors for
defining, introspecting, and executing `AnalysisTools` against an input
dataset type.

Subclasses of these tasks should declare the specific datasets to consume in
their connection classes and should specify a unique output name.
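
A minimal sketch of such a set of subclasses follows; all names here
(``ExampleConnections``, ``ExampleConfig``, ``ExampleTask``, and the
``exampleTable`` dataset) are purely illustrative::

    class ExampleConnections(
        AnalysisBaseConnections,
        dimensions=("skymap", "tract"),
        defaultTemplates={"outputName": "exampleTable"},
    ):
        data = ct.Input(
            doc="The input catalog to analyze.",
            name="exampleTable",
            storageClass="DataFrame",
            deferLoad=True,
            dimensions=("skymap", "tract"),
        )

    class ExampleConfig(AnalysisBaseConfig, pipelineConnections=ExampleConnections):
        pass

    class ExampleTask(AnalysisPipelineTask):
        ConfigClass = ExampleConfig
        _DefaultName = "exampleTask"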

"""

from __future__ import annotations

__all__ = ("AnalysisBaseConnections", "AnalysisBaseConfig", "AnalysisPipelineTask")

import weakref
from collections.abc import Iterable
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Mapping, MutableMapping, cast

import matplotlib.pyplot as plt
from lsst.verify import Measurement

if TYPE_CHECKING:
    from lsst.daf.butler import DeferredDatasetHandle
    from lsst.pipe.base import QuantumContext

from lsst.daf.butler import DataCoordinate
from lsst.pex.config import Field, ListField
from lsst.pex.config.configurableActions import ConfigurableActionStructField
from lsst.pipe.base import Instrument, PipelineTask, PipelineTaskConfig, PipelineTaskConnections, Struct
from lsst.pipe.base import connectionTypes as ct
from lsst.pipe.base.connections import InputQuantizedConnection, OutputQuantizedConnection
from lsst.pipe.base.pipelineIR import ConfigIR, ParametersIR

from ._actions import JointAction, MetricAction, NoMetric
from ._analysisTools import AnalysisTool
from ._interfaces import KeyedData, PlotTypes
from ._metricMeasurementBundle import MetricMeasurementBundle


# TODO: This _plotCloser function assists in closing all open plots at the
# conclusion of a PipelineTask. When DM-39114 is implemented, this function and
# all associated usage thereof should be removed.
def _plotCloser(*args):
    """Close all the plots in the given list."""
    for plot in args:
        plt.close(plot)


class AnalysisBaseConnections(
    PipelineTaskConnections, dimensions={}, defaultTemplates={"outputName": "Placeholder"}
):
    r"""Base class for Connections used for AnalysisTools PipelineTasks.

    This class has a pre-defined output connection for the
    `MetricMeasurementBundle`. The dataset type name for this connection is
    determined by the template ``outputName``.

    Output connections for plots created by `AnalysisPlot`\ s are created
    dynamically when an instance of the class is created. The init method
    examines all the `AnalysisPlot` actions specified in the associated
    `AnalysisBaseConfig` subclass, accumulating all the info needed to
    create the output connections.

    The dimensions for all of the output connections (metric and plot) will
    be the same as the dimensions specified for the AnalysisBaseConnections
    subclass (i.e. quantum dimensions).
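
    As an illustration (all names here are hypothetical): with ``outputName``
    set to ``"objectTable"`` and a configured tool whose ``getOutputNames``
    yields ``exampleTool_scatterPlot``, an output connection named
    ``objectTable_exampleTool_scatterPlot`` with storage class ``Plot`` is
    added to the instance.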

    """

    metrics = ct.Output(
        doc="Metrics calculated on input dataset type",
        name="{outputName}_metrics",
        storageClass="MetricMeasurementBundle",
    )

    def __init__(self, *, config: AnalysisBaseConfig = None):  # type: ignore
        # Validate that the outputName template has been set in config. This
        # should have been checked early with the config's validate method,
        # but it is possible for someone to manually create everything in a
        # script without running validate, so also check it late here.
        if (outputName := config.connections.outputName) == "Placeholder":  # type: ignore
            raise RuntimeError(
                "Subclasses must specify an alternative value for the defaultTemplate `outputName`"
            )
        super().__init__(config=config)

        # All arguments must be passed by keyword, but python has no way to
        # require that without specifying a default, so None is used. Validate
        # that it is not None. This is largely for typing reasons, as in the
        # normal course of operation code execution paths ensure this will not
        # be None.
        assert config is not None

        for tool in config.atools:
            match tool.produce:
                case JointAction():
                    if isinstance(tool.produce.metric, NoMetric):
                        continue
                    if len(tool.produce.metric.units) != 0:
                        hasMetrics = True
                        break
                case MetricAction():
                    hasMetrics = True
                    break
        else:
            hasMetrics = False

        # Set the dimensions for the metrics connection.
        if hasMetrics:
            self.metrics = ct.Output(
                name=self.metrics.name,
                doc=self.metrics.doc,
                storageClass=self.metrics.storageClass,
                dimensions=self.dimensions,
                multiple=False,
                isCalibration=False,
            )
        else:
            # There are no metrics to produce; remove the output connection.
            self.outputs.remove("metrics")

        # Look for any conflicting names, creating a set of them, as these
        # will be added to the instance as well as recorded in the outputs
        # set.
        existingNames = set(dir(self))

        # Accumulate all the names to be used from all of the defined
        # AnalysisPlots.
        names: list[str] = []
        for action in config.atools:
            if action.parameterizedBand:
                for band in config.bands:
                    names.extend(name.format(band=band) for name in action.getOutputNames())
            else:
                names.extend(action.getOutputNames())

        # For each of the names found, create output connections.
        for name in names:
            name = f"{outputName}_{name}"
            if name in self.outputs or name in existingNames:
                raise NameError(
                    f"Plot with name {name} conflicts with an existing connection;"
                    " are two plots named the same?"
                )
            outConnection = ct.Output(
                name=name,
                storageClass="Plot",
                doc="Dynamic connection for plotting",
                dimensions=self.dimensions,
            )
            object.__setattr__(self, name, outConnection)
            self.outputs.add(name)


class AnalysisBaseConfig(PipelineTaskConfig, pipelineConnections=AnalysisBaseConnections):
    """Base class for all configs used to define an `AnalysisPipelineTask`.

    This base class defines two fields that should be used in all subclasses:
    ``atools`` and ``bands``.

    The ``atools`` field is where the user configures which analysis tools
    will be run as part of this `PipelineTask`.

    The ``bands`` field specifies which bands will be looped over for
    `AnalysisTools` which support parameterized bands, i.e. each such tool is
    called once for each band in the list.
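
    A minimal sketch of how these fields might be set in a pipeline config
    override file (the tool name and class here are hypothetical)::

        from lsst.analysis.tools.atools import ExampleSourceTool

        config.atools.exampleSource = ExampleSourceTool
        config.bands = ["g", "r", "i"]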

    """

    atools = ConfigurableActionStructField[AnalysisTool](
        doc="The analysis tools that are to be run by this task at execution"
    )
    # Temporarily alias these for backwards compatibility.
    plots = atools
    metrics = atools
    bands = ListField[str](
        doc="Filter bands on which to run all of the actions", default=["u", "g", "r", "i", "z", "y"]
    )
    metric_tags = ListField[str](
        doc="List of tags which will be added to all configurable actions", default=[]
    )
    dataset_identifier = Field[str](doc="An identifier to be associated with output Metrics", optional=True)
    reference_package = Field[str](
        doc="A package whose version, at the time of metric upload to a "
        "time series database, will be converted to a timestamp of when "
        "that version was produced",
        default="lsst_distrib",
    )
    timestamp_version = Field[str](
        doc="Which timestamp should be used as the reference timestamp for a "
        "metric in a time series database; valid values are "
        "reference_package_timestamp, run_timestamp, current_timestamp, "
        "and dataset_timestamp",
        default="run_timestamp",
        check=lambda x: x
        in ("reference_package_timestamp", "run_timestamp", "current_timestamp", "dataset_timestamp"),
    )

    def applyConfigOverrides(
        self,
        instrument: Instrument | None,
        taskDefaultName: str,
        pipelineConfigs: Iterable[ConfigIR] | None,
        parameters: ParametersIR,
        label: str,
    ) -> None:
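        """Apply the standard config overrides, first injecting any
        Sasquatch-related pipeline parameters (``sasquatch_dataset_identifier``,
        ``sasquatch_reference_package``, and ``sasquatch_timestamp_version``)
        as additional config overrides.
        """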

        extraConfig = {}
        if (value := parameters.mapping.get("sasquatch_dataset_identifier", None)) is not None:
            extraConfig["dataset_identifier"] = value
        if (value := parameters.mapping.get("sasquatch_reference_package", None)) is not None:
            extraConfig["reference_package"] = value
        if (value := parameters.mapping.get("sasquatch_timestamp_version", None)) is not None:
            extraConfig["timestamp_version"] = value
        if extraConfig:
            newPipelineConfigs = [ConfigIR(rest=extraConfig)]
            if pipelineConfigs is not None:
                newPipelineConfigs.extend(pipelineConfigs)
            pipelineConfigs = newPipelineConfigs
        return super().applyConfigOverrides(instrument, taskDefaultName, pipelineConfigs, parameters, label)

    def freeze(self):
        # Copy the meta configuration values to each of the configured tools;
        # only do this if the config has not already been frozen.
        if not self._frozen:
            for tool in self.atools:
                for tag in self.metric_tags:
                    tool.metric_tags.insert(-1, tag)
        super().freeze()

    def validate(self):
        super().validate()
        # Validate that the required connections template is set.
        if self.connections.outputName == "Placeholder":  # type: ignore
            raise RuntimeError("Connections class 'outputName' must have a config explicitly set")


class _StandinPlotInfo(dict):
    """This class is an implementation detail to support plots in the case
    that no PlotInfo object is present in the call to run.
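
    Any missing key resolves to an empty string::

        >>> info = _StandinPlotInfo()
        >>> info["anything"]
        ''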

    """

    def __missing__(self, key):
        return ""


class AnalysisPipelineTask(PipelineTask):
    """Base class for `PipelineTasks` intended to run `AnalysisTools`.

    The run method will run all of the `AnalysisTools` defined in the config
    class.
    """

    # Typing config because type checkers don't know about our Task magic.
    config: AnalysisBaseConfig
    ConfigClass = AnalysisBaseConfig

    def _runTools(self, data: KeyedData, **kwargs) -> Struct:
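        """Run each configured `AnalysisTool` on the input data.

        Any `Measurement` values produced are accumulated into
        ``results.metrics``; each plot is attached to the returned
        `~lsst.pipe.base.Struct` under its output connection name, which has
        the form ``<outputName>_<plotName>``.
        """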

        results = Struct()
        results.metrics = MetricMeasurementBundle(
            dataset_identifier=self.config.dataset_identifier,
            reference_package=self.config.reference_package,
            timestamp_version=self.config.timestamp_version,
        )
        # Copy the plot info so that each action sees its own copy.
        plotInfo = kwargs.get("plotInfo")
        plotKey = f"{self.config.connections.outputName}_{{name}}"
        weakrefArgs = []
        for name, action in self.config.atools.items():
            kwargs["plotInfo"] = deepcopy(plotInfo)
            actionResult = action(data, **kwargs)
            metricAccumulate = []
            for resultName, value in actionResult.items():
                match value:
                    case PlotTypes():
                        setattr(results, plotKey.format(name=resultName), value)
                        weakrefArgs.append(value)
                    case Measurement():
                        metricAccumulate.append(value)
            # Only add the metrics if there are some.
            if metricAccumulate:
                results.metrics[name] = metricAccumulate
        # Wrap the return struct in a finalizer so that when results is
        # garbage collected the plots will be closed.
        # TODO: This finalize step closes all open plots at the conclusion of
        # a task. When DM-39114 is implemented, this step should no longer be
        # required and may be removed.
        weakref.finalize(results, _plotCloser, *weakrefArgs)
        return results

    def run(self, *, data: KeyedData | None = None, **kwargs) -> Struct:
        """Produce the outputs associated with this `PipelineTask`.

        Parameters
        ----------
        data : `KeyedData`
            The input data from which all `AnalysisTools` will run and produce
            outputs. Note that although the type annotation allows `None`,
            this is only due to a limitation in python, where a default value
            must be supplied in order to make all arguments keyword-only.
            This argument must not actually be `None`.
        **kwargs
            Additional arguments that are passed through to the
            `AnalysisTools` specified in the configuration.

        Returns
        -------
        results : `~lsst.pipe.base.Struct`
            The accumulated results of all the plots and metrics produced by
            this `PipelineTask`.

        Raises
        ------
        ValueError
            Raised if the supplied data argument is `None`.
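
        Examples
        --------
        A hypothetical direct call outside of normal `PipelineTask`
        execution, assuming ``catalog`` is an already-loaded `KeyedData`
        object::

            results = task.run(data=catalog, bands=["g", "r"])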

        """
        if data is None:
            raise ValueError("data must not be None")
        if "bands" not in kwargs:
            kwargs["bands"] = list(self.config.bands)
        if "plotInfo" not in kwargs:
            kwargs["plotInfo"] = _StandinPlotInfo()
        kwargs["plotInfo"]["bands"] = kwargs["bands"]
        if "SN" not in kwargs["plotInfo"].keys():
            kwargs["plotInfo"]["SN"] = "-"
        return self._runTools(data, **kwargs)

    def runQuantum(
        self,
        butlerQC: QuantumContext,
        inputRefs: InputQuantizedConnection,
        outputRefs: OutputQuantizedConnection,
    ) -> None:
        """Override the default runQuantum to load the minimal columns
        necessary to complete the action.

        Parameters
        ----------
        butlerQC : `~lsst.pipe.base.QuantumContext`
            A butler which is specialized to operate in the context of a
            `lsst.daf.butler.Quantum`.
        inputRefs : `InputQuantizedConnection`
            Data structure whose attribute names are the names that identify
            connections defined in the corresponding `PipelineTaskConnections`
            class. The values of these attributes are the
            `lsst.daf.butler.DatasetRef` objects associated with the defined
            input/prerequisite connections.
        outputRefs : `OutputQuantizedConnection`
            Data structure whose attribute names are the names that identify
            connections defined in the corresponding `PipelineTaskConnections`
            class. The values of these attributes are the
            `lsst.daf.butler.DatasetRef` objects associated with the defined
            output connections.
        """

        # TODO: This rcParams modification is a temporary solution, hiding
        # a matplotlib warning indicating too many figures have been opened.
        # When DM-39114 is implemented, this should be removed.
        plt.rcParams.update({"figure.max_open_warning": 0})
        inputs = butlerQC.get(inputRefs)
        dataId = butlerQC.quantum.dataId
        plotInfo = self.parsePlotInfo(inputs, dataId)
        data = self.loadData(inputs["data"])
        if "skymap" in inputs.keys():
            skymap = inputs["skymap"]
        else:
            skymap = None
        outputs = self.run(data=data, plotInfo=plotInfo, skymap=skymap)
        butlerQC.put(outputs, outputRefs)

    def _populatePlotInfoWithDataId(
        self, plotInfo: MutableMapping[str, Any], dataId: DataCoordinate | None
    ) -> None:
        """Update the plotInfo with the dataId values.

        Parameters
        ----------
        plotInfo : `dict`
            The plotInfo dictionary to update.
        dataId : `lsst.daf.butler.DataCoordinate`
            The dataId to use to update the plotInfo.
        """
        if dataId is not None:
            for dataInfo in dataId:
                plotInfo[dataInfo.name] = dataId[dataInfo.name]

    def parsePlotInfo(
        self, inputs: Mapping[str, Any] | None, dataId: DataCoordinate | None, connectionName: str = "data"
    ) -> Mapping[str, str]:
        """Parse the inputs and dataId to get the information needed
        to add to the figure.

        Parameters
        ----------
        inputs : `dict`
            The inputs to the task.
        dataId : `lsst.daf.butler.DataCoordinate`
            The dataId that the task is being run on.
        connectionName : `str`, optional
            Name of the input connection to use for determining the table
            name.

        Returns
        -------
        plotInfo : `dict`
            The table name and run collection of the input data, together
            with the values of the dataId.
        """

        if inputs is None:
            tableName = ""
            run = ""
        else:
            tableName = inputs[connectionName].ref.datasetType.name
            run = inputs[connectionName].ref.run

        # Initialize the plot info dictionary.
        plotInfo = {"tableName": tableName, "run": run}

        self._populatePlotInfoWithDataId(plotInfo, dataId)
        return plotInfo

    def loadData(self, handle: DeferredDatasetHandle, names: Iterable[str] | None = None) -> KeyedData:
        """Load the minimal set of keyed data from the input dataset.

        Parameters
        ----------
        handle : `DeferredDatasetHandle`
            Handle to load the dataset with only the specified columns.
        names : `Iterable` of `str`
            The names of keys to extract from the dataset.
            If `names` is `None` then the `collectInputNames` method
            is called to generate the names.
            For most purposes these are the names of columns to load from
            a catalog or data frame.

        Returns
        -------
        result : `KeyedData`
            The dataset with only the specified keys loaded.
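
        Examples
        --------
        A hypothetical call restricting the load to two columns (the column
        names are purely illustrative)::

            data = task.loadData(handle, names=["coord_ra", "coord_dec"])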

        """
        if names is None:
            names = self.collectInputNames()
        return cast(KeyedData, handle.get(parameters={"columns": names}))

    def collectInputNames(self) -> Iterable[str]:
        """Get the names of the inputs.

        If using the default `loadData` method this will gather the names
        of the keys to be loaded from an input dataset.

        Returns
        -------
        inputs : `Iterable` of `str`
            The names of the keys in the `KeyedData` object to extract.
        """
        inputs = set()
        for band in self.config.bands:
            for action in self.config.atools:
                for key, _ in action.getFormattedInputSchema(band=band):
                    inputs.add(key)
        return inputs

477 return inputs