Coverage for python/lsst/ap/association/packageAlerts.py: 22%

171 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-03-22 03:19 -0700

1# This file is part of ap_association. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22__all__ = ("PackageAlertsConfig", "PackageAlertsTask") 

23 

24import io 

25import os 

26import sys 

27 

28from astropy import wcs 

29import astropy.units as u 

30from astropy.nddata import CCDData, VarianceUncertainty 

31import pandas as pd 

32import struct 

33import fastavro 

34# confluent_kafka is not in the standard Rubin environment as it is a third 

35# party package and is only needed when producing alerts. 

36try: 

37 import confluent_kafka 

38 from confluent_kafka import KafkaException 

39except ImportError: 

40 confluent_kafka = None 

41 

42import lsst.alert.packet as alertPack 

43import lsst.afw.geom as afwGeom 

44import lsst.geom as geom 

45import lsst.pex.config as pexConfig 

46from lsst.pex.exceptions import InvalidParameterError 

47import lsst.pipe.base as pipeBase 

48from lsst.utils.timer import timeMethod 

49 

50 

class PackageAlertsConfig(pexConfig.Config):
    """Config class for PackageAlertsTask.
    """
    # URI of the Avro schema used to serialize alerts; defaults to the
    # latest schema shipped with lsst.alert.packet.
    schemaFile = pexConfig.Field(
        dtype=str,
        doc="Schema definition file URI for the avro alerts.",
        default=str(alertPack.get_uri_to_latest_schema())
    )
    minCutoutSize = pexConfig.RangeField(
        dtype=int,
        min=0,
        max=1000,
        default=30,
        doc="Dimension of the square image cutouts to package in the alert."
    )
    alertWriteLocation = pexConfig.Field(
        dtype=str,
        doc="Location to write alerts to.",
        default=os.path.join(os.getcwd(), "alerts"),
    )

    doProduceAlerts = pexConfig.Field(
        dtype=bool,
        doc="Turn on alert production to kafka if true and if confluent_kafka is in the environment.",
        default=False,
    )

    doWriteAlerts = pexConfig.Field(
        dtype=bool,
        doc="Write alerts to disk if true.",
        default=False,
    )

83 

84 

class PackageAlertsTask(pipeBase.Task):
    """Tasks for packaging Dia and Pipelines data into Avro alert packages.
    """
    ConfigClass = PackageAlertsConfig
    _DefaultName = "packageAlerts"

    # One arcsecond expressed in degrees; used as the CD-matrix scale of the
    # local gnomonic WCS built in ``makeLocalTransformMatrix``.
    _scale = (1.0 * geom.arcseconds).asDegrees()

    def __init__(self, **kwargs):
        """Load the alert schema, create the on-disk alert directory, and,
        if alert production is enabled, configure the Kafka producer.

        Parameters
        ----------
        **kwargs
            Keyword arguments forwarded to `lsst.pipe.base.Task`.

        Raises
        ------
        ValueError
            Raised if ``doProduceAlerts`` is set but any required Kafka
            environment variable is unset or empty.
        RuntimeError
            Raised if ``doProduceAlerts`` is set but confluent_kafka could
            not be imported.
        """
        super().__init__(**kwargs)
        self.alertSchema = alertPack.Schema.from_uri(self.config.schemaFile)
        # Create the write destination eagerly so later per-alert writes
        # (including the failed-send fallback) cannot fail on a missing dir.
        os.makedirs(self.config.alertWriteLocation, exist_ok=True)

        if self.config.doProduceAlerts:
            if confluent_kafka is not None:
                # All connection credentials come from the environment so
                # secrets never appear in task configuration or source code.
                self.password = os.getenv("AP_KAFKA_PRODUCER_PASSWORD")
                if not self.password:
                    raise ValueError("Kafka password environment variable was not set.")
                self.username = os.getenv("AP_KAFKA_PRODUCER_USERNAME")
                if not self.username:
                    raise ValueError("Kafka username environment variable was not set.")
                self.server = os.getenv("AP_KAFKA_SERVER")
                if not self.server:
                    raise ValueError("Kafka server environment variable was not set.")
                self.kafkaTopic = os.getenv("AP_KAFKA_TOPIC")
                if not self.kafkaTopic:
                    raise ValueError("Kafka topic environment variable was not set.")

                # confluent_kafka configures all of its classes with dictionaries. This one
                # sets up the bare minimum that is needed.
                self.kafkaConfig = {
                    # This is the URL to use to connect to the Kafka cluster.
                    "bootstrap.servers": self.server,
                    # These next two properties tell the Kafka client about the specific
                    # authentication and authorization protocols that should be used when
                    # connecting.
                    "security.protocol": "SASL_PLAINTEXT",
                    "sasl.mechanisms": "SCRAM-SHA-512",
                    # The sasl.username and sasl.password are passed through over
                    # SCRAM-SHA-512 auth to connect to the cluster. The username is not
                    # sensitive, but the password is (of course) a secret value which
                    # should never be committed to source code.
                    "sasl.username": self.username,
                    "sasl.password": self.password,
                    # Batch size limits the largest size of a kafka alert that can be sent.
                    # We set the batch size to 2 Mb.
                    "batch.size": 2097152,
                    "linger.ms": 5,
                }
                self.producer = confluent_kafka.Producer(**self.kafkaConfig)

            else:
                raise RuntimeError("Produce alerts is set but confluent_kafka is not present in "
                                   "the environment. Alerts will not be sent to the alert stream.")

