Coverage for python/lsst/ap/association/packageAlerts.py: 21% (182 statements)

# This file is part of ap_association.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

__all__ = ("PackageAlertsConfig", "PackageAlertsTask")

import io
import os
import struct
import sys

from astropy import wcs
import astropy.units as u
from astropy.nddata import CCDData, VarianceUncertainty
import pandas as pd
import fastavro
# confluent_kafka is not in the standard Rubin environment as it is a third
# party package and is only needed when producing alerts.
try:
    import confluent_kafka
    from confluent_kafka import KafkaException
    from confluent_kafka.admin import AdminClient
except ImportError:
    confluent_kafka = None

import lsst.alert.packet as alertPack
import lsst.afw.geom as afwGeom
import lsst.geom as geom
import lsst.pex.config as pexConfig
from lsst.pex.exceptions import InvalidParameterError
import lsst.pipe.base as pipeBase
from lsst.utils.timer import timeMethod


class PackageAlertsConfig(pexConfig.Config):
    """Config class for PackageAlertsTask.
    """
    schemaFile = pexConfig.Field(
        dtype=str,
        doc="Schema definition file URI for the avro alerts.",
        default=str(alertPack.get_uri_to_latest_schema())
    )
    minCutoutSize = pexConfig.RangeField(
        dtype=int,
        min=0,
        max=1000,
        default=30,
        doc="Dimension of the square image cutouts to package in the alert."
    )
    alertWriteLocation = pexConfig.Field(
        dtype=str,
        doc="Location to write alerts to.",
        default=os.path.join(os.getcwd(), "alerts"),
    )

    doProduceAlerts = pexConfig.Field(
        dtype=bool,
        doc="Turn on alert production to kafka if true and if "
            "confluent_kafka is in the environment.",
        default=False,
    )

    doWriteAlerts = pexConfig.Field(
        dtype=bool,
        doc="Write alerts to disk if true.",
        default=False,
    )

    doWriteFailedAlerts = pexConfig.Field(
        dtype=bool,
        doc="If an alert cannot be sent when doProduceAlerts is set, "
            "write it to disk for debugging purposes.",
        default=False,
    )


class PackageAlertsTask(pipeBase.Task):
    """Task for packaging Dia and Pipelines data into Avro alert packages.
    """
    ConfigClass = PackageAlertsConfig
    _DefaultName = "packageAlerts"

    _scale = (1.0 * geom.arcseconds).asDegrees()

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.alertSchema = alertPack.Schema.from_uri(self.config.schemaFile)
        os.makedirs(self.config.alertWriteLocation, exist_ok=True)

        if self.config.doProduceAlerts:
            if confluent_kafka is not None:
                self.password = os.getenv("AP_KAFKA_PRODUCER_PASSWORD")
                if not self.password:
                    raise ValueError("Kafka password environment variable was not set.")
                self.username = os.getenv("AP_KAFKA_PRODUCER_USERNAME")
                if not self.username:
                    raise ValueError("Kafka username environment variable was not set.")
                self.server = os.getenv("AP_KAFKA_SERVER")
                if not self.server:
                    raise ValueError("Kafka server environment variable was not set.")
                self.kafkaTopic = os.getenv("AP_KAFKA_TOPIC")
                if not self.kafkaTopic:
                    raise ValueError("Kafka topic environment variable was not set.")

                # confluent_kafka configures all of its classes with
                # dictionaries. This one sets up the bare minimum that is
                # needed.
                self.kafkaConfig = {
                    # This is the URL to use to connect to the Kafka cluster.
                    "bootstrap.servers": self.server,
                    # These next two properties tell the Kafka client about
                    # the specific authentication and authorization protocols
                    # that should be used when connecting.
                    "security.protocol": "SASL_PLAINTEXT",
                    "sasl.mechanisms": "SCRAM-SHA-512",
                    # The sasl.username and sasl.password are passed through
                    # over SCRAM-SHA-512 auth to connect to the cluster. The
                    # username is not sensitive, but the password is (of
                    # course) a secret value which should never be committed
                    # to source code.
                    "sasl.username": self.username,
                    "sasl.password": self.password,
                    # Batch size limits the largest size of a kafka alert
                    # that can be sent. We set the batch size to 2 MB.
                    "batch.size": 2097152,
                    "linger.ms": 5,
                }
                self.kafkaAdminConfig = {
                    # This is the URL to use to connect to the Kafka cluster.
                    "bootstrap.servers": self.server,
                    # These next two properties tell the Kafka client about
                    # the specific authentication and authorization protocols
                    # that should be used when connecting.
                    "security.protocol": "SASL_PLAINTEXT",
                    "sasl.mechanisms": "SCRAM-SHA-512",
                    # The sasl.username and sasl.password are passed through
                    # over SCRAM-SHA-512 auth to connect to the cluster. The
                    # username is not sensitive, but the password is (of
                    # course) a secret value which should never be committed
                    # to source code.
                    "sasl.username": self.username,
                    "sasl.password": self.password,
                }

                self._server_check()
                self.producer = confluent_kafka.Producer(**self.kafkaConfig)

            else:
                raise RuntimeError("Produce alerts is set but confluent_kafka "
                                   "is not present in the environment. Alerts "
                                   "will not be sent to the alert stream.")

    @timeMethod
    def run(self,
            diaSourceCat,
            diaObjectCat,
            diaSrcHistory,
            diaForcedSources,
            diffIm,
            calexp,
            template,
            doRunForcedMeasurement=True,
            ):
        """Package DiaSources/Object and exposure data into Avro alerts.

        Alerts can be sent to the alert stream if ``doProduceAlerts`` is set
        and written to disk if ``doWriteAlerts`` is set. Both can be set at
        the same time, and are independent of one another.

        Writes Avro alerts to a location determined by the
        ``alertWriteLocation`` configurable.

        Parameters
        ----------
        diaSourceCat : `pandas.DataFrame`
            New DiaSources to package. DataFrame should be indexed on
            ``["diaObjectId", "band", "diaSourceId"]``.
        diaObjectCat : `pandas.DataFrame`
            New and updated DiaObjects matched to the new DiaSources.
            DataFrame is indexed on ``["diaObjectId"]``.
        diaSrcHistory : `pandas.DataFrame`
            12 month history of DiaSources matched to the DiaObjects. Excludes
            the newest DiaSource and is indexed on
            ``["diaObjectId", "band", "diaSourceId"]``.
        diaForcedSources : `pandas.DataFrame`
            12 month history of DiaForcedSources matched to the DiaObjects.
            DataFrame is indexed on ``["diaObjectId"]``.
        diffIm : `lsst.afw.image.ExposureF`
            Difference image the sources in ``diaSourceCat`` were detected in.
        calexp : `lsst.afw.image.ExposureF`
            Calexp used to create the ``diffIm``.
        template : `lsst.afw.image.ExposureF`
            Template image used to create the ``diffIm``.
        doRunForcedMeasurement : `bool`, optional
            Flag to indicate whether forced measurement was run.
            This should only be turned off for debugging purposes.
            Added to allow disabling forced sources for performance
            reasons during the ops rehearsal.
        """
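        # Build one alert per new DiaSource, then optionally send the
        # batch to Kafka and/or write all alerts for this ccdVisit to a
        # single avro file.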

        alerts = []
        self._patchDiaSources(diaSourceCat)
        self._patchDiaSources(diaSrcHistory)
        ccdVisitId = diffIm.info.id
        diffImPhotoCalib = diffIm.getPhotoCalib()
        calexpPhotoCalib = calexp.getPhotoCalib()
        templatePhotoCalib = template.getPhotoCalib()
        for srcIndex, diaSource in diaSourceCat.iterrows():
            # Get all diaSources for the associated diaObject.
            # TODO: DM-31992 skip DiaSources associated with Solar System
            # Objects for now.
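            # srcIndex is the (diaObjectId, band, diaSourceId) MultiIndex
            # tuple for this row, so srcIndex[0] is the diaObjectId; an id
            # of 0 here marks a DiaSource matched to a Solar System Object
            # rather than to a DiaObject.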

            if srcIndex[0] == 0:
                continue
            diaObject = diaObjectCat.loc[srcIndex[0]]
            if diaObject["nDiaSources"] > 1:
                objSourceHistory = diaSrcHistory.loc[srcIndex[0]]
            else:
                objSourceHistory = None
            if doRunForcedMeasurement:
                objDiaForcedSources = diaForcedSources.loc[srcIndex[0]]
            else:
                # Send empty table with correct columns
                objDiaForcedSources = diaForcedSources.loc[[]]
            sphPoint = geom.SpherePoint(diaSource["ra"],
                                        diaSource["dec"],
                                        geom.degrees)

            cutoutExtent = self.createDiaSourceExtent(diaSource["bboxSize"])
            diffImCutout = self.createCcdDataCutout(
                diffIm,
                sphPoint,
                cutoutExtent,
                diffImPhotoCalib,
                diaSource["diaSourceId"])
            calexpCutout = self.createCcdDataCutout(
                calexp,
                sphPoint,
                cutoutExtent,
                calexpPhotoCalib,
                diaSource["diaSourceId"])
            templateCutout = self.createCcdDataCutout(
                template,
                sphPoint,
                cutoutExtent,
                templatePhotoCalib,
                diaSource["diaSourceId"])

            # TODO: Create alertIds DM-24858
            alertId = diaSource["diaSourceId"]
            alerts.append(
                self.makeAlertDict(alertId,
                                   diaSource,
                                   diaObject,
                                   objSourceHistory,
                                   objDiaForcedSources,
                                   diffImCutout,
                                   calexpCutout,
                                   templateCutout))

        if self.config.doProduceAlerts:
            self.produceAlerts(alerts, ccdVisitId)

        if self.config.doWriteAlerts:
            with open(os.path.join(self.config.alertWriteLocation, f"{ccdVisitId}.avro"), "wb") as f:
                self.alertSchema.store_alerts(f, alerts)

    def _patchDiaSources(self, diaSources):
        """Add the ``programId`` column to the data.

        Parameters
        ----------
        diaSources : `pandas.DataFrame`
            DataFrame of DiaSources to patch.
        """
        diaSources["programId"] = 0

    def createDiaSourceExtent(self, bboxSize):
        """Create the extent of the cutout box, given the size of the square
        BBox that covers the source footprint.

        Parameters
        ----------
        bboxSize : `int`
            Size of a side of the square bounding box in pixels.

        Returns
        -------
        extent : `lsst.geom.Extent2I`
            Geom object representing the size of the bounding box.
        """
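        # For example, with the default minCutoutSize of 30 a source with
        # bboxSize = 10 yields a 30x30 cutout, while bboxSize = 100 yields
        # a 100x100 cutout.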

        if bboxSize < self.config.minCutoutSize:
            extent = geom.Extent2I(self.config.minCutoutSize,
                                   self.config.minCutoutSize)
        else:
            extent = geom.Extent2I(bboxSize, bboxSize)
        return extent

    def produceAlerts(self, alerts, ccdVisitId):
        """Serialize alerts and send them to the alert stream using
        confluent_kafka's producer.

        Parameters
        ----------
        alerts : `list` of `dict`
            Alerts to be sent to the alert stream.
        ccdVisitId : `int`
            ccdVisitId of the alerts sent to the alert stream. Used to write
            out alerts which fail to be sent to the alert stream.
        """

        self._server_check()
        for alert in alerts:
            alertBytes = self._serializeAlert(alert, schema=self.alertSchema.definition, schema_id=1)
            try:
                self.producer.produce(self.kafkaTopic, alertBytes, callback=self._delivery_callback)
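                # Flushing after each produce makes delivery effectively
                # synchronous, so a failure can be attributed to the alert
                # just sent.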

                self.producer.flush()

            except KafkaException as e:
                self.log.warning('Kafka error: {}, message was {} bytes'.format(e, sys.getsizeof(alertBytes)))

                if self.config.doWriteFailedAlerts:
                    with open(os.path.join(self.config.alertWriteLocation,
                                           f"{ccdVisitId}_{alert['alertId']}.avro"), "wb") as f:
                        f.write(alertBytes)

        self.producer.flush()

    def createCcdDataCutout(self, image, skyCenter, extent, photoCalib, srcId):
        """Grab an image as a cutout and return a calibrated CCDData image.

        Parameters
        ----------
        image : `lsst.afw.image.ExposureF`
            Image to pull cutout from.
        skyCenter : `lsst.geom.SpherePoint`
            Center point of DiaSource on the sky.
        extent : `lsst.geom.Extent2I`
            Bounding box to cutout from the image.
        photoCalib : `lsst.afw.image.PhotoCalib`
            PhotoCalib object of the image the cutout is taken from.
        srcId : `int`
            Unique id of the DiaSource. Used in the log message if an error
            occurs while extracting the cutout.

        Returns
        -------
        ccdData : `astropy.nddata.CCDData` or `None`
            CCDData object storing the calibrated information from the input
            difference or template image.
        """

        # Catch errors in retrieving the cutout.
        try:
            cutout = image.getCutout(skyCenter, extent)
        except InvalidParameterError:
            point = image.getWcs().skyToPixel(skyCenter)
            imBBox = image.getBBox()
            if not geom.Box2D(imBBox).contains(point):
                self.log.warning(
                    "DiaSource id=%i centroid lies at pixel (%.2f, %.2f) "
                    "which is outside the Exposure with bounding box "
                    "((%i, %i), (%i, %i)). Returning None for cutout...",
                    srcId, point.x, point.y,
                    imBBox.minX, imBBox.maxX, imBBox.minY, imBBox.maxY)
            else:
                raise InvalidParameterError(
                    "Failed to retrieve cutout from image for DiaSource with "
                    "id=%i. InvalidParameterError thrown during cutout "
                    "creation. Exiting."
                    % srcId)
            return None

        # Find the value of the bottom corner of our cutout's BBox and
        # subtract 1 so that the CCDData cutout position value will be
        # [1, 1].
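        # (FITS WCS pixel coordinates, and hence the crpix set below, are
        # 1-based.)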

        cutOutMinX = cutout.getBBox().minX - 1
        cutOutMinY = cutout.getBBox().minY - 1
        center = cutout.getWcs().skyToPixel(skyCenter)
        calibCutout = photoCalib.calibrateImage(cutout.getMaskedImage())

        cutoutWcs = wcs.WCS(naxis=2)
        cutoutWcs.array_shape = (cutout.getBBox().getHeight(),
                                 cutout.getBBox().getWidth())
        cutoutWcs.wcs.crpix = [center.x - cutOutMinX, center.y - cutOutMinY]
        cutoutWcs.wcs.crval = [skyCenter.getRa().asDegrees(),
                               skyCenter.getDec().asDegrees()]
        cutoutWcs.wcs.cd = self.makeLocalTransformMatrix(cutout.getWcs(),
                                                         center,
                                                         skyCenter)

        return CCDData(
            data=calibCutout.getImage().array,
            uncertainty=VarianceUncertainty(calibCutout.getVariance().array),
            flags=calibCutout.getMask().array,
            wcs=cutoutWcs,
            meta={"cutMinX": cutOutMinX,
                  "cutMinY": cutOutMinY},
            unit=u.nJy)

    def makeLocalTransformMatrix(self, wcs, center, skyCenter):
        """Create a local, linear approximation of the wcs transformation
        matrix.

        The approximation is created as if the center is at RA=0, DEC=0. All
        x, y coordinates compared are relative to the position of center. The
        matrix is initially calculated with units of arcseconds and then
        converted to degrees. This yields higher precision results due to
        quirks in AST.

        Parameters
        ----------
        wcs : `lsst.afw.geom.SkyWcs`
            Wcs to approximate.
        center : `lsst.geom.Point2D`
            Point at which to evaluate the LocalWcs.
        skyCenter : `lsst.geom.SpherePoint`
            Point on sky to approximate the Wcs.

        Returns
        -------
        localMatrix : `numpy.ndarray`
            Matrix representation of the local wcs approximation, in units
            of degrees.
        """
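        # The Jacobian below is evaluated against a local gnomonic frame
        # scaled to 1 arcsecond per pixel, so dividing by 3600 converts it
        # to degrees per pixel. For an undistorted 0.2 arcsec/pixel camera,
        # for example, the result is approximately diag(0.2 / 3600), up to
        # rotation.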

        blankCDMatrix = [[self._scale, 0], [0, self._scale]]
        localGnomonicWcs = afwGeom.makeSkyWcs(
            center, skyCenter, blankCDMatrix)
        measurementToLocalGnomonic = wcs.getTransform().then(
            localGnomonicWcs.getTransform().inverted()
        )
        localMatrix = measurementToLocalGnomonic.getJacobian(center)
        return localMatrix / 3600

    def makeAlertDict(self,
                      alertId,
                      diaSource,
                      diaObject,
                      objDiaSrcHistory,
                      objDiaForcedSources,
                      diffImCutout,
                      calexpCutout,
                      templateCutout):
        """Convert data and package into a dictionary alert.

        Parameters
        ----------
        alertId : `int`
            Unique id of the alert. Currently set to the id of the
            triggering DiaSource (see DM-24858).
        diaSource : `pandas.DataFrame`
            New single DiaSource to package.
        diaObject : `pandas.DataFrame`
            DiaObject that ``diaSource`` is matched to.
        objDiaSrcHistory : `pandas.DataFrame`
            12 month history of ``diaObject`` excluding the latest DiaSource.
        objDiaForcedSources : `pandas.DataFrame`
            12 month history of ``diaObject`` forced measurements.
        diffImCutout : `astropy.nddata.CCDData` or `None`
            Cutout of the difference image around the location of
            ``diaSource`` with a min size set by the ``minCutoutSize``
            configurable.
        calexpCutout : `astropy.nddata.CCDData` or `None`
            Cutout of the calexp around the location of ``diaSource``
            with a min size set by the ``minCutoutSize`` configurable.
        templateCutout : `astropy.nddata.CCDData` or `None`
            Cutout of the template image around the location of
            ``diaSource`` with a min size set by the ``minCutoutSize``
            configurable.

        Returns
        -------
        alert : `dict`
            Alert payload matching the avro alert schema.
        """

        alert = dict()
        alert['alertId'] = alertId
        alert['diaSource'] = diaSource.to_dict()

        if objDiaSrcHistory is None:
            alert['prvDiaSources'] = objDiaSrcHistory
        else:
            alert['prvDiaSources'] = objDiaSrcHistory.to_dict("records")

        if isinstance(objDiaForcedSources, pd.Series):
            alert['prvDiaForcedSources'] = [objDiaForcedSources.to_dict()]
        else:
            alert['prvDiaForcedSources'] = objDiaForcedSources.to_dict("records")
        alert['prvDiaNondetectionLimits'] = None

        alert['diaObject'] = diaObject.to_dict()

        alert['ssObject'] = None

        if diffImCutout is None:
            alert['cutoutDifference'] = None
        else:
            alert['cutoutDifference'] = self.streamCcdDataToBytes(diffImCutout)

        if calexpCutout is None:
            alert['cutoutScience'] = None
        else:
            alert['cutoutScience'] = self.streamCcdDataToBytes(calexpCutout)

        if templateCutout is None:
            alert["cutoutTemplate"] = None
        else:
            alert["cutoutTemplate"] = self.streamCcdDataToBytes(templateCutout)

        return alert

    def streamCcdDataToBytes(self, cutout):
        """Serialize a cutout into bytes.

        Parameters
        ----------
        cutout : `astropy.nddata.CCDData`
            Cutout to serialize.

        Returns
        -------
        cutoutBytes : `bytes`
            Input cutout serialized into byte data.
        """
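        # CCDData.write with format="fits" serializes the cutout into an
        # in-memory FITS file, so no temporary file on disk is needed.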

        with io.BytesIO() as streamer:
            cutout.write(streamer, format="fits")
            cutoutBytes = streamer.getvalue()
        return cutoutBytes

    def _serializeAlert(self, alert, schema=None, schema_id=0):
        """Serialize an alert to a byte sequence for sending to Kafka.

        Parameters
        ----------
        alert : `dict`
            An alert payload to be serialized.
        schema : `dict`, optional
            An Avro schema definition describing how to encode ``alert``. By
            default the schema is None, which falls back to the task's
            configured alert schema.
        schema_id : `int`, optional
            The Confluent Schema Registry ID of the schema. By default, 0 (an
            invalid ID) is used, indicating that the schema is not registered.

        Returns
        -------
        serialized : `bytes`
            The byte sequence describing the alert, including the Confluent
            Wire Format prefix.
        """
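        # The payload is the 5-byte Confluent Wire Format header (magic
        # byte plus schema ID) followed by the schemaless Avro body;
        # consumers must look the schema up by ID to decode the message.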

        if schema is None:
            schema = self.alertSchema.definition

        buf = io.BytesIO()
        # TODO: Use a proper schema versioning system (DM-42606)
        buf.write(self._serializeConfluentWireHeader(schema_id))
        fastavro.schemaless_writer(buf, schema, alert)
        return buf.getvalue()

    @staticmethod
    def _serializeConfluentWireHeader(schema_version):
        """Returns the byte prefix for Confluent Wire Format-style Kafka
        messages.

        Parameters
        ----------
        schema_version : `int`
            A version number which indicates the Confluent Schema Registry ID
            number of the Avro schema used to encode the message that follows
            this header.

        Returns
        -------
        header : `bytes`
            The 5-byte encoded message prefix.

        Notes
        -----
        The Confluent Wire Format is described more fully here:
        https://docs.confluent.io/current/schema-registry/serdes-develop/index.html#wire-format
        """
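        # The struct format ">bi" packs a big-endian signed byte followed
        # by a 4-byte signed int; schema_version=1, for example, serializes
        # to b"\x00\x00\x00\x00\x01": magic byte 0, then the schema ID.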

        ConfluentWireFormatHeader = struct.Struct(">bi")
        return ConfluentWireFormatHeader.pack(0, schema_version)

    def _delivery_callback(self, err, msg):
        if err:
            self.log.warning('Message failed delivery: %s', err)
        else:
            self.log.debug('Message delivered to %s [%d] @ %d', msg.topic(), msg.partition(), msg.offset())

    def _server_check(self):
        """Check that the alert stream credentials are still valid and that
        the server is contactable.

        Raises
        ------
        KafkaException
            Raised if the server is not contactable.
        RuntimeError
            Raised if the server is contactable but there are no topics
            present.
        """
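        # list_topics() raises KafkaException if the broker cannot be
        # reached within the timeout; an empty topic dict means the broker
        # answered but no topics exist.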

        admin_client = AdminClient(self.kafkaAdminConfig)
        topics = admin_client.list_topics(timeout=0.5).topics

        if not topics:
            raise RuntimeError("Connected to the Kafka server but no topics "
                               "are present.")