Coverage for python / lsst / pipe / tasks / rgb2hips / _high_order_hips.py: 0%

132 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-22 08:53 +0000

1# This file is part of pipe_tasks. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ("HighOrderHipsTaskConnections", "HighOrderHipsTaskConfig", "HighOrderHipsTask") 

24 

25import numpy as np 

26from enum import Enum 

27from numpy.typing import NDArray 

28 

29from lsst.afw.geom import makeHpxWcs 

30from lsst.pipe.base import ( 

31 PipelineTask, 

32 PipelineTaskConfig, 

33 PipelineTaskConnections, 

34 Struct, 

35 QuantumContext, 

36 InputQuantizedConnection, 

37 OutputQuantizedConnection, 

38) 

39from lsst.pex.config import ConfigField, Field, ChoiceField 

40from lsst.pipe.base.connectionTypes import Input, Output 

41from lsst.skymap import BaseSkyMap 

42from lsst.afw.geom import SkyWcs 

43from lsst.geom import Box2I, Point2I, Extent2I 

44from lsst.afw.math import Warper 

45from lsst.daf.butler import DeferredDatasetHandle 

46from lsst.afw.image import ImageF 

47from lsst.resources import ResourcePath 

48 

49from collections.abc import Iterable 

50from lsst.sphgeom import RangeSet 

51 

52import cv2 

53 

54from ._utils import _write_hips_image 

55from ..prettyPictureMaker import FeatheredMosaicCreator 

56 

57 

class ColorChannel(Enum):
    """Index of each color plane along the last axis of an RGB image array.

    Iterating over the enum visits the channels in red, green, blue order,
    and each member's ``value`` is the position of that channel in the
    trailing axis of the array.
    """

    # Declared in channel order; the values double as array indices.
    RED = 0
    GREEN = 1
    BLUE = 2

64 

65 

class HighOrderHipsTaskConnections(PipelineTaskConnections, dimensions=("healpix8",)):
    """Butler connections for `HighOrderHipsTask`.

    Each quantum corresponds to a single ``healpix8`` data ID. It reads the
    tract/patch RGB arrays associated with that pixel as deferred handles
    (``deferLoad=True``) together with the skymap, and writes a single
    ``NumpyArray`` for the pixel.
    """

    input_images = Input(
        name="rgb_picture_array",
        doc="Color images which are to be turned into hips tiles",
        dimensions=("tract", "patch", "skymap"),
        storageClass="NumpyArray",
        deferLoad=True,
        multiple=True,
    )
    skymap = Input(
        name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
        doc="The skymap which the data has been mapped onto",
        dimensions=("skymap",),
        storageClass="SkyMap",
    )
    output_hpx = Output(
        name="rgb_picture_hips8",
        doc="Healpix tiles at order 8, but binned to 256x256",
        dimensions=("healpix8",),
        storageClass="NumpyArray",
    )

87 

88 

class HighOrderHipsTaskConfig(PipelineTaskConfig, pipelineConnections=HighOrderHipsTaskConnections):
    """Configuration class for the HighOrderHipsTask pipeline task."""

    # HealPix order to generate tiles for. Note that this is a plain class
    # attribute, not a config ``Field``.
    hips_order = 8
    warp = ConfigField[Warper.ConfigClass](
        doc="Warper configuration",
    )
    hips_base_uri = Field[str](
        doc="URI to HiPS base for output.",
        optional=False,
    )
    color_ordering = Field[str](
        doc=(
            "A string of the astrophysical bands that correspond to the RGB channels in the color image "
            "inputs to high_order_hips task. This is used in making the hips metadata and in naming the "
            "output directory."
        ),
        optional=False,
    )
    file_extension = ChoiceField[str](
        doc="Extension for the persisted image.",
        allowed={"png": "Use the png image extension", "webp": "Use the webp image extension"},
        default="png",
    )
    array_type = ChoiceField[str](
        doc="The data type for the output image array",
        default="uint8",
        allowed={
            "uint8": "Use 8 bit arrays, 255 max",
            "uint16": "Use 16 bit arrays, 65535 max",
            "half": "Use 16 bit float arrays, 1 max",
            "float": "Use 32 bit float arrays, 1 max",
        },
    )

    def setDefaults(self):
        # pex_config requires overrides of setDefaults to chain to the base
        # implementation so inherited defaults are applied first.
        super().setDefaults()
        # Warp inputs onto the HiPS grid with a 5th order Lanczos kernel.
        self.warp.warpingKernelName = "lanczos5"