139 

140 @timeMethod 

141 def run(self, 

142 diaSourceCat, 

143 diaObjectCat, 

144 diaSrcHistory, 

145 diaForcedSources, 

146 diffIm, 

147 calexp, 

148 template, 

149 doRunForcedMeasurement=True, 

150 ): 

151 """Package DiaSources/Object and exposure data into Avro alerts. 

152 

153 Alerts can be sent to the alert stream if ``doProduceAlerts`` is set 

154 and written to disk if ``doWriteAlerts`` is set. Both can be set at the 

155 same time, and are independent of one another. 

156 

157 Writes Avro alerts to a location determined by the 

158 ``alertWriteLocation`` configurable. 

159 

160 Parameters 

161 ---------- 

162 diaSourceCat : `pandas.DataFrame` 

163 New DiaSources to package. DataFrame should be indexed on 

164 ``["diaObjectId", "band", "diaSourceId"]`` 

165 diaObjectCat : `pandas.DataFrame` 

166 New and updated DiaObjects matched to the new DiaSources. DataFrame 

167 is indexed on ``["diaObjectId"]`` 

168 diaSrcHistory : `pandas.DataFrame` 

169 12 month history of DiaSources matched to the DiaObjects. Excludes 

170 the newest DiaSource and is indexed on 

171 ``["diaObjectId", "band", "diaSourceId"]`` 

172 diaForcedSources : `pandas.DataFrame` 

173 12 month history of DiaForcedSources matched to the DiaObjects. 

174 ``["diaObjectId"]`` 

175 diffIm : `lsst.afw.image.ExposureF` 

176 Difference image the sources in ``diaSourceCat`` were detected in. 

177 calexp : `lsst.afw.image.ExposureF` 

178 Calexp used to create the ``diffIm``. 

179 template : `lsst.afw.image.ExposureF` or `None` 

180 Template image used to create the ``diffIm``. 

181 doRunForcedMeasurement : `bool`, optional 

182 Flag to indicate whether forced measurement was run. 

183 This should only be turned off for debugging purposes. 

184 Added to allow disabling forced sources for performance 

185 reasons during the ops rehearsal. 

186 """ 

187 alerts = [] 

188 self._patchDiaSources(diaSourceCat) 

189 self._patchDiaSources(diaSrcHistory) 

190 ccdVisitId = diffIm.info.id 

191 diffImPhotoCalib = diffIm.getPhotoCalib() 

