Coverage for python/lsst/ap/association/packageAlerts.py: 21%

# This file is part of ap_association.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

__all__ = ("PackageAlertsConfig", "PackageAlertsTask")

import io
import os
import struct
import sys

from astropy import wcs
import astropy.units as u
from astropy.nddata import CCDData, VarianceUncertainty
import pandas as pd
import fastavro
# confluent_kafka is not in the standard Rubin environment as it is a third
# party package and is only needed when producing alerts.
try:
    import confluent_kafka
    from confluent_kafka import KafkaException
    from confluent_kafka.admin import AdminClient
except ImportError:
    confluent_kafka = None

import lsst.alert.packet as alertPack
import lsst.afw.geom as afwGeom
import lsst.geom as geom
import lsst.pex.config as pexConfig
from lsst.pex.exceptions import InvalidParameterError
import lsst.pipe.base as pipeBase
from lsst.utils.timer import timeMethod


class PackageAlertsConfig(pexConfig.Config):
    """Config class for PackageAlertsTask.
    """
    schemaFile = pexConfig.Field(
        dtype=str,
        doc="Schema definition file URI for the avro alerts.",
        default=str(alertPack.get_uri_to_latest_schema())
    )
    minCutoutSize = pexConfig.RangeField(
        dtype=int,
        min=0,
        max=1000,
        default=30,
        doc="Dimension of the square image cutouts to package in the alert."
    )
    alertWriteLocation = pexConfig.Field(
        dtype=str,
        doc="Location to write alerts to.",
        default=os.path.join(os.getcwd(), "alerts"),
    )

    doProduceAlerts = pexConfig.Field(
        dtype=bool,
        doc="Turn on alert production to kafka if true and if confluent_kafka is in the environment.",
        default=False,
    )

    doWriteAlerts = pexConfig.Field(
        dtype=bool,
        doc="Write alerts to disk if true.",
        default=False,
    )


class PackageAlertsTask(pipeBase.Task):
    """Task for packaging Dia and Pipelines data into Avro alert packages.
    """
    ConfigClass = PackageAlertsConfig
    _DefaultName = "packageAlerts"

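    # One arcsecond expressed in degrees; seeds the diagonal CD matrix of the
    # local gnomonic WCS built in ``makeLocalTransformMatrix``.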
    _scale = (1.0 * geom.arcseconds).asDegrees()

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.alertSchema = alertPack.Schema.from_uri(self.config.schemaFile)
        os.makedirs(self.config.alertWriteLocation, exist_ok=True)

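        # Set up the Kafka producer if alert production is enabled. Connection
        # settings are read from the environment: AP_KAFKA_PRODUCER_USERNAME,
        # AP_KAFKA_PRODUCER_PASSWORD, AP_KAFKA_SERVER, and AP_KAFKA_TOPIC.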
        if self.config.doProduceAlerts:
            if confluent_kafka is not None:
                self.password = os.getenv("AP_KAFKA_PRODUCER_PASSWORD")
                if not self.password:
                    raise ValueError("Kafka password environment variable was not set.")
                self.username = os.getenv("AP_KAFKA_PRODUCER_USERNAME")
                if not self.username:
                    raise ValueError("Kafka username environment variable was not set.")
                self.server = os.getenv("AP_KAFKA_SERVER")
                if not self.server:
                    raise ValueError("Kafka server environment variable was not set.")
                self.kafkaTopic = os.getenv("AP_KAFKA_TOPIC")
                if not self.kafkaTopic:
                    raise ValueError("Kafka topic environment variable was not set.")

                # confluent_kafka configures all of its classes with
                # dictionaries. This one sets up the bare minimum that is
                # needed.
                self.kafkaConfig = {
                    # This is the URL to use to connect to the Kafka cluster.
                    "bootstrap.servers": self.server,
                    # These next two properties tell the Kafka client about
                    # the specific authentication and authorization protocols
                    # that should be used when connecting.
                    "security.protocol": "SASL_PLAINTEXT",
                    "sasl.mechanisms": "SCRAM-SHA-512",
                    # The sasl.username and sasl.password are passed through
                    # over SCRAM-SHA-512 auth to connect to the cluster. The
                    # username is not sensitive, but the password is (of
                    # course) a secret value which should never be committed
                    # to source code.
                    "sasl.username": self.username,
                    "sasl.password": self.password,
                    # Batch size limits the largest size of a kafka alert
                    # that can be sent. We set the batch size to 2 MiB
                    # (2097152 bytes).
                    "batch.size": 2097152,
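                    # linger.ms is how long (in milliseconds) the producer
                    # waits for a batch to fill before sending it.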
                    "linger.ms": 5,
                }
                self.kafkaAdminConfig = {
                    # This is the URL to use to connect to the Kafka cluster.
                    "bootstrap.servers": self.server,
                    # These next two properties tell the Kafka client about
                    # the specific authentication and authorization protocols
                    # that should be used when connecting.
                    "security.protocol": "SASL_PLAINTEXT",
                    "sasl.mechanisms": "SCRAM-SHA-512",
                    # The sasl.username and sasl.password are passed through
                    # over SCRAM-SHA-512 auth to connect to the cluster. The
                    # username is not sensitive, but the password is (of
                    # course) a secret value which should never be committed
                    # to source code.
                    "sasl.username": self.username,
                    "sasl.password": self.password,
                }

                self._server_check()
                self.producer = confluent_kafka.Producer(**self.kafkaConfig)

            else:
                raise RuntimeError("``doProduceAlerts`` is set but confluent_kafka "
                                   "is not present in the environment. Alerts will "
                                   "not be sent to the alert stream.")

    @timeMethod
    def run(self,
            diaSourceCat,
            diaObjectCat,
            diaSrcHistory,
            diaForcedSources,
            diffIm,
            calexp,
            template,
            doRunForcedMeasurement=True,
            ):
        """Package DiaSources/Object and exposure data into Avro alerts.

        Alerts can be sent to the alert stream if ``doProduceAlerts`` is set
        and written to disk if ``doWriteAlerts`` is set. Both can be set at
        the same time, and are independent of one another.

        Writes Avro alerts to a location determined by the
        ``alertWriteLocation`` configurable.

        Parameters
        ----------
        diaSourceCat : `pandas.DataFrame`
            New DiaSources to package. DataFrame should be indexed on
            ``["diaObjectId", "band", "diaSourceId"]``.
        diaObjectCat : `pandas.DataFrame`
            New and updated DiaObjects matched to the new DiaSources.
            DataFrame is indexed on ``["diaObjectId"]``.
        diaSrcHistory : `pandas.DataFrame`
            12 month history of DiaSources matched to the DiaObjects.
            Excludes the newest DiaSource and is indexed on
            ``["diaObjectId", "band", "diaSourceId"]``.
        diaForcedSources : `pandas.DataFrame`
            12 month history of DiaForcedSources matched to the DiaObjects.
            Indexed on ``["diaObjectId"]``.
        diffIm : `lsst.afw.image.ExposureF`
            Difference image the sources in ``diaSourceCat`` were detected
            in.
        calexp : `lsst.afw.image.ExposureF`
            Calexp used to create the ``diffIm``.
        template : `lsst.afw.image.ExposureF` or `None`
            Template image used to create the ``diffIm``.
        doRunForcedMeasurement : `bool`, optional
            Flag to indicate whether forced measurement was run.
            This should only be turned off for debugging purposes.
            Added to allow disabling forced sources for performance
            reasons during the ops rehearsal.
        """
        alerts = []
        self._patchDiaSources(diaSourceCat)
        self._patchDiaSources(diaSrcHistory)
        ccdVisitId = diffIm.info.id
        diffImPhotoCalib = diffIm.getPhotoCalib()
        calexpPhotoCalib = calexp.getPhotoCalib()
        # ``template`` is documented as possibly None, so guard the calib
        # lookup (and the template cutout below) accordingly.
        templatePhotoCalib = template.getPhotoCalib() if template is not None else None
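        # Build one alert per new DiaSource, bundling its DiaObject record,
        # source history, forced sources, and the three image cutouts.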
        for srcIndex, diaSource in diaSourceCat.iterrows():
            # Get all diaSources for the associated diaObject.
            # TODO: DM-31992 skip DiaSources associated with Solar System
            # Objects for now.
            if srcIndex[0] == 0:
                continue
            diaObject = diaObjectCat.loc[srcIndex[0]]
            if diaObject["nDiaSources"] > 1:
                objSourceHistory = diaSrcHistory.loc[srcIndex[0]]
            else:
                objSourceHistory = None
            if doRunForcedMeasurement:
                objDiaForcedSources = diaForcedSources.loc[srcIndex[0]]
            else:
                # Send empty table with correct columns
                objDiaForcedSources = diaForcedSources.loc[[]]
            sphPoint = geom.SpherePoint(diaSource["ra"],
                                        diaSource["dec"],
                                        geom.degrees)

            cutoutExtent = self.createDiaSourceExtent(diaSource["bboxSize"])
            diffImCutout = self.createCcdDataCutout(
                diffIm,
                sphPoint,
                cutoutExtent,
                diffImPhotoCalib,
                diaSource["diaSourceId"])
            calexpCutout = self.createCcdDataCutout(
                calexp,
                sphPoint,
                cutoutExtent,
                calexpPhotoCalib,
                diaSource["diaSourceId"])
            # ``template`` may be None (see the run docstring), in which case
            # no template cutout can be made.
            if template is not None:
                templateCutout = self.createCcdDataCutout(
                    template,
                    sphPoint,
                    cutoutExtent,
                    templatePhotoCalib,
                    diaSource["diaSourceId"])
            else:
                templateCutout = None

            # TODO: Create alertIds DM-24858
            alertId = diaSource["diaSourceId"]
            alerts.append(
                self.makeAlertDict(alertId,
                                   diaSource,
                                   diaObject,
                                   objSourceHistory,
                                   objDiaForcedSources,
                                   diffImCutout,
                                   calexpCutout,
                                   templateCutout))

        if self.config.doProduceAlerts:
            self.produceAlerts(alerts, ccdVisitId)

        if self.config.doWriteAlerts:
            with open(os.path.join(self.config.alertWriteLocation, f"{ccdVisitId}.avro"), "wb") as f:
                self.alertSchema.store_alerts(f, alerts)

    def _patchDiaSources(self, diaSources):
        """Add the ``programId`` column to the data.

        Parameters
        ----------
        diaSources : `pandas.DataFrame`
            DataFrame of DiaSources to patch.
        """
        diaSources["programId"] = 0

    def createDiaSourceExtent(self, bboxSize):
        """Create an extent for a box for the cutouts given the size of the
        square BBox that covers the source footprint.

        Parameters
        ----------
        bboxSize : `int`
            Size of a side of the square bounding box in pixels.

        Returns
        -------
        extent : `lsst.geom.Extent2I`
            Geom object representing the size of the bounding box.
        """
        if bboxSize < self.config.minCutoutSize:
            extent = geom.Extent2I(self.config.minCutoutSize,
                                   self.config.minCutoutSize)
        else:
            extent = geom.Extent2I(bboxSize, bboxSize)
        return extent

    def produceAlerts(self, alerts, ccdVisitId):
        """Serialize alerts and send them to the alert stream using
        confluent_kafka's producer.

        Parameters
        ----------
        alerts : `list` of `dict`
            Alerts to be sent to the alert stream.
        ccdVisitId : `int`
            ccdVisitId of the alerts sent to the alert stream. Used to write
            out alerts which fail to be sent to the alert stream.
        """
        self._server_check()
        for alert in alerts:
            alertBytes = self._serializeAlert(alert, schema=self.alertSchema.definition, schema_id=1)
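            # produce() is asynchronous, so flush after each message to force
            # delivery before moving on; a KafkaException here lets the
            # failed alert be written to disk below.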
            try:
                self.producer.produce(self.kafkaTopic, alertBytes, callback=self._delivery_callback)
                self.producer.flush()

            except KafkaException as e:
                self.log.warning('Kafka error: {}, message was {} bytes'.format(e, sys.getsizeof(alertBytes)))
                with open(os.path.join(self.config.alertWriteLocation,
                                       f"{ccdVisitId}_{alert['alertId']}.avro"), "wb") as f:
                    f.write(alertBytes)

        self.producer.flush()

    def createCcdDataCutout(self, image, skyCenter, extent, photoCalib, srcId):
        """Grab an image as a cutout and return a calibrated CCDData image.

        Parameters
        ----------
        image : `lsst.afw.image.ExposureF`
            Image to pull cutout from.
        skyCenter : `lsst.geom.SpherePoint`
            Center point of DiaSource on the sky.
        extent : `lsst.geom.Extent2I`
            Bounding box to cutout from the image.
        photoCalib : `lsst.afw.image.PhotoCalib`
            Calibration object of the image the cutout is cut from.
        srcId : `int`
            Unique id of DiaSource. Used when an error occurs extracting
            a cutout.

        Returns
        -------
        ccdData : `astropy.nddata.CCDData` or `None`
            CCDData object storing the calibrated information from the input
            difference or template image.
        """
        # Catch errors in retrieving the cutout.
        try:
            cutout = image.getCutout(skyCenter, extent)
        except InvalidParameterError:
            point = image.getWcs().skyToPixel(skyCenter)
            imBBox = image.getBBox()
            if not geom.Box2D(imBBox).contains(point):
                self.log.warning(
                    "DiaSource id=%i centroid lies at pixel (%.2f, %.2f) "
                    "which is outside the Exposure with bounding box "
                    "((%i, %i), (%i, %i)). Returning None for cutout...",
                    srcId, point.x, point.y,
                    imBBox.minX, imBBox.maxX, imBBox.minY, imBBox.maxY)
            else:
                raise InvalidParameterError(
                    "Failed to retrieve cutout from image for DiaSource with "
                    "id=%i. InvalidParameterError thrown during cutout "
                    "creation. Exiting."
                    % srcId)
            return None

        # Find the value of the bottom corner of our cutout's BBox and
        # subtract 1 so that the CCDData cutout position value will be
        # [1, 1].
        cutOutMinX = cutout.getBBox().minX - 1
        cutOutMinY = cutout.getBBox().minY - 1
        center = cutout.getWcs().skyToPixel(skyCenter)
        calibCutout = photoCalib.calibrateImage(cutout.getMaskedImage())

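        # Build a minimal FITS-style WCS for the cutout: CRPIX in the
        # cutout's 1-based pixel frame, CRVAL at the DiaSource sky position,
        # and a CD matrix from the local linear approximation of the
        # exposure's WCS.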
        cutoutWcs = wcs.WCS(naxis=2)
        # array_shape follows the numpy convention of (height, width).
        cutoutWcs.array_shape = (cutout.getBBox().getHeight(),
                                 cutout.getBBox().getWidth())
        cutoutWcs.wcs.crpix = [center.x - cutOutMinX, center.y - cutOutMinY]
        cutoutWcs.wcs.crval = [skyCenter.getRa().asDegrees(),
                               skyCenter.getDec().asDegrees()]
        cutoutWcs.wcs.cd = self.makeLocalTransformMatrix(cutout.getWcs(),
                                                         center,
                                                         skyCenter)

        return CCDData(
            data=calibCutout.getImage().array,
            uncertainty=VarianceUncertainty(calibCutout.getVariance().array),
            flags=calibCutout.getMask().array,
            wcs=cutoutWcs,
            meta={"cutMinX": cutOutMinX,
                  "cutMinY": cutOutMinY},
            unit=u.nJy)

    def makeLocalTransformMatrix(self, wcs, center, skyCenter):
        """Create a local, linear approximation of the wcs transformation
        matrix.

        The approximation is created as if the center is at RA=0, DEC=0. All
        x, y coordinates are relative to the position of center. The matrix
        is initially calculated in units of arcseconds and then converted to
        degrees. This yields higher precision results due to quirks in AST.

        Parameters
        ----------
        wcs : `lsst.afw.geom.SkyWcs`
            Wcs to approximate.
        center : `lsst.geom.Point2D`
            Point at which to evaluate the LocalWcs.
        skyCenter : `lsst.geom.SpherePoint`
            Point on sky to approximate the Wcs.

        Returns
        -------
        localMatrix : `numpy.ndarray`
            Matrix representation of the local wcs approximation with units
            degrees.
        """
        blankCDMatrix = [[self._scale, 0], [0, self._scale]]
        localGnomonicWcs = afwGeom.makeSkyWcs(
            center, skyCenter, blankCDMatrix)
        measurementToLocalGnomonic = wcs.getTransform().then(
            localGnomonicWcs.getTransform().inverted()
        )
        localMatrix = measurementToLocalGnomonic.getJacobian(center)
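        # The Jacobian is in arcseconds per pixel because blankCDMatrix was
        # seeded with a one-arcsecond scale; dividing by 3600 converts the
        # matrix to degrees.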
        return localMatrix / 3600

    def makeAlertDict(self,
                      alertId,
                      diaSource,
                      diaObject,
                      objDiaSrcHistory,
                      objDiaForcedSources,
                      diffImCutout,
                      calexpCutout,
                      templateCutout):
        """Convert data and package into a dictionary alert.

        Parameters
        ----------
        alertId : `int`
            Id of the alert. Currently the ``diaSourceId`` of the triggering
            DiaSource (see DM-24858).
        diaSource : `pandas.DataFrame`
            New single DiaSource to package.
        diaObject : `pandas.DataFrame`
            DiaObject that ``diaSource`` is matched to.
        objDiaSrcHistory : `pandas.DataFrame`
            12 month history of ``diaObject`` excluding the latest DiaSource.
        objDiaForcedSources : `pandas.DataFrame`
            12 month history of ``diaObject`` forced measurements.
        diffImCutout : `astropy.nddata.CCDData` or `None`
            Cutout of the difference image around the location of
            ``diaSource`` with a min size set by the ``minCutoutSize``
            configurable.
        calexpCutout : `astropy.nddata.CCDData` or `None`
            Cutout of the calexp around the location of ``diaSource``
            with a min size set by the ``minCutoutSize`` configurable.
        templateCutout : `astropy.nddata.CCDData` or `None`
            Cutout of the template image around the location of
            ``diaSource`` with a min size set by the ``minCutoutSize``
            configurable.
        """
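        # The keys below must match the field names of the configured
        # lsst.alert.packet Avro schema; the finished dict is serialized
        # against that schema in ``_serializeAlert``.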
        alert = dict()
        alert['alertId'] = alertId
        alert['diaSource'] = diaSource.to_dict()

        if objDiaSrcHistory is None:
            alert['prvDiaSources'] = objDiaSrcHistory
        else:
            alert['prvDiaSources'] = objDiaSrcHistory.to_dict("records")

        if isinstance(objDiaForcedSources, pd.Series):
            alert['prvDiaForcedSources'] = [objDiaForcedSources.to_dict()]
        else:
            alert['prvDiaForcedSources'] = objDiaForcedSources.to_dict("records")
        alert['prvDiaNondetectionLimits'] = None

        alert['diaObject'] = diaObject.to_dict()

        alert['ssObject'] = None

        if diffImCutout is None:
            alert['cutoutDifference'] = None
        else:
            alert['cutoutDifference'] = self.streamCcdDataToBytes(diffImCutout)

        if calexpCutout is None:
            alert['cutoutScience'] = None
        else:
            alert['cutoutScience'] = self.streamCcdDataToBytes(calexpCutout)

        if templateCutout is None:
            alert["cutoutTemplate"] = None
        else:
            alert["cutoutTemplate"] = self.streamCcdDataToBytes(templateCutout)

        return alert

    def streamCcdDataToBytes(self, cutout):
        """Serialize a cutout into bytes.

        Parameters
        ----------
        cutout : `astropy.nddata.CCDData`
            Cutout to serialize.

        Returns
        -------
        cutoutBytes : `bytes`
            Input cutout serialized into byte data.
        """
        with io.BytesIO() as streamer:
            cutout.write(streamer, format="fits")
            cutoutBytes = streamer.getvalue()
        return cutoutBytes

    def _serializeAlert(self, alert, schema=None, schema_id=0):
        """Serialize an alert to a byte sequence for sending to Kafka.

        Parameters
        ----------
        alert : `dict`
            An alert payload to be serialized.
        schema : `dict`, optional
            An Avro schema definition describing how to encode ``alert``. By
            default, the schema is None, which falls back to the schema
            configured for this task.
        schema_id : `int`, optional
            The Confluent Schema Registry ID of the schema. By default, 0 (an
            invalid ID) is used, indicating that the schema is not registered.

        Returns
        -------
        serialized : `bytes`
            The byte sequence describing the alert, including the Confluent
            Wire Format prefix.
        """
        if schema is None:
            schema = self.alertSchema.definition

        buf = io.BytesIO()
        # TODO: Use a proper schema versioning system (DM-42606)
        buf.write(self._serializeConfluentWireHeader(schema_id))
        fastavro.schemaless_writer(buf, schema, alert)
        return buf.getvalue()

    @staticmethod
    def _serializeConfluentWireHeader(schema_version):
        """Return the byte prefix for Confluent Wire Format-style Kafka
        messages.

        Parameters
        ----------
        schema_version : `int`
            A version number which indicates the Confluent Schema Registry ID
            number of the Avro schema used to encode the message that follows
            this header.

        Returns
        -------
        header : `bytes`
            The 5-byte encoded message prefix.

        Notes
        -----
        The Confluent Wire Format is described more fully here:
        https://docs.confluent.io/current/schema-registry/serdes-develop/index.html#wire-format
        """
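        # ">bi" packs a zero "magic byte" followed by the schema ID as a
        # big-endian 32-bit signed integer, e.g. schema_version=1 gives
        # b"\x00\x00\x00\x00\x01".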
        ConfluentWireFormatHeader = struct.Struct(">bi")
        return ConfluentWireFormatHeader.pack(0, schema_version)

    def _delivery_callback(self, err, msg):
        if err:
            self.log.warning('Message failed delivery: %s', err)
        else:
            self.log.debug('Message delivered to %s [%d] @ %d', msg.topic(), msg.partition(), msg.offset())

    def _server_check(self):
        """Check that the alert stream credentials are still valid and the
        server is contactable.

        Raises
        ------
        KafkaException
            Raised if the server is not contactable.
        RuntimeError
            Raised if the server is contactable but there are no topics
            present.
        """
        admin_client = AdminClient(self.kafkaAdminConfig)
        topics = admin_client.list_topics(timeout=0.5).topics

        if not topics:
            raise RuntimeError("Kafka server is contactable but no topics are present.")
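

# A minimal usage sketch with hypothetical inputs (the catalog and exposure
# names below are illustrative; in production this task is normally run as a
# subtask of the AP pipeline rather than standalone):
#
#     config = PackageAlertsConfig(doWriteAlerts=True)
#     task = PackageAlertsTask(config=config)
#     task.run(diaSourceCat, diaObjectCat, diaSrcHistory,
#              diaForcedSources, diffIm, calexp, template)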