Coverage for python/lsst/analysis/tools/interfaces/_analysisTools.py: 22%

166 statements  

« prev     ^ index     » next       coverage.py v7.4.3, created at 2024-02-24 11:21 +0000

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ("AnalysisTool",) 

25 

26from collections.abc import Mapping 

27from functools import wraps 

28from operator import attrgetter 

29from typing import Callable, Iterable, Protocol, runtime_checkable 

30 

31import lsst.pex.config as pexConfig 

32from lsst.obs.base import Instrument 

33from lsst.pex.config import Field, FieldValidationError, ListField 

34from lsst.pex.config.configurableActions import ConfigurableActionField 

35from lsst.pipe.base import Pipeline 

36from lsst.verify import Measurement 

37 

38from ._actions import AnalysisAction, JointAction, JointResults, NoPlot, PlotAction 

39from ._interfaces import KeyedData, KeyedDataSchema, KeyedResults, PlotTypes 

40from ._stages import BasePrep, BaseProcess, BaseProduce 

41 

42 

43@runtime_checkable 

44class _HasOutputNames(Protocol): 

45 def getOutputNames(self, config: pexConfig.Config | None = None) -> Iterable[str]: ... 45 ↛ exitline 45 didn't return from function 'getOutputNames'

46 

47 

48def _finalizeWrapper( 

49 f: Callable[[AnalysisTool], None], cls: type[AnalysisTool] 

50) -> Callable[[AnalysisTool], None]: 

51 """Wrap a classes finalize function to ensure the base classes special 

52 finalize method only fires after the most derived finalize method. 

53 

54 Parameters 

55 ---------- 

56 f : `Callable` 

57 Function that is being wrapped 

58 cls : `type` of `AnalysisTool` 

59 The class which is having its function wrapped 

60 

61 Returns 

62 ------- 

63 function : `Callable` 

64 The new function which wraps the old 

65 """ 

66 

67 @wraps(f) 

68 def wrapper(self: AnalysisTool) -> None: 

69 # call the wrapped finalize function 

70 f(self) 

71 # get the method resolution order for the self variable 

72 mro = self.__class__.mro() 

73 

74 # Find which class in the mro that last defines a finalize method 

75 # note that this is in the reverse order from the mro, because the 

76 # last class in an inheritance stack is the first in the mro (aka you 

77 # walk from the furthest child first. 

78 # 

79 # Also note that the most derived finalize method need not be the same 

80 # as the type of self, as that might inherit from a parent somewhere 

81 # between it and the furthest parent. 

82 mostDerived: type | None = None 

83 for klass in mro: 

84 # inspect the classes dictionary to see if it specifically defines 

85 # finalize. This is needed because normal lookup will go through 

86 # the mro, but this needs to be restricted to each class. 

87 if "finalize" in vars(klass): 

88 mostDerived = klass 

89 break 

90 

91 # Find what stage in the MRO walking process the recursive function 

92 # call is in. 

93 this = super(cls, self).__thisclass__ 

94 

95 # If the current place in the MRO walking is also the class that 

96 # defines the most derived instance of finalize, then call the base 

97 # classes private finalize that must be called after everything else. 

98 if mostDerived is not None and this == mostDerived: 

99 self._baseFinalize() 

100 

101 return wrapper 

102 

103 

104class AnalysisTool(AnalysisAction): 

105 r"""A tool which which calculates a single type of analysis on input data, 

106 though it may return more than one result. 

107 

108 Although `AnalysisTool`\ s are considered a single type of analysis, the 

109 classes themselves can be thought of as a container. `AnalysisTool`\ s 

110 are aggregations of `AnalysisAction`\ s to form prep, process, and 

111 produce stages. These stages allow better reuse of individual 

112 `AnalysisActions` and easier introspection in contexts such as a notebook 

113 or interpreter. 

114 

115 An `AnalysisTool` can be thought of an an individual configuration that 

116 specifies which `AnalysisAction` should run for each stage. 

117 

118 The stages themselves are also configurable, allowing control over various 

119 aspects of the individual `AnalysisAction`\ s. 

120 """ 

121 

122 prep = ConfigurableActionField[AnalysisAction](doc="Action to run to prepare inputs", default=BasePrep) 

123 process = ConfigurableActionField[AnalysisAction]( 

124 doc="Action to process data into intended form", default=BaseProcess 

125 ) 

126 produce = ConfigurableActionField[AnalysisAction]( 

127 doc="Action to perform any finalization steps", default=BaseProduce 

128 ) 

129 metric_tags = ListField[str]( 

130 doc="List of tags which will be associated with metric measurement(s)", default=[] 

131 ) 

132 

133 def __init_subclass__(cls: type[AnalysisTool], **kwargs): 

134 super().__init_subclass__(**kwargs) 

135 # Wrap all definitions of the finalize method in a special wrapper that 

136 # ensures that the bases classes private finalize is called last. 

137 if "finalize" in vars(cls): 

138 cls.finalize = _finalizeWrapper(cls.finalize, cls) 

139 