192 calexpPhotoCalib = calexp.getPhotoCalib() 

193 templatePhotoCalib = template.getPhotoCalib() 

194 for srcIndex, diaSource in diaSourceCat.iterrows(): 

195 # Get all diaSources for the associated diaObject. 

196 # TODO: DM-31992 skip DiaSources associated with Solar System 

197 # Objects for now. 

198 if srcIndex[0] == 0: 

199 continue 

200 diaObject = diaObjectCat.loc[srcIndex[0]] 

201 if diaObject["nDiaSources"] > 1: 

202 objSourceHistory = diaSrcHistory.loc[srcIndex[0]] 

203 else: 

204 objSourceHistory = None 

205 if doRunForcedMeasurement: 

206 objDiaForcedSources = diaForcedSources.loc[srcIndex[0]] 

207 else: 

208 # Send empty table with correct columns 

209 objDiaForcedSources = diaForcedSources.loc[[]] 

210 sphPoint = geom.SpherePoint(diaSource["ra"], 

211 diaSource["dec"], 

212 geom.degrees) 

213 

214 cutoutExtent = self.createDiaSourceExtent(diaSource["bboxSize"]) 

215 diffImCutout = self.createCcdDataCutout( 

216 diffIm, 

217 sphPoint, 

218 cutoutExtent, 

219 diffImPhotoCalib, 

220 diaSource["diaSourceId"]) 

221 calexpCutout = self.createCcdDataCutout( 

222 calexp, 

223 sphPoint, 

224 cutoutExtent, 

225 calexpPhotoCalib, 

226 diaSource["diaSourceId"]) 

227 templateCutout = self.createCcdDataCutout( 

228 template, 

229 sphPoint, 

230 cutoutExtent, 

231 templatePhotoCalib, 

232 diaSource["diaSourceId"]) 

233 

234 # TODO: Create alertIds DM-24858 

235 alertId = diaSource["diaSourceId"] 

236 alerts.append( 

237 self.makeAlertDict(alertId, 

238 diaSource, 

239 diaObject, 

240 objSourceHistory, 

241 objDiaForcedSources, 

242 diffImCutout, 

243 calexpCutout, 

244 templateCutout)) 

245 

246 if self.config.doProduceAlerts: 

247 self.produceAlerts(alerts, ccdVisitId) 

248 

249 if self.config.doWriteAlerts: 

250 with open(os.path.join(self.config.alertWriteLocation, f"{ccdVisitId}.avro"), "wb") as f: 

251 self.alertSchema.store_alerts(f, alerts) 

252 

253 def _patchDiaSources(self, diaSources): 

254 """Add the ``programId`` column to the data. 

255 

256 Parameters 

257 ---------- 

258 diaSources : `pandas.DataFrame` 

259 DataFrame of DiaSources to patch. 

260 """ 

261 diaSources["programId"] = 0 

262 

263 def createDiaSourceExtent(self, bboxSize): 

264 """Create an extent for a box for the cutouts given the size of the 

265 square BBox that covers the source footprint. 

266 

267 Parameters 

268 ---------- 

269 bboxSize : `int` 

270 Size of a side of the square bounding box in pixels. 

271 

272 Returns 

273 ------- 

274 extent : `lsst.geom.Extent2I` 

275 Geom object representing the size of the bounding box. 

276 """ 

277 if bboxSize < self.config.minCutoutSize: 

278 extent = geom.Extent2I(self.config.minCutoutSize, 

279 self.config.minCutoutSize) 

280 else: 

281 extent = geom.Extent2I(bboxSize, bboxSize) 

282 return extent 

