Coverage for python/lsst/dax/apdb/apdbSql.py: 17%


305 statements  

# This file is part of dax_apdb.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

"""Module defining Apdb class and related methods.
"""

from __future__ import annotations

__all__ = ["ApdbSqlConfig", "ApdbSql"]

from contextlib import contextmanager
import logging
import numpy as np
import pandas
from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple

import lsst.daf.base as dafBase
from lsst.pex.config import Field, ChoiceField, ListField
from lsst.sphgeom import HtmPixelization, LonLat, Region, UnitVector3d
import sqlalchemy
from sqlalchemy import (func, sql)
from sqlalchemy.pool import NullPool
from .apdb import Apdb, ApdbConfig
from .apdbSchema import ApdbTables, TableDef
from .apdbSqlSchema import ApdbSqlSchema
from .timer import Timer


_LOG = logging.getLogger(__name__)


def _split(seq: Iterable, nItems: int) -> Iterator[List]:
    """Split a sequence into smaller sequences"""
    seq = list(seq)
    while seq:
        yield seq[:nItems]
        del seq[:nItems]
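# For illustration: _split(range(5), 2) yields [0, 1], [2, 3], [4]. It is used
# further below to keep the "IN (...)" ID lists in hand-built queries to at
# most 1000 entries per statement.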


def _coerce_uint64(df: pandas.DataFrame) -> pandas.DataFrame:
    """Change type of the uint64 columns to int64, return copy of data frame.
    """
    names = [c[0] for c in df.dtypes.items() if c[1] == np.uint64]
    return df.astype({name: np.int64 for name in names})
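# Note: most SQL backends lack an unsigned 64-bit integer type, which is why
# uint64 columns are stored as signed int64 here.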


def _make_midPointTai_start(visit_time: dafBase.DateTime, months: int) -> float:
    """Calculate starting point for time-based source search.

    Parameters
    ----------
    visit_time : `lsst.daf.base.DateTime`
        Time of current visit.
    months : `int`
        Number of months in the sources history.

    Returns
    -------
    time : `float`
        A ``midPointTai`` starting point, MJD time.
    """
    # TODO: `system` must be consistent with the code in ap_association
    # (see DM-31996)
    return visit_time.get(system=dafBase.DateTime.MJD) - months * 30
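# A "month" is approximated as 30 days, so e.g. months=12 moves the cutoff
# back by 360 days from the visit MJD rather than a full calendar year.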


@contextmanager
def _ansi_session(engine: sqlalchemy.engine.Engine) -> Iterator[sqlalchemy.engine.Connection]:
    """Returns a connection, makes sure that ANSI mode is set for MySQL
    """
    with engine.begin() as conn:
        if engine.name == 'mysql':
            conn.execute(sql.text("SET SESSION SQL_MODE = 'ANSI'"))
        yield conn
    return
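# ANSI mode makes MySQL accept double-quoted identifiers such as "DiaSource",
# matching the quoting style of the hand-built SQL strings in this module;
# other backends already treat double quotes that way.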


class ApdbSqlConfig(ApdbConfig):
    """APDB configuration class for SQL implementation (ApdbSql).
    """
    db_url = Field(
        dtype=str,
        doc="SQLAlchemy database connection URI"
    )
    isolation_level = ChoiceField(
        dtype=str,
        doc="Transaction isolation level, if unset then backend-default value "
            "is used, except for SQLite backend where we use READ_UNCOMMITTED. "
            "Some backends may not support every allowed value.",
        allowed={
            "READ_COMMITTED": "Read committed",
            "READ_UNCOMMITTED": "Read uncommitted",
            "REPEATABLE_READ": "Repeatable read",
            "SERIALIZABLE": "Serializable"
        },
        default=None,
        optional=True
    )
    connection_pool = Field(
        dtype=bool,
        doc="If False then disable SQLAlchemy connection pool. "
            "Do not use connection pool when forking.",
        default=True
    )
    connection_timeout = Field(
        dtype=float,
        doc="Maximum time to wait for a database lock to be released before "
            "exiting. Defaults to SQLAlchemy defaults if not set.",
        default=None,
        optional=True
    )
    sql_echo = Field(
        dtype=bool,
        doc="If True then pass SQLAlchemy echo option.",
        default=False
    )
    dia_object_index = ChoiceField(
        dtype=str,
        doc="Indexing mode for DiaObject table",
        allowed={
            'baseline': "Index defined in baseline schema",
            'pix_id_iov': "(pixelId, objectId, iovStart) PK",
            'last_object_table': "Separate DiaObjectLast table"
        },
        default='baseline'
    )
    htm_level = Field(
        dtype=int,
        doc="HTM indexing level",
        default=20
    )
    htm_max_ranges = Field(
        dtype=int,
        doc="Max number of ranges in HTM envelope",
        default=64
    )
    htm_index_column = Field(
        dtype=str,
        default="pixelId",
        doc="Name of a HTM index column for DiaObject and DiaSource tables"
    )
    ra_dec_columns = ListField(
        dtype=str,
        default=["ra", "decl"],
        doc="Names of ra/dec columns in DiaObject table"
    )
    dia_object_columns = ListField(
        dtype=str,
        doc="List of columns to read from DiaObject, by default read all columns",
        default=[]
    )
    object_last_replace = Field(
        dtype=bool,
        doc="If True (default) then use \"upsert\" for DiaObjectsLast table",
        default=True
    )
    prefix = Field(
        dtype=str,
        doc="Prefix to add to table names and index names",
        default=""
    )
    explain = Field(
        dtype=bool,
        doc="If True then run EXPLAIN SQL command on each executed query",
        default=False
    )
    timer = Field(
        dtype=bool,
        doc="If True then print/log timing information",
        default=False
    )

    def validate(self) -> None:
        super().validate()
        if len(self.ra_dec_columns) != 2:
            raise ValueError("ra_dec_columns must have exactly two column names")
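# A minimal configuration sketch (values are illustrative only; required
# fields inherited from ApdbConfig, such as the schema file locations, must
# also be set -- see the config/ folder for complete examples):
#
#     config = ApdbSqlConfig()
#     config.db_url = "sqlite:///apdb.db"
#     config.dia_object_index = "last_object_table"
#     apdb = ApdbSql(config)
#     apdb.makeSchema(drop=False)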


class ApdbSql(Apdb):
    """Implementation of APDB interface based on SQL database.

    The implementation is configured via standard ``pex_config`` mechanism
    using `ApdbSqlConfig` configuration class. For an example of different
    configurations check ``config/`` folder.

    Parameters
    ----------
    config : `ApdbSqlConfig`
        Configuration object.
    """

    ConfigClass = ApdbSqlConfig

    def __init__(self, config: ApdbSqlConfig):

        self.config = config

        _LOG.debug("APDB Configuration:")
        _LOG.debug(" dia_object_index: %s", self.config.dia_object_index)
        _LOG.debug(" read_sources_months: %s", self.config.read_sources_months)
        _LOG.debug(" read_forced_sources_months: %s", self.config.read_forced_sources_months)
        _LOG.debug(" dia_object_columns: %s", self.config.dia_object_columns)
        _LOG.debug(" object_last_replace: %s", self.config.object_last_replace)
        _LOG.debug(" schema_file: %s", self.config.schema_file)
        _LOG.debug(" extra_schema_file: %s", self.config.extra_schema_file)
        _LOG.debug(" schema prefix: %s", self.config.prefix)

        # engine is reused between multiple processes, make sure that we don't
        # share connections by disabling pool (by using NullPool class)
        kw = dict(echo=self.config.sql_echo)
        conn_args: Dict[str, Any] = dict()
        if not self.config.connection_pool:
            kw.update(poolclass=NullPool)
        if self.config.isolation_level is not None:
            kw.update(isolation_level=self.config.isolation_level)
        elif self.config.db_url.startswith("sqlite"):
            # Use READ_UNCOMMITTED as default value for sqlite.
            kw.update(isolation_level="READ_UNCOMMITTED")
        if self.config.connection_timeout is not None:
            if self.config.db_url.startswith("sqlite"):
                conn_args.update(timeout=self.config.connection_timeout)
            elif self.config.db_url.startswith(("postgresql", "mysql")):
                conn_args.update(connect_timeout=self.config.connection_timeout)
        kw.update(connect_args=conn_args)
        self._engine = sqlalchemy.create_engine(self.config.db_url, **kw)

        self._schema = ApdbSqlSchema(engine=self._engine,
                                     dia_object_index=self.config.dia_object_index,
                                     schema_file=self.config.schema_file,
                                     extra_schema_file=self.config.extra_schema_file,
                                     prefix=self.config.prefix,
                                     htm_index_column=self.config.htm_index_column)

        self.pixelator = HtmPixelization(self.config.htm_level)

    def tableRowCount(self) -> Dict[str, int]:
        """Returns dictionary with the table names and row counts.

        Used by ``ap_proto`` to keep track of the size of the database tables.
        Depending on database technology this could be an expensive operation.

        Returns
        -------
        row_counts : `dict`
            Dict where key is a table name and value is a row count.
        """
        res = {}
        tables: List[sqlalchemy.schema.Table] = [
            self._schema.objects, self._schema.sources, self._schema.forcedSources]
        if self.config.dia_object_index == 'last_object_table':
            tables.append(self._schema.objects_last)
        for table in tables:
            stmt = sql.select([func.count()]).select_from(table)
            count = self._engine.scalar(stmt)
            res[table.name] = count

        return res

    def tableDef(self, table: ApdbTables) -> Optional[TableDef]:
        # docstring is inherited from a base class
        return self._schema.tableSchemas.get(table)

    def makeSchema(self, drop: bool = False) -> None:
        # docstring is inherited from a base class
        self._schema.makeSchema(drop=drop)

    def getDiaObjects(self, region: Region) -> pandas.DataFrame:
        # docstring is inherited from a base class

        # decide what columns we need
        table: sqlalchemy.schema.Table
        if self.config.dia_object_index == 'last_object_table':
            table = self._schema.objects_last
        else:
            table = self._schema.objects
        if not self.config.dia_object_columns:
            query = table.select()
        else:
            columns = [table.c[col] for col in self.config.dia_object_columns]
            query = sql.select(columns)

        # build selection
        htm_index_column = table.columns[self.config.htm_index_column]
        exprlist = []
        pixel_ranges = self._htm_indices(region)
        for low, upper in pixel_ranges:
            upper -= 1
            if low == upper:
                exprlist.append(htm_index_column == low)
            else:
                exprlist.append(sql.expression.between(htm_index_column, low, upper))
        query = query.where(sql.expression.or_(*exprlist))

        # select latest version of objects
        if self.config.dia_object_index != 'last_object_table':
            query = query.where(table.c.validityEnd == None)  # noqa: E711

        _LOG.debug("query: %s", query)

        if self.config.explain:
            # run the same query with explain
            self._explain(query, self._engine)

        # execute select
        with Timer('DiaObject select', self.config.timer):
            with self._engine.begin() as conn:
                objects = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s DiaObjects", len(objects))
        return objects

    def getDiaSources(self, region: Region,
                      object_ids: Optional[Iterable[int]],
                      visit_time: dafBase.DateTime) -> Optional[pandas.DataFrame]:
        # docstring is inherited from a base class
        if self.config.read_sources_months == 0:
            _LOG.debug("Skip DiaSources fetching")
            return None

        if object_ids is None:
            # region-based select
            return self._getDiaSourcesInRegion(region, visit_time)
        else:
            return self._getDiaSourcesByIDs(list(object_ids), visit_time)

    def getDiaForcedSources(self, region: Region,
                            object_ids: Optional[Iterable[int]],
                            visit_time: dafBase.DateTime) -> Optional[pandas.DataFrame]:
        """Return catalog of DiaForcedSource instances from a given region.

        Parameters
        ----------
        region : `lsst.sphgeom.Region`
            Region to search for DIASources.
        object_ids : iterable [ `int` ], optional
            List of DiaObject IDs to further constrain the set of returned
            sources. If the list is empty then an empty catalog is returned
            with a correct schema.
        visit_time : `lsst.daf.base.DateTime`
            Time of the current visit.

        Returns
        -------
        catalog : `pandas.DataFrame`, or `None`
            Catalog containing DiaForcedSource records. `None` is returned if
            the ``read_forced_sources_months`` configuration parameter is set
            to 0.

        Raises
        ------
        NotImplementedError
            Raised if ``object_ids`` is `None`.

        Notes
        -----
        Even though the base class allows `None` to be passed for
        ``object_ids``, this class requires ``object_ids`` to be not-`None`.
        `NotImplementedError` is raised if `None` is passed.

        This method returns DiaForcedSource catalog for a region with
        additional filtering based on DiaObject IDs. Only a subset of the
        DiaForcedSource history is returned, limited by the
        ``read_forced_sources_months`` config parameter w.r.t. ``visit_time``.
        If ``object_ids`` is empty then an empty catalog is always returned
        with a correct schema (columns/types).
        """

        if self.config.read_forced_sources_months == 0:
            _LOG.debug("Skip DiaForcedSources fetching")
            return None

        if object_ids is None:
            # This implementation does not support region-based selection.
            raise NotImplementedError("Region-based selection is not supported")

        # TODO: DateTime.MJD must be consistent with code in ap_association,
        # alternatively we can fill midPointTai ourselves in store()
        midPointTai_start = _make_midPointTai_start(visit_time, self.config.read_forced_sources_months)
        _LOG.debug("midPointTai_start = %.6f", midPointTai_start)

        table: sqlalchemy.schema.Table = self._schema.forcedSources
        with Timer('DiaForcedSource select', self.config.timer):
            sources = self._getSourcesByIDs(table, list(object_ids), midPointTai_start)

        _LOG.debug("found %s DiaForcedSources", len(sources))
        return sources

    def store(self,
              visit_time: dafBase.DateTime,
              objects: pandas.DataFrame,
              sources: Optional[pandas.DataFrame] = None,
              forced_sources: Optional[pandas.DataFrame] = None) -> None:
        # docstring is inherited from a base class

        # fill pixelId column for DiaObjects
        objects = self._add_obj_htm_index(objects)
        self._storeDiaObjects(objects, visit_time)

        if sources is not None:
            # copy pixelId column from DiaObjects to DiaSources
            sources = self._add_src_htm_index(sources, objects)
            self._storeDiaSources(sources)

        if forced_sources is not None:
            self._storeDiaForcedSources(forced_sources)

    def dailyJob(self) -> None:
        # docstring is inherited from a base class

        if self._engine.name == 'postgresql':

            # do VACUUM on all tables
            _LOG.info("Running VACUUM on all tables")
            connection = self._engine.raw_connection()
            ISOLATION_LEVEL_AUTOCOMMIT = 0
            connection.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
            cursor = connection.cursor()
            cursor.execute("VACUUM ANALYSE")
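            # The hard-coded 0 matches psycopg2's ISOLATION_LEVEL_AUTOCOMMIT;
            # autocommit is needed because VACUUM cannot run inside a
            # transaction block.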

    def countUnassociatedObjects(self) -> int:
        # docstring is inherited from a base class

        # Retrieve the DiaObject table.
        table: sqlalchemy.schema.Table = self._schema.objects

        # Construct the sql statement.
        stmt = sql.select([func.count()]).select_from(table).where(table.c.nDiaSources == 1)
        stmt = stmt.where(table.c.validityEnd == None)  # noqa: E711

        # Return the count.
        with self._engine.begin() as conn:
            count = conn.scalar(stmt)

        return count

    def _getDiaSourcesInRegion(self, region: Region, visit_time: dafBase.DateTime
                               ) -> pandas.DataFrame:
        """Returns catalog of DiaSource instances from given region.

        Parameters
        ----------
        region : `lsst.sphgeom.Region`
            Region to search for DIASources.
        visit_time : `lsst.daf.base.DateTime`
            Time of the current visit.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records.
        """
        # TODO: DateTime.MJD must be consistent with code in ap_association,
        # alternatively we can fill midPointTai ourselves in store()
        midPointTai_start = _make_midPointTai_start(visit_time, self.config.read_sources_months)
        _LOG.debug("midPointTai_start = %.6f", midPointTai_start)

        table: sqlalchemy.schema.Table = self._schema.sources
        query = table.select()

        # build selection
        htm_index_column = table.columns[self.config.htm_index_column]
        exprlist = []
        pixel_ranges = self._htm_indices(region)
        for low, upper in pixel_ranges:
            upper -= 1
            if low == upper:
                exprlist.append(htm_index_column == low)
            else:
                exprlist.append(sql.expression.between(htm_index_column, low, upper))
        time_filter = table.columns["midPointTai"] > midPointTai_start
        where = sql.expression.and_(sql.expression.or_(*exprlist), time_filter)
        query = query.where(where)

        # execute select
        with Timer('DiaSource select', self.config.timer):
            with _ansi_session(self._engine) as conn:
                sources = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s DiaSources", len(sources))
        return sources

    def _getDiaSourcesByIDs(self, object_ids: List[int], visit_time: dafBase.DateTime
                            ) -> pandas.DataFrame:
        """Returns catalog of DiaSource instances given set of DiaObject IDs.

        Parameters
        ----------
        object_ids :
            Collection of DiaObject IDs
        visit_time : `lsst.daf.base.DateTime`
            Time of the current visit.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records.
        """
        # TODO: DateTime.MJD must be consistent with code in ap_association,
        # alternatively we can fill midPointTai ourselves in store()
        midPointTai_start = _make_midPointTai_start(visit_time, self.config.read_sources_months)
        _LOG.debug("midPointTai_start = %.6f", midPointTai_start)

        table: sqlalchemy.schema.Table = self._schema.sources
        with Timer('DiaSource select', self.config.timer):
            sources = self._getSourcesByIDs(table, object_ids, midPointTai_start)

        _LOG.debug("found %s DiaSources", len(sources))
        return sources

    def _getSourcesByIDs(self, table: sqlalchemy.schema.Table,
                         object_ids: List[int],
                         midPointTai_start: float
                         ) -> pandas.DataFrame:
        """Returns catalog of DiaSource or DiaForcedSource instances given set
        of DiaObject IDs.

        Parameters
        ----------
        table : `sqlalchemy.schema.Table`
            Database table.
        object_ids :
            Collection of DiaObject IDs
        midPointTai_start : `float`
            Earliest midPointTai to retrieve.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records. An empty catalog with a
            correct schema is returned when ``object_ids`` is empty.
        """
        sources: Optional[pandas.DataFrame] = None
        with _ansi_session(self._engine) as conn:
            if len(object_ids) <= 0:
                _LOG.debug("ID list is empty, just fetch empty result")
                query = table.select().where(False)
                sources = pandas.read_sql_query(query, conn)
            else:
                for ids in _split(sorted(object_ids), 1000):
                    query = f'SELECT * FROM "{table.name}" WHERE '

                    # select by object id
                    ids_str = ",".join(str(id) for id in ids)
                    query += f'"diaObjectId" IN ({ids_str})'
                    query += f' AND "midPointTai" > {midPointTai_start}'

                    # execute select
                    df = pandas.read_sql_query(sql.text(query), conn)
                    if sources is None:
                        sources = df
                    else:
                        sources = sources.append(df)
        assert sources is not None, "Catalog cannot be None"
        return sources
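    # The per-chunk statement generated above looks like (IDs and the table
    # name are illustrative):
    #   SELECT * FROM "DiaSource" WHERE "diaObjectId" IN (1,2,3) AND "midPointTai" > 59000.0
    # with at most 1000 IDs per statement; chunk results are concatenated with
    # DataFrame.append().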

    def _storeDiaObjects(self, objs: pandas.DataFrame, visit_time: dafBase.DateTime) -> None:
        """Store catalog of DiaObjects from current visit.

        Parameters
        ----------
        objs : `pandas.DataFrame`
            Catalog with DiaObject records.
        visit_time : `lsst.daf.base.DateTime`
            Time of the visit.
        """

        ids = sorted(objs['diaObjectId'])
        _LOG.debug("first object ID: %d", ids[0])

        # NOTE: workaround for sqlite, need this here to avoid
        # "database is locked" error.
        table: sqlalchemy.schema.Table = self._schema.objects

        # TODO: Need to verify that we are using correct scale here for
        # DATETIME representation (see DM-31996).
        dt = visit_time.toPython()

        # everything to be done in single transaction
        with _ansi_session(self._engine) as conn:

            ids_str = ",".join(str(id) for id in ids)

            if self.config.dia_object_index == 'last_object_table':

                # insert and replace all records in LAST table, mysql and
                # postgres have non-standard features
                table = self._schema.objects_last
                do_replace = self.config.object_last_replace
                # If the input data is of type Pandas, we drop the previous
                # objects regardless of the do_replace setting due to how
                # Pandas inserts objects.
                if not do_replace or isinstance(objs, pandas.DataFrame):
                    query = 'DELETE FROM "' + table.name + '" '
                    query += 'WHERE "diaObjectId" IN (' + ids_str + ') '

                    if self.config.explain:
                        # run the same query with explain
                        self._explain(query, conn)

                    with Timer(table.name + ' delete', self.config.timer):
                        res = conn.execute(sql.text(query))
                    _LOG.debug("deleted %s objects", res.rowcount)

                extra_columns: Dict[str, Any] = dict(lastNonForcedSource=dt)
                with Timer("DiaObjectLast insert", self.config.timer):
                    objs = _coerce_uint64(objs)
                    for col, data in extra_columns.items():
                        objs[col] = data
                    objs.to_sql("DiaObjectLast", conn, if_exists='append',
                                index=False)
            else:

                # truncate existing validity intervals
                table = self._schema.objects
                query = 'UPDATE "' + table.name + '" '
                query += "SET \"validityEnd\" = '" + str(dt) + "' "
                query += 'WHERE "diaObjectId" IN (' + ids_str + ') '
                query += 'AND "validityEnd" IS NULL'

                # _LOG.debug("query: %s", query)

                if self.config.explain:
                    # run the same query with explain
                    self._explain(query, conn)

                with Timer(table.name + ' truncate', self.config.timer):
                    res = conn.execute(sql.text(query))
                _LOG.debug("truncated %s intervals", res.rowcount)

            # insert new versions
            table = self._schema.objects
            extra_columns = dict(lastNonForcedSource=dt, validityStart=dt,
                                 validityEnd=None)
            with Timer("DiaObject insert", self.config.timer):
                objs = _coerce_uint64(objs)
                for col, data in extra_columns.items():
                    objs[col] = data
                objs.to_sql("DiaObject", conn, if_exists='append',
                            index=False)
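    # DiaObject rows are versioned by validity interval: the UPDATE above
    # closes the currently open interval (sets validityEnd to the visit time)
    # and the following INSERT opens a new one with validityEnd = NULL, which
    # is why readers select rows with validityEnd IS NULL to get the latest
    # versions.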

    def _storeDiaSources(self, sources: pandas.DataFrame) -> None:
        """Store catalog of DiaSources from current visit.

        Parameters
        ----------
        sources : `pandas.DataFrame`
            Catalog containing DiaSource records
        """
        # everything to be done in single transaction
        with _ansi_session(self._engine) as conn:

            with Timer("DiaSource insert", self.config.timer):
                sources = _coerce_uint64(sources)
                sources.to_sql("DiaSource", conn, if_exists='append', index=False)

    def _storeDiaForcedSources(self, sources: pandas.DataFrame) -> None:
        """Store a set of DiaForcedSources from current visit.

        Parameters
        ----------
        sources : `pandas.DataFrame`
            Catalog containing DiaForcedSource records
        """

        # everything to be done in single transaction
        with _ansi_session(self._engine) as conn:

            with Timer("DiaForcedSource insert", self.config.timer):
                sources = _coerce_uint64(sources)
                sources.to_sql("DiaForcedSource", conn, if_exists='append', index=False)

    def _explain(self, query: str, conn: sqlalchemy.engine.Connection) -> None:
        """Run the query with explain
        """

        _LOG.info("explain for query: %s...", query[:64])

        if conn.engine.name == 'mysql':
            query = "EXPLAIN EXTENDED " + query
        else:
            query = "EXPLAIN " + query

        res = conn.execute(sql.text(query))
        if res.returns_rows:
            _LOG.info("explain: %s", res.keys())
            for row in res:
                _LOG.info("explain: %s", row)
        else:
            _LOG.info("EXPLAIN returned nothing")

    def _htm_indices(self, region: Region) -> List[Tuple[int, int]]:
        """Generate a set of HTM indices covering specified region.

        Parameters
        ----------
        region: `sphgeom.Region`
            Region that needs to be indexed.

        Returns
        -------
        Sequence of ranges, range is a tuple (minHtmID, maxHtmID).
        """
        _LOG.debug('region: %s', region)
        indices = self.pixelator.envelope(region, self.config.htm_max_ranges)

        if _LOG.isEnabledFor(logging.DEBUG):
            for irange in indices.ranges():
                _LOG.debug('range: %s %s', self.pixelator.toString(irange[0]),
                           self.pixelator.toString(irange[1]))

        return indices.ranges()
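    # Note that the query-building code in getDiaObjects() and
    # _getDiaSourcesInRegion() subtracts one from the upper bound of each
    # returned range before using an inclusive BETWEEN, i.e. it treats these
    # ranges as half-open [low, upper) intervals.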

    def _add_obj_htm_index(self, df: pandas.DataFrame) -> pandas.DataFrame:
        """Calculate HTM index for each record and add it to a DataFrame.

        Notes
        -----
        This overrides any existing column in a DataFrame with the same name
        (pixelId). Original DataFrame is not changed, copy of a DataFrame is
        returned.
        """
        # calculate HTM index for every DiaObject
        htm_index = np.zeros(df.shape[0], dtype=np.int64)
        ra_col, dec_col = self.config.ra_dec_columns
        for i, (ra, dec) in enumerate(zip(df[ra_col], df[dec_col])):
            uv3d = UnitVector3d(LonLat.fromDegrees(ra, dec))
            idx = self.pixelator.index(uv3d)
            htm_index[i] = idx
        df = df.copy()
        df[self.config.htm_index_column] = htm_index
        return df

    def _add_src_htm_index(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame:
        """Add pixelId column to DiaSource catalog.

        Notes
        -----
        This method copies pixelId value from a matching DiaObject record.
        DiaObject catalog needs to have a pixelId column filled by
        ``_add_obj_htm_index`` method and DiaSource records need to be
        associated to DiaObjects via ``diaObjectId`` column.

        This overrides any existing column in a DataFrame with the same name
        (pixelId). Original DataFrame is not changed, copy of a DataFrame is
        returned.
        """
        pixel_id_map: Dict[int, int] = {
            diaObjectId: pixelId for diaObjectId, pixelId
            in zip(objs["diaObjectId"], objs[self.config.htm_index_column])
        }
        # DiaSources associated with SolarSystemObjects do not have an
        # associated DiaObject hence we skip them and set their htmIndex
        # value to 0.
        pixel_id_map[0] = 0
        htm_index = np.zeros(sources.shape[0], dtype=np.int64)
        for i, diaObjId in enumerate(sources["diaObjectId"]):
            htm_index[i] = pixel_id_map[diaObjId]
        sources = sources.copy()
        sources[self.config.htm_index_column] = htm_index
        return sources