140 dynamicOutputNames: bool | Field[bool] = False 

141 """Determines whether to grant the ``getOutputNames`` method access to 

142 config parameters. 

143 """ 

144 

145 parameterizedBand: bool | Field[bool] = True 

146 """Specifies if an `AnalysisTool` may parameterize a band within any field 

147 in any stage, or if the set of bands is already uniquely determined though 

148 configuration. I.e. can this `AnalysisTool` be automatically looped over to 

149 produce a result for multiple bands. 

150 """ 

151 

152 def __call__(self, data: KeyedData, **kwargs) -> KeyedResults: 

153 bands = kwargs.pop("bands", None) 

154 if "plotInfo" in kwargs and kwargs.get("plotInfo") is not None: 

155 if "plotName" not in kwargs["plotInfo"] or kwargs["plotInfo"]["plotName"] is None: 

156 kwargs["plotInfo"]["plotName"] = self.identity 

157 if not self.parameterizedBand or bands is None: 

158 if "band" not in kwargs: 

159 # Some tasks require a "band" key for naming. This shouldn't 

160 # affect the results. DM-35813 should make this unnecessary. 

161 kwargs["band"] = "analysisTools" 

162 return self._call_single(data, **kwargs) 

163 results: KeyedResults = {} 

164 for band in bands: 

165 kwargs["band"] = band 

166 if "plotInfo" in kwargs: 

167 kwargs["plotInfo"]["bands"] = band 

168 subResult = self._call_single(data, **kwargs) 

169 for key, value in subResult.items(): 

170 match value: 

171 case PlotTypes(): 

172 results[f"{band}_{key}"] = value 

173 case Measurement(): 

174 results[key] = value 

175 return results 

176 

177 def _call_single(self, data: KeyedData, **kwargs) -> KeyedResults: 

178 # create a shallow copy of kwargs 

179 kwargs = dict(**kwargs) 

180 kwargs["metric_tags"] = list(self.metric_tags or ()) 

181 prepped: KeyedData = self.prep(data, **kwargs) # type: ignore 

182 processed: KeyedData = self.process(prepped, **kwargs) # type: ignore 

183 finalized: ( 

184 Mapping[str, PlotTypes] | PlotTypes | Mapping[str, Measurement] | Measurement | JointResults 

185 ) = self.produce( 

186 processed, **kwargs 

187 ) # type: ignore 

188 return self._process_single_results(finalized) 

189 

190 def _getPlotType(self) -> str: 

191 match self.produce: 

192 case PlotAction(): 

193 return type(self.produce).__name__ 

194 case JointAction(plot=NoPlot()): 

195 pass 

196 case JointAction(plot=plotter): 

197 return type(plotter).__name__ 

198 

199 return "" 

200 

201 def _process_single_results( 

202 self, 

203 results: Mapping[str, PlotTypes] | PlotTypes | Mapping[str, Measurement] | Measurement | JointResults, 

204 ) -> KeyedResults: 

205 accumulation = {} 

206 suffix = self._getPlotType() 

207 predicate = f"{self.identity}" if self.identity else "" 

208 match results: 

209 case Mapping(): 

210 for key, value in results.items(): 

211 match value: 

212 case PlotTypes(): 

213 iterable = (predicate, key, suffix) 

214 case Measurement(): 

215 iterable = (predicate, key) 

216 refKey = "_".join(x for x in iterable if x) 

217 accumulation[refKey] = value 

218 case PlotTypes(): 

219 refKey = "_".join(x for x in (predicate, suffix) if x) 

220 accumulation[refKey] = results 

221 case Measurement(): 

222 accumulation[f"{predicate}"] = results 

223 case JointResults(plot=plotResults, metric=metricResults): 

224 if plotResults is not None: 

225 subResult = self._process_single_results(plotResults) 

226 accumulation.update(subResult) 

227 if metricResults is not None: 

228 subResult = self._process_single_results(metricResults) 

229 accumulation.update(subResult) 

230 return accumulation 

231 

232 def getInputSchema(self) -> KeyedDataSchema: 

233 return self.prep.getInputSchema() 

234 

235 def populatePrepFromProcess(self): 

236 """Add additional inputs to the prep stage if supported. 

237 

238 If the configured prep action supports adding to it's input schema, 

239 attempt to add the required inputs schema from the process stage to the 

240 prep stage. 

241 

242 This method will be a no-op if the prep action does not support this 

243 feature. 

244 """ 

245 self.prep.addInputSchema(self.process.getInputSchema()) 

246 

247 def getOutputNames(self, config: pexConfig.Config | None = None) -> Iterable[str]: 

248 """Return the names of the plots produced by this analysis tool. 

249 

250 If there is a `PlotAction` defined in the produce action, these names 

251 will either come from the `PlotAction` if it defines a 

252 ``getOutputNames`` method (likely if it returns a mapping of figures), 

253 or a default value is used and a single figure is assumed. 

254 

255 Parameters 

256 ---------- 

257 config : `lsst.pex.config.Config`, optional 

258 Configuration of the task. This is only used if the output naming 

259 needs to be config-aware. 

260 

261 Returns 

262 ------- 

263 result : `tuple` of `str` 

264 Names for each plot produced by this action. 

265 """ 

