Coverage for python/lsst/analysis/tools/interfaces/_analysisTools.py: 22%

176 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-04-26 04:07 -0700

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ("AnalysisTool",) 

25 

26from collections import ChainMap 

27from collections.abc import Mapping 

28from functools import wraps 

29from operator import attrgetter 

30from typing import Callable, Iterable, Protocol, runtime_checkable 

31 

32import lsst.pex.config as pexConfig 

33from lsst.obs.base import Instrument 

34from lsst.pex.config import Field, FieldValidationError, ListField 

35from lsst.pex.config.configurableActions import ConfigurableActionField 

36from lsst.pipe.base import Pipeline 

37from lsst.verify import Measurement 

38 

39from ._actions import AnalysisAction, JointAction, JointResults, NoPlot, PlotAction 

40from ._interfaces import KeyedData, KeyedDataSchema, KeyedResults, PlotTypes 

41from ._stages import BasePrep, BaseProcess, BaseProduce 

42 

43 

@runtime_checkable
class _HasOutputNames(Protocol):
    """Structural type for any object exposing a ``getOutputNames`` method.

    The protocol is decorated with `runtime_checkable` so that it can be used
    in ``isinstance`` checks and as a class pattern in ``match`` statements
    (see `AnalysisTool.getOutputNames`).
    """

    def getOutputNames(self, config: pexConfig.Config | None = None) -> Iterable[str]: ...

47 

48 

49def _finalizeWrapper( 

50 f: Callable[[AnalysisTool], None], cls: type[AnalysisTool] 

51) -> Callable[[AnalysisTool], None]: 

52 """Wrap a classes finalize function to ensure the base classes special 

53 finalize method only fires after the most derived finalize method. 

54 

55 Parameters 

56 ---------- 

57 f : `Callable` 

58 Function that is being wrapped 

59 cls : `type` of `AnalysisTool` 

60 The class which is having its function wrapped 

61 

62 Returns 

63 ------- 

64 function : `Callable` 

65 The new function which wraps the old 

66 """ 

67 

68 @wraps(f) 

69 def wrapper(self: AnalysisTool) -> None: 

70 # call the wrapped finalize function 

71 f(self) 

72 # get the method resolution order for the self variable 

73 mro = self.__class__.mro() 

74 

75 # Find which class in the mro that last defines a finalize method 

76 # note that this is in the reverse order from the mro, because the 

77 # last class in an inheritance stack is the first in the mro (aka you 

78 # walk from the furthest child first. 

79 # 

80 # Also note that the most derived finalize method need not be the same 

81 # as the type of self, as that might inherit from a parent somewhere 

82 # between it and the furthest parent. 

83 mostDerived: type | None = None 

84 for klass in mro: 

85 # inspect the classes dictionary to see if it specifically defines 

86 # finalize. This is needed because normal lookup will go through 

87 # the mro, but this needs to be restricted to each class. 

88 if "finalize" in vars(klass): 

89 mostDerived = klass 

90 break 

91 

92 # Find what stage in the MRO walking process the recursive function 

93 # call is in. 

94 this = super(cls, self).__thisclass__ 

95 

96 # If the current place in the MRO walking is also the class that 

97 # defines the most derived instance of finalize, then call the base 

98 # classes private finalize that must be called after everything else. 

99 if mostDerived is not None and this == mostDerived: 

100 self._baseFinalize() 

101 

102 return wrapper 

103 

104 

105class AnalysisTool(AnalysisAction): 

106 r"""A tool which which calculates a single type of analysis on input data, 

107 though it may return more than one result. 

108 

109 Although `AnalysisTool`\ s are considered a single type of analysis, the 

110 classes themselves can be thought of as a container. `AnalysisTool`\ s 

111 are aggregations of `AnalysisAction`\ s to form prep, process, and 

112 produce stages. These stages allow better reuse of individual 

113 `AnalysisActions` and easier introspection in contexts such as a notebook 

114 or interpreter. 

115 

116 An `AnalysisTool` can be thought of an an individual configuration that 

117 specifies which `AnalysisAction` should run for each stage. 

118 

119 The stages themselves are also configurable, allowing control over various 

120 aspects of the individual `AnalysisAction`\ s. 

121 """ 

