# This file is part of analysis_tools.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations

"""Base class implementation for the classes needed in creating `PipelineTasks`
which execute `AnalysisTools`.

The classes defined in this module have all the required behaviors for
defining, introspecting, and executing `AnalysisTools` against an input dataset
type.

Subclasses of these tasks should specify specific datasets to consume in their
connection classes and should specify a unique name.
"""

__all__ = ("AnalysisBaseConnections", "AnalysisBaseConfig", "AnalysisPipelineTask")

import datetime
import logging
import warnings
import weakref
from collections.abc import Iterable
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Mapping, MutableMapping, cast

import matplotlib.pyplot as plt
from lsst.verify import Measurement

if TYPE_CHECKING:
    from lsst.daf.butler import DeferredDatasetHandle
    from lsst.pipe.base import QuantumContext

from lsst.daf.butler import DataCoordinate
from lsst.pex.config import Field, ListField
from lsst.pex.config.configurableActions import ConfigurableActionStructField
from lsst.pipe.base import Instrument, PipelineTask, PipelineTaskConfig, PipelineTaskConnections, Struct
from lsst.pipe.base import connectionTypes as ct
from lsst.pipe.base.connections import InputQuantizedConnection, OutputQuantizedConnection
from lsst.pipe.base.pipelineIR import ConfigIR, ParametersIR

from ._actions import JointAction, MetricAction, NoMetric
from ._analysisTools import AnalysisTool
from ._interfaces import KeyedData, PlotTypes
from ._metricMeasurementBundle import MetricMeasurementBundle


# TODO: This _plotCloser function assists in closing all open plots at the
# conclusion of a PipelineTask. When DM-39114 is implemented, this function and
# all associated usage thereof should be removed.
def _plotCloser(*args):
    """Close all the plots in the given list."""
    for plot in args:
        plt.close(plot)


class AnalysisBaseConnections(
    PipelineTaskConnections, dimensions={}, defaultTemplates={"outputName": "Placeholder"}
):
    r"""Base class for Connections used for AnalysisTools PipelineTasks.

    This class has a pre-defined output connection for the
    MetricMeasurementMapping. The dataset type name for this connection is
    determined by the template ``outputName``.

    Output connections for plots created by `AnalysisPlot`\ s are created
    dynamically when an instance of the class is created. The init method
    examines all the `AnalysisPlot` actions specified in the associated
    `AnalysisBaseConfig` subclass, accumulating all the info needed to
    create the output connections.

    The dimensions for all of the output connections (metric and plot) will
    be the same as the dimensions specified for the AnalysisBaseConnections
    subclass (i.e. quantum dimensions).
    """
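
    # An illustrative sketch (hypothetical names) of a concrete subclass: it
    # sets the quantum dimensions, overrides the ``outputName`` template, and
    # declares the dataset the configured tools will consume.
    #
    #     class ExampleConnections(
    #         AnalysisBaseConnections,
    #         dimensions=("skymap", "tract"),
    #         defaultTemplates={"outputName": "exampleObjectTable"},
    #     ):
    #         data = ct.Input(
    #             doc="Catalog from which the configured tools read columns.",
    #             name="objectTable_tract",
    #             storageClass="DataFrame",
    #             deferLoad=True,
    #             dimensions=("skymap", "tract"),
    #         )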

    metrics = ct.Output(
        doc="Metrics calculated on input dataset type",
        name="{outputName}_metrics",
        storageClass="MetricMeasurementBundle",
    )

    def __init__(self, *, config: AnalysisBaseConfig = None):  # type: ignore
        # Validate that the outputName template has been set in the config.
        # This should have been checked early with the config's validate
        # method, but it is possible for someone to manually create everything
        # in a script without running validate, so also check it late here.
        if (outputName := config.connections.outputName) == "Placeholder":  # type: ignore
            raise RuntimeError(
                "Subclasses must specify an alternative value for the defaultTemplate `outputName`"
            )
        super().__init__(config=config)

        # All arguments must be passed by keyword, but python has no way to
        # require that without specifying a default, so None is used. Validate
        # that it is not None. This is largely for typing reasons, as in the
        # normal course of operation code execution paths ensure this will not
        # be None.
        assert config is not None

        # Determine whether any of the configured tools will actually produce
        # metrics. Note the for/else: hasMetrics is set to False only if no
        # tool triggered a break.
        for tool in config.atools:
            match tool.produce:
                case JointAction():
                    if isinstance(tool.produce.metric, NoMetric):
                        continue
                    if len(tool.produce.metric.units) != 0:
                        hasMetrics = True
                        break
                case MetricAction():
                    hasMetrics = True
                    break
        else:
            hasMetrics = False

        # Set the dimensions for the metric
        if hasMetrics:
            self.metrics = ct.Output(
                name=self.metrics.name,
                doc=self.metrics.doc,
                storageClass=self.metrics.storageClass,
                dimensions=self.dimensions,
                multiple=False,
                isCalibration=False,
            )
        else:
            # There are no metrics to produce; remove the output connection.
            self.outputs.remove("metrics")

        # Look for any conflicting names, creating a set of them, as these
        # will be added to the instance as well as recorded in the outputs
        # set.
        existingNames = set(dir(self))

        # Accumulate all the names to be used from all of the defined
        # AnalysisPlots.
        names: list[str] = []
        for action in config.atools:
            if action.dynamicOutputNames:
                outNames = action.getOutputNames(config=config)
            else:
                outNames = action.getOutputNames()
            if action.parameterizedBand:
                for band in config.bands:
                    names.extend(name.format(band=band) for name in outNames)
            else:
                names.extend(outNames)

        # For each of the names found, create output connections.
        for name in names:
            name = f"{outputName}_{name}"
            if name in self.outputs or name in existingNames:
                raise NameError(
                    f"Plot with name {name} conflicts with an existing connection;"
                    " are two plots named the same?"
                )
            outConnection = ct.Output(
                name=name,
                storageClass="Plot",
                doc="Dynamic connection for plotting",
                dimensions=self.dimensions,
            )
            object.__setattr__(self, name, outConnection)
            self.outputs.add(name)


def _timestampValidator(value: str) -> bool:
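    """Check that a ``timestamp_version`` value has a supported form.

    Examples
    --------
    Illustrative values (the datetime below is arbitrary):

    >>> _timestampValidator("run_timestamp")
    True
    >>> _timestampValidator("explicit_timestamp:20240120T131700+0000")
    True
    >>> _timestampValidator("not_a_supported_value")
    False
    """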

    if value in ("reference_package_timestamp", "run_timestamp", "current_timestamp", "dataset_timestamp"):
        return True
    elif "explicit_timestamp" in value:
        try:
            _, splitTime = value.split(":")
        except ValueError:
            logging.error(
                "Explicit timestamp must be given in the format 'explicit_timestamp:datetime', "
                r"where datetime is given in the form '%Y%m%dT%H%M%S%z'"
            )
            return False
        try:
            datetime.datetime.strptime(splitTime, r"%Y%m%dT%H%M%S%z")
        except ValueError:
            # This is explicitly chosen to be an f-string, as the string
            # contains percent-style control characters.
            logging.error(
                f"The supplied datetime {splitTime} could not be parsed correctly into "
                r"%Y%m%dT%H%M%S%z format"
            )
            return False
        return True
    else:
        return False


class AnalysisBaseConfig(PipelineTaskConfig, pipelineConnections=AnalysisBaseConnections):
    """Base class for all configs used to define an `AnalysisPipelineTask`.

    This base class defines two fields that should be used in all subclasses,
    ``atools`` and ``bands``.

    The ``atools`` field is where the user configures which analysis tools
    will be run as part of this `PipelineTask`.

    The ``bands`` field specifies which bands will be looped over for
    `AnalysisTools` which support parameterized bands, i.e. the tool is called
    once for each band in the list.
    """
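
    # A hypothetical pipetask config override exercising these fields, where
    # ``ExampleTool`` stands in for any concrete `AnalysisTool` subclass:
    #
    #     from lsst.analysis.tools.atools import ExampleTool
    #
    #     config.atools.example = ExampleTool
    #     config.bands = ["g", "r", "i"]
    #     config.connections.outputName = "exampleOutput"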

    atools = ConfigurableActionStructField[AnalysisTool](
        doc="The analysis tools that are to be run by this task at execution"
    )
    # Temporarily alias these for backwards compatibility
    plots = atools
    metrics = atools
    bands = ListField[str](
        doc="Filter bands on which to run all of the actions", default=["u", "g", "r", "i", "z", "y"]
    )
    metric_tags = ListField[str](
        doc="List of tags which will be added to all configurable actions", default=[]
    )
    dataset_identifier = Field[str](doc="An identifier to be associated with output Metrics", optional=True)
    reference_package = Field[str](
        doc="A package whose version, at the time of metric upload to a "
        "time series database, will be converted to a timestamp of when "
        "that version was produced",
        default="lsst_distrib",
    )
    timestamp_version = Field[str](
        doc="Which timestamp should be used as the reference timestamp for a "
        "metric in a time series database. Valid values are: "
        "reference_package_timestamp, run_timestamp, current_timestamp, "
        "dataset_timestamp, and explicit_timestamp:datetime where datetime is "
        "given in the form %Y%m%dT%H%M%S%z",
        default="run_timestamp",
        check=_timestampValidator,
    )

    def applyConfigOverrides(
        self,
        instrument: Instrument | None,
        taskDefaultName: str,
        pipelineConfigs: Iterable[ConfigIR] | None,
        parameters: ParametersIR,
        label: str,
    ) -> None:
        # Inject any sasquatch_* pipeline parameters as config overrides
        # before delegating to the base class implementation.
        extraConfig = {}
        if (value := parameters.mapping.get("sasquatch_dataset_identifier", None)) is not None:
            extraConfig["dataset_identifier"] = value
        if (value := parameters.mapping.get("sasquatch_reference_package", None)) is not None:
            extraConfig["reference_package"] = value
        if (value := parameters.mapping.get("sasquatch_timestamp_version", None)) is not None:
            if "explicit_timestamp" in value:
                try:
                    _, splitTime = value.split(":")
                except ValueError as excpt:
                    raise ValueError(
                        "Explicit timestamp must be given in the format 'explicit_timestamp:datetime', "
                        "where datetime is given in the form '%Y%m%dT%H%M%S%z'"
                    ) from excpt
                try:
                    datetime.datetime.strptime(splitTime, r"%Y%m%dT%H%M%S%z")
                except ValueError as excpt:
                    raise ValueError(
                        f"The supplied datetime {splitTime} could not be parsed correctly into "
                        "%Y%m%dT%H%M%S%z format"
                    ) from excpt
            extraConfig["timestamp_version"] = value
        if extraConfig:
            newPipelineConfigs = [ConfigIR(rest=extraConfig)]
            if pipelineConfigs is not None:
                newPipelineConfigs.extend(pipelineConfigs)
            pipelineConfigs = newPipelineConfigs
        return super().applyConfigOverrides(instrument, taskDefaultName, pipelineConfigs, parameters, label)

    def freeze(self):
        # Copy the meta configuration values to each of the configured tools;
        # only do this if the tool has not been further specialized.
        if not self._frozen:
            for tool in self.atools:
                for tag in self.metric_tags:
                    tool.metric_tags.insert(-1, tag)
        super().freeze()

    def validate(self):
        super().validate()
        # Validate that the required connections template is set.
        if self.connections.outputName == "Placeholder":  # type: ignore
            raise RuntimeError("Connections class 'outputName' must have a config explicitly set")


class _StandinPlotInfo(dict):
    """This class is an implementation detail to support plots in the case
    where no PlotInfo object is present in the call to run.
    """
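
    # Illustration: any missing key resolves to an empty string, so e.g.
    # _StandinPlotInfo()["SN"] == "" without raising a KeyError.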

    def __missing__(self, key):
        return ""


class AnalysisPipelineTask(PipelineTask):
    """Base class for `PipelineTasks` intended to run `AnalysisTools`.

    The run method will run all of the `AnalysisTools` defined in the config
    class.
    """
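
    # A sketch of direct (non-butler) use, assuming a configured concrete
    # subclass ``ExampleAnalysisTask`` (hypothetical) and an in-memory
    # ``catalog``:
    #
    #     task = ExampleAnalysisTask()
    #     results = task.run(data=catalog, bands=["r"])
    #     bundle = results.metrics  # a MetricMeasurementBundle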

    # Typing config because type checkers don't know about our Task magic
    config: AnalysisBaseConfig
    ConfigClass = AnalysisBaseConfig

    warnings_all = (
        "divide by zero encountered in divide",
        "invalid value encountered in arcsin",
        "invalid value encountered in cos",
        "invalid value encountered in divide",
        "invalid value encountered in log10",
        "invalid value encountered in scalar divide",
        "invalid value encountered in sin",
        "invalid value encountered in sqrt",
        "invalid value encountered in true_divide",
        "Mean of empty slice",
    )

    def _runTools(self, data: KeyedData, **kwargs) -> Struct:
        with warnings.catch_warnings():
            # Change below to "in self.warnings_all" to find otherwise
            # unfiltered numpy warnings.
            for warning in ():
                warnings.filterwarnings("error", warning, RuntimeWarning)
            results = Struct()
            results.metrics = MetricMeasurementBundle(
                dataset_identifier=self.config.dataset_identifier,
                reference_package=self.config.reference_package,
                timestamp_version=self.config.timestamp_version,
            )
            # Copy the plot info to be sure each action sees its own copy.
            plotInfo = kwargs.get("plotInfo")
            # Plot results are attached to the Struct under the same names as
            # the dynamically created output connections.
            plotKey = f"{self.config.connections.outputName}_{{name}}"
            weakrefArgs = []
            for name, action in self.config.atools.items():
                kwargs["plotInfo"] = deepcopy(plotInfo)
                actionResult = action(data, **kwargs)
                metricAccumulate = []
                for resultName, value in actionResult.items():
                    match value:
                        case PlotTypes():
                            setattr(results, plotKey.format(name=resultName), value)
                            weakrefArgs.append(value)
                        case Measurement():
                            metricAccumulate.append(value)
                # Only add the metrics if there are some.
                if metricAccumulate:
                    results.metrics[name] = metricAccumulate
            # Wrap the return struct in a finalizer so that when results is
            # garbage collected the plots will be closed.
            # TODO: This finalize step closes all open plots at the conclusion
            # of a task. When DM-39114 is implemented, this step should not
            # be required and may be removed.
            weakref.finalize(results, _plotCloser, *weakrefArgs)
            return results

    def run(self, *, data: KeyedData | None = None, **kwargs) -> Struct:
        """Produce the outputs associated with this `PipelineTask`.

        Parameters
        ----------
        data : `KeyedData`
            The input data from which all `AnalysisTools` will run and produce
            outputs. As a side note, the python typing specifies that this can
            be None, but that is only due to a limitation in python where, in
            order to specify that all arguments be passed only as keywords,
            the argument must be given a default. This argument must not
            actually be None.
        **kwargs
            Additional arguments that are passed through to the
            `AnalysisTools` specified in the configuration.

        Returns
        -------
        results : `~lsst.pipe.base.Struct`
            The accumulated results of all the plots and metrics produced by
            this `PipelineTask`.

        Raises
        ------
        ValueError
            Raised if the supplied data argument is `None`.
        """
        if data is None:
            raise ValueError("data must not be None")
        if "bands" not in kwargs:
            kwargs["bands"] = list(self.config.bands)
        if "plotInfo" not in kwargs:
            kwargs["plotInfo"] = _StandinPlotInfo()
        kwargs["plotInfo"]["bands"] = kwargs["bands"]
        if "SN" not in kwargs["plotInfo"].keys():
            kwargs["plotInfo"]["SN"] = "-"
        return self._runTools(data, **kwargs)

    def runQuantum(
        self,
        butlerQC: QuantumContext,
        inputRefs: InputQuantizedConnection,
        outputRefs: OutputQuantizedConnection,
    ) -> None:
        """Override the default runQuantum to load the minimal columns
        necessary to complete the action.

        Parameters
        ----------
        butlerQC : `~lsst.pipe.base.QuantumContext`
            A butler which is specialized to operate in the context of a
            `lsst.daf.butler.Quantum`.
        inputRefs : `InputQuantizedConnection`
            Datastructure whose attribute names are the names that identify
            connections defined in the corresponding `PipelineTaskConnections`
            class. The values of these attributes are the
            `lsst.daf.butler.DatasetRef` objects associated with the defined
            input/prerequisite connections.
        outputRefs : `OutputQuantizedConnection`
            Datastructure whose attribute names are the names that identify
            connections defined in the corresponding `PipelineTaskConnections`
            class. The values of these attributes are the
            `lsst.daf.butler.DatasetRef` objects associated with the defined
            output connections.
        """
        # TODO: This rcParams modification is a temporary solution, hiding
        # a matplotlib warning indicating too many figures have been opened.
        # When DM-39114 is implemented, this should be removed.
        plt.rcParams.update({"figure.max_open_warning": 0})
        inputs = butlerQC.get(inputRefs)
        dataId = butlerQC.quantum.dataId
        plotInfo = self.parsePlotInfo(inputs, dataId)
        data = self.loadData(inputs["data"])
        if "skymap" in inputs.keys():
            skymap = inputs["skymap"]
        else:
            skymap = None
        outputs = self.run(data=data, plotInfo=plotInfo, skymap=skymap)
        butlerQC.put(outputs, outputRefs)

    def _populatePlotInfoWithDataId(
        self, plotInfo: MutableMapping[str, Any], dataId: DataCoordinate | None
    ) -> None:
        """Update the plotInfo with the dataId values.

        Parameters
        ----------
        plotInfo : `dict`
            The plotInfo dictionary to update.
        dataId : `lsst.daf.butler.DataCoordinate`
            The dataId to use to update the plotInfo.
        """
        if dataId is not None:
            plotInfo.update(dataId.mapping)

    def parsePlotInfo(
        self, inputs: Mapping[str, Any] | None, dataId: DataCoordinate | None, connectionName: str = "data"
    ) -> Mapping[str, str]:
        """Parse the inputs and dataId to get the information needed to
        add to the figure.

        Parameters
        ----------
        inputs : `dict`
            The inputs to the task.
        dataId : `lsst.daf.butler.DataCoordinate`
            The dataId that the task is being run on.
        connectionName : `str`, optional
            Name of the input connection to use for determining the table
            name.

        Returns
        -------
        plotInfo : `dict`
            A dictionary containing the table name, run, and the dataId
            values.
        """
        if inputs is None:
            tableName = ""
            run = ""
        else:
            tableName = inputs[connectionName].ref.datasetType.name
            run = inputs[connectionName].ref.run

        # Initialize the plot info dictionary.
        plotInfo = {"tableName": tableName, "run": run}

        self._populatePlotInfoWithDataId(plotInfo, dataId)
        return plotInfo

    def loadData(self, handle: DeferredDatasetHandle, names: Iterable[str] | None = None) -> KeyedData:
        """Load the minimal set of keyed data from the input dataset.

        Parameters
        ----------
        handle : `DeferredDatasetHandle`
            Handle to load the dataset with only the specified columns.
        names : `Iterable` of `str`
            The names of keys to extract from the dataset.
            If `names` is `None` then the `collectInputNames` method
            is called to generate the names.
            For most purposes these are the names of columns to load from
            a catalog or data frame.

        Returns
        -------
        result : `KeyedData`
            The dataset with only the specified keys loaded.
        """
        if names is None:
            names = self.collectInputNames()
        return cast(KeyedData, handle.get(parameters={"columns": names}))

    def collectInputNames(self) -> Iterable[str]:
        """Get the names of the inputs.

        If using the default `loadData` method this will gather the names
        of the keys to be loaded from an input dataset.

        Returns
        -------
        inputs : `Iterable` of `str`
            The names of the keys in the `KeyedData` object to extract.
        """
        inputs = set()
        for band in self.config.bands:
            for action in self.config.atools:
                for key, _ in action.getFormattedInputSchema(band=band):
                    inputs.add(key)
        return inputs