283 

    def produceAlerts(self, alerts, ccdVisitId):
        """Serialize alerts and send them to the alert stream using
        confluent_kafka's producer.

        Alerts that raise a `KafkaException` during sending are written to
        ``alertWriteLocation`` on disk instead, so no alert is silently lost.

        Parameters
        ----------
        alerts : `dict`
            Dictionary of alerts to be sent to the alert stream.
        ccdVisitId : `int`
            ccdVisitId of the alerts sent to the alert stream. Used to write
            out alerts which fail to be sent to the alert stream.
        """
        for alert in alerts:
            # schema_id=1 is a placeholder wire-format ID; see _serializeAlert.
            alertBytes = self._serializeAlert(alert, schema=self.alertSchema.definition, schema_id=1)
            try:
                self.producer.produce(self.kafkaTopic, alertBytes, callback=self._delivery_callback)
                # NOTE(review): flushing after every produce() blocks until
                # delivery and defeats producer batching; presumably done so
                # a KafkaException can be attributed to this specific alert
                # for the fallback write below — confirm before changing.
                self.producer.flush()

            except KafkaException as e:
                self.log.warning('Kafka error: {}, message was {} bytes'.format(e, sys.getsizeof(alertBytes)))

                # Preserve the failed alert on disk, keyed by visit and
                # alertId, so it can be re-sent later.
                with open(os.path.join(self.config.alertWriteLocation,
                                       f"{ccdVisitId}_{alert['alertId']}.avro"), "wb") as f:
                    f.write(alertBytes)

        # Final flush in case any message is still queued.
        self.producer.flush()

310 

311 def createCcdDataCutout(self, image, skyCenter, extent, photoCalib, srcId): 

312 """Grab an image as a cutout and return a calibrated CCDData image. 

313 

314 Parameters 

315 ---------- 

316 image : `lsst.afw.image.ExposureF` 

317 Image to pull cutout from. 

318 skyCenter : `lsst.geom.SpherePoint` 

319 Center point of DiaSource on the sky. 

320 extent : `lsst.geom.Extent2I` 

321 Bounding box to cutout from the image. 

322 photoCalib : `lsst.afw.image.PhotoCalib` 

323 Calibrate object of the image the cutout is cut from. 

324 srcId : `int` 

325 Unique id of DiaSource. Used for when an error occurs extracting 

326 a cutout. 

327 

328 Returns 

329 ------- 

330 ccdData : `astropy.nddata.CCDData` or `None` 

331 CCDData object storing the calibrate information from the input 

332 difference or template image. 

333 """ 

334 # Catch errors in retrieving the cutout. 

335 try: 

336 cutout = image.getCutout(skyCenter, extent) 

337 except InvalidParameterError: 

338 point = image.getWcs().skyToPixel(skyCenter) 

339 imBBox = image.getBBox() 

340 if not geom.Box2D(image.getBBox()).contains(point): 

341 self.log.warning( 

342 "DiaSource id=%i centroid lies at pixel (%.2f, %.2f) " 

343 "which is outside the Exposure with bounding box " 

344 "((%i, %i), (%i, %i)). Returning None for cutout...", 

345 srcId, point.x, point.y, 

346 imBBox.minX, imBBox.maxX, imBBox.minY, imBBox.maxY) 

347 else: 

348 raise InvalidParameterError( 

349 "Failed to retrieve cutout from image for DiaSource with " 

350 "id=%i. InvalidParameterError thrown during cutout " 

351 "creation. Exiting." 

352 % srcId) 

353 return None 

354 

355 # Find the value of the bottom corner of our cutout's BBox and 

356 # subtract 1 so that the CCDData cutout position value will be 

357 # [1, 1]. 

358 cutOutMinX = cutout.getBBox().minX - 1 

359 cutOutMinY = cutout.getBBox().minY - 1 

360 center = cutout.getWcs().skyToPixel(skyCenter) 

361 calibCutout = photoCalib.calibrateImage(cutout.getMaskedImage()) 

362 

363 cutoutWcs = wcs.WCS(naxis=2) 

364 cutoutWcs.array_shape = (cutout.getBBox().getWidth(), 

365 cutout.getBBox().getWidth()) 

366 cutoutWcs.wcs.crpix = [center.x - cutOutMinX, center.y - cutOutMinY] 