122 

123 prep = ConfigurableActionField[AnalysisAction](doc="Action to run to prepare inputs", default=BasePrep) 

124 process = ConfigurableActionField[AnalysisAction]( 

125 doc="Action to process data into intended form", default=BaseProcess 

126 ) 

127 produce = ConfigurableActionField[AnalysisAction]( 

128 doc="Action to perform any finalization steps", default=BaseProduce 

129 ) 

130 metric_tags = ListField[str]( 

131 doc="List of tags which will be associated with metric measurement(s)", default=[] 

132 ) 

133 

134 def __init_subclass__(cls: type[AnalysisTool], **kwargs): 

135 super().__init_subclass__(**kwargs) 

136 # Wrap all definitions of the finalize method in a special wrapper that 

137 # ensures that the bases classes private finalize is called last. 

138 if "finalize" in vars(cls): 

139 cls.finalize = _finalizeWrapper(cls.finalize, cls) 

140 

141 dynamicOutputNames: bool | Field[bool] = False 

142 """Determines whether to grant the ``getOutputNames`` method access to 

143 config parameters. 

144 """ 

145 

146 parameterizedBand: bool | Field[bool] = True 

147 """Specifies if an `AnalysisTool` may parameterize a band within any field 

148 in any stage, or if the set of bands is already uniquely determined though 

149 configuration. I.e. can this `AnalysisTool` be automatically looped over to 

150 produce a result for multiple bands. 

151 """ 

152 

153 propagateData: bool | Field[bool] = False 

154 """If this value is set to True, the input data `KeyedData` will be passed 

155 to each stage in addition to the ``prep`` stage. Any keys created in a 

156 stage with the same key that exists in the input ``data`` will shadow that 

157 key/value. 

158 """ 

159 

160 def __call__(self, data: KeyedData, **kwargs) -> KeyedResults: 

161 bands = kwargs.pop("bands", None) 

162 if "plotInfo" in kwargs and kwargs.get("plotInfo") is not None: 

163 if "plotName" not in kwargs["plotInfo"] or kwargs["plotInfo"]["plotName"] is None: 

164 kwargs["plotInfo"]["plotName"] = self.identity 

165 if not self.parameterizedBand or bands is None: 

166 if "band" not in kwargs: 

167 # Some tasks require a "band" key for naming. This shouldn't 

168 # affect the results. DM-35813 should make this unnecessary. 

169 kwargs["band"] = "analysisTools" 

170 return self._call_single(data, **kwargs) 

171 results: KeyedResults = {} 

172 for band in bands: 

173 kwargs["band"] = band 

174 if "plotInfo" in kwargs: 

175 kwargs["plotInfo"]["bands"] = band 

176 subResult = self._call_single(data, **kwargs) 

177 for key, value in subResult.items(): 

178 match value: 

179 case PlotTypes(): 

180 results[f"{band}_{key}"] = value 

181 case Measurement(): 

182 results[key] = value 

183 return results 

184 

185 def _call_single(self, data: KeyedData, **kwargs) -> KeyedResults: 

186 # create a shallow copy of kwargs 

187 kwargs = dict(**kwargs) 

188 kwargs["metric_tags"] = list(self.metric_tags or ()) 

189 prepped: KeyedData = self.prep(data, **kwargs) # type: ignore 

190 if self.propagateData: 

191 prepped = ChainMap(data, prepped) 

192 processed: KeyedData = self.process(prepped, **kwargs) # type: ignore 

193 if self.propagateData: 

194 processed = ChainMap(data, processed) 

