from future import standard_library
standard_library.install_aliases()
import sys
from builtins import str

# 2017 March 9
# str_cast exists because numpy.dtype does
# not like unicode-like things as the names
# of columns.  Unfortunately, in python 2,
# builtins.str looks unicode-like.  We will
# use str_cast in python 2 to maintain
# both python 3 compatibility and our use of
# numpy dtype
str_cast = str
if sys.version_info.major == 2:
    from past.builtins import str as past_str
    str_cast = past_str

from builtins import zip
from builtins import object
import warnings
import numpy
import os
import inspect
from io import BytesIO
from collections import OrderedDict

from .utils import loadData
from sqlalchemy.orm import scoped_session, sessionmaker
from sqlalchemy.sql import expression
from sqlalchemy.engine import reflection, url
from sqlalchemy import (create_engine, MetaData,
                        Table, event, text)
from sqlalchemy import exc as sa_exc
from lsst.daf.butler.registry import DbAuth
from lsst.sims.utils.CodeUtilities import sims_clean_up
from lsst.utils import getPackageDir

# The documentation at http://docs.sqlalchemy.org/en/rel_0_7/core/types.html#sqlalchemy.types.Numeric
# suggests using the cdecimal module.  Since it is not standard, import decimal.
# TODO: test for cdecimal and use it if it exists.
import decimal
from future.utils import with_metaclass

__all__ = ["ChunkIterator", "DBObject", "CatalogDBObject", "fileDBObject"]


def valueOfPi():
    """
    A function to return the value of pi.  This is needed for adding PI()
    to sqlite databases
    """
    return numpy.pi


def declareTrigFunctions(conn, connection_rec, connection_proxy):
    """
    A database event listener
    which will define the math functions necessary for evaluating the
    Haversine function in sqlite databases (where they are not otherwise
    defined)

    see: http://docs.sqlalchemy.org/en/latest/core/events.html
    """

    conn.create_function("COS", 1, numpy.cos)
    conn.create_function("SIN", 1, numpy.sin)
    conn.create_function("ASIN", 1, numpy.arcsin)
    conn.create_function("SQRT", 1, numpy.sqrt)
    conn.create_function("POWER", 2, numpy.power)
    conn.create_function("PI", 0, valueOfPi)
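
# Usage sketch (not from the original module; the file name 'my_catalog.db' is
# hypothetical): once declareTrigFunctions is attached to a sqlite engine's
# 'checkout' event, raw SQL can call the math functions that sqlite does not
# provide natively, e.g.
#
#     from sqlalchemy import create_engine, event, text
#     engine = create_engine('sqlite:///my_catalog.db')
#     event.listen(engine, 'checkout', declareTrigFunctions)
#     with engine.connect() as conn:
#         print(conn.execute(text("SELECT 2.0*ASIN(SQRT(0.5))")).fetchone())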


# ------------------------------------------------------------
# Iterator for database chunks

class ChunkIterator(object):
    """Iterator for query chunks"""
    def __init__(self, dbobj, query, chunk_size, arbitrarySQL=False):
        self.dbobj = dbobj
        self.exec_query = dbobj.connection.session.execute(query)
        self.chunk_size = chunk_size

        # arbitrarySQL exists in case a CatalogDBObject calls
        # get_arbitrary_chunk_iterator; in that case, we need to
        # be able to tell this object to call _postprocess_arbitrary_results,
        # rather than _postprocess_results
        self.arbitrarySQL = arbitrarySQL

    def __iter__(self):
        return self

    def __next__(self):
        if self.chunk_size is None and not self.exec_query.closed:
            chunk = self.exec_query.fetchall()
            return self._postprocess_results(chunk)
        elif self.chunk_size is not None:
            chunk = self.exec_query.fetchmany(self.chunk_size)
            return self._postprocess_results(chunk)
        else:
            raise StopIteration

    def _postprocess_results(self, chunk):
        if len(chunk) == 0:
            raise StopIteration
        if self.arbitrarySQL:
            return self.dbobj._postprocess_arbitrary_results(chunk)
        else:
            return self.dbobj._postprocess_results(chunk)


class DBConnection(object):
    """
    This is a class that will hold the engine, session, and metadata for a
    DBObject.  This will allow multiple DBObjects to share the same
    sqlalchemy connection, when appropriate.
    """

    def __init__(self, database=None, driver=None, host=None, port=None, verbose=False):
        """
        @param [in] database is the name of the database file being connected to

        @param [in] driver is the dialect of the database (e.g. 'sqlite', 'mssql', etc.)

        @param [in] host is the URL of the remote host, if appropriate

        @param [in] port is the port on the remote host to connect to, if appropriate

        @param [in] verbose is a boolean controlling sqlalchemy's verbosity
        """

        self._database = database
        self._driver = driver
        self._host = host
        self._port = port
        self._verbose = verbose

        self._validate_conn_params()
        self._connect_to_engine()

    def __del__(self):
        try:
            del self._metadata
        except AttributeError:
            pass

        try:
            del self._engine
        except AttributeError:
            pass

        try:
            del self._session
        except AttributeError:
            pass

    def _connect_to_engine(self):

        # DbAuth will not look up hosts that are None, '' or 0
        if self._host:
            # This is triggered when you need to connect to a remote database.
            # Use 'HOME' as the default location (backwards compatibility) but fail gracefully
            authdir = os.getenv('HOME')
            if authdir is None:
                # Use an empty file in this package, which causes
                # a fallback to database-native authentication.
                authdir = getPackageDir('sims_catalogs')
                auth = DbAuth(os.path.join(authdir, "tests", "db-auth.yaml"))
            else:
                auth = DbAuth(os.path.join(authdir, ".lsst", "db-auth.yaml"))
            username, password = auth.getAuth(
                self._driver, host=self._host, port=self._port,
                database=self._database)
            dbUrl = url.URL(self._driver,
                            host=self._host,
                            port=self._port,
                            database=self._database,
                            username=username,
                            password=password)
        else:
            dbUrl = url.URL(self._driver,
                            database=self._database)

        self._engine = create_engine(dbUrl, echo=self._verbose)

        if self._engine.dialect.name == 'sqlite':
            event.listen(self._engine, 'checkout', declareTrigFunctions)

        self._session = scoped_session(sessionmaker(autoflush=True,
                                                    bind=self._engine))
        self._metadata = MetaData(bind=self._engine)

    def _validate_conn_params(self):
        """Validate connection parameters

        - Check whether the user passed a dbAddress instead of a database.  Convert and warn.
        - Check that the required connection parameters are present
        - Replace the default host/port if the driver is 'sqlite'
        """

        if self._database is None:
            raise AttributeError("Cannot instantiate DBConnection; database is 'None'")

        if '//' in self._database:
            warnings.warn("Database name '%s' is invalid but looks like a dbAddress. "
                          "Attempting to convert to database, driver, host, "
                          "and port parameters. Any usernames and passwords are ignored and must "
                          "be in the db-auth.paf policy file. " % (self.database), FutureWarning)

            dbUrl = url.make_url(self._database)
            dialect = dbUrl.get_dialect()
            self._driver = dialect.name + '+' + dialect.driver if dialect.driver else dialect.name
            for key, value in dbUrl.translate_connect_args().items():
                if value is not None:
                    setattr(self, '_' + key, value)

        errMessage = "Please supply a 'driver' kwarg to the constructor or in class definition. "
        errMessage += "'driver' is formatted as dialect+driver, such as 'sqlite' or 'mssql+pymssql'."
        if not hasattr(self, '_driver'):
            raise AttributeError("%s has no attribute 'driver'. " % (self.__class__.__name__) + errMessage)
        elif self._driver is None:
            raise AttributeError("%s.driver is None. " % (self.__class__.__name__) + errMessage)

        errMessage = "Please supply a 'database' kwarg to the constructor or in class definition. "
        errMessage += " 'database' is the database name or the filename path if driver is 'sqlite'. "
        if not hasattr(self, '_database'):
            raise AttributeError("%s has no attribute 'database'. " % (self.__class__.__name__) + errMessage)
        elif self._database is None:
            raise AttributeError("%s.database is None. " % (self.__class__.__name__) + errMessage)

        if 'sqlite' in self._driver:
            # When passed a sqlite database, override the default host/port
            self._host = None
            self._port = None

    def __eq__(self, other):
        return (str(self._database) == str(other._database)) and \
               (str(self._driver) == str(other._driver)) and \
               (str(self._host) == str(other._host)) and \
               (str(self._port) == str(other._port))

    @property
    def engine(self):
        return self._engine

    @property
    def session(self):
        return self._session

    @property
    def metadata(self):
        return self._metadata

    @property
    def database(self):
        return self._database

    @property
    def driver(self):
        return self._driver

    @property
    def host(self):
        return self._host

    @property
    def port(self):
        return self._port

    @property
    def verbose(self):
        return self._verbose
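
# Usage sketch (not from the original module; the file name 'stars.db' is
# hypothetical, and engine.table_names() assumes an older SQLAlchemy of the
# kind this module targets): open a connection to a local sqlite file and
# inspect it through the read-only properties defined above.
#
#     connection = DBConnection(database='stars.db', driver='sqlite')
#     print(connection.driver, connection.database)
#     print(connection.engine.table_names())
#     rows = connection.session.execute('SELECT 1').fetchall()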


class DBObject(object):

    def __init__(self, database=None, driver=None, host=None, port=None, verbose=False,
                 connection=None, cache_connection=True):
        """
        Initialize DBObject.

        @param [in] database is the name of the database file being connected to

        @param [in] driver is the dialect of the database (e.g. 'sqlite', 'mssql', etc.)

        @param [in] host is the URL of the remote host, if appropriate

        @param [in] port is the port on the remote host to connect to, if appropriate

        @param [in] verbose is a boolean controlling sqlalchemy's verbosity (default False)

        @param [in] connection is an optional instance of DBConnection, in the event that
        this DBObject can share a database connection with another DBObject.  This is only
        necessary or even possible in a few specialized cases and should be used carefully.

        @param [in] cache_connection is a boolean.  If True, DBObject will use a cache of
        DBConnections (if available) to get the connection to this database.
        """

        # this is a cache for the query, so that any one query does not have to guess dtype multiple times
        self.dtype = None

        if connection is None:
            # Explicit constructor to DBObject preferred
            kwargDict = dict(database=database,
                             driver=driver,
                             host=host,
                             port=port,
                             verbose=verbose)

            for key, value in kwargDict.items():
                if value is not None or not hasattr(self, key):
                    setattr(self, key, value)

            self.connection = self._get_connection(self.database, self.driver, self.host, self.port,
                                                   use_cache=cache_connection)

        else:
            self.connection = connection
            self.database = connection.database
            self.driver = connection.driver
            self.host = connection.host
            self.port = connection.port
            self.verbose = connection.verbose

    def _get_connection(self, database, driver, host, port, use_cache=True):
        """
        Search self._connection_cache (if it exists; it won't for DBObject, but
        will for CatalogDBObject) for a DBConnection matching the specified
        parameters.  If it exists, return it.  If not, open a connection to
        the specified database, add it to the cache, and return the connection.

        Parameters
        ----------
        database is the name of the database file being connected to

        driver is the dialect of the database (e.g. 'sqlite', 'mssql', etc.)

        host is the URL of the remote host, if appropriate

        port is the port on the remote host to connect to, if appropriate

        use_cache is a boolean specifying whether or not we try to use the
        cache of database connections (you don't want to if opening many
        connections in many threads).
        """

        if use_cache and hasattr(self, '_connection_cache'):
            for conn in self._connection_cache:
                if str(conn.database) == str(database):
                    if str(conn.driver) == str(driver):
                        if str(conn.host) == str(host):
                            if str(conn.port) == str(port):
                                return conn

        conn = DBConnection(database=database, driver=driver, host=host, port=port)

        if use_cache and hasattr(self, '_connection_cache'):
            self._connection_cache.append(conn)

        return conn

    def get_table_names(self):
        """Return a list of the names of the tables (and views) in the database"""
        return [str(xx) for xx in reflection.Inspector.from_engine(self.connection.engine).get_table_names()] + \
               [str(xx) for xx in reflection.Inspector.from_engine(self.connection.engine).get_view_names()]

    def get_column_names(self, tableName=None):
        """
        Return a list of the names of the columns in the specified table.
        If no table is specified, return a dict of lists.  The dict will be keyed
        to the table names.  The lists will be of the column names in that table
        """
        tableNameList = self.get_table_names()
        if tableName is not None:
            if tableName not in tableNameList:
                return []
            return [str_cast(xx['name'])
                    for xx in reflection.Inspector.from_engine(self.connection.engine).get_columns(tableName)]
        else:
            columnDict = {}
            for name in tableNameList:
                columnList = [str_cast(xx['name'])
                              for xx in reflection.Inspector.from_engine(self.connection.engine).get_columns(name)]
                columnDict[name] = columnList
            return columnDict
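
    # Usage sketch (not from the original module; the sqlite file and table
    # names are hypothetical): the reflection helpers above can be used to
    # explore an unfamiliar database before writing queries against it.
    #
    #     db = DBObject(database='stars.db', driver='sqlite')
    #     print(db.get_table_names())           # e.g. ['stars', 'star_summary']
    #     print(db.get_column_names('stars'))   # column names for one table
    #     print(db.get_column_names())          # dict of lists, keyed by table name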


    def _final_pass(self, results):
        """ Make final modifications to a set of data before returning it to the user

        **Parameters**

        * results : a structured array constructed from the result set from a query

        **Returns**

        * results : a potentially modified structured array.  The default is to do nothing.

        """
        return results

    def _convert_results_to_numpy_recarray_dbobj(self, results):
        if self.dtype is None:
            # Determine the dtype from the data and store it in self.dtype
            # so that we do not have to repeat the guess on every chunk.
            dataString = ''

            # We are going to detect the dtype by reading in a single row
            # of data with np.genfromtxt.  To do this, we must pass the
            # row as a string delimited by a specified character.  Here we
            # select a character that does not occur anywhere in the data.
            delimit_char_list = [',', ';', '|', ':', '/', '\\']
            delimit_char = None
            for cc in delimit_char_list:
                is_valid = True
                for xx in results[0]:
                    if cc in str(xx):
                        is_valid = False
                        break

                if is_valid:
                    delimit_char = cc
                    break

            if delimit_char is None:
                raise RuntimeError("DBObject could not detect the dtype of your return rows\n"
                                   "Please specify a dtype with the 'dtype' kwarg.")

            for xx in results[0]:
                if dataString != '':
                    dataString += delimit_char
                dataString += str(xx)
            names = [str_cast(ww) for ww in results[0].keys()]
            dataArr = numpy.genfromtxt(BytesIO(dataString.encode()), dtype=None,
                                       names=names, delimiter=delimit_char,
                                       encoding='utf-8')
            dt_list = []
            for name in dataArr.dtype.names:
                type_name = str(dataArr.dtype[name])
                sub_list = [name]
                if type_name.startswith('S') or type_name.startswith('|S'):
                    sub_list.append(str_cast)
                    sub_list.append(int(type_name.replace('S', '').replace('|', '')))
                else:
                    sub_list.append(dataArr.dtype[name])
                dt_list.append(tuple(sub_list))

            self.dtype = numpy.dtype(dt_list)

        if len(results) == 0:
            return numpy.recarray((0,), dtype=self.dtype)

        retresults = numpy.rec.fromrecords([tuple(xx) for xx in results], dtype=self.dtype)
        return retresults

    def _postprocess_results(self, results):
        """
        This wrapper exists so that a ChunkIterator built from a DBObject
        can have the same API as a ChunkIterator built from a CatalogDBObject
        """
        return self._postprocess_arbitrary_results(results)

    def _postprocess_arbitrary_results(self, results):

        if not isinstance(results, numpy.recarray):
            retresults = self._convert_results_to_numpy_recarray_dbobj(results)
        else:
            retresults = results

        return self._final_pass(retresults)

    def execute_arbitrary(self, query, dtype=None):
        """
        Executes an arbitrary query.  Returns a recarray of the results.

        dtype will be the dtype of the output recarray.  If it is None, then
        the code will guess the datatype and assign generic names to the columns
        """

        try:
            # basestring only exists in python 2
            is_string = isinstance(query, basestring)
        except NameError:
            is_string = isinstance(query, str)

        if not is_string:
            raise RuntimeError("DBObject execute must be called with a string query")

        unacceptableCommands = ["delete", "drop", "insert", "update"]
        for badCommand in unacceptableCommands:
            if query.lower().find(badCommand.lower()) >= 0:
                raise RuntimeError("query made to DBObject execute contained %s " % badCommand)

        self.dtype = dtype
        retresults = self._postprocess_arbitrary_results(self.connection.session.execute(query).fetchall())
        return retresults
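
    # Usage sketch (not from the original module; the table, column and file
    # names are hypothetical): run a read-only query and get back a numpy
    # recarray, either with a guessed dtype or with one supplied explicitly.
    #
    #     db = DBObject(database='stars.db', driver='sqlite')
    #     dtype = numpy.dtype([('id', int), ('ra', float), ('decl', float)])
    #     data = db.execute_arbitrary('SELECT id, ra, decl FROM stars', dtype=dtype)
    #     print(data['ra'].max())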


    def get_arbitrary_chunk_iterator(self, query, chunk_size=None, dtype=None):
        """
        This wrapper exists so that CatalogDBObjects can refer to
        get_arbitrary_chunk_iterator and DBObjects can refer to
        get_chunk_iterator
        """
        return self.get_chunk_iterator(query, chunk_size=chunk_size, dtype=dtype)

    def get_chunk_iterator(self, query, chunk_size=None, dtype=None):
        """
        Take an arbitrary, user-specified query and return a ChunkIterator that
        executes that query

        dtype will tell the ChunkIterator what datatype to expect for this query.
        This information gets passed to _postprocess_results.

        If 'None', then _postprocess_results will just guess the datatype
        and return generic names for the columns.
        """
        self.dtype = dtype
        return ChunkIterator(self, query, chunk_size, arbitrarySQL=True)
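
# Usage sketch (not from the original module; the query is hypothetical): each
# iteration of the ChunkIterator returns a recarray of at most chunk_size
# rows, and the loop ends when the underlying result set is exhausted.
#
#     db = DBObject(database='stars.db', driver='sqlite')
#     for chunk in db.get_chunk_iterator('SELECT id, ra, decl FROM stars',
#                                        chunk_size=10000):
#         print(len(chunk), chunk.dtype.names)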


class CatalogDBObjectMeta(type):
    """Meta class for registering new objects.

    When any new type of object class is created, this registers it
    in a `registry` class attribute, available to all derived instance
    catalogs.
    """
    def __new__(cls, name, bases, dct):
        # check if attribute objid is specified.
        # If not, create a default
        if 'registry' in dct:
            warnings.warn("registry class attribute should not be "
                          "over-ridden in InstanceCatalog classes. "
                          "Proceed with caution")
        if 'objid' not in dct:
            dct['objid'] = name
        return super(CatalogDBObjectMeta, cls).__new__(cls, name, bases, dct)

    def __init__(cls, name, bases, dct):
        # check if 'registry' is specified.
        # if not, then this is the base class: add the registry
        if not hasattr(cls, 'registry'):
            cls.registry = {}
        else:
            if not cls.skipRegistration:
                # add this class to the registry
                if cls.objid in cls.registry:
                    srcfile = inspect.getsourcefile(cls.registry[cls.objid])
                    srcline = inspect.getsourcelines(cls.registry[cls.objid])[1]
                    warnings.warn('duplicate object identifier %s specified. ' % (cls.objid) +
                                  'This will override previous definition on line %i of %s' %
                                  (srcline, srcfile))
                cls.registry[cls.objid] = cls

        # check if the list of unique ids is specified
        # if not, then this is the base class: add the list
        if not hasattr(cls, 'objectTypeIdList'):
            cls.objectTypeIdList = []
        else:
            if cls.skipRegistration:
                pass
            elif cls.objectTypeId is None:
                pass  # Don't add typeIds that are None
            elif cls.objectTypeId in cls.objectTypeIdList:
                warnings.warn('Duplicate object type id %s specified: ' % cls.objectTypeId +
                              '\nOutput object ids may not be unique.\nThis may not be a problem if you do not '
                              'want globally unique id values')
            else:
                cls.objectTypeIdList.append(cls.objectTypeId)
        return super(CatalogDBObjectMeta, cls).__init__(name, bases, dct)

    def __str__(cls):
        dbObjects = cls.registry.keys()
        outstr = "++++++++++++++++++++++++++++++++++++++++++++++\n" + \
                 "Registered object types are:\n"
        for dbObject in dbObjects:
            outstr += "%s\n" % (dbObject)
        outstr += "\n\n"
        outstr += "To query the possible column names do:\n"
        outstr += "$> CatalogDBObject.from_objid([name]).show_mapped_columns()\n"
        outstr += "+++++++++++++++++++++++++++++++++++++++++++++"
        return outstr


class CatalogDBObject(with_metaclass(CatalogDBObjectMeta, DBObject)):
    """Database Object base class

    """

    epoch = 2000.0
    skipRegistration = False
    objid = None
    tableid = None
    idColKey = None
    objectTypeId = None
    columns = None
    generateDefaultColumnMap = True
    dbDefaultValues = {}
    raColName = None
    decColName = None

    _connection_cache = []  # a list to store open database connections in

    # Provide information if this object should be tested in the unit test
    doRunTest = False
    testObservationMetaData = None

    #: Mapping of DDL types to python types.  Strings are assumed to be 256 characters;
    #: this can be overridden by modifying the dbTypeMap or by making a custom columns
    #: list.
    #: numpy doesn't know how to convert decimal.Decimal types, so I changed this to float
    #: TODO this doesn't seem to make a difference but make sure.
    dbTypeMap = {'BIGINT': (int,), 'BOOLEAN': (bool,), 'FLOAT': (float,), 'INTEGER': (int,),
                 'NUMERIC': (float,), 'SMALLINT': (int,), 'TINYINT': (int,), 'VARCHAR': (str, 256),
                 'TEXT': (str, 256), 'CLOB': (str, 256), 'NVARCHAR': (str, 256),
                 'NCLOB': (str, 256), 'NTEXT': (str, 256), 'CHAR': (str, 1), 'INT': (int,),
                 'REAL': (float,), 'DOUBLE': (float,), 'STRING': (str, 256), 'DOUBLE_PRECISION': (float,),
                 'DECIMAL': (float,)}

    @classmethod
    def from_objid(cls, objid, *args, **kwargs):
        """Given a string objid, return an instance of
        the appropriate CatalogDBObject class.
        """
        if objid not in cls.registry:
            raise RuntimeError('Attempting to construct an object that does not exist')
        cls = cls.registry.get(objid, CatalogDBObject)
        return cls(*args, **kwargs)
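
    # A minimal subclass sketch (not from the original module; all names are
    # hypothetical).  Defining the class is enough to register it, because
    # CatalogDBObjectMeta stores it in CatalogDBObject.registry under its
    # objid; entries in `columns` map catalog column names to database columns
    # or SQL expressions, with an optional python type (float by default).
    #
    #     class ExampleStarObj(CatalogDBObject):
    #         objid = 'example_stars'
    #         tableid = 'stars'
    #         idColKey = 'id'
    #         objectTypeId = 42
    #         raColName = 'ra'
    #         decColName = 'decl'
    #         columns = [('raJ2000', 'ra*PI()/180.'),
    #                    ('decJ2000', 'decl*PI()/180.')]
    #
    #     db = CatalogDBObject.from_objid('example_stars',
    #                                     database='stars.db', driver='sqlite')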


    def __init__(self, database=None, driver=None, host=None, port=None, verbose=False,
                 table=None, objid=None, idColKey=None, connection=None,
                 cache_connection=True):
        if not verbose:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", category=sa_exc.SAWarning)

        if self.tableid is not None and table is not None:
            raise ValueError("Double-specified tableid in CatalogDBObject:"
                             " once in class definition, once in __init__")

        if table is not None:
            self.tableid = table

        if self.objid is not None and objid is not None:
            raise ValueError("Double-specified objid in CatalogDBObject:"
                             " once in class definition, once in __init__")

        if objid is not None:
            self.objid = objid

        if self.idColKey is not None and idColKey is not None:
            raise ValueError("Double-specified idColKey in CatalogDBObject:"
                             " once in class definition, once in __init__")

        if idColKey is not None:
            self.idColKey = idColKey

        if self.idColKey is None:
            self.idColKey = self.getIdColKey()
        if (self.objid is None) or (self.tableid is None) or (self.idColKey is None):
            msg = ("CatalogDBObject must be subclassed, and "
                   "define objid, tableid and idColKey.  You are missing: ")
            if self.objid is None:
                msg += "objid, "
            if self.tableid is None:
                msg += "tableid, "
            if self.idColKey is None:
                msg += "idColKey"
            raise ValueError(msg)

        if (self.objectTypeId is None) and verbose:
            warnings.warn("objectTypeId has not "
                          "been set.  Input files for phosim are not "
                          "possible.")

        # pass the caller's cache_connection flag through (rather than a
        # hard-coded True) so that callers can opt out of the connection cache
        super(CatalogDBObject, self).__init__(database=database, driver=driver, host=host, port=port,
                                              verbose=verbose, connection=connection,
                                              cache_connection=cache_connection)

        try:
            self._get_table()
        except sa_exc.OperationalError as e:
            if self.driver == 'mssql+pymssql':
                message = "\n To connect to the UW CATSIM database: "
                message += " Check that you have valid connection parameters, an open ssh tunnel "
                message += "and that your $HOME/.lsst/db-auth.paf contains the appropriate credentials. "
                message += "Please consult the following link for more information on access: "
                message += " https://confluence.lsstcorp.org/display/SIM/Accessing+the+UW+CATSIM+Database "
            else:
                message = ''
            raise RuntimeError("Failed to connect to %s: sqlalchemy.%s %s"
                               % (self.connection.engine, e.args[0], message))

        # Need to do this after the table is instantiated so that
        # the default columns can be filled from the table object.
        if self.generateDefaultColumnMap:
            self._make_default_columns()
        # build column mapping and type mapping dicts from columns
        self._make_column_map()
        self._make_type_map()

    def show_mapped_columns(self):
        for col in self.columnMap.keys():
            print("%s -- %s" % (col, self.typeMap[col][0].__name__))

    def show_db_columns(self):
        for col in self.table.c.keys():
            print("%s -- %s" % (col, self.table.c[col].type.__visit_name__))

    def getCatalog(self, ftype, *args, **kwargs):
        try:
            from lsst.sims.catalogs.definitions import InstanceCatalog
            return InstanceCatalog.new_catalog(ftype, self, *args, **kwargs)
        except ImportError:
            raise ImportError("sims_catalogs not set up. Cannot get InstanceCatalog from the object.")

    def getIdColKey(self):
        return self.idColKey

    def getObjectTypeId(self):
        return self.objectTypeId

    def _get_table(self):
        self.table = Table(self.tableid, self.connection.metadata,
                           autoload=True)

    def _make_column_map(self):
        self.columnMap = OrderedDict([(el[0], el[1] if el[1] else el[0])
                                      for el in self.columns])

    def _make_type_map(self):
        self.typeMap = OrderedDict([(el[0], el[2:] if len(el) > 2 else (float,))
                                    for el in self.columns])

    def _make_default_columns(self):
        if self.columns:
            colnames = [el[0] for el in self.columns]
        else:
            self.columns = []
            colnames = []
        for col in self.table.c.keys():
            dbtypestr = self.table.c[col].type.__visit_name__
            dbtypestr = dbtypestr.upper()
            if col in colnames:
                if self.verbose:  # Warn for possible column redefinition
                    warnings.warn("Database column, %s, overridden in self.columns... " % (col) +
                                  "Skipping default assignment.")
            elif dbtypestr in self.dbTypeMap:
                self.columns.append((col, col) + self.dbTypeMap[dbtypestr])
            else:
                if self.verbose:
                    warnings.warn("Can't create default column for %s.  There is no mapping " % (col) +
                                  "for type %s.  Modify the dbTypeMap, or make a custom columns " % (dbtypestr) +
                                  "list.")

    def _get_column_query(self, colnames=None):
        """Given a list of valid column names, return the query object"""
        if colnames is None:
            colnames = [k for k in self.columnMap]
        try:
            vals = [self.columnMap[k] for k in colnames]
        except KeyError:
            offending_columns = '\n'
            for col in colnames:
                if col in self.columnMap:
                    continue
                else:
                    offending_columns += '%s\n' % col
            raise ValueError('entries in colnames must be in self.columnMap. '
                             'These:%sare not' % offending_columns)

        # Get the first query
        idColName = self.columnMap[self.idColKey]
        if idColName in vals:
            idLabel = self.idColKey
        else:
            idLabel = idColName

        query = self.connection.session.query(self.table.c[idColName].label(idLabel))

        for col, val in zip(colnames, vals):
            if val is idColName:
                continue
            # Check if the column is a default column (col == val)
            if col == val:
                # If the column is in the table, use it.
                query = query.add_columns(self.table.c[col].label(col))
            else:
                # If not, assume the user specified the column correctly
                query = query.add_columns(expression.literal_column(val).label(col))

        return query

    def filter(self, query, bounds):
        """Filter the query by the associated metadata"""
        if bounds is not None:
            on_clause = bounds.to_SQL(self.raColName, self.decColName)
            query = query.filter(text(on_clause))
        return query

    def _convert_results_to_numpy_recarray_catalogDBObj(self, results):
        """Post-process the query results to put them
        in a structured array.

        **Parameters**

        * results : a result set as returned by execution of the query

        **Returns**

        * retresults : a structured array constructed from the query data
          (the _final_pass method is applied later, by _postprocess_results).
        """

        if len(results) > 0:
            cols = [str(k) for k in results[0].keys()]
        else:
            return results

        if sys.version_info.major == 2:
            dt_list = []
            for k in cols:
                sub_list = [past_str(k)]
                if self.typeMap[k][0] is not str:
                    for el in self.typeMap[k]:
                        sub_list.append(el)
                else:
                    sub_list.append(past_str)
                    for el in self.typeMap[k][1:]:
                        sub_list.append(el)
                dt_list.append(tuple(sub_list))

            dtype = numpy.dtype(dt_list)

        else:
            dtype = numpy.dtype([(k,) + self.typeMap[k] for k in cols])

        if len(set(cols) & set(self.dbDefaultValues)) > 0:

            results_array = []

            for result in results:
                results_array.append(tuple(result[colName]
                                           if result[colName] or colName not in self.dbDefaultValues
                                           else self.dbDefaultValues[colName]
                                           for colName in cols))

        else:
            results_array = [tuple(rr) for rr in results]

        retresults = numpy.rec.fromrecords(results_array, dtype=dtype)
        return retresults

    def _postprocess_results(self, results):
        if not isinstance(results, numpy.recarray):
            retresults = self._convert_results_to_numpy_recarray_catalogDBObj(results)
        else:
            retresults = results
        return self._final_pass(retresults)

    def query_columns(self, colnames=None, chunk_size=None,
                      obs_metadata=None, constraint=None, limit=None):
        """Execute a query

        **Parameters**

        * colnames : list or None
          a list of valid column names, corresponding to entries in the
          `columns` class attribute.  If not specified, all columns are
          queried.
        * chunk_size : int (optional)
          if specified, then return an iterator object to query the database,
          each time returning the next `chunk_size` elements.  If not
          specified, all matching results will be returned.
        * obs_metadata : object (optional)
          an observation metadata object which has a "filter" method, which
          will add a filter string to the query.
        * constraint : str (optional)
          a string which is interpreted as SQL and used as a predicate on the query
        * limit : int (optional)
          limits the number of rows returned by the query

        **Returns**

        * result : list or iterator
          If chunk_size is not specified, then result is a list of all
          items which match the specified query.  If chunk_size is specified,
          then result is an iterator over lists of the given size.

        """
        query = self._get_column_query(colnames)

        if obs_metadata is not None:
            query = self.filter(query, obs_metadata.bounds)

        if constraint is not None:
            query = query.filter(text(constraint))

        if limit is not None:
            query = query.limit(limit)

        return ChunkIterator(self, query, chunk_size)
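
# Usage sketch (not from the original module; this continues the hypothetical
# ExampleStarObj defined in an earlier comment): mapped column names are
# resolved through self.columnMap, and passing chunk_size yields chunked
# recarrays from a ChunkIterator.
#
#     db = CatalogDBObject.from_objid('example_stars',
#                                     database='stars.db', driver='sqlite')
#     results = db.query_columns(colnames=['raJ2000', 'decJ2000'],
#                                constraint='decl < 0.0', chunk_size=10000)
#     for chunk in results:
#         print(chunk['raJ2000'][:5])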


sims_clean_up.targets.append(CatalogDBObject._connection_cache)


class fileDBObject(CatalogDBObject):
    '''Class to read a file into a database and then query it'''

    # Column names to index.  Specify compound indexes using tuples of column names
    indexCols = []

    def __init__(self, dataLocatorString, runtable=None, driver="sqlite", host=None, port=None,
                 database=":memory:", dtype=None, numGuess=1000, delimiter=None, verbose=False,
                 idColKey=None, **kwargs):
        """
        Initialize an object for querying databases loaded from a file

        Keyword arguments:
        @param dataLocatorString: Path to the file to load
        @param runtable: The name of the table to create.  If None, a random table name will be used.
        @param driver: name of database driver (e.g. 'sqlite', 'mssql+pymssql')
        @param host: hostname for database connection (None if sqlite)
        @param port: port for database connection (None if sqlite)
        @param database: name of database (filename if sqlite)
        @param dtype: The numpy dtype to use when loading the file.  If None, the dtype will be guessed.
        @param numGuess: The number of lines to use in guessing the dtype from the file.
        @param delimiter: The delimiter to use when parsing the file; the default is white space.
        @param idColKey: The name of the column that uniquely identifies each row in the database
        """
        self.verbose = verbose

        if idColKey is not None:
            self.idColKey = idColKey

        if (self.objid is None) or (self.idColKey is None):
            raise ValueError("CatalogDBObject must be subclassed, and "
                             "define objid and tableid and idColKey.")

        if (self.objectTypeId is None) and self.verbose:
            warnings.warn("objectTypeId has not "
                          "been set.  Input files for phosim are not "
                          "possible.")

        if os.path.exists(dataLocatorString):
            self.driver = driver
            self.host = host
            self.port = port
            self.database = database
            self.connection = DBConnection(database=self.database, driver=self.driver, host=self.host,
                                           port=self.port, verbose=verbose)
            self.tableid = loadData(dataLocatorString, dtype, delimiter, runtable, self.idColKey,
                                    self.connection.engine, self.connection.metadata, numGuess,
                                    indexCols=self.indexCols, **kwargs)
            self._get_table()
        else:
            raise ValueError("Could not locate file %s." % (dataLocatorString))

        if self.generateDefaultColumnMap:
            self._make_default_columns()

        self._make_column_map()
        self._make_type_map()

    @classmethod
    def from_objid(cls, objid, *args, **kwargs):
        """Given a string objid, return an instance of
        the appropriate fileDBObject class.
        """
        cls = cls.registry.get(objid, CatalogDBObject)
        return cls(*args, **kwargs)
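
# Usage sketch (not from the original module; the file and column names are
# hypothetical): load a whitespace-delimited text file into an in-memory
# sqlite table and query it through the normal CatalogDBObject interface.
#
#     class ExampleFileStars(fileDBObject):
#         objid = 'example_file_stars'
#         idColKey = 'id'
#         objectTypeId = 43
#         raColName = 'ra'
#         decColName = 'decl'
#
#     db = ExampleFileStars('star_cache.txt', runtable='stars')
#     for chunk in db.query_columns(chunk_size=1000):
#         print(len(chunk))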