lsst.jointcal  15.0-9-g5661f8f+1
dataIds.py
Go to the documentation of this file.
1 # See COPYRIGHT file at the top of the source tree.
2 
3 # Copied from HyperSuprime-Cam/pipe_tasks
import collections

import lsst.log
import lsst.afw.table
import lsst.afw.image
import lsst.afw.geom
import lsst.pex.exceptions
from lsst.geom import convexHull

from lsst.coadd.utils import CoaddDataIdContainer
13 
14 __all__ = ["PerTractCcdDataIdContainer", "overlapsTract"]
15 
16 
class PerTractCcdDataIdContainer(CoaddDataIdContainer):
    """A version of lsst.pipe.base.DataIdContainer that combines raw data IDs (defined as whatever we use
    for 'src') with a tract.
    """

    def castDataIds(self, butler):
        """Validate data IDs and cast them to the correct type (modify idList in place).

        The valid keys are those the butler reports for the ``src`` dataset at
        ``self.level``, plus ``tract`` (cast to `int`). Values whose expected type is
        `str` are left untouched.

        Parameters
        ----------
        butler
            data butler

        Raises
        ------
        KeyError
            If the butler cannot report keys for ``src`` at ``self.level``, or if a
            data ID contains an unrecognized key.
        TypeError
            If a value cannot be cast to the type expected for its key.
        """
        try:
            idKeyTypeDict = butler.getKeys(datasetType="src", level=self.level)
        except KeyError as e:
            raise KeyError("Cannot get keys for datasetType %s at level %s: %s" % ("src", self.level, e)) from e

        # Copy so that adding "tract" does not modify the mapping returned by the butler.
        idKeyTypeDict = idKeyTypeDict.copy()
        idKeyTypeDict["tract"] = int

        for dataDict in self.idList:
            for key, strVal in dataDict.items():
                try:
                    keyType = idKeyTypeDict[key]
                except KeyError:
                    validKeys = sorted(idKeyTypeDict.keys())
                    raise KeyError("Unrecognized ID key %r; valid keys are: %s" % (key, validKeys)) from None
                if keyType != str:
                    try:
                        castVal = keyType(strVal)
                    except Exception as e:
                        raise TypeError("Cannot cast value %r to %s for ID key %r" % (strVal, keyType, key,)) from e
                    # Replacing the value of an existing key does not change the dict's size,
                    # so it is safe while iterating over items().
                    dataDict[key] = castVal

    def _addDataRef(self, namespace, dataId, tract):
        """Construct a dataRef based on dataId, but with an added tract key, and append it to self.refList.

        Parameters
        ----------
        namespace
            Argument-parser namespace; only ``namespace.butler`` is used.
        dataId
            Data ID to base the reference on; it is copied, not modified.
        tract
            Tract ID to add under the ``tract`` key.
        """
        forcedDataId = dataId.copy()
        forcedDataId['tract'] = tract
        dataRef = namespace.butler.dataRef(datasetType=self.datasetType, dataId=forcedDataId)
        self.refList.append(dataRef)

    def makeDataRefList(self, namespace):
        """Make self.refList from self.idList.

        Data IDs that include a ``tract`` are expanded through the ``src`` dataset. Each
        existing ``src`` reference is added for that tract only.

        Data IDs without a ``tract`` are expanded through ``calexp``. The nearest tract to
        each calexp's center is found from its WCS and bounding box. Every component of a
        visit is then added to every tract that any component of that visit overlaps, so
        the components of a visit stay together. Components of a visit that overlaps no
        tract are not added.

        ``self.idList`` is not modified.

        Parameters
        ----------
        namespace
            Argument-parser namespace. ``namespace.butler`` is used directly, and the
            namespace is passed to ``self.getSkymap``.

        Raises
        ------
        RuntimeError
            If ``setDatasetType`` has not been called.
        """
        if self.datasetType is None:
            raise RuntimeError("Must call setDatasetType first")
        skymap = None
        log = lsst.log.Log.getLogger("jointcal.dataIds")
        visitTract = collections.defaultdict(set)  # Set of tracts for each visit
        visitRefs = collections.defaultdict(list)  # List of data references for each visit
        for dataId in self.idList:
            if "tract" not in dataId:
                # Discover which tracts the data overlaps
                log.infof("Reading WCS to determine tracts for components of dataId={}", dict(dataId))
                if skymap is None:
                    skymap = self.getSkymap(namespace)

                for ref in namespace.butler.subset("calexp", dataId=dataId):
                    if not ref.datasetExists("calexp"):
                        # Report the specific missing component, not the (possibly partial) query ID.
                        log.warnf("calexp with dataId: {} not found.", dict(ref.dataId))
                        continue

                    # TODO: fancier mechanism to select an individual exposure than just pulling out "visit"?
                    if "visit" in ref.dataId.keys():
                        visit = ref.dataId["visit"]
                    else:
                        # Fallback if visit is not in the dataId
                        visit = namespace.butler.queryMetadata("calexp", "visit", ref.dataId)[0]
                    visitRefs[visit].append(ref)

                    wcs = ref.get("calexp_wcs", immediate=True)
                    box = lsst.afw.geom.Box2D(ref.get("calexp_bbox"))
                    # Going with just the nearest tract. Since we're throwing all tracts for the visit
                    # together, this shouldn't be a problem unless the tracts are much smaller than a CCD.
                    tract = skymap.findTract(wcs.pixelToSky(box.getCenter()))
                    if overlapsTract(tract, wcs, box):
                        visitTract[visit].add(tract.getId())
            else:
                # Work on a copy: removing "tract" must not alter self.idList, otherwise a
                # second call would treat this data ID as having no tract.
                dataId = dataId.copy()
                tract = dataId.pop("tract")
                # making a DataRef for src fills out any missing keys and allows us to iterate
                for ref in namespace.butler.subset("src", dataId=dataId):
                    if ref.datasetExists():
                        self._addDataRef(namespace, ref.dataId, tract)

        # Ensure all components of a visit are kept together by putting them all in the same set of tracts
        # NOTE: sorted() here is to keep py2 and py3 dataRefs in the same order.
        # NOTE: see DM-9393 for details.
        for visit, tractSet in sorted(visitTract.items()):
            for ref in visitRefs[visit]:
                for tract in sorted(tractSet):
                    self._addDataRef(namespace, ref.dataId, tract)
        if visitTract:
            tractCounter = collections.Counter()
            for tractSet in visitTract.values():
                tractCounter.update(tractSet)
            log.infof("Number of visits per tract: {}", dict(tractCounter))