367 cutoutWcs.wcs.crval = [skyCenter.getRa().asDegrees(), 

368 skyCenter.getDec().asDegrees()] 

369 cutoutWcs.wcs.cd = self.makeLocalTransformMatrix(cutout.getWcs(), 

370 center, 

371 skyCenter) 

372 

373 return CCDData( 

374 data=calibCutout.getImage().array, 

375 uncertainty=VarianceUncertainty(calibCutout.getVariance().array), 

376 flags=calibCutout.getMask().array, 

377 wcs=cutoutWcs, 

378 meta={"cutMinX": cutOutMinX, 

379 "cutMinY": cutOutMinY}, 

380 unit=u.nJy) 

381 

    def makeLocalTransformMatrix(self, wcs, center, skyCenter):
        """Create a local, linear approximation of the wcs transformation
        matrix.

        The approximation is created as if the center is at RA=0, DEC=0. All
        comparing x,y coordinate are relative to the position of center. Matrix
        is initially calculated with units arcseconds and then converted to
        degrees. This yields higher precision results due to quirks in AST.

        Parameters
        ----------
        wcs : `lsst.afw.geom.SkyWcs`
            Wcs to approximate
        center : `lsst.geom.Point2D`
            Point at which to evaluate the LocalWcs.
        skyCenter : `lsst.geom.SpherePoint`
            Point on sky to approximate the Wcs.

        Returns
        -------
        localMatrix : `numpy.ndarray`
            Matrix representation the local wcs approximation with units
            degrees.
        """
        # Identity CD matrix scaled to 1 arcsec/pixel (``_scale`` is one
        # arcsecond in degrees) for the reference gnomonic projection.
        blankCDMatrix = [[self._scale, 0], [0, self._scale]]
        localGnomonicWcs = afwGeom.makeSkyWcs(
            center, skyCenter, blankCDMatrix)
        # Compose pixel->sky of the input wcs with sky->pixel of the local
        # gnomonic wcs, giving a pixel->pixel mapping whose Jacobian at
        # ``center`` is the local linear approximation in arcseconds.
        measurementToLocalGnomonic = wcs.getTransform().then(
            localGnomonicWcs.getTransform().inverted()
        )
        localMatrix = measurementToLocalGnomonic.getJacobian(center)
        # Convert from arcseconds to degrees.
        return localMatrix / 3600

414 

415 def makeAlertDict(self, 

416 alertId, 

417 diaSource, 

418 diaObject, 

419 objDiaSrcHistory, 

420 objDiaForcedSources, 

421 diffImCutout, 

422 calexpCutout, 

423 templateCutout): 

424 """Convert data and package into a dictionary alert. 

425 

426 Parameters 

427 ---------- 

428 diaSource : `pandas.DataFrame` 

429 New single DiaSource to package. 

430 diaObject : `pandas.DataFrame` 

431 DiaObject that ``diaSource`` is matched to. 

432 objDiaSrcHistory : `pandas.DataFrame` 

433 12 month history of ``diaObject`` excluding the latest DiaSource. 

434 objDiaForcedSources : `pandas.DataFrame` 

435 12 month history of ``diaObject`` forced measurements. 

436 diffImCutout : `astropy.nddata.CCDData` or `None` 

437 Cutout of the difference image around the location of ``diaSource`` 

438 with a min size set by the ``cutoutSize`` configurable. 

439 calexpCutout : `astropy.nddata.CCDData` or `None` 

440 Cutout of the calexp around the location of ``diaSource`` 

441 with a min size set by the ``cutoutSize`` configurable. 

442 templateCutout : `astropy.nddata.CCDData` or `None` 

443 Cutout of the template image around the location of ``diaSource`` 

444 with a min size set by the ``cutoutSize`` configurable. 

445 """ 

446 alert = dict() 

447 alert['alertId'] = alertId 

448 alert['diaSource'] = diaSource.to_dict() 

449 

450 if objDiaSrcHistory is None: 

451 alert['prvDiaSources'] = objDiaSrcHistory 

452 else: 

