
# This file is part of pipe_tasks
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

import functools
import pandas as pd
from collections import defaultdict

import lsst.pex.config as pexConfig
import lsst.pipe.base as pipeBase
from lsst.pipe.base import CmdLineTask, ArgumentParser
from lsst.coadd.utils.coaddDataIdContainer import CoaddDataIdContainer

from .parquetTable import ParquetTable
from .multiBandUtils import makeMergeArgumentParser, MergeSourcesRunner
from .functors import CompositeFunctor, RAColumn, DecColumn, Column


def flattenFilters(df, filterDict, noDupCols=['coord_ra', 'coord_dec'], camelCase=False):
    """Flatten a dataframe with a multilevel column index.
    """
    newDf = pd.DataFrame()
    for filt, filtShort in filterDict.items():
        subdf = df[filt]
        columnFormat = '{0}{1}' if camelCase else '{0}_{1}'
        newColumns = {c: columnFormat.format(filtShort, c)
                      for c in subdf.columns if c not in noDupCols}
        cols = list(newColumns.keys())
        newDf = pd.concat([newDf, subdf[cols].rename(columns=newColumns)], axis=1)

    newDf = pd.concat([subdf[noDupCols], newDf], axis=1)
    return newDf
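

# Illustrative sketch only (not part of the original pipeline): what
# `flattenFilters` does to a DataFrame whose columns form a (filter, column)
# MultiIndex.  It prepends the short filter name to each per-filter column
# and keeps the `noDupCols` only once.  The filter names, short names, and
# values below are hypothetical.
def _exampleFlattenFilters():
    columns = pd.MultiIndex.from_product(
        [['HSC-G', 'HSC-R'], ['coord_ra', 'coord_dec', 'base_PsfFlux']],
        names=('filter', 'column'))
    df = pd.DataFrame([[10.0, -2.0, 1.5, 10.0, -2.0, 2.5]], columns=columns)
    # With camelCase=True, 'base_PsfFlux' in HSC-G becomes 'gbase_PsfFlux';
    # with camelCase=False it would become 'g_base_PsfFlux'.
    return flattenFilters(df, {'HSC-G': 'g', 'HSC-R': 'r'}, camelCase=True)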


class WriteObjectTableConfig(pexConfig.Config):
    priorityList = pexConfig.ListField(
        dtype=str,
        default=[],
        doc="Priority-ordered list of bands for the merge."
    )
    engine = pexConfig.Field(
        dtype=str,
        default="pyarrow",
        doc="Parquet engine for writing (pyarrow or fastparquet)"
    )
    coaddName = pexConfig.Field(
        dtype=str,
        default="deep",
        doc="Name of coadd"
    )

    def validate(self):
        pexConfig.Config.validate(self)
        if len(self.priorityList) == 0:
            raise RuntimeError("No priority list provided")
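

# Illustrative sketch only: `WriteObjectTableConfig.validate` requires a
# non-empty, priority-ordered band list.  The band names below are
# hypothetical; in practice they come from an obs-package config override.
def _exampleWriteObjectTableConfig():
    config = WriteObjectTableConfig()
    config.priorityList = ['HSC-I', 'HSC-R', 'HSC-G']
    config.validate()  # would raise RuntimeError if priorityList were empty
    return config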


class WriteObjectTableTask(CmdLineTask):
    """Write filter-merged source tables to parquet.
    """
    _DefaultName = "writeObjectTable"
    ConfigClass = WriteObjectTableConfig
    RunnerClass = MergeSourcesRunner

    # Names of table datasets to be merged
    inputDatasets = ('forced_src', 'meas', 'ref')

    # Tag of output dataset written by `MergeSourcesTask.write`
    outputDataset = 'obj'

    def __init__(self, butler=None, schema=None, **kwargs):
        # This class cannot use the default CmdLineTask init: doing so would
        # require its own specialized task runner, which would add many more
        # lines of code, so the base __init__ is called explicitly for now.
        CmdLineTask.__init__(self, **kwargs)

    def runDataRef(self, patchRefList):
        """Merge coadd sources from multiple bands.

        Reads one catalog per filter, merges them with `run`, and writes the
        result.

        Parameters
        ----------
        patchRefList : `list` of `lsst.daf.persistence.ButlerDataRef`
            Data references, one per filter.
        """
        catalogs = dict(self.readCatalog(patchRef) for patchRef in patchRefList)
        dataId = patchRefList[0].dataId
        mergedCatalog = self.run(catalogs, tract=dataId['tract'], patch=dataId['patch'])
        self.write(patchRefList[0], mergedCatalog)

    @classmethod
    def _makeArgumentParser(cls):
        """Create a suitable ArgumentParser.

        We will use the ArgumentParser to get a list of data
        references for patches; the RunnerClass will sort them into lists
        of data references for the same patch.

        The parser references the first element of `self.inputDatasets`,
        rather than a single `self.inputDataset`.
        """
        return makeMergeArgumentParser(cls._DefaultName, cls.inputDatasets[0])

    def readCatalog(self, patchRef):
        """Read input catalogs.

        Read all the input datasets given by the 'inputDatasets'
        attribute.

        Parameters
        ----------
        patchRef : `lsst.daf.persistence.ButlerDataRef`
            Data reference for patch.

        Returns
        -------
        Tuple consisting of filter name and a dict of catalogs, keyed by
        dataset name.
        """
        filterName = patchRef.dataId["filter"]
        catalogDict = {}
        for dataset in self.inputDatasets:
            catalog = patchRef.get(self.config.coaddName + "Coadd_" + dataset, immediate=True)
            self.log.info("Read %d sources from %s for filter %s: %s" %
                          (len(catalog), dataset, filterName, patchRef.dataId))
            catalogDict[dataset] = catalog
        return filterName, catalogDict

    def run(self, catalogs, tract, patch):
        """Merge multiple catalogs.

        Parameters
        ----------
        catalogs : `dict`
            Mapping from filter names to dicts of catalogs.
        tract : `int`
            Tract id to use for the tractId column.
        patch : `str`
            Patch id to use for the patchId column.

        Returns
        -------
        catalog : `lsst.pipe.tasks.parquetTable.ParquetTable`
            Merged dataframe, with each column prefixed by
            `filter_tag(filt)`, wrapped in the parquet writer shim class.
        """

        dfs = []
        for filt, tableDict in catalogs.items():
            for dataset, table in tableDict.items():
                # Convert afwTable to pandas DataFrame
                df = table.asAstropy().to_pandas().set_index('id', drop=True)

                # Sort columns by name, to ensure matching schema among patches
                df = df.reindex(sorted(df.columns), axis=1)
                df['tractId'] = tract
                df['patchId'] = patch

                # Make columns a 3-level MultiIndex
                df.columns = pd.MultiIndex.from_tuples([(dataset, filt, c) for c in df.columns],
                                                       names=('dataset', 'filter', 'column'))
                dfs.append(df)

        catalog = functools.reduce(lambda d1, d2: d1.join(d2), dfs)
        return ParquetTable(dataFrame=catalog)

    def write(self, patchRef, catalog):
        """Write the output.

        Parameters
        ----------
        catalog : `ParquetTable`
            Catalog to write.
        patchRef : `lsst.daf.persistence.ButlerDataRef`
            Data reference for patch.
        """
        patchRef.put(catalog, self.config.coaddName + "Coadd_" + self.outputDataset)
        # since the filter isn't actually part of the data ID for the dataset we're saving,
        # it's confusing to see it in the log message, even if the butler simply ignores it.
        mergeDataId = patchRef.dataId.copy()
        del mergeDataId["filter"]
        self.log.info("Wrote merged catalog: %s" % (mergeDataId,))

    def writeMetadata(self, dataRefList):
        """No metadata to write; not implemented for a list of dataRefs.
        """
        pass
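

# Illustrative sketch only: this mimics how `WriteObjectTableTask.run` labels
# each per-band, per-dataset table with a 3-level (dataset, filter, column)
# MultiIndex and then joins the pieces on the shared object id.  The filter
# name, column names, and values below are hypothetical.
def _exampleMergedObjectColumns():
    index = pd.Index([11, 12], name='id')
    meas = pd.DataFrame({'base_PsfFlux_instFlux': [1.0, 2.0]}, index=index)
    meas.columns = pd.MultiIndex.from_tuples([('meas', 'HSC-G', c) for c in meas.columns],
                                             names=('dataset', 'filter', 'column'))
    ref = pd.DataFrame({'detect_isPrimary': [True, False]}, index=index)
    ref.columns = pd.MultiIndex.from_tuples([('ref', 'HSC-G', c) for c in ref.columns],
                                            names=('dataset', 'filter', 'column'))
    # `run` performs the same pairwise join via functools.reduce over all pieces.
    return meas.join(ref)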


class PostprocessAnalysis(object):
    """Calculate columns from a ParquetTable.

    This object manages and organizes an arbitrary set of computations
    on a catalog. The catalog is defined by a
    `lsst.pipe.tasks.parquetTable.ParquetTable` object (or list thereof), such as a
    `deepCoadd_obj` dataset, and the computations are defined by a collection
    of `lsst.pipe.tasks.functor.Functor` objects (or, equivalently,
    a `CompositeFunctor`).

    After the object is initialized, accessing the `.df` attribute (which
    holds the `pandas.DataFrame` containing the results of the calculations)
    triggers computation of that dataframe.

    One of the conveniences of using this object is the ability to define a desired common
    filter for all functors. This enables the same functor collection to be passed to
    several different `PostprocessAnalysis` objects without having to change the original
    functor collection, since the `filt` keyword argument of this object triggers an
    overwrite of the `filt` property for all functors in the collection.

    This object also allows a list of refFlags to be passed, and defines a set of default
    refFlags that are always included even if not requested.

    If a list of `ParquetTable` objects is passed, rather than a single one, then the
    calculations will be mapped over all the input catalogs. In principle, it should
    be straightforward to parallelize this activity, but initial tests have failed
    (see TODO in code comments).

    Parameters
    ----------
    parq : `lsst.pipe.tasks.parquetTable.ParquetTable` (or list of such)
        Source catalog(s) for computation.

    functors : `list`, `dict`, or `lsst.pipe.tasks.functors.CompositeFunctor`
        Computations to do (functors that act on `parq`).
        If a dict, the output DataFrame will have columns keyed accordingly.
        If a list, the column keys will come from the
        `.shortname` attribute of each functor.

    filt : `str` (optional)
        Filter in which to calculate. If provided,
        this will overwrite any existing `.filt` attribute
        of the provided functors.

    flags : `list` (optional)
        List of flags (per-band) to include in the output table.

    refFlags : `list` (optional)
        List of refFlags (reference band only) to include in the output table.

    """
    _defaultRefFlags = ('calib_psf_used', 'detect_isPrimary')
    _defaultFuncs = (('coord_ra', RAColumn()),
                     ('coord_dec', DecColumn()))

    def __init__(self, parq, functors, filt=None, flags=None, refFlags=None):
        self.parq = parq
        self.functors = functors

        self.filt = filt
        self.flags = list(flags) if flags is not None else []
        self.refFlags = list(self._defaultRefFlags)
        if refFlags is not None:
            self.refFlags += list(refFlags)

        self._df = None

    @property
    def defaultFuncs(self):
        funcs = dict(self._defaultFuncs)
        return funcs

    @property
    def func(self):
        additionalFuncs = self.defaultFuncs
        additionalFuncs.update({flag: Column(flag, dataset='ref') for flag in self.refFlags})
        additionalFuncs.update({flag: Column(flag, dataset='meas') for flag in self.flags})

        if isinstance(self.functors, CompositeFunctor):
            func = self.functors
        else:
            func = CompositeFunctor(self.functors)

        func.funcDict.update(additionalFuncs)
        func.filt = self.filt

        return func

    @property
    def noDupCols(self):
        return [name for name, func in self.func.funcDict.items() if func.noDup or func.dataset == 'ref']

    @property
    def df(self):
        if self._df is None:
            self.compute()
        return self._df

    def compute(self, dropna=False, pool=None):
        # map over multiple parquet tables
        if type(self.parq) in (list, tuple):
            if pool is None:
                dflist = [self.func(parq, dropna=dropna) for parq in self.parq]
            else:
                # TODO: Figure out why this doesn't work (pyarrow pickling issues?)
                dflist = pool.map(functools.partial(self.func, dropna=dropna), self.parq)
            self._df = pd.concat(dflist)
        else:
            self._df = self.func(self.parq, dropna=dropna)

        return self._df
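

# Illustrative sketch only: build a `PostprocessAnalysis` over a
# `deepCoadd_obj`-style ParquetTable (`parq`, assumed to exist) using a small
# dict of functors.  The column name, band, and refFlag below are
# hypothetical choices; coord_ra/coord_dec and the default reference flags
# are added automatically.
def _examplePostprocessAnalysis(parq):
    funcs = {'psfFlux': Column('base_PsfFlux_instFlux', dataset='meas')}
    analysis = PostprocessAnalysis(parq, funcs, filt='HSC-I',
                                   refFlags=['merge_measurement_i'])
    # Accessing `.df` triggers the computation.
    return analysis.df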


class TransformCatalogBaseConfig(pexConfig.Config):
    coaddName = pexConfig.Field(
        dtype=str,
        default="deep",
        doc="Name of coadd"
    )
    functorFile = pexConfig.Field(
        dtype=str,
        doc='Path to YAML file specifying functors to be computed',
        default=None
    )


class TransformCatalogBaseTask(CmdLineTask):
    """Base class for transforming/standardizing a catalog by applying functors
    that convert units and apply calibrations.

    The purpose of this task is to perform a set of computations on
    an input `ParquetTable` dataset (such as `deepCoadd_obj`) and write the
    results to a new dataset (which needs to be declared in an `outputDataset`
    attribute).

    The calculations to be performed are defined in a YAML file that specifies
    a set of functors to be computed, provided as
    a `functorFile` config parameter. An example of such a YAML file
    is the following:

        funcs:
            psfMag:
                functor: Mag
                args:
                    - base_PsfFlux
                filt: HSC-G
                dataset: meas
            cmodel_magDiff:
                functor: MagDiff
                args:
                    - modelfit_CModel
                    - base_PsfFlux
                filt: HSC-G
            gauss_magDiff:
                functor: MagDiff
                args:
                    - base_GaussianFlux
                    - base_PsfFlux
                filt: HSC-G
            count:
                functor: Column
                args:
                    - base_InputCount_value
                filt: HSC-G
            deconvolved_moments:
                functor: DeconvolvedMoments
                filt: HSC-G
                dataset: forced_src
        refFlags:
            - calib_psfUsed
            - merge_measurement_i
            - merge_measurement_r
            - merge_measurement_z
            - merge_measurement_y
            - merge_measurement_g
            - base_PixelFlags_flag_inexact_psfCenter
            - detect_isPrimary

    The name of each entry under "funcs" will become the name of a column in the
    output dataset. All the functors referenced are defined in `lsst.pipe.tasks.functors`.
    Positional arguments to be passed to each functor are in the `args` list,
    and any additional entries for each column other than "functor" or "args" (e.g., `'filt'`,
    `'dataset'`) are treated as keyword arguments to be passed to the functor initialization.

    The "refFlags" entry is a shortcut for a set of `Column` functors that keep the
    original column name and are taken from the `'ref'` dataset.

    The "flags" entry will be expanded out per band.

    Note that if `'filter'` is provided as part of the `dataId` when running this task (even though
    `deepCoadd_obj` does not use `'filter'`), it will override the `filt` kwargs
    provided in the YAML file, and the calculations will be done in that filter.

    This task uses the `lsst.pipe.tasks.postprocess.PostprocessAnalysis` object
    to organize and execute the calculations.

    """
    @property
    def _DefaultName(self):
        raise NotImplementedError('Subclass must define "_DefaultName" attribute')

    @property
    def outputDataset(self):
        raise NotImplementedError('Subclass must define "outputDataset" attribute')

    @property
    def inputDataset(self):
        raise NotImplementedError('Subclass must define "inputDataset" attribute')

    @property
    def ConfigClass(self):
        raise NotImplementedError('Subclass must define "ConfigClass" attribute')

    def runDataRef(self, patchRef):
        parq = patchRef.get()
        dataId = patchRef.dataId
        funcs = self.getFunctors()
        self.log.info("Transforming/standardizing the catalog of %s", dataId)
        df = self.run(parq, funcs=funcs, dataId=dataId)
        self.write(df, patchRef)
        return df

    def run(self, parq, funcs=None, dataId=None):
        """Do postprocessing calculations.

        Takes a `ParquetTable` object and dataId,
        returns a dataframe with results of postprocessing calculations.

        Parameters
        ----------
        parq : `lsst.pipe.tasks.parquetTable.ParquetTable`
            ParquetTable from which calculations are done.
        funcs : `lsst.pipe.tasks.functors.Functors`
            Functors to apply to the table's columns.
        dataId : dict, optional
            Used to add dataId columns (e.g. `patchId`) to the output dataframe.

        Returns
        -------
        `pandas.DataFrame`

        """
        filt = dataId.get('filter', None)
        return self.transform(filt, parq, funcs, dataId).df

    def getFunctors(self):
        funcs = CompositeFunctor.from_file(self.config.functorFile)
        funcs.update(dict(PostprocessAnalysis._defaultFuncs))
        return funcs

    def getAnalysis(self, parq, funcs=None, filt=None):
        # Avoids disk access if funcs is passed
        if funcs is None:
            funcs = self.getFunctors()
        analysis = PostprocessAnalysis(parq, funcs, filt=filt)
        return analysis

    def transform(self, filt, parq, funcs, dataId):
        analysis = self.getAnalysis(parq, funcs=funcs, filt=filt)
        df = analysis.df
        if dataId is not None:
            for key, value in dataId.items():
                df[key] = value

        return pipeBase.Struct(
            df=df,
            analysis=analysis
        )

    def write(self, df, parqRef):
        parqRef.put(ParquetTable(dataFrame=df), self.outputDataset)

    def writeMetadata(self, dataRef):
        """No metadata to write.
        """
        pass
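

# Illustrative sketch only: the functor specification is loaded with
# `CompositeFunctor.from_file`, exactly as `getFunctors` does above.  The
# default path below is hypothetical and would point at a YAML file like the
# one shown in the TransformCatalogBaseTask docstring.
def _exampleLoadFunctors(functorFile='/path/to/Object.yaml'):
    funcs = CompositeFunctor.from_file(functorFile)
    # The default ra/dec functors are always added.
    funcs.update(dict(PostprocessAnalysis._defaultFuncs))
    return funcs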


class TransformObjectCatalogConfig(TransformCatalogBaseConfig):
    filterMap = pexConfig.DictField(
        keytype=str,
        itemtype=str,
        default={},
        doc=("Dictionary mapping full filter name to short one for column name munging. "
             "These filters determine the output columns no matter what filters the "
             "input data actually contain.")
    )
    camelCase = pexConfig.Field(
        dtype=bool,
        default=True,
        doc=("Write per-filter column names in camelCase, else with underscores. "
             "For example: gPsfFlux instead of g_PsfFlux.")
    )
    multilevelOutput = pexConfig.Field(
        dtype=bool,
        default=False,
        doc=("Whether the results dataframe should have a multilevel column index (True) or be flat "
             "and name-munged (False).")
    )
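

# Illustrative sketch only: a hypothetical configuration of the filter-name
# munging used when flattening the output table.  The filter map below is an
# assumption; real values come from obs-package config overrides.
def _exampleObjectCatalogConfig():
    config = TransformObjectCatalogConfig()
    config.filterMap = {'HSC-G': 'g', 'HSC-R': 'r', 'HSC-I': 'i'}
    config.camelCase = True          # per-filter columns like 'gPsfFlux'
    config.multilevelOutput = False  # flatten with flattenFilters
    return config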


class TransformObjectCatalogTask(TransformCatalogBaseTask):
    """Compute a flattened Object Table as defined in the DPDD.

    Do the same set of postprocessing calculations on all bands.

    This is identical to `TransformCatalogBaseTask`, except that it does the
    specified functor calculations for all filters present in the
    input `deepCoadd_obj` table. Any specific `"filt"` keywords specified
    by the YAML file will be superseded.
    """
    _DefaultName = "transformObjectCatalog"
    ConfigClass = TransformObjectCatalogConfig

    inputDataset = 'deepCoadd_obj'
    outputDataset = 'objectTable'

    @classmethod
    def _makeArgumentParser(cls):
        parser = ArgumentParser(name=cls._DefaultName)
        parser.add_id_argument("--id", cls.inputDataset,
                               ContainerClass=CoaddDataIdContainer,
                               help="data ID, e.g. --id tract=12345 patch=1,2")
        return parser

    def run(self, parq, funcs=None, dataId=None):
        dfDict = {}
        analysisDict = {}
        templateDf = pd.DataFrame()
        # Perform the transform for data of filters that exist in parq and are
        # specified in config.filterMap
        for filt in parq.columnLevelNames['filter']:
            if filt not in self.config.filterMap:
                self.log.info("Ignoring %s data in the input", filt)
                continue
            self.log.info("Transforming the catalog of filter %s", filt)
            result = self.transform(filt, parq, funcs, dataId)
            dfDict[filt] = result.df
            analysisDict[filt] = result.analysis
            if templateDf.empty:
                templateDf = result.df

        # Fill NaNs in columns of other wanted filters
        for filt in self.config.filterMap:
            if filt not in dfDict:
                self.log.info("Adding empty columns for filter %s", filt)
                dfDict[filt] = pd.DataFrame().reindex_like(templateDf)

        # This makes a multilevel column index, with filter as the first level
        df = pd.concat(dfDict, axis=1, names=['filter', 'column'])

        if not self.config.multilevelOutput:
            noDupCols = list(set.union(*[set(v.noDupCols) for v in analysisDict.values()]))
            if dataId is not None:
                noDupCols += list(dataId.keys())
            df = flattenFilters(df, self.config.filterMap, noDupCols=noDupCols,
                                camelCase=self.config.camelCase)

        self.log.info("Made a table of %d columns and %d rows", len(df.columns), len(df))
        return df
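

# Illustrative sketch only: how filters listed in filterMap but absent from
# the input are padded with all-NaN columns matching the template filter's
# schema.  The column name and values below are hypothetical.
def _exampleEmptyFilterColumns():
    templateDf = pd.DataFrame({'psfMag': [21.3, 22.1]}, index=[11, 12])
    # An empty frame reindexed like the template has the same index and
    # columns, filled with NaN.
    return pd.DataFrame().reindex_like(templateDf)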


class TractObjectDataIdContainer(CoaddDataIdContainer):

    def makeDataRefList(self, namespace):
        """Make self.refList from self.idList.

        Generate a list of data references given tract and/or patch.
        This was adapted from `TractQADataIdContainer`, which was
        `TractDataIdContainer` modified to not require "filter".
        Only existing dataRefs are returned.
        """
        def getPatchRefList(tract):
            return [namespace.butler.dataRef(datasetType=self.datasetType,
                                             tract=tract.getId(),
                                             patch="%d,%d" % patch.getIndex()) for patch in tract]

        tractRefs = defaultdict(list)  # Data references for each tract
        for dataId in self.idList:
            skymap = self.getSkymap(namespace)

            if "tract" in dataId:
                tractId = dataId["tract"]
                if "patch" in dataId:
                    tractRefs[tractId].append(namespace.butler.dataRef(datasetType=self.datasetType,
                                                                       tract=tractId,
                                                                       patch=dataId['patch']))
                else:
                    tractRefs[tractId] += getPatchRefList(skymap[tractId])
            else:
                tractRefs = dict((tract.getId(), tractRefs.get(tract.getId(), []) + getPatchRefList(tract))
                                 for tract in skymap)
        outputRefList = []
        for tractRefList in tractRefs.values():
            existingRefs = [ref for ref in tractRefList if ref.datasetExists()]
            outputRefList.append(existingRefs)

        self.refList = outputRefList


class ConsolidateObjectTableConfig(pexConfig.Config):
    coaddName = pexConfig.Field(
        dtype=str,
        default="deep",
        doc="Name of coadd"
    )


class ConsolidateObjectTableTask(CmdLineTask):
    """Write patch-merged source tables to a tract-level parquet file
    """
    _DefaultName = "consolidateObjectTable"
    ConfigClass = ConsolidateObjectTableConfig

    inputDataset = 'objectTable'
    outputDataset = 'objectTable_tract'

    @classmethod
    def _makeArgumentParser(cls):
        parser = ArgumentParser(name=cls._DefaultName)

        parser.add_id_argument("--id", cls.inputDataset,
                               help="data ID, e.g. --id tract=12345",
                               ContainerClass=TractObjectDataIdContainer)
        return parser

    def runDataRef(self, patchRefList):
        df = pd.concat([patchRef.get().toDataFrame() for patchRef in patchRefList])
        patchRefList[0].put(ParquetTable(dataFrame=df), self.outputDataset)

    def writeMetadata(self, dataRef):
        """No metadata to write.
        """
        pass
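

# Illustrative sketch only: `ConsolidateObjectTableTask.runDataRef` above
# consolidates per-patch object tables by concatenating their dataframes.
# Given a list of ParquetTable objects (assumed to exist), the equivalent
# standalone operation is:
def _exampleConsolidate(patchTables):
    df = pd.concat([table.toDataFrame() for table in patchTables])
    return ParquetTable(dataFrame=df)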