Coverage for python/lsst/ap/association/packageAlerts.py: 22% (183 statements)

# This file is part of ap_association.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

__all__ = ("PackageAlertsConfig", "PackageAlertsTask")

import io
import os
import struct
import sys

from astropy import wcs
import astropy.units as u
from astropy.nddata import CCDData, VarianceUncertainty
import fastavro
import pandas as pd
# confluent_kafka is not in the standard Rubin environment as it is a third
# party package and is only needed when producing alerts.
try:
    import confluent_kafka
    from confluent_kafka import KafkaException
    from confluent_kafka.admin import AdminClient
except ImportError:
    confluent_kafka = None

import lsst.alert.packet as alertPack
import lsst.afw.geom as afwGeom
import lsst.geom as geom
import lsst.pex.config as pexConfig
from lsst.pex.exceptions import InvalidParameterError
import lsst.pipe.base as pipeBase
from lsst.utils.timer import timeMethod


class PackageAlertsConfig(pexConfig.Config):
    """Config class for PackageAlertsTask.
    """
    schemaFile = pexConfig.Field(
        dtype=str,
        doc="Schema definition file URI for the avro alerts.",
        default=str(alertPack.get_uri_to_latest_schema())
    )
    minCutoutSize = pexConfig.RangeField(
        dtype=int,
        min=0,
        max=1000,
        default=30,
        doc="Dimension of the square image cutouts to package in the alert."
    )
    alertWriteLocation = pexConfig.Field(
        dtype=str,
        doc="Location to write alerts to.",
        default=os.path.join(os.getcwd(), "alerts"),
    )

    doProduceAlerts = pexConfig.Field(
        dtype=bool,
        doc="Turn on alert production to kafka if true and if confluent_kafka is in the environment.",
        default=False,
    )

    doWriteAlerts = pexConfig.Field(
        dtype=bool,
        doc="Write alerts to disk if true.",
        default=False,
    )

    doWriteFailedAlerts = pexConfig.Field(
        dtype=bool,
        doc="If an alert cannot be sent when doProduceAlerts is set, "
            "write it to disk for debugging purposes.",
        default=False,
    )

    maxTimeout = pexConfig.Field(
        dtype=float,
        doc="Maximum time in seconds to wait for the alert stream "
            "broker to respond to a query before timing out.",
        default=15.0,
    )

    deliveryTimeout = pexConfig.Field(
        dtype=float,
        doc="Time in milliseconds to wait for the producer to deliver "
            "an alert before timing out.",
        default=1200.0,
    )
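
# A minimal configuration sketch (illustrative values, not recommended
# defaults), assuming the usual pex_config override pattern:
#
#     config = PackageAlertsConfig()
#     config.doWriteAlerts = True
#     config.minCutoutSize = 64
#     config.alertWriteLocation = "/tmp/alerts"  # hypothetical path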


class PackageAlertsTask(pipeBase.Task):
    """Task for packaging DIA and pipeline data into Avro alert packets.
    """
    ConfigClass = PackageAlertsConfig
    _DefaultName = "packageAlerts"

    _scale = (1.0 * geom.arcseconds).asDegrees()

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.alertSchema = alertPack.Schema.from_uri(self.config.schemaFile)
        os.makedirs(self.config.alertWriteLocation, exist_ok=True)

        if self.config.doProduceAlerts:
            if confluent_kafka is not None:
                self.password = os.getenv("AP_KAFKA_PRODUCER_PASSWORD")
                if not self.password:
                    raise ValueError("Kafka password environment variable was not set.")
                self.username = os.getenv("AP_KAFKA_PRODUCER_USERNAME")
                if not self.username:
                    raise ValueError("Kafka username environment variable was not set.")
                self.server = os.getenv("AP_KAFKA_SERVER")
                if not self.server:
                    raise ValueError("Kafka server environment variable was not set.")
                self.kafkaTopic = os.getenv("AP_KAFKA_TOPIC")
                if not self.kafkaTopic:
                    raise ValueError("Kafka topic environment variable was not set.")

                # confluent_kafka configures all of its classes with dictionaries. This one
                # sets up the bare minimum that is needed.
                self.kafkaConfig = {
                    # This is the URL to use to connect to the Kafka cluster.
                    "bootstrap.servers": self.server,
                    # These next two properties tell the Kafka client about the specific
                    # authentication and authorization protocols that should be used when
                    # connecting.
                    "security.protocol": "SASL_PLAINTEXT",
                    "sasl.mechanisms": "SCRAM-SHA-512",
                    # The sasl.username and sasl.password are passed through over
                    # SCRAM-SHA-512 auth to connect to the cluster. The username is not
                    # sensitive, but the password is (of course) a secret value which
                    # should never be committed to source code.
                    "sasl.username": self.username,
                    "sasl.password": self.password,
                    # Batch size limits the largest size of a kafka alert that can be sent.
                    # We set the batch size to 2 MB.
                    "batch.size": 2097152,
                    "linger.ms": 5,
                    "delivery.timeout.ms": self.config.deliveryTimeout,
                }
                self.kafkaAdminConfig = {
                    # This is the URL to use to connect to the Kafka cluster.
                    "bootstrap.servers": self.server,
                    # These next two properties tell the Kafka client about the specific
                    # authentication and authorization protocols that should be used when
                    # connecting.
                    "security.protocol": "SASL_PLAINTEXT",
                    "sasl.mechanisms": "SCRAM-SHA-512",
                    # The sasl.username and sasl.password are passed through over
                    # SCRAM-SHA-512 auth to connect to the cluster. The username is not
                    # sensitive, but the password is (of course) a secret value which
                    # should never be committed to source code.
                    "sasl.username": self.username,
                    "sasl.password": self.password,
                }

                self._server_check()
                self.producer = confluent_kafka.Producer(**self.kafkaConfig)

            else:
                raise RuntimeError("doProduceAlerts is set but confluent_kafka is not present in "
                                   "the environment. Alerts will not be sent to the alert stream.")

    @timeMethod
    def run(self,
            diaSourceCat,
            diaObjectCat,
            diaSrcHistory,
            diaForcedSources,
            diffIm,
            calexp,
            template,
            doRunForcedMeasurement=True,
            ):
        """Package DiaSources/Objects and exposure data into Avro alerts.

        Alerts can be sent to the alert stream if ``doProduceAlerts`` is set
        and written to disk if ``doWriteAlerts`` is set. Both can be set at
        the same time, and are independent of one another.

        Writes Avro alerts to a location determined by the
        ``alertWriteLocation`` configurable.

        Parameters
        ----------
        diaSourceCat : `pandas.DataFrame`
            New DiaSources to package. DataFrame should be indexed on
            ``["diaObjectId", "band", "diaSourceId"]``.
        diaObjectCat : `pandas.DataFrame`
            New and updated DiaObjects matched to the new DiaSources.
            DataFrame is indexed on ``["diaObjectId"]``.
        diaSrcHistory : `pandas.DataFrame`
            12 month history of DiaSources matched to the DiaObjects. Excludes
            the newest DiaSource and is indexed on
            ``["diaObjectId", "band", "diaSourceId"]``.
        diaForcedSources : `pandas.DataFrame`
            12 month history of DiaForcedSources matched to the DiaObjects.
            Indexed on ``["diaObjectId"]``.
        diffIm : `lsst.afw.image.ExposureF`
            Difference image the sources in ``diaSourceCat`` were detected in.
        calexp : `lsst.afw.image.ExposureF`
            Calexp used to create the ``diffIm``.
        template : `lsst.afw.image.ExposureF` or `None`
            Template image used to create the ``diffIm``.
        doRunForcedMeasurement : `bool`, optional
            Flag to indicate whether forced measurement was run.
            This should only be turned off for debugging purposes.
            Added to allow disabling forced sources for performance
            reasons during the ops rehearsal.
        """
        alerts = []
        self._patchDiaSources(diaSourceCat)
        self._patchDiaSources(diaSrcHistory)
        ccdVisitId = diffIm.info.id
        diffImPhotoCalib = diffIm.getPhotoCalib()
        calexpPhotoCalib = calexp.getPhotoCalib()
        templatePhotoCalib = template.getPhotoCalib()
        for srcIndex, diaSource in diaSourceCat.iterrows():
            # Get all diaSources for the associated diaObject.
            # TODO: DM-31992 skip DiaSources associated with Solar System
            # Objects for now.
            if srcIndex[0] == 0:
                continue
            diaObject = diaObjectCat.loc[srcIndex[0]]
            if diaObject["nDiaSources"] > 1:
                objSourceHistory = diaSrcHistory.loc[srcIndex[0]]
            else:
                objSourceHistory = None
            if doRunForcedMeasurement:
                objDiaForcedSources = diaForcedSources.loc[srcIndex[0]]
            else:
                # Send an empty table with the correct columns.
                objDiaForcedSources = diaForcedSources.loc[[]]
            sphPoint = geom.SpherePoint(diaSource["ra"],
                                        diaSource["dec"],
                                        geom.degrees)

            cutoutExtent = self.createDiaSourceExtent(diaSource["bboxSize"])
            diffImCutout = self.createCcdDataCutout(
                diffIm,
                sphPoint,
                cutoutExtent,
                diffImPhotoCalib,
                diaSource["diaSourceId"])
            calexpCutout = self.createCcdDataCutout(
                calexp,
                sphPoint,
                cutoutExtent,
                calexpPhotoCalib,
                diaSource["diaSourceId"])
            templateCutout = self.createCcdDataCutout(
                template,
                sphPoint,
                cutoutExtent,
                templatePhotoCalib,
                diaSource["diaSourceId"])

            # TODO: Create alertIds DM-24858
            alertId = diaSource["diaSourceId"]
            alerts.append(
                self.makeAlertDict(alertId,
                                   diaSource,
                                   diaObject,
                                   objSourceHistory,
                                   objDiaForcedSources,
                                   diffImCutout,
                                   calexpCutout,
                                   templateCutout))

        if self.config.doProduceAlerts:
            self.produceAlerts(alerts, ccdVisitId)

        if self.config.doWriteAlerts:
            with open(os.path.join(self.config.alertWriteLocation, f"{ccdVisitId}.avro"), "wb") as f:
                self.alertSchema.store_alerts(f, alerts)

    def _patchDiaSources(self, diaSources):
        """Add the ``programId`` column to the data.

        Parameters
        ----------
        diaSources : `pandas.DataFrame`
            DataFrame of DiaSources to patch.
        """
        diaSources["programId"] = 0

    def createDiaSourceExtent(self, bboxSize):
        """Create the extent of a cutout box given the size of the square
        BBox that covers the source footprint, enforcing the configured
        minimum cutout size.

        Parameters
        ----------
        bboxSize : `int`
            Size of a side of the square bounding box in pixels.

        Returns
        -------
        extent : `lsst.geom.Extent2I`
            Geom object representing the size of the bounding box.
        """
        if bboxSize < self.config.minCutoutSize:
            extent = geom.Extent2I(self.config.minCutoutSize,
                                   self.config.minCutoutSize)
        else:
            extent = geom.Extent2I(bboxSize, bboxSize)
        return extent

    def produceAlerts(self, alerts, ccdVisitId):
        """Serialize alerts and send them to the alert stream using
        confluent_kafka's producer.

        Parameters
        ----------
        alerts : `list` of `dict`
            Alerts to be sent to the alert stream.
        ccdVisitId : `int`
            ccdVisitId of the alerts sent to the alert stream. Used to write
            out alerts which fail to be sent to the alert stream.
        """
        for alert in alerts:
            alertBytes = self._serializeAlert(alert, schema=self.alertSchema.definition, schema_id=1)
            try:
                self.producer.produce(self.kafkaTopic, alertBytes, callback=self._delivery_callback)
                self.producer.flush()

            except KafkaException as e:
                self.log.warning('Kafka error: %s, message was %s bytes', e, sys.getsizeof(alertBytes))

                if self.config.doWriteFailedAlerts:
                    with open(os.path.join(self.config.alertWriteLocation,
                                           f"{ccdVisitId}_{alert['alertId']}.avro"), "wb") as f:
                        f.write(alertBytes)

        self.producer.flush()

    def createCcdDataCutout(self, image, skyCenter, extent, photoCalib, srcId):
        """Grab an image as a cutout and return a calibrated CCDData image.

        Parameters
        ----------
        image : `lsst.afw.image.ExposureF`
            Image to pull cutout from.
        skyCenter : `lsst.geom.SpherePoint`
            Center point of DiaSource on the sky.
        extent : `lsst.geom.Extent2I`
            Bounding box to cutout from the image.
        photoCalib : `lsst.afw.image.PhotoCalib`
            Photometric calibration of the image the cutout is taken from.
        srcId : `int`
            Unique id of DiaSource. Used for when an error occurs extracting
            a cutout.

        Returns
        -------
        ccdData : `astropy.nddata.CCDData` or `None`
            CCDData object storing the calibrated information from the input
            difference or template image.
        """
        # Catch errors in retrieving the cutout.
        try:
            cutout = image.getCutout(skyCenter, extent)
        except InvalidParameterError:
            point = image.getWcs().skyToPixel(skyCenter)
            imBBox = image.getBBox()
            if not geom.Box2D(image.getBBox()).contains(point):
                self.log.warning(
                    "DiaSource id=%i centroid lies at pixel (%.2f, %.2f) "
                    "which is outside the Exposure with bounding box "
                    "((%i, %i), (%i, %i)). Returning None for cutout...",
                    srcId, point.x, point.y,
                    imBBox.minX, imBBox.maxX, imBBox.minY, imBBox.maxY)
            else:
                raise InvalidParameterError(
                    "Failed to retrieve cutout from image for DiaSource with "
                    "id=%i. InvalidParameterError thrown during cutout "
                    "creation. Exiting."
                    % srcId)
            return None

        # Find the value of the bottom corner of our cutout's BBox and
        # subtract 1 so that the CCDData cutout position value will be
        # [1, 1].
        cutOutMinX = cutout.getBBox().minX - 1
        cutOutMinY = cutout.getBBox().minY - 1
        center = cutout.getWcs().skyToPixel(skyCenter)
        calibCutout = photoCalib.calibrateImage(cutout.getMaskedImage())

        cutoutWcs = wcs.WCS(naxis=2)
        # array_shape is (ny, nx), i.e. (height, width).
        cutoutWcs.array_shape = (cutout.getBBox().getHeight(),
                                 cutout.getBBox().getWidth())
        cutoutWcs.wcs.crpix = [center.x - cutOutMinX, center.y - cutOutMinY]
        cutoutWcs.wcs.crval = [skyCenter.getRa().asDegrees(),
                               skyCenter.getDec().asDegrees()]
        cutoutWcs.wcs.cd = self.makeLocalTransformMatrix(cutout.getWcs(),
                                                         center,
                                                         skyCenter)

        return CCDData(
            data=calibCutout.getImage().array,
            uncertainty=VarianceUncertainty(calibCutout.getVariance().array),
            flags=calibCutout.getMask().array,
            wcs=cutoutWcs,
            meta={"cutMinX": cutOutMinX,
                  "cutMinY": cutOutMinY},
            unit=u.nJy)

    def makeLocalTransformMatrix(self, wcs, center, skyCenter):
        """Create a local, linear approximation of the wcs transformation
        matrix.

        The approximation is created as if the center is at RA=0, DEC=0. All
        x,y coordinates are relative to the position of ``center``. The matrix
        is initially calculated in units of arcseconds and then converted to
        degrees. This yields higher precision results due to quirks in AST.

        Parameters
        ----------
        wcs : `lsst.afw.geom.SkyWcs`
            Wcs to approximate.
        center : `lsst.geom.Point2D`
            Point at which to evaluate the LocalWcs.
        skyCenter : `lsst.geom.SpherePoint`
            Point on sky to approximate the Wcs.

        Returns
        -------
        localMatrix : `numpy.ndarray`
            Matrix representation of the local wcs approximation with units
            of degrees.
        """
        blankCDMatrix = [[self._scale, 0], [0, self._scale]]
        localGnomonicWcs = afwGeom.makeSkyWcs(
            center, skyCenter, blankCDMatrix)
        measurementToLocalGnomonic = wcs.getTransform().then(
            localGnomonicWcs.getTransform().inverted()
        )
        localMatrix = measurementToLocalGnomonic.getJacobian(center)
        return localMatrix / 3600
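
    # A minimal usage sketch for the method above (``task``, ``cutout``,
    # ``center`` and ``skyCenter`` are assumed to be in scope): the returned
    # 2x2 CD matrix is in degrees per pixel, so a small pixel offset maps to
    # a sky offset via a matrix-vector product:
    #
    #     cd = task.makeLocalTransformMatrix(cutout.getWcs(), center, skyCenter)
    #     dRaDec = cd @ [deltaX, deltaY]  # degrees, valid near ``center``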

    def makeAlertDict(self,
                      alertId,
                      diaSource,
                      diaObject,
                      objDiaSrcHistory,
                      objDiaForcedSources,
                      diffImCutout,
                      calexpCutout,
                      templateCutout):
        """Convert data and package into a dictionary alert.

        Parameters
        ----------
        alertId : `int`
            Unique id of the alert. Currently set to the id of the
            triggering DiaSource (see DM-24858).
        diaSource : `pandas.Series`
            New single DiaSource to package.
        diaObject : `pandas.Series`
            DiaObject that ``diaSource`` is matched to.
        objDiaSrcHistory : `pandas.DataFrame` or `None`
            12 month history of ``diaObject`` excluding the latest DiaSource.
        objDiaForcedSources : `pandas.DataFrame`
            12 month history of ``diaObject`` forced measurements.
        diffImCutout : `astropy.nddata.CCDData` or `None`
            Cutout of the difference image around the location of ``diaSource``
            with a min size set by the ``minCutoutSize`` configurable.
        calexpCutout : `astropy.nddata.CCDData` or `None`
            Cutout of the calexp around the location of ``diaSource``
            with a min size set by the ``minCutoutSize`` configurable.
        templateCutout : `astropy.nddata.CCDData` or `None`
            Cutout of the template image around the location of ``diaSource``
            with a min size set by the ``minCutoutSize`` configurable.

        Returns
        -------
        alert : `dict`
            Alert payload ready to be serialized to Avro.
        """
        alert = dict()
        alert['alertId'] = alertId
        alert['diaSource'] = diaSource.to_dict()

        if objDiaSrcHistory is None:
            alert['prvDiaSources'] = objDiaSrcHistory
        else:
            alert['prvDiaSources'] = objDiaSrcHistory.to_dict("records")

        if isinstance(objDiaForcedSources, pd.Series):
            alert['prvDiaForcedSources'] = [objDiaForcedSources.to_dict()]
        else:
            alert['prvDiaForcedSources'] = objDiaForcedSources.to_dict("records")
        alert['prvDiaNondetectionLimits'] = None

        alert['diaObject'] = diaObject.to_dict()

        alert['ssObject'] = None

        if diffImCutout is None:
            alert['cutoutDifference'] = None
        else:
            alert['cutoutDifference'] = self.streamCcdDataToBytes(diffImCutout)

        if calexpCutout is None:
            alert['cutoutScience'] = None
        else:
            alert['cutoutScience'] = self.streamCcdDataToBytes(calexpCutout)

        if templateCutout is None:
            alert["cutoutTemplate"] = None
        else:
            alert["cutoutTemplate"] = self.streamCcdDataToBytes(templateCutout)

        return alert

    def streamCcdDataToBytes(self, cutout):
        """Serialize a cutout into bytes.

        Parameters
        ----------
        cutout : `astropy.nddata.CCDData`
            Cutout to serialize.

        Returns
        -------
        cutoutBytes : `bytes`
            Input cutout serialized into byte data.
        """
        with io.BytesIO() as streamer:
            cutout.write(streamer, format="fits")
            cutoutBytes = streamer.getvalue()
        return cutoutBytes
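
    # A minimal round-trip sketch for debugging (assumes astropy's FITS
    # reader accepts a file-like object, as it does for the ``fits`` format):
    #
    #     from astropy.nddata import CCDData
    #     restored = CCDData.read(io.BytesIO(cutoutBytes), format="fits")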

    def _serializeAlert(self, alert, schema=None, schema_id=0):
        """Serialize an alert to a byte sequence for sending to Kafka.

        Parameters
        ----------
        alert : `dict`
            An alert payload to be serialized.
        schema : `dict`, optional
            An Avro schema definition describing how to encode ``alert``. By
            default the schema is None, which falls back to the task's
            configured alert schema.
        schema_id : `int`, optional
            The Confluent Schema Registry ID of the schema. By default, 0 (an
            invalid ID) is used, indicating that the schema is not registered.

        Returns
        -------
        serialized : `bytes`
            The byte sequence describing the alert, including the Confluent
            Wire Format prefix.
        """
        if schema is None:
            schema = self.alertSchema.definition

        buf = io.BytesIO()
        # TODO: Use a proper schema versioning system (DM-42606)
        buf.write(self._serializeConfluentWireHeader(schema_id))
        fastavro.schemaless_writer(buf, schema, alert)
        return buf.getvalue()
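
    # A minimal deserialization sketch, the inverse of ``_serializeAlert``
    # (debugging aid, not part of the task API): skip the 5-byte Confluent
    # Wire Format prefix, then decode the Avro payload:
    #
    #     buf = io.BytesIO(serialized)
    #     buf.seek(5)  # skip the magic byte and 4-byte schema ID
    #     alert = fastavro.schemaless_reader(buf, schema)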

    @staticmethod
    def _serializeConfluentWireHeader(schema_version):
        """Returns the byte prefix for Confluent Wire Format-style Kafka messages.

        Parameters
        ----------
        schema_version : `int`
            A version number which indicates the Confluent Schema Registry ID
            number of the Avro schema used to encode the message that follows
            this header.

        Returns
        -------
        header : `bytes`
            The 5-byte encoded message prefix.

        Notes
        -----
        The Confluent Wire Format is described more fully here:
        https://docs.confluent.io/current/schema-registry/serdes-develop/index.html#wire-format
        """
        ConfluentWireFormatHeader = struct.Struct(">bi")
        return ConfluentWireFormatHeader.pack(0, schema_version)
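
    # The header packs a zero magic byte followed by the schema ID as a
    # big-endian 4-byte integer; as a quick sanity-check sketch it can be
    # unpacked with the same struct layout:
    #
    #     magic, schema_id = struct.unpack(">bi", header)
    #     assert magic == 0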

    def _delivery_callback(self, err, msg):
        if err:
            self.log.warning('Message failed delivery: %s', err)
        else:
            self.log.debug('Message delivered to %s [%d] @ %d', msg.topic(), msg.partition(), msg.offset())

    def _server_check(self):
        """Check that the alert stream credentials are still valid and that
        the server is contactable.

        Raises
        ------
        KafkaException
            Raised if the server is not contactable.
        RuntimeError
            Raised if the server is contactable but there are no topics
            present.
        """
        admin_client = AdminClient(self.kafkaAdminConfig)
        topics = admin_client.list_topics(timeout=self.config.maxTimeout).topics

        if not topics:
            raise RuntimeError("Server is contactable but no Kafka topics are present.")
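

# A minimal end-to-end usage sketch (hypothetical variable names; the
# catalogs and exposures would normally be produced upstream in the AP
# pipeline):
#
#     task = PackageAlertsTask(config=PackageAlertsConfig())
#     task.run(diaSourceCat, diaObjectCat, diaSrcHistory,
#              diaForcedSources, diffIm, calexp, template)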