Coverage for python / lsst / pipe / tasks / rgb2hips / _high_order_hips.py: 0%
132 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-05 18:37 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-05 18:37 +0000
1# This file is part of pipe_tasks.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21from __future__ import annotations
23__all__ = ("HighOrderHipsTaskConnections", "HighOrderHipsTaskConfig", "HighOrderHipsTask")
25import numpy as np
26from enum import Enum
27from numpy.typing import NDArray
29from lsst.afw.geom import makeHpxWcs
30from lsst.pipe.base import (
31 PipelineTask,
32 PipelineTaskConfig,
33 PipelineTaskConnections,
34 Struct,
35 QuantumContext,
36 InputQuantizedConnection,
37 OutputQuantizedConnection,
38)
39from lsst.pex.config import ConfigField, Field, ChoiceField
40from lsst.pipe.base.connectionTypes import Input, Output
41from lsst.skymap import BaseSkyMap
42from lsst.afw.geom import SkyWcs
43from lsst.geom import Box2I, Point2I, Extent2I
44from lsst.afw.math import Warper
45from lsst.daf.butler import DeferredDatasetHandle
46from lsst.afw.image import ImageF
47from lsst.resources import ResourcePath
49from collections.abc import Iterable
50from lsst.sphgeom import RangeSet
52import cv2
54from ._utils import _write_hips_image
55from ..prettyPictureMaker import FeatheredMosaicCreator
class ColorChannel(Enum):
    """Color planes of an RGB image array.

    Each member's value is the index of that plane along the last axis of an
    ``(..., 3)`` image array.
    """

    RED, GREEN, BLUE = range(3)
class HighOrderHipsTaskConnections(PipelineTaskConnections, dimensions=("healpix8",)):
    """Butler connections for `HighOrderHipsTask`.

    One quantum is run per order 8 healpix pixel. It reads the RGB patch
    images for that pixel (loaded lazily) together with the skymap, and it
    writes one downsampled array for the pixel.
    """

    input_images = Input(
        name="rgb_picture",
        storageClass="ColorImage",
        dimensions=("tract", "patch", "skymap"),
        doc="Color images which are to be turned into hips tiles",
        deferLoad=True,
        multiple=True,
    )
    skymap = Input(
        name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
        storageClass="SkyMap",
        dimensions=("skymap",),
        doc="The skymap which the data has been mapped onto",
    )
    output_hpx = Output(
        name="rgb_picture_hips8",
        storageClass="NumpyArray",
        dimensions=("healpix8",),
        doc="Healpix tiles at order 8, but binned to 256x256",
    )
class HighOrderHipsTaskConfig(PipelineTaskConfig, pipelineConnections=HighOrderHipsTaskConnections):
    """Configuration class for the HighOrderHipsTask pipeline task."""

    # HealPix order to generate tiles for. This is a plain class attribute,
    # not a config Field, so it cannot be overridden through config files.
    hips_order = 8
    warp = ConfigField[Warper.ConfigClass](
        doc="Warper configuration",
    )
    hips_base_uri = Field[str](
        doc="URI to HiPS base for output.",
        optional=False,
    )
    color_ordering = Field[str](
        doc=(
            "A string of the astrophysical bands that correspond to the RGB channels in the color image "
            "inputs to high_order_hips task. This is used in making the hips metadata"
        ),
        optional=False,
    )
    file_extension = ChoiceField[str](
        doc="Extension for the persisted image.",
        allowed={"png": "Use the png image extension", "webp": "Use the webp image extension"},
        default="png",
    )
    array_type = ChoiceField[str](
        doc="The data type (dtype) for the output image array",
        default="uint8",
        allowed={
            "uint8": "Use 8 bit arrays, 255 max",
            "uint16": "Use 16 bit arrays, 65535 max",
            "half": "Use 16 bit float arrays, 1 max",
            "float": "Use 32 bit float arrays, 1 max",
        },
    )

    def setDefaults(self):
        # pex_config requires overrides to chain to the base class so that
        # inherited defaults are applied.
        super().setDefaults()
        self.warp.warpingKernelName = "lanczos5"