453 alert['prvDiaSources'] = objDiaSrcHistory.to_dict("records") 

454 

455 if isinstance(objDiaForcedSources, pd.Series): 

456 alert['prvDiaForcedSources'] = [objDiaForcedSources.to_dict()] 

457 else: 

458 alert['prvDiaForcedSources'] = objDiaForcedSources.to_dict("records") 

459 alert['prvDiaNondetectionLimits'] = None 

460 

461 alert['diaObject'] = diaObject.to_dict() 

462 

463 alert['ssObject'] = None 

464 

465 if diffImCutout is None: 

466 alert['cutoutDifference'] = None 

467 else: 

468 alert['cutoutDifference'] = self.streamCcdDataToBytes(diffImCutout) 

469 

470 if calexpCutout is None: 

471 alert['cutoutScience'] = None 

472 else: 

473 alert['cutoutScience'] = self.streamCcdDataToBytes(calexpCutout) 

474 

475 if templateCutout is None: 

476 alert["cutoutTemplate"] = None 

477 else: 

478 alert["cutoutTemplate"] = self.streamCcdDataToBytes(templateCutout) 

479 

480 return alert 

481 

482 def streamCcdDataToBytes(self, cutout): 

483 """Serialize a cutout into bytes. 

484 

485 Parameters 

486 ---------- 

487 cutout : `astropy.nddata.CCDData` 

488 Cutout to serialize. 

489 

490 Returns 

491 ------- 

492 coutputBytes : `bytes` 

493 Input cutout serialized into byte data. 

494 """ 

495 with io.BytesIO() as streamer: 

496 cutout.write(streamer, format="fits") 

497 cutoutBytes = streamer.getvalue() 

498 return cutoutBytes 

499 

500 def _serializeAlert(self, alert, schema=None, schema_id=0): 

501 """Serialize an alert to a byte sequence for sending to Kafka. 

502 

503 Parameters 

504 ---------- 

505 alert : `dict` 

506 An alert payload to be serialized. 

507 schema : `dict`, optional 

508 An Avro schema definition describing how to encode `alert`. By default, 

509 the schema is None, which sets it to the latest schema available. 

510 schema_id : `int`, optional 

511 The Confluent Schema Registry ID of the schema. By default, 0 (an 

512 invalid ID) is used, indicating that the schema is not registered. 

513 

514 Returns 

515 ------- 

516 serialized : `bytes` 

517 The byte sequence describing the alert, including the Confluent Wire 

518 Format prefix. 

519 """ 

520 if schema is None: 

521 schema = self.alertSchema.definition 

522 

523 buf = io.BytesIO() 

524 # TODO: Use a proper schema versioning system (DM-42606) 

525 buf.write(self._serializeConfluentWireHeader(schema_id)) 

526 fastavro.schemaless_writer(buf, schema, alert) 

527 return buf.getvalue() 

528 

529 @staticmethod 

530 def _serializeConfluentWireHeader(schema_version): 

531 """Returns the byte prefix for Confluent Wire Format-style Kafka messages. 

532 

533 Parameters 

534 ---------- 

535 schema_version : `int` 

536 A version number which indicates the Confluent Schema Registry ID 

537 number of the Avro schema used to encode the message that follows this 

538 header. 

539 

540 Returns 

541 ------- 

542 header : `bytes` 

543 The 5-byte encoded message prefix. 

544 

545 Notes 

546 ----- 

547 The Confluent Wire Format is described more fully here: 

548 https://docs.confluent.io/current/schema-registry/serdes-develop/index.html#wire-format 

549 """ 

550 ConfluentWireFormatHeader = struct.Struct(">bi") 

551 return ConfluentWireFormatHeader.pack(0, schema_version) 

552 

553 def _delivery_callback(self, err, msg): 

554 if err: 

555 self.log.warning('Message failed delivery: %s\n' % err) 

556 else: 

557 self.log.debug('Message delivered to %s [%d] @ %d', msg.topic(), msg.partition(), msg.offset())