Coverage for python/lsst/pipe/tasks/isolatedStarAssociation.py: 13%
219 statements
« prev ^ index » next coverage.py v7.4.3, created at 2024-02-24 10:27 +0000
« prev ^ index » next coverage.py v7.4.3, created at 2024-02-24 10:27 +0000
1# This file is part of pipe_tasks.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
# Public API of this module: the connections, config and task classes.
__all__ = [
    "IsolatedStarAssociationConnections",
    "IsolatedStarAssociationConfig",
    "IsolatedStarAssociationTask",
]
26import numpy as np
27import pandas as pd
28from smatch.matcher import Matcher
30import lsst.pex.config as pexConfig
31import lsst.pipe.base as pipeBase
32from lsst.skymap import BaseSkyMap
33from lsst.meas.algorithms.sourceSelector import sourceSelectorRegistry
class IsolatedStarAssociationConnections(pipeBase.PipelineTaskConnections,
                                         dimensions=("instrument", "tract", "skymap"),
                                         defaultTemplates={}):
    """Butler connections for `IsolatedStarAssociationTask`.

    Inputs are the per-visit source tables (deferred-load, multiple) and the
    skymap; outputs are two per-(instrument, tract, skymap) DataFrames: the
    individual associated sources and the isolated star positions.
    """

    # Per-visit source tables; loaded lazily so only the needed columns are read.
    source_table_visit = pipeBase.connectionTypes.Input(
        doc="Source table in parquet format, per visit",
        name="sourceTable_visit",
        storageClass="DataFrame",
        dimensions=("instrument", "visit"),
        deferLoad=True,
        multiple=True,
    )

    # Skymap used to look up tract geometry.
    skymap = pipeBase.connectionTypes.Input(
        doc="Input definition of geometry/bbox and projection/wcs for warped exposures",
        name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
        storageClass="SkyMap",
        dimensions=("skymap",),
    )

    # One row per source that was associated with an isolated star.
    isolated_star_sources = pipeBase.connectionTypes.Output(
        doc="Catalog of individual sources for the isolated stars",
        name="isolated_star_sources",
        storageClass="DataFrame",
        dimensions=("instrument", "tract", "skymap"),
    )

    # One row per isolated star.
    isolated_star_cat = pipeBase.connectionTypes.Output(
        doc="Catalog of isolated star positions",
        name="isolated_star_cat",
        storageClass="DataFrame",
        dimensions=("instrument", "tract", "skymap"),
    )