126 

127 

128class HighOrderHipsTask(PipelineTask): 

129 """Pipeline task that generates high-order HealPix tiles from RGB images. 

130 

131 Of Note; This task has special dispensation to write "out-of-tree" to a 

132 location not within the butler. DO NOT model other tasks on this one. 

133 

134 This task takes in RGB images generated on a tract patch grid. It assembles 

135 them into a 4096 x 4096 image aligned with the wcs coordinates of hips 

136 order 8 pixels. This is then divided up into an 8x8 grid to produce 512x512 

137 images at hips order 11. The images is then resampled using lanczos order 4 

138 such that the image is half the size. The original image is then divided 

139 into a 4x4 grid to produce hips images at order 10. The process is repeated 

140 to produce hips images at order 9, and finally the image is resampled down 

141 to 512x512 and saved out at hips order 8. 

142 

143 The order 8 image is resampled one more time to 256x256 and presisted by 

144 the butler for later consumption in the `LowOrderHipsTask`. 

145 

146 The difference at producding wcs at order 8 and working up to 11, is tested 

147 to be less than 6 decimal places when converting ra dec to pixel coordinates, 

148 and even that is likely to be due to differences in warping kernels, 

149 and not an intrinsic error. Doing processing like this allows hips generation 

150 to be more effectively split across compute nodes. 

151 """ 

152 

153 _DefaultName = "highOrderHipsTask" 

154 ConfigClass = HighOrderHipsTaskConfig 

155 

156 config: ConfigClass 

157 

158 def __init__(self, **kwargs): 

159 super().__init__(**kwargs) 

160 self.warper = Warper.fromConfig(self.config.warp) 

161 

162 # Set the base resource path that will be used for all outputs 

163 self.hips_base_path = ResourcePath(self.config.hips_base_uri, forceDirectory=True) 

164 self.hips_base_path = self.hips_base_path.join( 

165 f"color_{self.config.color_ordering}", forceDirectory=True 

166 ) 

167 

168 def run(self, input_images: Iterable[tuple[NDArray, SkyWcs, Box2I]], healpix_id) -> Struct: 

169 """Main execution method for generating HealPix tiles. 

170 

171 Parameters 

172 ---------- 

173 input_images : Iterable[tuple[NDArray, SkyWcs, Box2I]] 

174 Iterable of tuples containing image data, WCS, and bounding box information. 

175 healpix_id : int 

176 The HealPix order 8 ID to process. 

177 

178 Returns 

179 ------- 

180 Struct 

181 Output structure containing the processed HealPix order 8 tile. 

182 This has been downsampled to 256x256 corresponding to a quarter of a healpix 

183 order 7 image. 

184 """ 

185 # Make the WCS for the transform, intentionally over-sampled to shift order 12. 

186 # This creates as 4096 x 4096 image that can be broken apart to form the higher 

187 # orders, binning each as needed 

188 target_wcs = makeHpxWcs(8, healpix_id, 12, False) 

189 

190 # construct a bounding box that holds the warping results for each channel 

191 exp_bbox = Box2I(corner=Point2I(0, 0), dimensions=Extent2I(2**12, 2**12)) 

192 

193 output_array_hpx = np.zeros((4096, 4096, 3), dtype=np.float32) 

194 output_array_hpx[:, :, :] = np.nan 

195 

196 self.log.info("Warping input exposures and populating hpx8 super tile.") 

197 # Need to loop over input arrays then channel 

198 # Warp and combine input images into the HealPix tile 

199 for input_image, in_wcs, in_box in input_images: 

200 tmp_image = ImageF(in_box) 

201 in_image: NDArray = input_image 

202 

203 # Normalize image data based on dtype 

204 match in_image.dtype: 

205 case np.uint8: 

206 in_image = in_image.astype(np.float32) / 255.0 

207 case np.uint16: 

208 in_image = in_image.astype(np.float32) / 65535 

209 case np.float16: 

210 in_image = in_image.astype(np.float32) 

211 

212 # Process each color channel separately 

213 for channel in ColorChannel: 

214 # existing data 

215 existing = output_array_hpx[..., channel.value] 

216 

217 # construct an Exposure object from one channel in the array 