116  tractCounter.update(tractSet)
117  log.infof("Number of visits per tract: {}", dict(tractCounter))
118 
119 
120 def overlapsTract(tract, imageWcs, imageBox):
121  """Return whether the image (specified by Wcs and bounding box) overlaps the tract
122 
123  Parameters
124  ----------
125  tract
126  TractInfo specifying a tract
127  imageWcs
128  Wcs for image
129  imageBox
130  Bounding box for image
131 
132  Returns
133  -------
134  bool
135  True if the image overlaps the tract, False otherwise.
136  """
137  tractWcs = tract.getWcs()
138  tractCorners = [tractWcs.pixelToSky(lsst.afw.geom.Point2D(coord)).getVector() for
139  coord in tract.getBBox().getCorners()]
140  tractPoly = convexHull(tractCorners)
141 
142  try:
143  imageCorners = [imageWcs.pixelToSky(lsst.afw.geom.Point2D(pix)) for pix in imageBox.getCorners()]
144  except lsst.pex.exceptions.LsstCppException as e:
145  # Protecting ourselves from awful Wcs solutions in input images
146  if (not isinstance(e.message, lsst.pex.exceptions.DomainErrorException) and
147  not isinstance(e.message, lsst.pex.exceptions.RuntimeErrorException)):
148  raise
149  return False
150 
151  imagePoly = convexHull([coord.getVector() for coord in imageCorners])
152  if imagePoly is None:
153  return False
154  return tractPoly.intersects(imagePoly) # "intersects" also covers "contains" or "is contained by"
def _addDataRef(self, namespace, dataId, tract)
Definition: dataIds.py:52
static Log getLogger(std::string const &loggername)
def overlapsTract(tract, imageWcs, imageBox)
Definition: dataIds.py:120
STL class.
STL class.