class IsolatedStarAssociationConfig(pipeBase.PipelineTaskConfig,
                                    pipelineConnections=IsolatedStarAssociationConnections):
    """Configuration for IsolatedStarAssociationTask."""

    inst_flux_field = pexConfig.Field(
        doc=('Full name of instFlux field to use for s/n selection and persistence. '
             'The associated flag will be implicitly included in bad_flags. '
             'Note that this is expected to end in ``instFlux``.'),
        dtype=str,
        default='apFlux_12_0_instFlux',
    )
    match_radius = pexConfig.Field(
        doc='Match radius (arcseconds)',
        dtype=float,
        default=1.0,
    )
    isolation_radius = pexConfig.Field(
        doc=('Isolation radius (arcseconds). Any stars with average centroids '
             'within this radius of another star will be rejected from the final '
             'catalog. This radius should be at least 2x match_radius.'),
        dtype=float,
        default=2.0,
    )
    band_order = pexConfig.ListField(
        doc=('Ordered list of bands to use for matching/storage. '
             'Any bands not listed will not be matched.'),
        dtype=str,
        default=['i', 'z', 'r', 'g', 'y', 'u'],
    )
    id_column = pexConfig.Field(
        doc='Name of column with source id.',
        dtype=str,
        default='sourceId',
    )
    ra_column = pexConfig.Field(
        doc='Name of column with right ascension.',
        dtype=str,
        default='ra',
    )
    dec_column = pexConfig.Field(
        doc='Name of column with declination.',
        dtype=str,
        default='dec',
    )
    physical_filter_column = pexConfig.Field(
        doc='Name of column with physical filter name',
        dtype=str,
        default='physical_filter',
    )
    band_column = pexConfig.Field(
        doc='Name of column with band name',
        dtype=str,
        default='band',
    )
    extra_columns = pexConfig.ListField(
        doc='Extra names of columns to read and persist (beyond instFlux and error).',
        dtype=str,
        default=['x',
                 'y',
                 'xErr',
                 'yErr',
                 'apFlux_17_0_instFlux',
                 'apFlux_17_0_instFluxErr',
                 'apFlux_17_0_flag',
                 'localBackground_instFlux',
                 'localBackground_flag',
                 'ixx',
                 'iyy',
                 'ixy'],
    )
    source_selector = sourceSelectorRegistry.makeField(
        doc='How to select sources. Under normal usage this should not be changed.',
        default='science'
    )

    def setDefaults(self):
        """Configure the ``science`` source selector for isolated stars.

        Enables flag, unresolved, signal-to-noise, isolation, finite ra/dec
        and primary cuts, and points the s/n and ra/dec cuts at the columns
        configured on this config.
        """
        super().setDefaults()

        source_selector = self.source_selector['science']
        source_selector.setDefaults()

        source_selector.doFlags = True
        source_selector.doUnresolved = True
        source_selector.doSignalToNoise = True
        source_selector.doIsolated = True
        source_selector.doRequireFiniteRaDec = True
        source_selector.doRequirePrimary = True

        source_selector.signalToNoise.minimum = 10.0
        source_selector.signalToNoise.maximum = 1000.0

        # NOTE(review): the values below are derived from inst_flux_field,
        # ra_column and dec_column at the time setDefaults runs. Overriding
        # those fields afterwards does not update the selector settings;
        # they must be overridden alongside them.
        flux_flag_name = self.inst_flux_field.replace("instFlux", "flag")

        source_selector.flags.bad = ['pixelFlags_edge',
                                     'pixelFlags_interpolatedCenter',
                                     'pixelFlags_saturatedCenter',
                                     'pixelFlags_crCenter',
                                     'pixelFlags_bad',
                                     'pixelFlags_interpolated',
                                     'pixelFlags_saturated',
                                     'centroid_flag',
                                     flux_flag_name]

        source_selector.signalToNoise.fluxField = self.inst_flux_field
        source_selector.signalToNoise.errField = self.inst_flux_field + 'Err'

        source_selector.isolated.parentName = 'parentSourceId'
        source_selector.isolated.nChildName = 'deblend_nChild'

        source_selector.unresolved.maximum = 0.5
        source_selector.unresolved.name = 'extendedness'

        source_selector.requireFiniteRaDec.raColName = self.ra_column
        source_selector.requireFiniteRaDec.decColName = self.dec_column