218 channel_array = in_image[..., channel.value] 

219 tmp_image.array[:, :] = channel_array 

220 

221 # Warp the image to the target WCS 

222 warpped = self.warper.warpImage(target_wcs, tmp_image, in_wcs, maxBBox=exp_bbox) 

223 warpped_box_slices = warpped.getBBox().slices 

224 

225 # Update the output array with valid (non-NaN) values 

226 are_warpped = np.isfinite(warpped.array) 

227 existing[warpped_box_slices][are_warpped] = warpped.array[are_warpped] 

228 

229 # Replace any remaining NaN values with zeros 

230 output_array_hpx[np.isnan(output_array_hpx)] = 0 

231 

232 # Flip the y-axis to match HealPix indexing 

233 output_array_hpx = output_array_hpx[::-1, :, :] 

234 

235 # Generate tiles for different HealPix orders using Lanczos resampling instead of binning. 

236 # This handles how intensities should change as the hips level changes. 

237 # 

238 # what this does is take a single 4096 x 4096 image and resamples it in a courser grain such 

239 # that the output pixels correspond to a 4x4 grid of hips pixels at an increasingly lower scale. 

240 # This works because hips is a hierarchy of tiles all contained in the same area of the sky. 

241 # This allows us to generate all the output images by resampling the inputs and saves the time 

242 # required to generate whole new images at each scale. 

243 # 

244 # The loop variables are the resampling factor, the hips order, and the number of sub-divisions 

245 # a pixel has gone through (used to determine quadrant). 

246 for zoom, hips_level, factor in zip((0, 2, 4, 8), (11, 10, 9, 8), (3, 2, 1, 0)): 

247 self.log.info("Generating tiles for hxp level %d", hips_level) 

248 if zoom: 

249 size = 4096 // zoom 

250 binned_array = cv2.resize(output_array_hpx, (size, size), interpolation=cv2.INTER_LANCZOS4) 

251 else: 

252 binned_array = output_array_hpx 

253 # always create blocks of 512x512 as that is native shift order 9 size 

254 # 

255 # Figure out the hips pixel ids at this hips order. This is complicated because each hipx pixel 

256 # turns into 4 at a higher level, but must be in a specific order to correspond to how the data 

257 # is layed out in an y,x grid. So if a hips order 8 pixel A turns into four pixels b,c,d,e, they 

258 # are layed out like [[b,d], [c,e]]. This is true for every pixel as you go up in order. So 

259 # if you start at order 8 with one pixel, you need to do order 9 and calculate the layout. Then 

260 # for each order 9 pixel, do the same to get the layout in order 10, etc. This leaves a grid 

261 # of pixels that are the ids of the corresponding 512,512 sub grid pixel in the input image. 

262 tmp_pixels = np.array([[healpix_id]]) 

263 for _ in range(factor): 

264 tmp_array = np.zeros(np.array(tmp_pixels.shape) * 2) 

265 for ii in range(tmp_pixels.shape[0]): 

266 for jj in range(tmp_pixels.shape[1]): 

267 tmp_array_view = tmp_array[ii * 2 : ii * 2 + 2, jj * 2 : jj * 2 + 2] 

268 tmp_range_set = RangeSet(int(tmp_pixels[ii, jj])) 

269 tmp_array_view[:, :] = ( 

270 np.array([x for x in range(*tmp_range_set.scaled(4)[0])], dtype=int)[[0, 2, 1, 3]] 

271 ).reshape(2, 2) 

272 tmp_pixels = tmp_array 

273 

274 # now for each 512x512 sub pixel region write the hips image with the corresponding healpix id 

275 hpx_id_array = tmp_pixels 

