Coverage for python/lsst/analysis/tools/interfaces/_task.py: 19%

# This file is part of analysis_tools.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations

"""Base class implementation for the classes needed in creating `PipelineTasks`
which execute `AnalysisTools`.

The classes defined in this module have all the required behaviors for
defining, introspecting, and executing `AnalysisTools` against an input dataset
type.

Subclasses of these tasks should specify the particular datasets to consume in
their connection classes and should specify a unique name.
"""

__all__ = ("AnalysisBaseConnections", "AnalysisBaseConfig", "AnalysisPipelineTask")

import datetime
import logging
import warnings
import weakref
from collections.abc import Iterable
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Mapping, MutableMapping, cast

import matplotlib.pyplot as plt
from lsst.verify import Measurement

if TYPE_CHECKING:
    from lsst.daf.butler import DeferredDatasetHandle
    from lsst.pipe.base import QuantumContext

from lsst.daf.butler import DataCoordinate
from lsst.pex.config import Field, ListField
from lsst.pex.config.configurableActions import ConfigurableActionStructField
from lsst.pipe.base import Instrument, PipelineTask, PipelineTaskConfig, PipelineTaskConnections, Struct
from lsst.pipe.base import connectionTypes as ct
from lsst.pipe.base.connections import InputQuantizedConnection, OutputQuantizedConnection
from lsst.pipe.base.pipelineIR import ConfigIR, ParametersIR

from ._actions import JointAction, MetricAction, NoMetric
from ._analysisTools import AnalysisTool
from ._interfaces import KeyedData, PlotTypes
from ._metricMeasurementBundle import MetricMeasurementBundle

# TODO: This rcParams modification is a temporary solution, hiding
# a matplotlib warning indicating too many figures have been opened.
# When DM-39114 is implemented, this should be removed.
plt.rcParams.update({"figure.max_open_warning": 0})


# TODO: This _plotCloser function assists in closing all open plots at the
# conclusion of a PipelineTask. When DM-39114 is implemented, this function and
# all associated usage thereof should be removed.
def _plotCloser(*args):
    """Close all the plots in the given list."""
    for plot in args:
        plt.close(plot)


class AnalysisBaseConnections(
    PipelineTaskConnections, dimensions={}, defaultTemplates={"outputName": "Placeholder"}
):
    r"""Base class for Connections used for AnalysisTools PipelineTasks.

    This class has a pre-defined output connection for the
    MetricMeasurementMapping. The dataset type name for this connection is
    determined by the template ``outputName``.

    Output connections for plots created by `AnalysisPlot`\ s are created
    dynamically when an instance of the class is created. The init method
    examines all the `AnalysisPlot` actions specified in the associated
    `AnalysisBaseConfig` subclass, accumulating all the info needed to
    create the output connections.

    The dimensions for all of the output connections (metric and plot) will
    be the same as the dimensions specified for the AnalysisBaseConnections
    subclass (i.e. quantum dimensions).
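
    Examples
    --------
    A minimal subclass sketch; the connection, dataset type, and dimension
    names here are hypothetical::

        class MyToolConnections(
            AnalysisBaseConnections,
            dimensions=("skymap", "tract"),
            defaultTemplates={"outputName": "myToolOutput"},
        ):
            data = ct.Input(
                doc="Input catalog to analyze.",
                name="objectTable_tract",
                storageClass="ArrowAstropy",
                deferLoad=True,
                dimensions=("skymap", "tract"),
            )

    Declaring the input with ``deferLoad=True`` yields a
    `~lsst.daf.butler.DeferredDatasetHandle`, which allows the default
    `AnalysisPipelineTask.loadData` implementation to read only the columns
    the configured tools actually need.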

    """

    metrics = ct.Output(
        doc="Metrics calculated on input dataset type",
        name="{outputName}_metrics",
        storageClass="MetricMeasurementBundle",
    )

    def __init__(self, *, config: AnalysisBaseConfig = None):  # type: ignore
        # Validate that the outputName template has been set in config. This
        # should have been checked early with the configs validate method, but
        # it is possible for someone to manually create everything in a script
        # without running validate, so also check it late here.
        if (outputName := config.connections.outputName) == "Placeholder":  # type: ignore
            raise RuntimeError(
                "Subclasses must specify an alternative value for the defaultTemplate `outputName`"
            )
        super().__init__(config=config)

        # All arguments must be passed by keyword, but python has no way to
        # enforce that without specifying a default, so None is used. Validate
        # that it is not None. This is largely for typing reasons, as in the
        # normal course of operation code execution paths ensure this will
        # not be None.
        assert config is not None

        for tool in config.atools:
            match tool.produce:
                case JointAction():
                    if isinstance(tool.produce.metric, NoMetric):
                        continue
                    if len(tool.produce.metric.units) != 0:
                        hasMetrics = True
                        break
                case MetricAction():
                    hasMetrics = True
                    break
        else:
            hasMetrics = False
        # Set the dimensions for the metric
        if hasMetrics:
            self.metrics = ct.Output(
                name=self.metrics.name,
                doc=self.metrics.doc,
                storageClass=self.metrics.storageClass,
                dimensions=self.dimensions,
                multiple=False,
                isCalibration=False,
            )
        else:
            # There are no metrics to produce, remove the output connection
            self.outputs.remove("metrics")

        # Look for any conflicting names, creating a set of them, as these
        # will be added to the instance as well as recorded in the outputs
        # set.
        existingNames = set(dir(self))

        # Accumulate all the names to be used from all of the defined
        # AnalysisPlots.
        names: list[str] = []
        for action in config.atools:
            if action.dynamicOutputNames:
                outNames = action.getOutputNames(config=config)
            else:
                outNames = action.getOutputNames()
            if action.parameterizedBand:
                for band in config.bands:
                    names.extend(name.format(band=band) for name in outNames)
            else:
                names.extend(outNames)

        # For each of the names found, create output connections.
        for name in names:
            name = f"{outputName}_{name}"
            if name in self.outputs or name in existingNames:
                raise NameError(
                    f"Plot with name {name} conflicts with an existing connection;"
                    " are two plots named the same?"
                )
            outConnection = ct.Output(
                name=name,
                storageClass="Plot",
                doc="Dynamic connection for plotting",
                dimensions=self.dimensions,
            )
            object.__setattr__(self, name, outConnection)
            self.outputs.add(name)


def _timestampValidator(value: str) -> bool:
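    """Check that ``value`` names a supported timestamp source.

    Valid values are the fixed keywords checked below, or an
    ``explicit_timestamp:datetime`` pair with the datetime given in
    ``%Y%m%dT%H%M%S%z`` form, e.g. (a hypothetical value)
    ``"explicit_timestamp:20240510T110400+0000"``.
    """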

    if value in ("reference_package_timestamp", "run_timestamp", "current_timestamp", "dataset_timestamp"):
        return True
    elif "explicit_timestamp" in value:
        try:
            _, splitTime = value.split(":")
        except ValueError:
            logging.error(
                "Explicit timestamp must be given in the format 'explicit_timestamp:datetime', "
                r"where datetime is given in the form '%Y%m%dT%H%M%S%z'"
            )
            return False
        try:
            datetime.datetime.strptime(splitTime, r"%Y%m%dT%H%M%S%z")
        except ValueError:
            # This is explicitly chosen to be an f-string, as the string
            # otherwise contains percent-style formatting characters.
            logging.error(
                f"The supplied datetime {splitTime} could not be parsed correctly into "
                r"%Y%m%dT%H%M%S%z format"
            )
            return False
        return True
    else:
        return False


class AnalysisBaseConfig(PipelineTaskConfig, pipelineConnections=AnalysisBaseConnections):
    """Base class for all configs used to define an `AnalysisPipelineTask`.

    This base class defines two fields that should be used in all subclasses:
    ``atools`` and ``bands``.

    The ``atools`` field is where the user configures which analysis tools
    will be run as part of this `PipelineTask`.

    The ``bands`` field specifies which bands will be looped over for
    `AnalysisTools` that support parameterized bands, i.e. each such tool is
    called once for each band in the list.
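
    For example, a config override might look like this (the tool name here
    is hypothetical)::

        from lsst.analysis.tools.atools import SomeAnalysisTool

        config.atools.myTool = SomeAnalysisTool
        config.bands = ["g", "r", "i"]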

    """

    atools = ConfigurableActionStructField[AnalysisTool](
        doc="The analysis tools that are to be run by this task at execution"
    )
    # Temporarily alias these for backwards compatibility
    plots = atools
    metrics = atools
    bands = ListField[str](
        doc="Filter bands on which to run all of the actions", default=["u", "g", "r", "i", "z", "y"]
    )
    metric_tags = ListField[str](
        doc="List of tags which will be added to all configurable actions", default=[]
    )
    dataset_identifier = Field[str](doc="An identifier to be associated with output Metrics", optional=True)
    reference_package = Field[str](
        doc="A package whose version, at the time of metric upload to a "
        "time series database, will be converted to a timestamp of when "
        "that version was produced",
        default="lsst_distrib",
    )
    timestamp_version = Field[str](
        doc="Which time stamp should be used as the reference timestamp for a "
        "metric in a time series database. Valid values are: "
        "reference_package_timestamp, run_timestamp, current_timestamp, "
        "dataset_timestamp, and explicit_timestamp:datetime, where datetime is "
        "given in the form %Y%m%dT%H%M%S%z",
        default="run_timestamp",
        check=_timestampValidator,
    )

    def applyConfigOverrides(
        self,
        instrument: Instrument | None,
        taskDefaultName: str,
        pipelineConfigs: Iterable[ConfigIR] | None,
        parameters: ParametersIR,
        label: str,
    ) -> None:
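        """Apply config overrides, additionally mapping any pipeline-level
        ``sasquatch_*`` parameters onto the corresponding config fields.

        For example, a pipeline definition might supply (hypothetical
        values)::

            parameters:
              sasquatch_dataset_identifier: my_dataset
              sasquatch_reference_package: lsst_distrib
              sasquatch_timestamp_version: run_timestamp
        """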

        extraConfig = {}
        if (value := parameters.mapping.get("sasquatch_dataset_identifier", None)) is not None:
            extraConfig["dataset_identifier"] = value
        if (value := parameters.mapping.get("sasquatch_reference_package", None)) is not None:
            extraConfig["reference_package"] = value
        if (value := parameters.mapping.get("sasquatch_timestamp_version", None)) is not None:
            if "explicit_timestamp" in value:
                try:
                    _, splitTime = value.split(":")
                except ValueError as excpt:
                    raise ValueError(
                        "Explicit timestamp must be given in the format 'explicit_timestamp:datetime', "
                        "where datetime is given in the form '%Y%m%dT%H%M%S%z'"
                    ) from excpt
                try:
                    datetime.datetime.strptime(splitTime, r"%Y%m%dT%H%M%S%z")
                except ValueError as excpt:
                    raise ValueError(
                        f"The supplied datetime {splitTime} could not be parsed correctly into "
                        "%Y%m%dT%H%M%S%z format"
                    ) from excpt
            extraConfig["timestamp_version"] = value
        if extraConfig:
            newPipelineConfigs = [ConfigIR(rest=extraConfig)]
            if pipelineConfigs is not None:
                newPipelineConfigs.extend(pipelineConfigs)
            pipelineConfigs = newPipelineConfigs
        return super().applyConfigOverrides(instrument, taskDefaultName, pipelineConfigs, parameters, label)

    def freeze(self):
        # Copy the meta configuration values to each of the configured tools;
        # only do this if the tool has not been further specialized.
        if not self._frozen:
            for tool in self.atools:
                for tag in self.metric_tags:
                    tool.metric_tags.insert(-1, tag)
        super().freeze()

    def validate(self):
        super().validate()
        # Validate that the required connections template is set.
        if self.connections.outputName == "Placeholder":  # type: ignore
            raise RuntimeError("Connections class 'outputName' must have a config explicitly set")


class _StandinPlotInfo(dict):
    """This class is an implementation detail to support plots in the case
    that no PlotInfo object is present in the call to run.
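
    Missing keys resolve to an empty string, so, e.g.,
    ``_StandinPlotInfo()["run"]`` yields ``""``.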

    """

    def __missing__(self, key):
        return ""


class AnalysisPipelineTask(PipelineTask):
    """Base class for `PipelineTasks` intended to run `AnalysisTools`.

    The run method will run all of the `AnalysisTools` defined in the config
    class.
    """

    # Typing config because type checkers don't know about our Task magic
    config: AnalysisBaseConfig
    ConfigClass = AnalysisBaseConfig

    warnings_all = (
        "divide by zero encountered in divide",
        "invalid value encountered in arcsin",
        "invalid value encountered in cos",
        "invalid value encountered in divide",
        "invalid value encountered in log10",
        "invalid value encountered in scalar divide",
        "invalid value encountered in sin",
        "invalid value encountered in sqrt",
        "invalid value encountered in true_divide",
        "Mean of empty slice",
    )

    def _runTools(self, data: KeyedData, **kwargs) -> Struct:
        with warnings.catch_warnings():
            # Change below to "in self.warnings_all" to find otherwise
            # unfiltered numpy warnings.
            for warning in ():
                warnings.filterwarnings("error", warning, RuntimeWarning)
            results = Struct()
            results.metrics = MetricMeasurementBundle(
                dataset_identifier=self.config.dataset_identifier,
                reference_package=self.config.reference_package,
                timestamp_version=self.config.timestamp_version,
            )
            # Copy plot info to be sure each action sees its own copy.
            plotInfo = kwargs.get("plotInfo")
            plotKey = f"{self.config.connections.outputName}_{{name}}"
            weakrefArgs = []
            for name, action in self.config.atools.items():
                kwargs["plotInfo"] = deepcopy(plotInfo)
                actionResult = action(data, **kwargs)
                metricAccumulate = []
                for resultName, value in actionResult.items():
                    match value:
                        case PlotTypes():
                            setattr(results, plotKey.format(name=resultName), value)
                            weakrefArgs.append(value)
                        case Measurement():
                            metricAccumulate.append(value)
                # Only add the metrics if there are some.
                if metricAccumulate:
                    results.metrics[name] = metricAccumulate
            # Wrap the return struct in a finalizer so that when results is
            # garbage collected the plots will be closed.
            # TODO: This finalize step closes all open plots at the conclusion
            # of a task. When DM-39114 is implemented, this step should not
            # be required and may be removed.
            weakref.finalize(results, _plotCloser, *weakrefArgs)
            return results

    def run(self, *, data: KeyedData | None = None, **kwargs) -> Struct:
        """Produce the outputs associated with this `PipelineTask`.

        Parameters
        ----------
        data : `KeyedData`
            The input data from which all `AnalysisTools` will run and produce
            outputs. As a side note, the python typing specifies that this can
            be `None`, but that is only due to a limitation in python: in
            order to require that all arguments be passed by keyword, the
            argument must be given a default. This argument must not actually
            be `None`.
        **kwargs
            Additional arguments that are passed through to the
            `AnalysisTools` specified in the configuration.

        Returns
        -------
        results : `~lsst.pipe.base.Struct`
            The accumulated results of all the plots and metrics produced by
            this `PipelineTask`.

        Raises
        ------
        ValueError
            Raised if the supplied data argument is `None`.
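
        Examples
        --------
        A minimal sketch of calling ``run`` directly; the task class and
        column names here are hypothetical::

            import numpy as np

            task = MyAnalysisTask()
            results = task.run(
                data={"objectId": np.arange(10)}, bands=["g", "r"]
            )
            metrics = results.metrics  # a MetricMeasurementBundle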

        """
        if data is None:
            raise ValueError("data must not be None")
        if "bands" not in kwargs:
            kwargs["bands"] = list(self.config.bands)
        if "plotInfo" not in kwargs:
            kwargs["plotInfo"] = _StandinPlotInfo()
        kwargs["plotInfo"]["bands"] = kwargs["bands"]
        return self._runTools(data, **kwargs)

    def runQuantum(
        self,
        butlerQC: QuantumContext,
        inputRefs: InputQuantizedConnection,
        outputRefs: OutputQuantizedConnection,
    ) -> None:
        """Override default runQuantum to load the minimal columns necessary
        to complete the action.

        Parameters
        ----------
        butlerQC : `~lsst.pipe.base.QuantumContext`
            A butler which is specialized to operate in the context of a
            `lsst.daf.butler.Quantum`.
        inputRefs : `InputQuantizedConnection`
            Data structure whose attribute names are the names that identify
            connections defined in the corresponding `PipelineTaskConnections`
            class. The values of these attributes are the
            `lsst.daf.butler.DatasetRef` objects associated with the defined
            input/prerequisite connections.
        outputRefs : `OutputQuantizedConnection`
            Data structure whose attribute names are the names that identify
            connections defined in the corresponding `PipelineTaskConnections`
            class. The values of these attributes are the
            `lsst.daf.butler.DatasetRef` objects associated with the defined
            output connections.
        """
        inputs = butlerQC.get(inputRefs)
        dataId = butlerQC.quantum.dataId
        plotInfo = self.parsePlotInfo(inputs, dataId)
        # We implicitly assume that 'data' has been defined, but do not have a
        # corresponding input connection in the base class. Thus, we capture
        # and re-raise the error with a more helpful message.
        try:
            # data has to be popped out to avoid duplication in the call to
            # the `run` method.
            inputData = inputs.pop("data")
        except KeyError:
            raise RuntimeError("'data' is a required input connection, but is not defined.")
        data = self.loadData(inputData)
        outputs = self.run(data=data, plotInfo=plotInfo, **inputs)
        butlerQC.put(outputs, outputRefs)

    def _populatePlotInfoWithDataId(
        self, plotInfo: MutableMapping[str, Any], dataId: DataCoordinate | None
    ) -> None:
        """Update the plotInfo with the dataId values.

        Parameters
        ----------
        plotInfo : `dict`
            The plotInfo dictionary to update.
        dataId : `lsst.daf.butler.DataCoordinate`
            The dataId to use to update the plotInfo.
        """
        if dataId is not None:
            plotInfo.update(dataId.mapping)

    def parsePlotInfo(
        self, inputs: Mapping[str, Any] | None, dataId: DataCoordinate | None, connectionName: str = "data"
    ) -> Mapping[str, str]:
        """Parse the inputs and dataId to get the information needed to
        add to the figure.

        Parameters
        ----------
        inputs : `dict`
            The inputs to the task.
        dataId : `lsst.daf.butler.DataCoordinate`
            The dataId that the task is being run on.
        connectionName : `str`, optional
            Name of the input connection to use for determining table name.

        Returns
        -------
        plotInfo : `dict`
            A dictionary containing the table name, run, and dataId values.
        """
        if inputs is None:
            tableName = ""
            run = ""
        else:
            tableName = inputs[connectionName].ref.datasetType.name
            run = inputs[connectionName].ref.run

        # Initialize the plot info dictionary.
        plotInfo = {"tableName": tableName, "run": run}

        self._populatePlotInfoWithDataId(plotInfo, dataId)
        return plotInfo

    def loadData(self, handle: DeferredDatasetHandle, names: Iterable[str] | None = None) -> KeyedData:
        """Load the minimal set of keyed data from the input dataset.

        Parameters
        ----------
        handle : `DeferredDatasetHandle`
            Handle to load the dataset with only the specified columns.
        names : `Iterable` of `str`
            The names of keys to extract from the dataset.
            If `names` is `None` then the `collectInputNames` method
            is called to generate the names.
            For most purposes these are the names of columns to load from
            a catalog or data frame.

        Returns
        -------
        result : `KeyedData`
            The dataset with only the specified keys loaded.
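
        Examples
        --------
        A sketch of loading a column subset directly; the dataset type and
        column names here are hypothetical::

            handle = butler.getDeferred("objectTable_tract", dataId=dataId)
            data = task.loadData(handle, names=["objectId", "g_psfFlux"])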

        """
        if names is None:
            names = self.collectInputNames()
        return cast(KeyedData, handle.get(parameters={"columns": names}))

    def collectInputNames(self) -> Iterable[str]:
        """Get the names of the inputs.

        If using the default `loadData` method this will gather the names
        of the keys to be loaded from an input dataset.

        Returns
        -------
        inputs : `Iterable` of `str`
            The names of the keys in the `KeyedData` object to extract.
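
        Examples
        --------
        For a config with ``bands=["g", "r"]`` and a single tool whose input
        schema contains ``{band}_psfFlux`` (a hypothetical schema), this
        returns ``{"g_psfFlux", "r_psfFlux"}``.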

        """
        inputs = set()
        for band in self.config.bands:
            for action in self.config.atools:
                for key, _ in action.getFormattedInputSchema(band=band):
                    inputs.add(key)
        return inputs