266 match self.produce: 

267 case JointAction(plot=NoPlot()): 

268 return tuple() 

269 case _HasOutputNames(): 

270 outNames = tuple(self.produce.getOutputNames(config=config)) 

271 case _: 

272 raise ValueError(f"Unsupported Action type {type(self.produce)} for getting output names") 

273 

274 results = [] 

275 suffix = self._getPlotType() 

276 if self.parameterizedBand: 

277 prefix = "_".join(x for x in ("{band}", self.identity) if x) 

278 else: 

279 prefix = f"{self.identity}" if self.identity else "" 

280 

281 if outNames: 

282 for name in outNames: 

283 results.append("_".join(x for x in (prefix, name, suffix) if x)) 

284 else: 

285 results.append("_".join(x for x in (prefix, suffix) if x)) 

286 return results 

287 

288 @classmethod 

289 def fromPipeline( 

290 cls, 

291 pipeline: str | Pipeline, 

292 name: str, 

293 fullpath: bool = False, 

294 instrument: Instrument | str | None = None, 

295 ) -> AnalysisTool | None: 

296 """Construct an `AnalysisTool` from a definition written in a 

297 `~lsst.pipe.base.Pipeline`. 

298 

299 Parameters 

300 ---------- 

301 pipeline : `str` or `~lsst.pipe.base.Pipeline` 

302 The pipeline to load the `AnalysisTool` from. 

303 name : `str` 

304 The name of the analysis tool to run. This can either be just the 

305 name assigned to the tool, or an absolute name in a config 

306 hierarchy. 

307 fullpath : `bool` 

308 Determines if the name is interpreted as an absolute path in a 

309 config hierarchy, or is relative to an `AnalysisTool` ``atools`` 

310 `~lsst.pex.config.configurableActions.ConfigurableActionStructField` 

311 . 

312 instrument : `~lsst.daf.butler.instrument.Instrument` or `str` or\ 

313 `None` 

314 Either a derived class object of a `lsst.daf.butler.instrument` or 

315 a string corresponding to a fully qualified 

316 `lsst.daf.butler.instrument` name or None if no instrument needs 

317 specified or if the pipeline contains the instrument in it. 

318 Defaults to None. 

319 

320 Returns 

321 ------- 

322 tool : `AnalysisTool` 

323 The loaded `AnalysisTool` as configured in the pipeline. 

324 

325 Raises 

326 ------ 

327 ValueError 

328 Raised if the config field specified does not point to an 

329 `AnalysisTool`. 

330 Raised if an instrument is specified and it conflicts with the 

331 pipelines instrument. 

332 """ 

333 if not isinstance(pipeline, Pipeline): 

334 pipeline = Pipeline.fromFile(pipeline) 

335 # If the caller specified an instrument, verify it does not conflict 

336 # with the pipelines instrument, and add it to the pipeline 

337 if instrument is not None: 

338 if (pipeInstrument := pipeline.getInstrument()) and pipeInstrument != instrument: 

339 raise ValueError( 

340 f"The supplied instrument {instrument} conflicts with the pipelines instrument " 

341 f"{pipeInstrument}." 

342 ) 

343 else: 

344 pipeline.addInstrument(instrument) 

345 try: 

346 pipelineGraph = pipeline.to_graph() 

347 except (FieldValidationError, ValueError) as err: 

348 raise ValueError( 

349 "There was an error instantiating the pipeline, do you need to specify an instrument?" 

350 ) from err 

351 if not fullpath: 

352 name = f"atools.{name}" 

353 for task in pipelineGraph.tasks.values(): 

354 config = task.config 

355 try: 

356 attr = attrgetter(name)(config) 

357 except AttributeError: 

358 continue 

359 if not isinstance(attr, AnalysisTool): 

360 raise ValueError("The requested name did not refer to an analysisTool") 

361 return attr 

362 return None 

363 

364 def finalize(self) -> None: 

365 """Run any finalization code that depends on configuration being 

366 complete. 

367 """ 

368 pass 

369 

370 def _baseFinalize(self) -> None: 

371 self.populatePrepFromProcess() 

372 

373 def freeze(self): 

374 if not self.__dict__.get("_finalizeRun"): 

375 self.finalize() 

376 self.__dict__["_finalizeRun"] = True 

377 super().freeze() 

378 

379 

# Explicitly wrap the finalize of the base class. ``__init_subclass__`` only
# runs for subclasses of AnalysisTool, never for AnalysisTool itself, so its
# own ``finalize`` must be wrapped here. This is what makes ``_baseFinalize``
# fire for subclasses that do not override ``finalize``: for them, the first
# class in the MRO defining ``finalize`` is AnalysisTool. This statement must
# come after the class definition.
AnalysisTool.finalize = _finalizeWrapper(AnalysisTool.finalize, AnalysisTool)