276 for i in range(binned_array.shape[0] // 512): 

277 for j in range(binned_array.shape[1] // 512): 

278 pixel_id = int(hpx_id_array[i, j]) 

279 sub_pixel = binned_array[i * 512 : i * 512 + 512, j * 512 : j * 512 + 512, :] 

280 self.log.info(f"writing sub_pixel {pixel_id}") 

281 _write_hips_image( 

282 sub_pixel, 

283 pixel_id, 

284 hips_level, 

285 self.hips_base_path, 

286 self.config.file_extension, 

287 self.config.array_type, 

288 ) 

289 

290 # Finally, bin the level 8 hpx to 256x256 (1/4 order 7) to save to the butler. 

291 # This makes smaller arrays to load, and saves the binning operation in the joint phase. 

292 zoomed = cv2.resize(output_array_hpx, (256, 256), interpolation=cv2.INTER_LANCZOS4) 

293 

294 return Struct(output_hpx=zoomed) 

295 

296 def _assemble_sub_region( 

297 self, tract_patch: dict[int, Iterable[tuple[DeferredDatasetHandle, SkyWcs, Box2I]]], patch_grow: int 

298 ) -> list[tuple[NDArray, SkyWcs, Box2I]]: 

299 """Assemble all the patches in each tract into images. 

300 

301 This function takes in an input keyed by tract, with values 

302 corresponding the patches in that tract that overlap the quatum's 

303 healpix value. It assembles each of these into a single image such 

304 that the return values is a list of images (and metadata) one element 

305 for each input tract. 

306 

307 Parameters 

308 ---------- 

309 tract_patch : `dict` of `int` to `iterable` of `tuple` of 

310 `DeferredDatasetHandle`, `SkyWcs` and `Box2I` 

311 Input images and metadata organized into corresponding tracts. 

312 patch_grow : `int` 

313 Amount to grow patches by 

314 

315 Returns 

316 ------- 

317 output_list : `list` of `tuple` of `NDArray` `SkyWcs` and `Box2I` 

318 List of assembled images and metadata, one element for each tract 

319 

320 """ 

321 

322 boxes = [] 

323 for _, iterable in tract_patch.items(): 

324 mosaic_maker = FeatheredMosaicCreator(patch_grow) 

325 new_box = Box2I() 

326 for _, _, bbox in iterable: 

327 new_box.include(bbox) 

328 # allocate tmp array 

329 new_array = np.zeros((new_box.getHeight(), new_box.getWidth(), 3), dtype=np.float32) 

330 for handle, skyWcs, box in iterable: 

331 # Make a new box of the same size, but with the origin centered 

332 # on the lowest corner were there is data. 

333 localOrigin = box.getBegin() - new_box.getBegin() 

334 localOrigin = Point2I( 

335 x=int(np.floor(localOrigin.x)), 

336 y=int(np.floor(localOrigin.y)), 

337 ) 

338 

339 localExtent = Extent2I( 

340 x=int(np.floor(box.getWidth())), 

341 y=int(np.floor(box.getHeight())), 

342 ) 

343 tmpBox = Box2I(localOrigin, localExtent) 

344 tmp_new_box = Box2I(Point2I(x=0, y=0), Extent2I(x=new_box.getWidth(), y=new_box.getHeight())) 

345 

346 image = handle.get() 

347 mosaic_maker.add_to_image(new_array, image, tmp_new_box, tmpBox, reverse=False) 

348 boxes.append((new_array, skyWcs, new_box)) 

349 return boxes 

350 

351 def runQuantum( 

352 self, 

353 butlerQC: QuantumContext, 

354 inputRefs: InputQuantizedConnection, 

355 outputRefs: OutputQuantizedConnection, 

356 ) -> None: 

357 # First get what healpix pixel this task is working on 

358 healpix_id = butlerQC.quantum.dataId["healpix8"] 

359 

360 # grab the skymap 

361 skymap: BaseSkyMap = butlerQC.get(inputRefs.skymap) 

362 

363 # Iterate over the input image refs, to get the corresponding bbox 

364 # and assemble into container for run 

365 inputs_by_tract = {} 

366 for input_image_ref in inputRefs.input_images: 

367 tract = input_image_ref.dataId["tract"] 

368 patch = input_image_ref.dataId["patch"] 

369 # All boxes in a given skymap will have the same inner dimensions 

370 # for x and y and will be the same for all patches 

371 imageWcs = skymap[tract][patch].getWcs() 

372 box = skymap[tract][patch].getOuterBBox() 

373 patch_grow = skymap[tract][patch].getCellInnerDimensions().getX() 

374 imageHandle = butlerQC.get(input_image_ref) 

375 container = inputs_by_tract.setdefault(tract, list()) 

376 container.append((imageHandle, imageWcs, box)) 

377 

378 input_images = self._assemble_sub_region(inputs_by_tract, patch_grow) 

379 

380 outputs = self.run(input_images, healpix_id) 

381 butlerQC.put(outputs, outputRefs)