195 finalized: ( 

196 Mapping[str, PlotTypes] | PlotTypes | Mapping[str, Measurement] | Measurement | JointResults 

197 ) = self.produce( 

198 processed, **kwargs 

199 ) # type: ignore 

200 return self._process_single_results(finalized) 

201 

202 def _getPlotType(self) -> str: 

203 match self.produce: 

204 case PlotAction(): 

205 return self.produce.getPlotType() 

206 case JointAction(plot=NoPlot()): 

207 pass 

208 case JointAction(plot=plotter): 

209 return plotter.getPlotType() 

210 

211 return "" 

212 

213 def _process_single_results( 

214 self, 

215 results: Mapping[str, PlotTypes] | PlotTypes | Mapping[str, Measurement] | Measurement | JointResults, 

216 ) -> KeyedResults: 

217 accumulation = {} 

218 predicate = f"{self.identity}" if self.identity else "" 

219 match results: 

220 case Mapping(): 

221 suffix = self._getPlotType() 

222 for key, value in results.items(): 

223 match value: 

224 case Measurement(): 

225 iterable = (predicate, key) 

226 case PlotTypes(): 

227 iterable = (predicate, key, suffix) 

228 case _: 

229 raise RuntimeError(f"Unexpected {key=}, {value=} from:\n{self=}") 

230 refKey = "_".join(x for x in iterable if x) 

231 accumulation[refKey] = value 

232 case PlotTypes(): 

233 suffix = self._getPlotType() 

234 refKey = "_".join(x for x in (predicate, suffix) if x) 

235 accumulation[refKey] = results 

236 case Measurement(): 

237 accumulation[f"{predicate}"] = results 

238 case JointResults(plot=plotResults, metric=metricResults): 

239 if plotResults is not None: 

240 subResult = self._process_single_results(plotResults) 

241 accumulation.update(subResult) 

242 if metricResults is not None: 

243 subResult = self._process_single_results(metricResults) 

244 accumulation.update(subResult) 

245 return accumulation 

246 

247 def getInputSchema(self) -> KeyedDataSchema: 

248 return self.prep.getInputSchema() 

249 

250 def populatePrepFromProcess(self): 

251 """Add additional inputs to the prep stage if supported. 

252 

253 If the configured prep action supports adding to it's input schema, 

254 attempt to add the required inputs schema from the process stage to the 

255 prep stage. 

256 

257 This method will be a no-op if the prep action does not support this 

258 feature. 

259 """ 

260 self.prep.addInputSchema(self.process.getInputSchema()) 

261 

262 def getOutputNames(self, config: pexConfig.Config | None = None) -> Iterable[str]: 

263 """Return the names of the plots produced by this analysis tool. 

264 

265 If there is a `PlotAction` defined in the produce action, these names 

266 will either come from the `PlotAction` if it defines a 

267 ``getOutputNames`` method (likely if it returns a mapping of figures), 

268 or a default value is used and a single figure is assumed. 

269 

270 Parameters 

271 ---------- 

272 config : `lsst.pex.config.Config`, optional 

273 Configuration of the task. This is only used if the output naming 

274 needs to be config-aware. 

275 

276 Returns 

277 ------- 

278 result : `tuple` of `str` 

279 Names for each plot produced by this action. 

280 """ 

281 match self.produce: 

282 case JointAction(plot=NoPlot()): 

283 return tuple() 

284 case _HasOutputNames(): 

285 outNames = tuple(self.produce.getOutputNames(config=config)) 

286 case _: 

287 raise ValueError(f"Unsupported Action type {type(self.produce)} for getting output names") 

288 

289 results = [] 

290 suffix = self._getPlotType() 

291 if self.parameterizedBand: 

292 prefix = "_".join(x for x in ("{band}", self.identity) if x) 

293 else: 

294 prefix = f"{self.identity}" if self.identity else "" 

295 

296 if outNames: 

297 for name in outNames: 

298 results.append("_".join(x for x in (prefix, name, suffix) if x)) 