class IsolatedStarAssociationTask(pipeBase.PipelineTask):
    """Associate sources into isolated star catalogs.

    Per-visit source tables are read and filtered with the configured source
    selector. Per-band "primary" star positions are then built, stars with
    close neighbors are removed, the catalog is cropped to the inner tract
    region, and every selected source is associated with its isolated star.
    """
    ConfigClass = IsolatedStarAssociationConfig
    _DefaultName = 'isolatedStarAssociation'

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self.makeSubtask('source_selector')
        # Only log warning and fatal errors from the source_selector
        self.source_selector.log.setLevel(self.source_selector.log.WARN)

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        input_ref_dict = butlerQC.get(inputRefs)

        tract = butlerQC.quantum.dataId['tract']

        source_table_refs = input_ref_dict['source_table_visit']

        self.log.info('Running with %d source_table_visit dataRefs',
                      len(source_table_refs))

        refs_by_visit = {source_table_ref.dataId['visit']: source_table_ref
                         for source_table_ref in source_table_refs}

        # Warn about bands present in the inputs that will never be matched.
        bands = {source_table_ref.dataId['band'] for source_table_ref in source_table_refs}
        for band in bands:
            if band not in self.config.band_order:
                self.log.warning('Input data has data from band %s but that band is not '
                                 'configured for matching', band)

        # TODO: Sort by visit until DM-31701 is done and we have deterministic
        # dataset ordering.
        # Dict keys are unique, so sorting the items only ever compares visits.
        source_table_ref_dict = dict(sorted(refs_by_visit.items()))

        struct = self.run(input_ref_dict['skymap'], tract, source_table_ref_dict)

        butlerQC.put(pd.DataFrame(struct.star_source_cat),
                     outputRefs.isolated_star_sources)
        butlerQC.put(pd.DataFrame(struct.star_cat),
                     outputRefs.isolated_star_cat)

    def run(self, skymap, tract, source_table_ref_dict):
        """Run the isolated star association task.

        Parameters
        ----------
        skymap : `lsst.skymap.SkyMap`
            Skymap object.
        tract : `int`
            Tract number.
        source_table_ref_dict : `dict`
            Dictionary of source_table refs. Key is visit, value is dataref.

        Returns
        -------
        struct : `lsst.pipe.base.Struct`
            Struct with outputs for persistence, with attributes:

            ``star_source_cat``
                Associated individual sources (`np.ndarray`).
            ``star_cat``
                Isolated star catalog (`np.ndarray`).
        """
        star_source_cat = self._make_all_star_sources(skymap[tract], source_table_ref_dict)

        primary_bands = self.config.band_order

        # Do primary matching
        primary_star_cat = self._match_primary_stars(primary_bands, star_source_cat)
        if len(primary_star_cat) == 0:
            return self._make_empty_struct(star_source_cat, primary_star_cat)

        # Remove neighbors
        primary_star_cat = self._remove_neighbors(primary_star_cat)
        if len(primary_star_cat) == 0:
            return self._make_empty_struct(star_source_cat, primary_star_cat)

        # Crop to inner tract region
        inner_tract_ids = skymap.findTractIdArray(primary_star_cat[self.config.ra_column],
                                                  primary_star_cat[self.config.dec_column],
                                                  degrees=True)
        use = (inner_tract_ids == tract)
        self.log.info('Total of %d isolated stars in inner tract.', use.sum())

        primary_star_cat = primary_star_cat[use]
        if len(primary_star_cat) == 0:
            return self._make_empty_struct(star_source_cat, primary_star_cat)

        # Set the unique ids.
        primary_star_cat['isolated_star_id'] = self._compute_unique_ids(skymap,
                                                                        tract,
                                                                        len(primary_star_cat))

        # Match to sources.
        star_source_cat, primary_star_cat = self._match_sources(primary_bands,
                                                                star_source_cat,
                                                                primary_star_cat)

        return pipeBase.Struct(star_source_cat=star_source_cat,
                               star_cat=primary_star_cat)

    @staticmethod
    def _make_empty_struct(star_source_cat, star_cat):
        """Make an output struct holding zero-length catalogs.

        Parameters
        ----------
        star_source_cat : `np.ndarray`
            Star source catalog whose dtype the empty output should carry.
        star_cat : `np.ndarray`
            Star catalog whose dtype the empty output should carry.

        Returns
        -------
        struct : `lsst.pipe.base.Struct`
            Struct with empty ``star_source_cat`` and ``star_cat`` arrays.
        """
        return pipeBase.Struct(star_source_cat=np.zeros(0, star_source_cat.dtype),
                               star_cat=np.zeros(0, star_cat.dtype))

    def _make_all_star_sources(self, tract_info, source_table_ref_dict):
        """Make a catalog of all the star sources.

        Parameters
        ----------
        tract_info : `lsst.skymap.TractInfo`
            Information about the tract.
        source_table_ref_dict : `dict`
            Dictionary of source_table refs. Key is visit, value is dataref.

        Returns
        -------
        star_source_cat : `np.ndarray`
            Catalog of star sources.
        """
        # Import explicitly: ``numpy.lib.recfunctions`` is not guaranteed to
        # be loaded (and so reachable as ``np.lib.recfunctions``) by a plain
        # ``import numpy``.
        from numpy.lib import recfunctions as rfn

        # Internally, we use a numpy recarray, they are by far the fastest
        # option in testing for relatively narrow tables.
        # (have not tested wide tables)
        all_columns, persist_columns = self._get_source_table_visit_column_names()
        poly = tract_info.outer_sky_polygon

        tables = []
        for source_table_ref in source_table_ref_dict.values():
            df = source_table_ref.get(parameters={'columns': all_columns})
            df.reset_index(inplace=True)

            goodSrc = self.source_selector.selectSources(df)

            table = df[persist_columns][goodSrc.selected].to_records()

            # Append columns that include the row in the source table
            # and the matched object index (to be filled later).
            table = rfn.append_fields(table,
                                      ['source_row',
                                       'obj_index'],
                                      [np.where(goodSrc.selected)[0],
                                       np.zeros(goodSrc.selected.sum(), dtype=np.int32)],
                                      dtypes=['i4', 'i4'],
                                      usemask=False)

            # We cut to the outer tract polygon to ensure consistent matching
            # from tract to tract.
            tract_use = poly.contains(np.deg2rad(table[self.config.ra_column]),
                                      np.deg2rad(table[self.config.dec_column]))

            tables.append(table[tract_use])

        # Combine tables
        star_source_cat = np.concatenate(tables)

        return star_source_cat

    def _get_source_table_visit_column_names(self):
        """Get the list of sourceTable_visit columns from the config.

        Returns
        -------
        all_columns : `list` [`str`]
            All columns to read
        persist_columns : `list` [`str`]
            Columns to persist (excluding selection columns)
        """
        columns = [self.config.id_column,
                   'visit', 'detector',
                   self.config.ra_column, self.config.dec_column,
                   self.config.physical_filter_column, self.config.band_column,
                   self.config.inst_flux_field, self.config.inst_flux_field + 'Err']
        columns.extend(self.config.extra_columns)

        all_columns = columns.copy()
        if self.source_selector.config.doFlags:
            all_columns.extend(self.source_selector.config.flags.bad)
        if self.source_selector.config.doUnresolved:
            all_columns.append(self.source_selector.config.unresolved.name)
        if self.source_selector.config.doIsolated:
            all_columns.append(self.source_selector.config.isolated.parentName)
            all_columns.append(self.source_selector.config.isolated.nChildName)
        if self.source_selector.config.doRequirePrimary:
            all_columns.append(self.source_selector.config.requirePrimary.primaryColName)

        # De-duplicate while keeping first-seen order. Otherwise a column that
        # is both configured in extra_columns and required by the selector (or
        # listed twice) would be requested more than once.
        return list(dict.fromkeys(all_columns)), list(dict.fromkeys(columns))

    def _match_primary_stars(self, primary_bands, star_source_cat):
        """Match primary stars.

        Parameters
        ----------
        primary_bands : `list` [`str`]
            Ordered list of primary bands.
        star_source_cat : `np.ndarray`
            Catalog of star sources.

        Returns
        -------
        primary_star_cat : `np.ndarray`
            Catalog of primary star positions
        """
        ra_col = self.config.ra_column
        dec_col = self.config.dec_column

        dtype = self._get_primary_dtype(primary_bands)

        primary_star_cat = None
        for primary_band in primary_bands:
            use = (star_source_cat['band'] == primary_band)

            ra = star_source_cat[ra_col][use]
            dec = star_source_cat[dec_col][use]

            if ra.size == 0:
                # No sources in this band: skip instead of building a matcher
                # on empty arrays.
                self.log.info('Found 0 primary stars in %s band.', primary_band)
                continue

            with Matcher(ra, dec) as matcher:
                try:
                    # New smatch API
                    idx = matcher.query_groups(self.config.match_radius/3600., min_match=1)
                except AttributeError:
                    # Old smatch API
                    idx = matcher.query_self(self.config.match_radius/3600., min_match=1)

            count = len(idx)

            if count == 0:
                self.log.info('Found 0 primary stars in %s band.', primary_band)
                continue

            band_cat = np.zeros(count, dtype=dtype)
            band_cat['primary_band'] = primary_band

            # If the tract cross ra=0 (that is, it has both low ra and high ra)
            # then we need to remap all ra values from [0, 360) to [-180, 180)
            # before doing any position averaging.
            remapped = False
            if ra.min() < 60.0 and ra.max() > 300.0:
                ra_temp = (ra + 180.0) % 360. - 180.
                remapped = True
            else:
                ra_temp = ra

            # Compute mean position for each primary star
            for i, row in enumerate(idx):
                row = np.array(row)
                band_cat[ra_col][i] = np.mean(ra_temp[row])
                band_cat[dec_col][i] = np.mean(dec[row])

            if remapped:
                # Remap ra back to [0, 360)
                band_cat[ra_col] %= 360.0

            # Match to previous band catalog(s), and remove duplicates.
            if primary_star_cat is None or len(primary_star_cat) == 0:
                primary_star_cat = band_cat
            else:
                with Matcher(band_cat[ra_col], band_cat[dec_col]) as matcher:
                    idx = matcher.query_radius(primary_star_cat[ra_col],
                                               primary_star_cat[dec_col],
                                               self.config.match_radius/3600.)
                # Any object with a match should be removed.
                match_indices = np.array([i for i in range(len(idx)) if len(idx[i]) > 0])
                if len(match_indices) > 0:
                    band_cat = np.delete(band_cat, match_indices)

                primary_star_cat = np.append(primary_star_cat, band_cat)
            self.log.info('Found %d primary stars in %s band.', len(band_cat), primary_band)

        # If everything was cut, we still want the correct datatype.
        if primary_star_cat is None:
            primary_star_cat = np.zeros(0, dtype=dtype)

        return primary_star_cat

    def _remove_neighbors(self, primary_star_cat):
        """Remove neighbors from the primary star catalog.

        Parameters
        ----------
        primary_star_cat : `np.ndarray`
            Primary star catalog.

        Returns
        -------
        primary_star_cat_cut : `np.ndarray`
            Primary star cat with neighbors removed.
        """
        ra_col = self.config.ra_column
        dec_col = self.config.dec_column

        with Matcher(primary_star_cat[ra_col], primary_star_cat[dec_col]) as matcher:
            # By setting min_match=2 objects that only match to themselves
            # will not be recorded.
            try:
                # New smatch API
                idx = matcher.query_groups(self.config.isolation_radius/3600., min_match=2)
            except AttributeError:
                # Old smatch API
                idx = matcher.query_self(self.config.isolation_radius/3600., min_match=2)

        try:
            neighbor_indices = np.concatenate(idx)
        except ValueError:
            # np.concatenate raises ValueError when there are no groups.
            neighbor_indices = np.zeros(0, dtype=int)

        if len(neighbor_indices) > 0:
            neighbored = np.unique(neighbor_indices)
            self.log.info('Cutting %d objects with close neighbors.', len(neighbored))
            primary_star_cat = np.delete(primary_star_cat, neighbored)

        return primary_star_cat

    def _match_sources(self, bands, star_source_cat, primary_star_cat):
        """Match individual sources to primary stars.

        Parameters
        ----------
        bands : `list` [`str`]
            List of bands.
        star_source_cat : `np.ndarray`
            Array of star sources.
        primary_star_cat : `np.ndarray`
            Array of primary stars.

        Returns
        -------
        star_source_cat_sorted : `np.ndarray`
            Sorted and cropped array of star sources.
        primary_star_cat : `np.ndarray`
            Catalog of isolated stars, with indexes to star_source_cat_cut.
        """
        ra_col = self.config.ra_column
        dec_col = self.config.dec_column

        # We match sources per-band because it allows us to have sorted
        # sources for easy retrieval of per-band matches.
        n_source_per_band_per_obj = np.zeros((len(bands),
                                              len(primary_star_cat)),
                                             dtype=np.int32)
        band_uses = []
        idxs = []
        with Matcher(primary_star_cat[ra_col], primary_star_cat[dec_col]) as matcher:
            for b, band in enumerate(bands):
                band_use, = np.where(star_source_cat['band'] == band)

                idx = matcher.query_radius(star_source_cat[ra_col][band_use],
                                           star_source_cat[dec_col][band_use],
                                           self.config.match_radius/3600.)
                n_source_per_band_per_obj[b, :] = np.array([len(row) for row in idx])
                idxs.append(idx)
                band_uses.append(band_use)

        n_source_per_obj = np.sum(n_source_per_band_per_obj, axis=0)

        primary_star_cat['nsource'] = n_source_per_obj
        primary_star_cat['source_cat_index'][1:] = np.cumsum(n_source_per_obj)[:-1]

        n_tot_source = primary_star_cat['source_cat_index'][-1] + primary_star_cat['nsource'][-1]

        # Temporary arrays until we crop/sort the source catalog
        source_index = np.zeros(n_tot_source, dtype=np.int32)
        obj_index = np.zeros(n_tot_source, dtype=np.int32)

        # Lay out sources star by star and, within each star, band by band.
        ctr = 0
        for i in range(len(primary_star_cat)):
            obj_index[ctr: ctr + n_source_per_obj[i]] = i
            for b in range(len(bands)):
                source_index[ctr: ctr + n_source_per_band_per_obj[b, i]] = band_uses[b][idxs[b][i]]
                ctr += n_source_per_band_per_obj[b, i]

        source_cat_index_band_offset = np.cumsum(n_source_per_band_per_obj, axis=0)

        for b, band in enumerate(bands):
            primary_star_cat[f'nsource_{band}'] = n_source_per_band_per_obj[b, :]
            if b == 0:
                # The first band listed is the same as the overall star
                primary_star_cat[f'source_cat_index_{band}'] = primary_star_cat['source_cat_index']
            else:
                # Other band indices are offset from the previous band
                primary_star_cat[f'source_cat_index_{band}'] = (primary_star_cat['source_cat_index']
                                                                + source_cat_index_band_offset[b - 1, :])

        star_source_cat = star_source_cat[source_index]
        star_source_cat['obj_index'] = obj_index

        return star_source_cat, primary_star_cat

    def _compute_unique_ids(self, skymap, tract, nstar):
        """Compute unique star ids.

        This is a simple hash of the tract and star to provide an
        id that is unique for a given processing.

        Parameters
        ----------
        skymap : `lsst.skymap.Skymap`
            Skymap object.
        tract : `int`
            Tract id number.
        nstar : `int`
            Number of stars.

        Returns
        -------
        ids : `np.ndarray`
            Array of unique star ids.
        """
        # The end of the id will be big enough to hold the tract number
        mult = 10**(int(np.log10(len(skymap))) + 1)

        # Use an explicit 64-bit dtype. The default integer is 32 bits on
        # some platforms/NumPy versions, where star_number*mult could overflow.
        return np.arange(1, nstar + 1, dtype=np.int64)*mult + tract

    def _get_primary_dtype(self, primary_bands):
        """Get the numpy datatype for the primary star catalog.

        Parameters
        ----------
        primary_bands : `list` [`str`]
            List of primary bands.

        Returns
        -------
        dtype : `list` [`tuple`]
            Datatype specification of the primary catalog.
        """
        max_len = max(len(primary_band) for primary_band in primary_bands)

        dtype = [('isolated_star_id', 'i8'),
                 (self.config.ra_column, 'f8'),
                 (self.config.dec_column, 'f8'),
                 ('primary_band', f'U{max_len}'),
                 ('source_cat_index', 'i4'),
                 ('nsource', 'i4')]

        for band in primary_bands:
            dtype.append((f'source_cat_index_{band}', 'i4'))
            dtype.append((f'nsource_{band}', 'i4'))

        return dtype