Coverage for python/lsst/analysis/tools/interfaces/_task.py: 21%

143 statements  

coverage.py v7.2.5, created at 2023-05-09 03:19 -0700

# This file is part of analysis_tools.
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

from __future__ import annotations

"""Base class implementation for the classes needed in creating `PipelineTasks`
which execute `AnalysisTools`.

The classes defined in this module have all the required behaviors for
defining, introspecting, and executing `AnalysisTools` against an input dataset
type.

Subclasses of these tasks should specify the specific datasets to consume in
their connection classes and should specify a unique name.
"""

__all__ = ("AnalysisBaseConnections", "AnalysisBaseConfig", "AnalysisPipelineTask")

from copy import deepcopy
from typing import TYPE_CHECKING, Any, Iterable, Mapping, MutableMapping, cast

from lsst.verify import Measurement

if TYPE_CHECKING:
    from lsst.daf.butler import DeferredDatasetHandle

from lsst.daf.butler import DataCoordinate
from lsst.pex.config import Field, ListField
from lsst.pex.config.configurableActions import ConfigurableActionStructField
from lsst.pipe.base import PipelineTask, PipelineTaskConfig, PipelineTaskConnections, Struct
from lsst.pipe.base import connectionTypes as ct
from lsst.pipe.base.butlerQuantumContext import ButlerQuantumContext
from lsst.pipe.base.connections import InputQuantizedConnection, OutputQuantizedConnection

from ._actions import JointAction, MetricAction, NoMetric
from ._analysisTools import AnalysisTool
from ._interfaces import KeyedData, PlotTypes
from ._metricMeasurementBundle import MetricMeasurementBundle



class AnalysisBaseConnections(
    PipelineTaskConnections, dimensions={}, defaultTemplates={"outputName": "Placeholder"}
):
    r"""Base class for Connections used for AnalysisTools PipelineTasks.

    This class has a pre-defined output connection for the
    MetricMeasurementBundle. The dataset type name for this connection is
    determined by the template ``outputName``.

    Output connections for plots created by `AnalysisPlot`\ s are created
    dynamically when an instance of the class is created. The init method
    examines all the `AnalysisPlot` actions specified in the associated
    `AnalysisBaseConfig` subclass, accumulating all the info needed to
    create the output connections.

    The dimensions for all of the output connections (metric and plot) will
    be the same as the dimensions specified for the AnalysisBaseConnections
    subclass (i.e. quantum dimensions).
    """


    metrics = ct.Output(
        doc="Metrics calculated on input dataset type",
        name="{outputName}_metrics",
        storageClass="MetricMeasurementBundle",
    )


    def __init__(self, *, config: AnalysisBaseConfig = None):  # type: ignore
        # Validate that the outputName template has been set in config. This
        # should have been checked early with the config's validate method,
        # but it is possible for someone to manually create everything in a
        # script without running validate, so also check it late here.
        if (outputName := config.connections.outputName) == "Placeholder":  # type: ignore
            raise RuntimeError(
                "Subclasses must specify an alternative value for the defaultTemplate `outputName`"
            )
        super().__init__(config=config)

        # All arguments must be passed by keyword, but python has no way to
        # require that without specifying a default, so None is used. Validate
        # that it is not None. This is largely for typing reasons, as in the
        # normal course of operation code execution paths ensure this will
        # not be None.
        assert config is not None


        for tool in config.atools:
            match tool.produce:
                case JointAction():
                    if isinstance(tool.produce.metric, NoMetric):
                        continue
                    if len(tool.produce.metric.units) != 0:
                        hasMetrics = True
                        break
                case MetricAction():
                    hasMetrics = True
                    break
        else:
            hasMetrics = False


        # Set the dimensions for the metrics connection
        if hasMetrics:
            self.metrics = ct.Output(
                name=self.metrics.name,
                doc=self.metrics.doc,
                storageClass=self.metrics.storageClass,
                dimensions=self.dimensions,
                multiple=False,
                isCalibration=False,
            )
        else:
            # There are no metrics to produce, remove the output connection
            self.outputs.remove("metrics")

        # Look for any conflicting names, creating a set of them, as these
        # will be added to the instance as well as recorded in the outputs
        # set.
        existingNames = set(dir(self))

        # Accumulate all the names to be used from all of the defined
        # AnalysisPlots.
        names: list[str] = []
        for action in config.atools:
            if action.parameterizedBand:
                for band in config.bands:
                    names.extend(name.format(band=band) for name in action.getOutputNames())
            else:
                names.extend(action.getOutputNames())
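
        # For example (hypothetical values): with outputName="demo", bands
        # ["g", "r"], and a single tool whose getOutputNames() yields
        # "{band}_examplePlot", the loop above accumulates "g_examplePlot"
        # and "r_examplePlot"; the loop below then registers output
        # connections named "demo_g_examplePlot" and "demo_r_examplePlot".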


        # For each of the names found, create output connections.
        for name in names:
            name = f"{outputName}_{name}"
            if name in self.outputs or name in existingNames:
                raise NameError(
                    f"Plot with name {name} conflicts with an existing connection;"
                    " are two plots named the same?"
                )
            outConnection = ct.Output(
                name=name,
                storageClass="Plot",
                doc="Dynamic connection for plotting",
                dimensions=self.dimensions,
            )
            object.__setattr__(self, name, outConnection)
            self.outputs.add(name)



class AnalysisBaseConfig(PipelineTaskConfig, pipelineConnections=AnalysisBaseConnections):
    """Base class for all configs used to define an `AnalysisPipelineTask`.

    This base class defines two fields that should be used in all subclasses:
    ``atools`` and ``bands``.

    The ``atools`` field is where the user configures which analysis tools
    will be run as part of this `PipelineTask`.

    The ``bands`` field specifies which bands will be looped over for
    `AnalysisTools` which support parameterized bands, i.e. those called once
    for each band in the list.
    """


    atools = ConfigurableActionStructField[AnalysisTool](
        doc="The analysis tools that are to be run by this task at execution"
    )
    # Temporarily alias these for backwards compatibility
    plots = atools
    metrics = atools
    bands = ListField[str](
        doc="Filter bands on which to run all of the actions", default=["u", "g", "r", "i", "z", "y"]
    )
    metric_tags = ListField[str](
        doc="List of tags which will be added to all configurable actions", default=[]
    )
    dataset_identifier = Field[str](doc="An identifier to be associated with output Metrics", optional=True)
    reference_package = Field[str](
        doc="A package whose version, at the time of metric upload to a "
        "time series database, will be converted to a timestamp of when "
        "that version was produced",
        default="lsst_distrib",
    )
    timestamp_version = Field[str](
        doc="Which time stamp should be used as the reference timestamp for a "
        "metric in a time series database; valid values are "
        "reference_package_timestamp, run_timestamp, current_timestamp, "
        "and dataset_timestamp",
        default="run_timestamp",
        check=lambda x: x
        in ("reference_package_timestamp", "run_timestamp", "current_timestamp", "dataset_timestamp"),
    )


    def freeze(self):
        # Copy the meta configuration values to each of the configured tools;
        # only do this if the config has not already been frozen.
        if not self._frozen:
            for tool in self.atools:
                for tag in self.metric_tags:
                    tool.metric_tags.insert(-1, tag)
        super().freeze()

    def validate(self):
        super().validate()
        # Validate that the required connections template is set.
        if self.connections.outputName == "Placeholder":  # type: ignore
            raise RuntimeError("Connections class 'outputName' must have a config explicitly set")
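
# As a rough sketch (hypothetical task and tool names), a pipeline YAML
# fragment configuring a subclass of AnalysisBaseConfig might look like:
#
#     tasks:
#       myAnalysis:
#         class: mypackage.tasks.MyTask
#         config:
#           connections.outputName: myAnalysis
#           bands: ["g", "r", "i"]
#           atools.someTool: SomeTool
#           python: |
#             from mypackage.atools import SomeTool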



class _StandinPlotInfo(dict):
    """This class is an implementation detail to support plots in the case
    that no PlotInfo object is present in the call to run.
    """

    def __missing__(self, key):
        return ""
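
# Because ``__missing__`` returns an empty string, any key lookup on a
# _StandinPlotInfo succeeds; e.g. _StandinPlotInfo()["anything"] == "".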



class AnalysisPipelineTask(PipelineTask):
    """Base class for `PipelineTasks` intended to run `AnalysisTools`.

    The run method will run all of the `AnalysisTools` defined in the config
    class.
    """

    # Typing config because type checkers don't know about our Task magic
    config: AnalysisBaseConfig
    ConfigClass = AnalysisBaseConfig


    def _runTools(self, data: KeyedData, **kwargs) -> Struct:
        results = Struct()
        results.metrics = MetricMeasurementBundle(
            dataset_identifier=self.config.dataset_identifier,
            reference_package=self.config.reference_package,
            timestamp_version=self.config.timestamp_version,
        )
        # Copy the plot info so that each action sees its own copy.
        plotInfo = kwargs.get("plotInfo")
        plotKey = f"{self.config.connections.outputName}_{{name}}"
        for name, action in self.config.atools.items():
            kwargs["plotInfo"] = deepcopy(plotInfo)
            actionResult = action(data, **kwargs)
            metricAccumulate = []
            for resultName, value in actionResult.items():
                match value:
                    case PlotTypes():
                        setattr(results, plotKey.format(name=resultName), value)
                    case Measurement():
                        metricAccumulate.append(value)
            # Only add the metrics if there are some.
            if metricAccumulate:
                results.metrics[name] = metricAccumulate
        return results


    def run(self, *, data: KeyedData | None = None, **kwargs) -> Struct:
        """Produce the outputs associated with this `PipelineTask`.

        Parameters
        ----------
        data : `KeyedData`
            The input data from which all `AnalysisTools` will run and produce
            outputs. As a side note, the python typing specifies that this can
            be None, but that is only due to a limitation in python where, in
            order to specify that all arguments be passed only as keywords,
            the argument must be given a default. This argument must not
            actually be None.
        **kwargs
            Additional arguments that are passed through to the
            `AnalysisTools` specified in the configuration.

        Returns
        -------
        results : `~lsst.pipe.base.Struct`
            The accumulated results of all the plots and metrics produced by
            this `PipelineTask`.

        Raises
        ------
        ValueError
            Raised if the supplied data argument is `None`.
        """
        if data is None:
            raise ValueError("data must not be None")
        if "bands" not in kwargs:
            kwargs["bands"] = list(self.config.bands)
        if "plotInfo" not in kwargs:
            kwargs["plotInfo"] = _StandinPlotInfo()
        kwargs["plotInfo"]["bands"] = kwargs["bands"]
        if "SN" not in kwargs["plotInfo"].keys():
            kwargs["plotInfo"]["SN"] = "-"
        return self._runTools(data, **kwargs)
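
    # ``run`` may also be invoked directly, outside of ``runQuantum``; a
    # minimal sketch (hypothetical subclass and data; any mapping satisfying
    # KeyedData, such as a dict of column arrays, works):
    #
    #     task = MyTask()
    #     results = task.run(data={"g_exampleFlux": fluxArray}, bands=["g"])
    #     bundle = results.metrics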


    def runQuantum(
        self,
        butlerQC: ButlerQuantumContext,
        inputRefs: InputQuantizedConnection,
        outputRefs: OutputQuantizedConnection,
    ) -> None:
        """Override default runQuantum to load the minimal columns necessary
        to complete the action.

        Parameters
        ----------
        butlerQC : `ButlerQuantumContext`
            A butler which is specialized to operate in the context of a
            `lsst.daf.butler.Quantum`.
        inputRefs : `InputQuantizedConnection`
            Datastructure whose attribute names are the names that identify
            connections defined in the corresponding `PipelineTaskConnections`
            class. The values of these attributes are the
            `lsst.daf.butler.DatasetRef` objects associated with the defined
            input/prerequisite connections.
        outputRefs : `OutputQuantizedConnection`
            Datastructure whose attribute names are the names that identify
            connections defined in the corresponding `PipelineTaskConnections`
            class. The values of these attributes are the
            `lsst.daf.butler.DatasetRef` objects associated with the defined
            output connections.
        """
        inputs = butlerQC.get(inputRefs)
        dataId = butlerQC.quantum.dataId
        plotInfo = self.parsePlotInfo(inputs, dataId)
        data = self.loadData(inputs["data"])
        if "skymap" in inputs.keys():
            skymap = inputs["skymap"]
        else:
            skymap = None
        outputs = self.run(data=data, plotInfo=plotInfo, skymap=skymap)
        butlerQC.put(outputs, outputRefs)


    def _populatePlotInfoWithDataId(
        self, plotInfo: MutableMapping[str, Any], dataId: DataCoordinate | None
    ) -> None:
        """Update the plotInfo with the dataId values.

        Parameters
        ----------
        plotInfo : `dict`
            The plotInfo dictionary to update.
        dataId : `lsst.daf.butler.DataCoordinate`
            The dataId to use to update the plotInfo.
        """
        if dataId is not None:
            for dataInfo in dataId:
                plotInfo[dataInfo.name] = dataId[dataInfo.name]


    def parsePlotInfo(
        self, inputs: Mapping[str, Any] | None, dataId: DataCoordinate | None, connectionName: str = "data"
    ) -> Mapping[str, str]:
        """Parse the inputs and dataId to get the information needed to
        add to the figure.

        Parameters
        ----------
        inputs : `dict`
            The inputs to the task.
        dataId : `lsst.daf.butler.DataCoordinate`
            The dataId that the task is being run on.
        connectionName : `str`, optional
            Name of the input connection to use for determining table name.

        Returns
        -------
        plotInfo : `dict`
            The information parsed from the inputs and dataId.
        """

        if inputs is None:
            tableName = ""
            run = ""
        else:
            tableName = inputs[connectionName].ref.datasetType.name
            run = inputs[connectionName].ref.run

        # Initialize the plot info dictionary
        plotInfo = {"tableName": tableName, "run": run}
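
        # After _populatePlotInfoWithDataId below, plotInfo also carries the
        # dataId keys, e.g. (illustrative values):
        #     {"tableName": "objectTable_tract", "run": "u/someone/run",
        #      "tract": 9813, "skymap": "hsc_rings_v1"}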


        self._populatePlotInfoWithDataId(plotInfo, dataId)
        return plotInfo


    def loadData(self, handle: DeferredDatasetHandle, names: Iterable[str] | None = None) -> KeyedData:
        """Load the minimal set of keyed data from the input dataset.

        Parameters
        ----------
        handle : `DeferredDatasetHandle`
            Handle to load the dataset with only the specified columns.
        names : `Iterable` of `str`
            The names of keys to extract from the dataset.
            If `names` is `None` then the `collectInputNames` method
            is called to generate the names.
            For most purposes these are the names of columns to load from
            a catalog or data frame.

        Returns
        -------
        result : `KeyedData`
            The dataset with only the specified keys loaded.
        """
        if names is None:
            names = self.collectInputNames()
        return cast(KeyedData, handle.get(parameters={"columns": names}))
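
    # A usage sketch for ``loadData`` (illustrative column names): passing
    # explicit names bypasses collectInputNames below and loads only those
    # columns, e.g.
    #
    #     data = task.loadData(handle, names=["coord_ra", "coord_dec"])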


    def collectInputNames(self) -> Iterable[str]:
        """Get the names of the inputs.

        If using the default `loadData` method this will gather the names
        of the keys to be loaded from an input dataset.

        Returns
        -------
        inputs : `Iterable` of `str`
            The names of the keys in the `KeyedData` object to extract.
        """
        inputs = set()
        for band in self.config.bands:
            for action in self.config.atools:
                for key, _ in action.getFormattedInputSchema(band=band):
                    inputs.add(key)
        return inputs
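
    # For example (hypothetical schema): a tool declaring an input key
    # "{band}_exampleFlux" with bands ["g", "r"] contributes both
    # "g_exampleFlux" and "r_exampleFlux" to the set of columns loaded.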