Coverage for python/lsst/ap/association/packageAlerts.py: 23%

164 statements  

coverage.py v7.4.4, created at 2024-03-19 09:44 +0000

# This file is part of ap_association.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

__all__ = ("PackageAlertsConfig", "PackageAlertsTask")

import io
import os
import struct
import sys

from astropy import wcs
import astropy.units as u
from astropy.nddata import CCDData, VarianceUncertainty
import fastavro
import pandas as pd

# confluent_kafka is not in the standard Rubin environment as it is a
# third-party package and is only needed when producing alerts.
try:
    import confluent_kafka
    from confluent_kafka import KafkaException
except ImportError:
    confluent_kafka = None

import lsst.alert.packet as alertPack
import lsst.afw.geom as afwGeom
import lsst.geom as geom
import lsst.pex.config as pexConfig
from lsst.pex.exceptions import InvalidParameterError
import lsst.pipe.base as pipeBase
from lsst.utils.timer import timeMethod


class PackageAlertsConfig(pexConfig.Config):
    """Config class for PackageAlertsTask.
    """
    schemaFile = pexConfig.Field(
        dtype=str,
        doc="Schema definition file URI for the avro alerts.",
        default=str(alertPack.get_uri_to_latest_schema())
    )
    minCutoutSize = pexConfig.RangeField(
        dtype=int,
        min=0,
        max=1000,
        default=30,
        doc="Dimension of the square image cutouts to package in the alert."
    )
    alertWriteLocation = pexConfig.Field(
        dtype=str,
        doc="Location to write alerts to.",
        default=os.path.join(os.getcwd(), "alerts"),
    )

    doProduceAlerts = pexConfig.Field(
        dtype=bool,
        doc="Turn on alert production to kafka if true and if "
            "confluent_kafka is in the environment.",
        default=False,
    )

    doWriteAlerts = pexConfig.Field(
        dtype=bool,
        doc="Write alerts to disk if true.",
        default=False,
    )
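# A minimal configuration sketch (hypothetical values, shown only to
# illustrate how the fields above might be overridden in a pipeline config
# file):
#
#     config.doWriteAlerts = True
#     config.minCutoutSize = 40
#     config.alertWriteLocation = "/tmp/alerts"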

class PackageAlertsTask(pipeBase.Task):
    """Task for packaging Dia and Pipelines data into Avro alert packages.
    """
    ConfigClass = PackageAlertsConfig
    _DefaultName = "packageAlerts"

    _scale = (1.0 * geom.arcseconds).asDegrees()

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.alertSchema = alertPack.Schema.from_uri(self.config.schemaFile)
        os.makedirs(self.config.alertWriteLocation, exist_ok=True)

        if self.config.doProduceAlerts:
            if confluent_kafka is not None:
                self.password = os.getenv("AP_KAFKA_PRODUCER_PASSWORD")
                if not self.password:
                    raise ValueError("Kafka password environment variable was not set.")
                self.username = os.getenv("AP_KAFKA_PRODUCER_USERNAME")
                if not self.username:
                    raise ValueError("Kafka username environment variable was not set.")
                self.server = os.getenv("AP_KAFKA_SERVER")
                if not self.server:
                    raise ValueError("Kafka server environment variable was not set.")
                self.kafkaTopic = os.getenv("AP_KAFKA_TOPIC")
                if not self.kafkaTopic:
                    raise ValueError("Kafka topic environment variable was not set.")
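                # The four environment variables above must be exported
                # before running; a hedged sketch (all values are
                # illustrative placeholders, not real credentials or
                # endpoints):
                #
                #     export AP_KAFKA_PRODUCER_USERNAME="kafka-user"
                #     export AP_KAFKA_PRODUCER_PASSWORD="********"
                #     export AP_KAFKA_SERVER="kafka.example.com:9092"
                #     export AP_KAFKA_TOPIC="alerts-topic"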

                # confluent_kafka configures all of its classes with
                # dictionaries. This one sets up the bare minimum that is
                # needed.
                self.kafkaConfig = {
                    # This is the URL to use to connect to the Kafka cluster.
                    "bootstrap.servers": self.server,
                    # These next two properties tell the Kafka client about
                    # the specific authentication and authorization protocols
                    # that should be used when connecting.
                    "security.protocol": "SASL_PLAINTEXT",
                    "sasl.mechanisms": "SCRAM-SHA-512",
                    # The sasl.username and sasl.password are passed through
                    # over SCRAM-SHA-512 auth to connect to the cluster. The
                    # username is not sensitive, but the password is (of
                    # course) a secret value which should never be committed
                    # to source code.
                    "sasl.username": self.username,
                    "sasl.password": self.password,
                    # Batch size limits the largest size of a kafka alert
                    # that can be sent. We set the batch size to 2 MB.
                    "batch.size": 2097152,
                    "linger.ms": 5,
                }
                self.producer = confluent_kafka.Producer(**self.kafkaConfig)

            else:
                raise RuntimeError("Produce alerts is set but confluent_kafka is not "
                                   "present in the environment. Alerts will not be "
                                   "sent to the alert stream.")

    @timeMethod
    def run(self,
            diaSourceCat,
            diaObjectCat,
            diaSrcHistory,
            diaForcedSources,
            diffIm,
            template,
            ):
        """Package DiaSources/Object and exposure data into Avro alerts.

        Alerts can be sent to the alert stream if ``doProduceAlerts`` is set
        and written to disk if ``doWriteAlerts`` is set. Both can be set at
        the same time, and are independent of one another.

        Writes Avro alerts to a location determined by the
        ``alertWriteLocation`` configurable.

        Parameters
        ----------
        diaSourceCat : `pandas.DataFrame`
            New DiaSources to package. DataFrame should be indexed on
            ``["diaObjectId", "band", "diaSourceId"]``.
        diaObjectCat : `pandas.DataFrame`
            New and updated DiaObjects matched to the new DiaSources.
            DataFrame is indexed on ``["diaObjectId"]``.
        diaSrcHistory : `pandas.DataFrame`
            12 month history of DiaSources matched to the DiaObjects.
            Excludes the newest DiaSource and is indexed on
            ``["diaObjectId", "band", "diaSourceId"]``.
        diaForcedSources : `pandas.DataFrame`
            12 month history of DiaForcedSources matched to the DiaObjects.
            Indexed on ``["diaObjectId"]``.
        diffIm : `lsst.afw.image.ExposureF`
            Difference image the sources in ``diaSourceCat`` were detected
            in.
        template : `lsst.afw.image.ExposureF` or `None`
            Template image used to create the ``diffIm``.
        """

        alerts = []
        self._patchDiaSources(diaSourceCat)
        self._patchDiaSources(diaSrcHistory)
        ccdVisitId = diffIm.info.id
        diffImPhotoCalib = diffIm.getPhotoCalib()
        templatePhotoCalib = template.getPhotoCalib()
        for srcIndex, diaSource in diaSourceCat.iterrows():
            # Get all diaSources for the associated diaObject.
            # TODO: DM-31992 skip DiaSources associated with Solar System
            # Objects for now.
            if srcIndex[0] == 0:
                continue
            diaObject = diaObjectCat.loc[srcIndex[0]]
            if diaObject["nDiaSources"] > 1:
                objSourceHistory = diaSrcHistory.loc[srcIndex[0]]
            else:
                objSourceHistory = None
            objDiaForcedSources = diaForcedSources.loc[srcIndex[0]]
            sphPoint = geom.SpherePoint(diaSource["ra"],
                                        diaSource["dec"],
                                        geom.degrees)

            cutoutExtent = self.createDiaSourceExtent(diaSource["bboxSize"])
            diffImCutout = self.createCcdDataCutout(
                diffIm,
                sphPoint,
                cutoutExtent,
                diffImPhotoCalib,
                diaSource["diaSourceId"])
            templateCutout = self.createCcdDataCutout(
                template,
                sphPoint,
                cutoutExtent,
                templatePhotoCalib,
                diaSource["diaSourceId"])

            # TODO: Create alertIds DM-24858
            alertId = diaSource["diaSourceId"]
            alerts.append(
                self.makeAlertDict(alertId,
                                   diaSource,
                                   diaObject,
                                   objSourceHistory,
                                   objDiaForcedSources,
                                   diffImCutout,
                                   templateCutout))

        if self.config.doProduceAlerts:
            self.produceAlerts(alerts, ccdVisitId)

        if self.config.doWriteAlerts:
            with open(os.path.join(self.config.alertWriteLocation,
                                   f"{ccdVisitId}.avro"), "wb") as f:
                self.alertSchema.store_alerts(f, alerts)

    def _patchDiaSources(self, diaSources):
        """Add the ``programId`` column to the data.

        Parameters
        ----------
        diaSources : `pandas.DataFrame`
            DataFrame of DiaSources to patch.
        """
        diaSources["programId"] = 0

    def createDiaSourceExtent(self, bboxSize):
        """Create an extent for a box for the cutouts given the size of the
        square BBox that covers the source footprint.

        Parameters
        ----------
        bboxSize : `int`
            Size of a side of the square bounding box in pixels.

        Returns
        -------
        extent : `lsst.geom.Extent2I`
            Geom object representing the size of the bounding box.
        """
        if bboxSize < self.config.minCutoutSize:
            extent = geom.Extent2I(self.config.minCutoutSize,
                                   self.config.minCutoutSize)
        else:
            extent = geom.Extent2I(bboxSize, bboxSize)
        return extent
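    # The clamping above guarantees a floor on the cutout size; for example,
    # with the default minCutoutSize of 30, a 12 pixel footprint still yields
    # a geom.Extent2I(30, 30) cutout, while a 64 pixel footprint yields
    # geom.Extent2I(64, 64).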

    def produceAlerts(self, alerts, ccdVisitId):
        """Serialize alerts and send them to the alert stream using
        confluent_kafka's producer.

        Parameters
        ----------
        alerts : `list` of `dict`
            Alert dictionaries to be sent to the alert stream.
        ccdVisitId : `int`
            ccdVisitId of the alerts sent to the alert stream. Used to name
            the files for alerts which fail to be sent to the alert stream
            and are written to disk instead.
        """
        for alert in alerts:
            alertBytes = self._serializeAlert(alert, schema=self.alertSchema.definition, schema_id=1)
            try:
                self.producer.produce(self.kafkaTopic, alertBytes, callback=self._delivery_callback)
                self.producer.flush()

            except KafkaException as e:
                self.log.warning("Kafka error: %s, message was %s bytes", e, sys.getsizeof(alertBytes))

                with open(os.path.join(self.config.alertWriteLocation,
                                       f"{ccdVisitId}_{alert['alertId']}.avro"), "wb") as f:
                    f.write(alertBytes)

        self.producer.flush()

    def createCcdDataCutout(self, image, skyCenter, extent, photoCalib, srcId):
        """Grab an image as a cutout and return a calibrated CCDData image.

        Parameters
        ----------
        image : `lsst.afw.image.ExposureF`
            Image to pull cutout from.
        skyCenter : `lsst.geom.SpherePoint`
            Center point of DiaSource on the sky.
        extent : `lsst.geom.Extent2I`
            Bounding box to cutout from the image.
        photoCalib : `lsst.afw.image.PhotoCalib`
            Calibration object of the image the cutout is cut from.
        srcId : `int`
            Unique id of the DiaSource. Used for logging when an error
            occurs while extracting a cutout.

        Returns
        -------
        ccdData : `astropy.nddata.CCDData` or `None`
            CCDData object storing the calibrated information from the input
            difference or template image.
        """
        # Catch errors in retrieving the cutout.
        try:
            cutout = image.getCutout(skyCenter, extent)
        except InvalidParameterError:
            point = image.getWcs().skyToPixel(skyCenter)
            imBBox = image.getBBox()
            if not geom.Box2D(imBBox).contains(point):
                self.log.warning(
                    "DiaSource id=%i centroid lies at pixel (%.2f, %.2f) "
                    "which is outside the Exposure with bounding box "
                    "((%i, %i), (%i, %i)). Returning None for cutout...",
                    srcId, point.x, point.y,
                    imBBox.minX, imBBox.maxX, imBBox.minY, imBBox.maxY)
            else:
                raise InvalidParameterError(
                    "Failed to retrieve cutout from image for DiaSource with "
                    "id=%i. InvalidParameterError thrown during cutout "
                    "creation. Exiting."
                    % srcId)
            return None

        # Find the value of the bottom corner of our cutout's BBox and
        # subtract 1 so that the CCDData cutout position value will be
        # [1, 1].
        cutOutMinX = cutout.getBBox().minX - 1
        cutOutMinY = cutout.getBBox().minY - 1
        center = cutout.getWcs().skyToPixel(skyCenter)
        calibCutout = photoCalib.calibrateImage(cutout.getMaskedImage())

        cutoutWcs = wcs.WCS(naxis=2)
        # array_shape is (rows, columns), i.e. (height, width).
        cutoutWcs.array_shape = (cutout.getBBox().getHeight(),
                                 cutout.getBBox().getWidth())
        cutoutWcs.wcs.crpix = [center.x - cutOutMinX, center.y - cutOutMinY]
        cutoutWcs.wcs.crval = [skyCenter.getRa().asDegrees(),
                               skyCenter.getDec().asDegrees()]
        cutoutWcs.wcs.cd = self.makeLocalTransformMatrix(cutout.getWcs(),
                                                         center,
                                                         skyCenter)

        return CCDData(
            data=calibCutout.getImage().array,
            uncertainty=VarianceUncertainty(calibCutout.getVariance().array),
            flags=calibCutout.getMask().array,
            wcs=cutoutWcs,
            meta={"cutMinX": cutOutMinX,
                  "cutMinY": cutOutMinY},
            unit=u.nJy)

    def makeLocalTransformMatrix(self, wcs, center, skyCenter):
        """Create a local, linear approximation of the wcs transformation
        matrix.

        The approximation is created as if the center is at RA=0, DEC=0. All
        x, y coordinates are relative to the position of ``center``. The
        matrix is initially calculated in units of arcseconds and then
        converted to degrees. This yields higher precision results due to
        quirks in AST.

        Parameters
        ----------
        wcs : `lsst.afw.geom.SkyWcs`
            Wcs to approximate.
        center : `lsst.geom.Point2D`
            Point at which to evaluate the LocalWcs.
        skyCenter : `lsst.geom.SpherePoint`
            Point on sky to approximate the Wcs.

        Returns
        -------
        localMatrix : `numpy.ndarray`
            Matrix representation of the local wcs approximation, in units
            of degrees.
        """
        blankCDMatrix = [[self._scale, 0], [0, self._scale]]
        localGnomonicWcs = afwGeom.makeSkyWcs(
            center, skyCenter, blankCDMatrix)
        measurementToLocalGnomonic = wcs.getTransform().then(
            localGnomonicWcs.getTransform().inverted()
        )
        localMatrix = measurementToLocalGnomonic.getJacobian(center)
        # The Jacobian is in arcseconds per pixel; divide by 3600 to convert
        # to degrees per pixel.
        return localMatrix / 3600
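    # For reference, a WCS with a pixel scale of 0.2 arcsec/pixel gives a
    # Jacobian from getJacobian with diagonal terms of about 0.2; the
    # division by 3600 above converts those to the ~5.6e-5 deg/pixel values
    # expected in a FITS CD matrix.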

    def makeAlertDict(self,
                      alertId,
                      diaSource,
                      diaObject,
                      objDiaSrcHistory,
                      objDiaForcedSources,
                      diffImCutout,
                      templateCutout):
        """Convert data and package into a dictionary alert.

        Parameters
        ----------
        alertId : `int`
            Unique id for the alert.
        diaSource : `pandas.DataFrame`
            New single DiaSource to package.
        diaObject : `pandas.DataFrame`
            DiaObject that ``diaSource`` is matched to.
        objDiaSrcHistory : `pandas.DataFrame` or `None`
            12 month history of ``diaObject`` excluding the latest DiaSource.
        objDiaForcedSources : `pandas.DataFrame`
            12 month history of ``diaObject`` forced measurements.
        diffImCutout : `astropy.nddata.CCDData` or `None`
            Cutout of the difference image around the location of
            ``diaSource`` with a min size set by the ``minCutoutSize``
            configurable.
        templateCutout : `astropy.nddata.CCDData` or `None`
            Cutout of the template image around the location of
            ``diaSource`` with a min size set by the ``minCutoutSize``
            configurable.

        Returns
        -------
        alert : `dict`
            Alert data packaged as a dictionary.
        """
        alert = dict()
        alert['alertId'] = alertId
        alert['diaSource'] = diaSource.to_dict()

        if objDiaSrcHistory is None:
            alert['prvDiaSources'] = objDiaSrcHistory
        else:
            alert['prvDiaSources'] = objDiaSrcHistory.to_dict("records")

        if isinstance(objDiaForcedSources, pd.Series):
            alert['prvDiaForcedSources'] = [objDiaForcedSources.to_dict()]
        else:
            alert['prvDiaForcedSources'] = objDiaForcedSources.to_dict("records")
        alert['prvDiaNondetectionLimits'] = None

        alert['diaObject'] = diaObject.to_dict()

        alert['ssObject'] = None

        if diffImCutout is None:
            alert['cutoutDifference'] = None
        else:
            alert['cutoutDifference'] = self.streamCcdDataToBytes(diffImCutout)

        if templateCutout is None:
            alert["cutoutTemplate"] = None
        else:
            alert["cutoutTemplate"] = self.streamCcdDataToBytes(templateCutout)

        return alert

    def streamCcdDataToBytes(self, cutout):
        """Serialize a cutout into bytes.

        Parameters
        ----------
        cutout : `astropy.nddata.CCDData`
            Cutout to serialize.

        Returns
        -------
        cutoutBytes : `bytes`
            Input cutout serialized into byte data.
        """
        with io.BytesIO() as streamer:
            cutout.write(streamer, format="fits")
            cutoutBytes = streamer.getvalue()
        return cutoutBytes
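    # A minimal round-trip sketch of the serialization above, assuming only
    # numpy and astropy are available (the 3x3 array and "adu" unit are
    # illustrative):
    #
    #     >>> import io
    #     >>> import numpy as np
    #     >>> from astropy.nddata import CCDData
    #     >>> cutout = CCDData(np.zeros((3, 3)), unit="adu")
    #     >>> with io.BytesIO() as streamer:
    #     ...     cutout.write(streamer, format="fits")
    #     ...     cutoutBytes = streamer.getvalue()
    #     >>> cutoutBytes[:6]
    #     b'SIMPLE'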

    def _serializeAlert(self, alert, schema=None, schema_id=0):
        """Serialize an alert to a byte sequence for sending to Kafka.

        Parameters
        ----------
        alert : `dict`
            An alert payload to be serialized.
        schema : `dict`, optional
            An Avro schema definition describing how to encode ``alert``. By
            default the schema is None, which falls back to the schema
            configured for this task.
        schema_id : `int`, optional
            The Confluent Schema Registry ID of the schema. By default, 0 (an
            invalid ID) is used, indicating that the schema is not registered.

        Returns
        -------
        serialized : `bytes`
            The byte sequence describing the alert, including the Confluent
            Wire Format prefix.
        """
        if schema is None:
            schema = self.alertSchema.definition

        buf = io.BytesIO()
        # TODO: Use a proper schema versioning system (DM-42606)
        buf.write(self._serializeConfluentWireHeader(schema_id))
        fastavro.schemaless_writer(buf, schema, alert)
        return buf.getvalue()
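    # A minimal sketch of the inverse operation, assuming only fastavro:
    # skip the 5-byte Confluent Wire Format header, then decode the
    # schemaless Avro record with the same schema used to write it (the
    # names ``serialized`` and ``schema`` refer to the docstring above).
    #
    #     >>> import io
    #     >>> import fastavro
    #     >>> payload = io.BytesIO(serialized[5:])
    #     >>> alert = fastavro.schemaless_reader(payload, schema)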

    @staticmethod
    def _serializeConfluentWireHeader(schema_version):
        """Return the byte prefix for Confluent Wire Format-style Kafka
        messages.

        Parameters
        ----------
        schema_version : `int`
            A version number which indicates the Confluent Schema Registry ID
            number of the Avro schema used to encode the message that follows
            this header.

        Returns
        -------
        header : `bytes`
            The 5-byte encoded message prefix.

        Notes
        -----
        The Confluent Wire Format is described more fully here:
        https://docs.confluent.io/current/schema-registry/serdes-develop/index.html#wire-format
        """
        ConfluentWireFormatHeader = struct.Struct(">bi")
        return ConfluentWireFormatHeader.pack(0, schema_version)
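    # A worked example of the header layout, using only the standard
    # library: byte 0 is the magic byte (always 0) and bytes 1-4 are the
    # big-endian schema ID.
    #
    #     >>> import struct
    #     >>> header = struct.Struct(">bi").pack(0, 1)
    #     >>> header
    #     b'\x00\x00\x00\x00\x01'
    #     >>> struct.unpack(">bi", header)
    #     (0, 1)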

    def _delivery_callback(self, err, msg):
        if err:
            self.log.warning('Message failed delivery: %s', err)
        else:
            self.log.debug('Message delivered to %s [%d] @ %d',
                           msg.topic(), msg.partition(), msg.offset())