lsst.jointcal  14.0-26-gc4bc114+9
dataIds.py
Go to the documentation of this file.
1 # See COPYRIGHT file at the top of the source tree.
2 
3 # Copied from HyperSuprime-Cam/pipe_tasks
4 
5 from __future__ import division, absolute_import, print_function
6 
7 import collections
8 
9 import lsst.log
11 import lsst.afw.table
12 import lsst.afw.image
13 from lsst.geom import convexHull
14 
15 from lsst.coadd.utils import CoaddDataIdContainer
16 
17 
class PerTractCcdDataIdContainer(CoaddDataIdContainer):
    """A version of lsst.pipe.base.DataIdContainer that combines raw data IDs (defined as whatever we use
    for 'src') with a tract.

    Data IDs that specify a tract are expanded directly against the 'src' dataset. For data IDs without
    a tract, the overlapping tracts are discovered from each calexp's WCS and bounding box. All
    components of a visit are then assigned to every tract found for that visit.
    """

    def castDataIds(self, butler):
        """Validate data IDs and cast them to the correct type (modify idList in place).

        The valid keys are those the butler reports for datasetType 'src' at self.level, plus 'tract'
        (cast to int). Values whose key type is str are left untouched.

        @param butler: data butler
        @raise KeyError if the butler cannot report keys for 'src' at self.level, or if a data ID
            contains an unrecognized key.
        @raise TypeError if a value cannot be cast to the type of its key.
        """
        try:
            idKeyTypeDict = butler.getKeys(datasetType="src", level=self.level)
        except KeyError as e:
            raise KeyError("Cannot get keys for datasetType %s at level %s: %s" % ("src", self.level, e))

        # Copy before adding "tract" so the mapping returned by the butler is not modified.
        idKeyTypeDict = idKeyTypeDict.copy()
        idKeyTypeDict["tract"] = int

        for dataDict in self.idList:
            # Snapshot the items: values are reassigned below while walking the keys.
            for key, strVal in list(dataDict.items()):
                try:
                    keyType = idKeyTypeDict[key]
                except KeyError:
                    validKeys = sorted(idKeyTypeDict.keys())
                    raise KeyError("Unrecognized ID key %r; valid keys are: %s" % (key, validKeys))
                if keyType != str:
                    try:
                        castVal = keyType(strVal)
                    except Exception:
                        raise TypeError("Cannot cast value %r to %s for ID key %r" % (strVal, keyType, key,))
                    dataDict[key] = castVal

    def _addDataRef(self, namespace, dataId, tract):
        """Construct a dataRef based on dataId, but with an added tract key, and append it to refList.

        @param namespace: parsed command-line namespace; its `butler` attribute is used.
        @param dataId: data ID to copy (the argument itself is not modified).
        @param tract: tract ID to store under the 'tract' key of the copy.
        """
        forcedDataId = dataId.copy()
        forcedDataId['tract'] = tract
        dataRef = namespace.butler.dataRef(datasetType=self.datasetType, dataId=forcedDataId)
        self.refList.append(dataRef)

    def makeDataRefList(self, namespace):
        """Make self.refList from self.idList.

        For each data ID in self.idList:

        - If it has no 'tract' key, read the WCS and bounding box of every existing matching 'calexp'.
          Find the tract nearest the image centre, and record it for the exposure's visit if the image
          overlaps it. Afterwards, every component of such a visit is added once per recorded tract,
          so that all components of a visit are kept together.
        - If it has a 'tract' key, add a data reference (with that tract) for every existing 'src'
          dataset that matches the remaining keys.

        The entries of self.idList are not modified.

        @param namespace: parsed command-line namespace; its `butler` attribute is used.
        @raise RuntimeError if setDatasetType has not been called.
        """
        if self.datasetType is None:
            raise RuntimeError("Must call setDatasetType first")
        skymap = None
        log = lsst.log.Log.getLogger("jointcal.dataIds")
        visitTract = collections.defaultdict(set)  # Set of tracts for each visit
        visitRefs = collections.defaultdict(list)  # List of data references for each visit
        for dataId in self.idList:
            if "tract" not in dataId:
                # Discover which tracts the data overlaps
                log.infof("Reading WCS to determine tracts for components of dataId={}", dict(dataId))
                if skymap is None:
                    skymap = self.getSkymap(namespace)

                for ref in namespace.butler.subset("calexp", dataId=dataId):
                    if not ref.datasetExists("calexp"):
                        # Report the specific missing exposure, not the (possibly partial) input data ID.
                        log.warnf("calexp with dataId: {} not found.", dict(ref.dataId))
                        continue

                    # XXX fancier mechanism to select an individual exposure than just pulling out "visit"?
                    if "visit" in ref.dataId:
                        visit = ref.dataId["visit"]
                    else:
                        # Fallback if visit is not in the dataId
                        visit = namespace.butler.queryMetadata("calexp", "visit", ref.dataId)[0]
                    visitRefs[visit].append(ref)

                    wcs = ref.get("calexp_wcs", immediate=True)
                    box = lsst.afw.geom.Box2D(ref.get("calexp_bbox"))
                    # Going with just the nearest tract. Since we're throwing all tracts for the visit
                    # together, this shouldn't be a problem unless the tracts are much smaller than a CCD.
                    tract = skymap.findTract(wcs.pixelToSky(box.getCenter()))
                    if overlapsTract(tract, wcs, box):
                        visitTract[visit].add(tract.getId())
            else:
                # Work on a copy so the caller's data ID keeps its "tract" key.
                srcDataId = dataId.copy()
                tract = srcDataId.pop("tract")
                # making a DataRef for src fills out any missing keys and allows us to iterate
                for ref in namespace.butler.subset("src", dataId=srcDataId):
                    if ref.datasetExists():
                        self._addDataRef(namespace, ref.dataId, tract)

        # Ensure all components of a visit are kept together by putting them all in the same set of tracts
        # NOTE: sorted() here is to keep py2 and py3 dataRefs in the same order.
        # NOTE: see DM-9393 for details.
        for visit, tractSet in sorted(visitTract.items()):
            for ref in visitRefs[visit]:
                for tract in sorted(tractSet):
                    self._addDataRef(namespace, ref.dataId, tract)
        if visitTract:
            tractCounter = collections.Counter()
            for tractSet in visitTract.values():
                tractCounter.update(tractSet)
            log.infof("Number of visits per tract: {}", dict(tractCounter))