299 else: 

300 results.append("_".join(x for x in (prefix, suffix) if x)) 

301 return results 

302 

303 @classmethod 

304 def fromPipeline( 

305 cls, 

306 pipeline: str | Pipeline, 

307 name: str, 

308 fullpath: bool = False, 

309 instrument: Instrument | str | None = None, 

310 ) -> AnalysisTool | None: 

311 """Construct an `AnalysisTool` from a definition written in a 

312 `~lsst.pipe.base.Pipeline`. 

313 

314 Parameters 

315 ---------- 

316 pipeline : `str` or `~lsst.pipe.base.Pipeline` 

317 The pipeline to load the `AnalysisTool` from. 

318 name : `str` 

319 The name of the analysis tool to run. This can either be just the 

320 name assigned to the tool, or an absolute name in a config 

321 hierarchy. 

322 fullpath : `bool` 

323 Determines if the name is interpreted as an absolute path in a 

324 config hierarchy, or is relative to an `AnalysisTool` ``atools`` 

325 `~lsst.pex.config.configurableActions.ConfigurableActionStructField` 

326 . 

327 instrument : `~lsst.daf.butler.instrument.Instrument` or `str` or\ 

328 `None` 

329 Either a derived class object of a `lsst.daf.butler.instrument` or 

330 a string corresponding to a fully qualified 

331 `lsst.daf.butler.instrument` name or None if no instrument needs 

332 specified or if the pipeline contains the instrument in it. 

333 Defaults to None. 

334 

335 Returns 

336 ------- 

337 tool : `AnalysisTool` 

338 The loaded `AnalysisTool` as configured in the pipeline. 

339 

340 Raises 

341 ------ 

342 ValueError 

343 Raised if the config field specified does not point to an 

344 `AnalysisTool`. 

345 Raised if an instrument is specified and it conflicts with the 

346 pipelines instrument. 

347 """ 

348 if not isinstance(pipeline, Pipeline): 

349 pipeline = Pipeline.fromFile(pipeline) 

350 # If the caller specified an instrument, verify it does not conflict 

351 # with the pipelines instrument, and add it to the pipeline 

352 if instrument is not None: 

353 if (pipeInstrument := pipeline.getInstrument()) and pipeInstrument != instrument: 

354 raise ValueError( 

355 f"The supplied instrument {instrument} conflicts with the pipelines instrument " 

356 f"{pipeInstrument}." 

357 ) 

358 else: 

359 pipeline.addInstrument(instrument) 

360 try: 

361 pipelineGraph = pipeline.to_graph() 

362 except (FieldValidationError, ValueError) as err: 

363 raise ValueError( 

364 "There was an error instantiating the pipeline, do you need to specify an instrument?" 

365 ) from err 

366 if not fullpath: 

367 name = f"atools.{name}" 

368 for task in pipelineGraph.tasks.values(): 

369 config = task.config 

370 try: 

371 attr = attrgetter(name)(config) 

372 except AttributeError: 

373 continue 

374 if not isinstance(attr, AnalysisTool): 

375 raise ValueError("The requested name did not refer to an analysisTool") 

376 return attr 

377 return None 

378 

379 def finalize(self) -> None: 

380 """Run any finalization code that depends on configuration being 

381 complete. 

382 """ 

383 pass 

384 

385 def _baseFinalize(self) -> None: 

386 self.populatePrepFromProcess() 

387 

388 def freeze(self): 

389 if not self.__dict__.get("_finalizeRun"): 

390 self.finalize() 

391 self.__dict__["_finalizeRun"] = True 

392 super().freeze() 

393 

394 

# Explicitly wrap the finalize of the base class: ``__init_subclass__`` only
# runs for subclasses, so ``AnalysisTool`` itself is not wrapped by it.
AnalysisTool.finalize = _finalizeWrapper(AnalysisTool.finalize, AnalysisTool)