128class HighOrderHipsTask(PipelineTask):
129 """Pipeline task that generates high-order HealPix tiles from RGB images.
131 Of Note; This task has special dispensation to write "out-of-tree" to a
132 location not within the butler. DO NOT model other tasks on this one.
134 This task takes in RGB images generated on a tract patch grid. It assembles
135 them into a 4096 x 4096 image aligned with the wcs coordinates of hips
136 order 8 pixels. This is then divided up into an 8x8 grid to produce 512x512
137 images at hips order 11. The images is then resampled using lanczos order 4
138 such that the image is half the size. The original image is then divided
139 into a 4x4 grid to produce hips images at order 10. The process is repeated
140 to produce hips images at order 9, and finally the image is resampled down
141 to 512x512 and saved out at hips order 8.
143 The order 8 image is resampled one more time to 256x256 and presisted by
144 the butler for later consumption in the `LowOrderHipsTask`.
146 The difference at producding wcs at order 8 and working up to 11, is tested
147 to be less than 6 decimal places when converting ra dec to pixel coordinates,
148 and even that is likely to be due to differences in warping kernels,
149 and not an intrinsic error. Doing processing like this allows hips generation
150 to be more effectively split across compute nodes.
151 """
153 _DefaultName = "highOrderHipsTask"
154 ConfigClass = HighOrderHipsTaskConfig
156 config: ConfigClass
158 def __init__(self, **kwargs):
159 super().__init__(**kwargs)
160 self.warper = Warper.fromConfig(self.config.warp)
162 # Set the base resource path that will be used for all outputs
163 self.hips_base_path = ResourcePath(self.config.hips_base_uri, forceDirectory=True)
164 self.hips_base_path = self.hips_base_path.join(
165 f"color_{self.config.color_ordering}", forceDirectory=True
166 )
168 def run(self, input_images: Iterable[tuple[NDArray, SkyWcs, Box2I]], healpix_id) -> Struct:
169 """Main execution method for generating HealPix tiles.
171 Parameters
172 ----------
173 input_images : Iterable[tuple[NDArray, SkyWcs, Box2I]]
174 Iterable of tuples containing image data, WCS, and bounding box information.
175 healpix_id : int
176 The HealPix order 8 ID to process.
178 Returns
179 -------
180 Struct
181 Output structure containing the processed HealPix order 8 tile.
182 This has been downsampled to 256x256 corresponding to a quarter of a healpix
183 order 7 image.
184 """
185 # Make the WCS for the transform, intentionally over-sampled to shift order 12.
186 # This creates as 4096 x 4096 image that can be broken apart to form the higher
187 # orders, binning each as needed
188 target_wcs = makeHpxWcs(8, healpix_id, 12, False)
190 # construct a bounding box that holds the warping results for each channel
191 exp_bbox = Box2I(corner=Point2I(0, 0), dimensions=Extent2I(2**12, 2**12))
193 output_array_hpx = np.zeros((4096, 4096, 3), dtype=np.float32)
194 output_array_hpx[:, :, :] = np.nan
196 self.log.info("Warping input exposures and populating hpx8 super tile.")
197 # Need to loop over input arrays then channel
198 # Warp and combine input images into the HealPix tile
199 for input_image, in_wcs, in_box in input_images:
200 tmp_image = ImageF(in_box)
201 in_image: NDArray = input_image
203 # Normalize image data based on dtype
204 match in_image.dtype:
205 case np.uint8:
206 in_image = in_image.astype(np.float32) / 255.0
207 case np.uint16:
208 in_image = in_image.astype(np.float32) / 65535
209 case np.float16:
210 in_image = in_image.astype(np.float32)
212 # Process each color channel separately
213 for channel in ColorChannel:
214 # existing data
215 existing = output_array_hpx[..., channel.value]
217 # construct an Exposure object from one channel in the array
218 channel_array = in_image[..., channel.value]
219 tmp_image.array[:, :] = channel_array
221 # Warp the image to the target WCS
222 warpped = self.warper.warpImage(target_wcs, tmp_image, in_wcs, maxBBox=exp_bbox)
223 warpped_box_slices = warpped.getBBox().slices
225 # Update the output array with valid (non-NaN) values
226 are_warpped = np.isfinite(warpped.array)
227 existing[warpped_box_slices][are_warpped] = warpped.array[are_warpped]
229 # Replace any remaining NaN values with zeros
230 output_array_hpx[np.isnan(output_array_hpx)] = 0
232 # Flip the y-axis to match HealPix indexing
233 output_array_hpx = output_array_hpx[::-1, :, :]
235 # Generate tiles for different HealPix orders using Lanczos resampling instead of binning.
236 # This handles how intensities should change as the hips level changes.
237 #
238 # what this does is take a single 4096 x 4096 image and resamples it in a courser grain such
239 # that the output pixels correspond to a 4x4 grid of hips pixels at an increasingly lower scale.
240 # This works because hips is a hierarchy of tiles all contained in the same area of the sky.
241 # This allows us to generate all the output images by resampling the inputs and saves the time
242 # required to generate whole new images at each scale.
243 #
244 # The loop variables are the resampling factor, the hips order, and the number of sub-divisions
245 # a pixel has gone through (used to determine quadrant).
246 for zoom, hips_level, factor in zip((0, 2, 4, 8), (11, 10, 9, 8), (3, 2, 1, 0)):
247 self.log.info("Generating tiles for hxp level %d", hips_level)
248 if zoom:
249 size = 4096 // zoom
250 binned_array = cv2.resize(output_array_hpx, (size, size), interpolation=cv2.INTER_LANCZOS4)
251 else:
252 binned_array = output_array_hpx
253 # always create blocks of 512x512 as that is native shift order 9 size
254 #
255 # Figure out the hips pixel ids at this hips order. This is complicated because each hipx pixel
256 # turns into 4 at a higher level, but must be in a specific order to correspond to how the data
257 # is layed out in an y,x grid. So if a hips order 8 pixel A turns into four pixels b,c,d,e, they
258 # are layed out like [[b,d], [c,e]]. This is true for every pixel as you go up in order. So
259 # if you start at order 8 with one pixel, you need to do order 9 and calculate the layout. Then
260 # for each order 9 pixel, do the same to get the layout in order 10, etc. This leaves a grid
261 # of pixels that are the ids of the corresponding 512,512 sub grid pixel in the input image.
262 tmp_pixels = np.array([[healpix_id]])
263 for _ in range(factor):
264 tmp_array = np.zeros(np.array(tmp_pixels.shape) * 2)
265 for ii in range(tmp_pixels.shape[0]):
266 for jj in range(tmp_pixels.shape[1]):
267 tmp_array_view = tmp_array[ii * 2 : ii * 2 + 2, jj * 2 : jj * 2 + 2]
268 tmp_range_set = RangeSet(int(tmp_pixels[ii, jj]))
269 tmp_array_view[:, :] = (
270 np.array([x for x in range(*tmp_range_set.scaled(4)[0])], dtype=int)[[0, 2, 1, 3]]
271 ).reshape(2, 2)
272 tmp_pixels = tmp_array
274 # now for each 512x512 sub pixel region write the hips image with the corresponding healpix id
275 hpx_id_array = tmp_pixels
276 for i in range(binned_array.shape[0] // 512):
277 for j in range(binned_array.shape[1] // 512):
278 pixel_id = int(hpx_id_array[i, j])
279 sub_pixel = binned_array[i * 512 : i * 512 + 512, j * 512 : j * 512 + 512, :]
280 self.log.info(f"writing sub_pixel {pixel_id}")
281 _write_hips_image(
282 sub_pixel,
283 pixel_id,
284 hips_level,
285 self.hips_base_path,
286 self.config.file_extension,
287 self.config.array_type,
288 )
290 # Finally, bin the level 8 hpx to 256x256 (1/4 order 7) to save to the butler.
291 # This makes smaller arrays to load, and saves the binning operation in the joint phase.
292 zoomed = cv2.resize(output_array_hpx, (256, 256), interpolation=cv2.INTER_LANCZOS4)
294 return Struct(output_hpx=zoomed)
296 def _assemble_sub_region(
297 self, tract_patch: dict[int, Iterable[tuple[DeferredDatasetHandle, SkyWcs, Box2I]]], patch_grow: int
298 ) -> list[tuple[NDArray, SkyWcs, Box2I]]:
299 """Assemble all the patches in each tract into images.
301 This function takes in an input keyed by tract, with values
302 corresponding the patches in that tract that overlap the quatum's
303 healpix value. It assembles each of these into a single image such
304 that the return values is a list of images (and metadata) one element
305 for each input tract.
307 Parameters
308 ----------
309 tract_patch : `dict` of `int` to `iterable` of `tuple` of
310 `DeferredDatasetHandle`, `SkyWcs` and `Box2I`
311 Input images and metadata organized into corresponding tracts.
312 patch_grow : `int`
313 Amount to grow patches by
315 Returns
316 -------
317 output_list : `list` of `tuple` of `NDArray` `SkyWcs` and `Box2I`
318 List of assembled images and metadata, one element for each tract
320 """
322 boxes = []
323 for _, iterable in tract_patch.items():
324 mosaic_maker = FeatheredMosaicCreator(patch_grow)
325 new_box = Box2I()
326 for _, _, bbox in iterable:
327 new_box.include(bbox)
328 # allocate tmp array
329 new_array = np.zeros((new_box.getHeight(), new_box.getWidth(), 3), dtype=np.float32)
330 for handle, skyWcs, box in iterable:
331 # Make a new box of the same size, but with the origin centered
332 # on the lowest corner were there is data.
333 localOrigin = box.getBegin() - new_box.getBegin()
334 localOrigin = Point2I(
335 x=int(np.floor(localOrigin.x)),
336 y=int(np.floor(localOrigin.y)),
337 )
339 localExtent = Extent2I(
340 x=int(np.floor(box.getWidth())),
341 y=int(np.floor(box.getHeight())),
342 )
343 tmpBox = Box2I(localOrigin, localExtent)
344 tmp_new_box = Box2I(Point2I(x=0, y=0), Extent2I(x=new_box.getWidth(), y=new_box.getHeight()))
346 image = handle.get()
347 mosaic_maker.add_to_image(new_array, image.array, tmp_new_box, tmpBox, reverse=False)
348 boxes.append((new_array, skyWcs, new_box))
349 return boxes
351 def runQuantum(
352 self,
353 butlerQC: QuantumContext,
354 inputRefs: InputQuantizedConnection,
355 outputRefs: OutputQuantizedConnection,
356 ) -> None:
357 # First get what healpix pixel this task is working on
358 healpix_id = butlerQC.quantum.dataId["healpix8"]
360 # grab the skymap
361 skymap: BaseSkyMap = butlerQC.get(inputRefs.skymap)
363 # Iterate over the input image refs, to get the corresponding bbox
364 # and assemble into container for run
365 inputs_by_tract = {}
366 for input_image_ref in inputRefs.input_images:
367 tract = input_image_ref.dataId["tract"]
368 patch = input_image_ref.dataId["patch"]
369 # All boxes in a given skymap will have the same inner dimensions
370 # for x and y and will be the same for all patches
371 imageWcs = skymap[tract][patch].getWcs()
372 box = skymap[tract][patch].getOuterBBox()
373 patch_grow = skymap[tract][patch].getCellInnerDimensions().getX()
374 imageHandle = butlerQC.get(input_image_ref)
375 container = inputs_by_tract.setdefault(tract, list())
376 container.append((imageHandle, imageWcs, box))
378 input_images = self._assemble_sub_region(inputs_by_tract, patch_grow)
380 outputs = self.run(input_images, healpix_id)
381 butlerQC.put(outputs, outputRefs)