114  log.infof("Number of visits per tract: {}", dict(tractCounter))
115 
116 
def overlapsTract(tract, imageWcs, imageBox):
    """Return whether the image (specified by Wcs and bounding box) overlaps the tract.

    The tract and the image are each turned into a convex polygon on the sky (the convex hull of the
    sky vectors of their bounding-box corners), and the two polygons are tested for intersection.

    @param tract: TractInfo specifying a tract
    @param imageWcs: Wcs for image
    @param imageBox: Bounding box for image
    @return bool: True if the polygons intersect, including when one contains the other. False if they
        do not intersect, if imageWcs raises DomainError or RuntimeError while mapping the image corners
        to the sky, or if no convex hull can be built from the image corners.
    """
    # Function-scope import so the exception types below are always resolvable in this function.
    import lsst.pex.exceptions

    tractWcs = tract.getWcs()
    tractCorners = [tractWcs.pixelToSky(lsst.afw.geom.Point2D(coord)).getVector() for
                    coord in tract.getBBox().getCorners()]
    tractPoly = convexHull(tractCorners)

    try:
        imageCorners = [imageWcs.pixelToSky(lsst.afw.geom.Point2D(pix)) for pix in imageBox.getCorners()]
    except (lsst.pex.exceptions.DomainError, lsst.pex.exceptions.RuntimeError):
        # Protecting ourselves from awful Wcs solutions in input images: treat them as non-overlapping.
        # Any other exception type propagates to the caller.
        return False

    imagePoly = convexHull([coord.getVector() for coord in imageCorners])
    if imagePoly is None:
        return False
    return tractPoly.intersects(imagePoly)  # "intersects" also covers "contains" or "is contained by"
def _addDataRef(self, namespace, dataId, tract)
Definition: dataIds.py:49
static Log getLogger(std::string const &loggername)
def overlapsTract(tract, imageWcs, imageBox)
Definition: dataIds.py:117
STL class.
STL class.