from __future__ import print_function
from future import standard_library
standard_library.install_aliases()
import sys
from builtins import str

# 2017 March 9
# str_cast exists because numpy.dtype does
# not like unicode-like things as the names
# of columns. Unfortunately, in python 2,
# builtins.str looks unicode-like. We will
# use str_cast in python 2 to maintain
# both python 3 compatibility and our use of
# numpy dtype
str_cast = str
if sys.version_info.major == 2:
    from past.builtins import str as past_str
    str_cast = past_str

from builtins import zip
from builtins import object
import warnings
import numpy
import os
import inspect
from io import BytesIO
from collections import OrderedDict

from .utils import loadData
from sqlalchemy.orm import scoped_session, sessionmaker
from sqlalchemy.sql import expression
from sqlalchemy.engine import reflection, url
from sqlalchemy import (create_engine, MetaData,
                        Table, event, text)
from sqlalchemy import exc as sa_exc
from lsst.daf.persistence import DbAuth
from lsst.sims.utils.CodeUtilities import sims_clean_up

# The documentation at http://docs.sqlalchemy.org/en/rel_0_7/core/types.html#sqlalchemy.types.Numeric
# suggests using the cdecimal module. Since it is not standard, import decimal.
# TODO: test for cdecimal and use it if it exists.
import decimal
from future.utils import with_metaclass

__all__ = ["ChunkIterator", "DBObject", "CatalogDBObject", "fileDBObject"]

def valueOfPi():
    """
    A function to return the value of pi. This is needed for adding PI()
    to sqlite databases.
    """
    return numpy.pi

def declareTrigFunctions(conn, connection_rec, connection_proxy):
    """
    A database event listener that defines the math functions needed to
    evaluate the Haversine formula in sqlite databases (where they are not
    otherwise defined).

    see: http://docs.sqlalchemy.org/en/latest/core/events.html
    """

    conn.create_function("COS", 1, numpy.cos)
    conn.create_function("SIN", 1, numpy.sin)
    conn.create_function("ASIN", 1, numpy.arcsin)
    conn.create_function("SQRT", 1, numpy.sqrt)
    conn.create_function("POWER", 2, numpy.power)
    conn.create_function("PI", 0, valueOfPi)
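# A minimal sketch (not part of the original module) of how these functions
# become available: once declareTrigFunctions is attached as a 'checkout'
# listener (see DBConnection._connect_to_engine below), raw SQL against a
# sqlite engine can call the trig functions directly. The file name and the
# query below are illustrative only.
#
#     engine = create_engine('sqlite:///my_catalog.db')
#     event.listen(engine, 'checkout', declareTrigFunctions)
#     with engine.connect() as conn:
#         conn.execute(text("SELECT 2.0*ASIN(SQRT(0.25))*180.0/PI()"))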

#------------------------------------------------------------
# Iterator for database chunks

class ChunkIterator(object):
    """Iterator for query chunks"""
    def __init__(self, dbobj, query, chunk_size, arbitrarySQL=False):
        self.dbobj = dbobj
        self.exec_query = dbobj.connection.session.execute(query)
        self.chunk_size = chunk_size

        # arbitrarySQL exists in case a CatalogDBObject calls
        # get_arbitrary_chunk_iterator; in that case, we need to
        # be able to tell this object to call _postprocess_arbitrary_results,
        # rather than _postprocess_results
        self.arbitrarySQL = arbitrarySQL

    def __iter__(self):
        return self

    def __next__(self):
        if self.chunk_size is None and not self.exec_query.closed:
            chunk = self.exec_query.fetchall()
            return self._postprocess_results(chunk)
        elif self.chunk_size is not None:
            chunk = self.exec_query.fetchmany(self.chunk_size)
            return self._postprocess_results(chunk)
        else:
            raise StopIteration

    def _postprocess_results(self, chunk):
        if len(chunk) == 0:
            raise StopIteration
        if self.arbitrarySQL:
            return self.dbobj._postprocess_arbitrary_results(chunk)
        else:
            return self.dbobj._postprocess_results(chunk)
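# Usage sketch (the database file and query are hypothetical): a ChunkIterator
# is normally obtained from DBObject.get_chunk_iterator (defined below) rather
# than constructed directly, and yields one numpy recarray per chunk.
#
#     db = DBObject(database='my_db.db', driver='sqlite')
#     for chunk in db.get_chunk_iterator('SELECT * FROM stars', chunk_size=10000):
#         print(len(chunk))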

class DBConnection(object):
    """
    This is a class that will hold the engine, session, and metadata for a
    DBObject. This will allow multiple DBObjects to share the same
    sqlalchemy connection, when appropriate.
    """

    def __init__(self, database=None, driver=None, host=None, port=None, verbose=False):
        """
        @param [in] database is the name of the database file being connected to

        @param [in] driver is the dialect of the database (e.g. 'sqlite', 'mssql', etc.)

        @param [in] host is the URL of the remote host, if appropriate

        @param [in] port is the port on the remote host to connect to, if appropriate

        @param [in] verbose is a boolean controlling sqlalchemy's verbosity
        """

        self._database = database
        self._driver = driver
        self._host = host
        self._port = port
        self._verbose = verbose

        self._validate_conn_params()
        self._connect_to_engine()

    def __del__(self):
        try:
            del self._metadata
        except AttributeError:
            pass

        try:
            del self._engine
        except AttributeError:
            pass

        try:
            del self._session
        except AttributeError:
            pass

    def _connect_to_engine(self):

        # DbAuth will not look up hosts that are None, '' or 0
        if self._host:
            try:
                authDict = {'username': DbAuth.username(self._host, str(self._port)),
                            'password': DbAuth.password(self._host, str(self._port))}
            except Exception:
                if self._driver == 'mssql+pymssql':
                    print("\nFor more information on database authentication using the db-auth.paf"
                          " policy file see: "
                          "https://confluence.lsstcorp.org/display/SIM/Accessing+the+UW+CATSIM+Database\n")
                raise

            dbUrl = url.URL(self._driver,
                            host=self._host,
                            port=self._port,
                            database=self._database,
                            **authDict)
        else:
            dbUrl = url.URL(self._driver,
                            database=self._database)

        self._engine = create_engine(dbUrl, echo=self._verbose)

        if self._engine.dialect.name == 'sqlite':
            event.listen(self._engine, 'checkout', declareTrigFunctions)

        self._session = scoped_session(sessionmaker(autoflush=True,
                                                    bind=self._engine))
        self._metadata = MetaData(bind=self._engine)

    def _validate_conn_params(self):
        """Validate connection parameters

        - Check if the user passed a dbAddress instead of a database. Convert and warn.
        - Check that required connection parameters are present.
        - Replace default host/port if driver is 'sqlite'.
        """

        if self._database is None:
            raise AttributeError("Cannot instantiate DBConnection; database is 'None'")

        if '//' in self._database:
            warnings.warn("Database name '%s' is invalid but looks like a dbAddress. "
                          "Attempting to convert to database, driver, host, "
                          "and port parameters. Any usernames and passwords are ignored and must "
                          "be in the db-auth.paf policy file. " % (self._database), FutureWarning)

            dbUrl = url.make_url(self._database)
            dialect = dbUrl.get_dialect()
            self._driver = dialect.name + '+' + dialect.driver if dialect.driver else dialect.name
            for key, value in dbUrl.translate_connect_args().items():
                if value is not None:
                    setattr(self, '_' + key, value)

        errMessage = "Please supply a 'driver' kwarg to the constructor or in class definition. "
        errMessage += "'driver' is formatted as dialect+driver, such as 'sqlite' or 'mssql+pymssql'."
        if not hasattr(self, '_driver'):
            raise AttributeError("%s has no attribute 'driver'. " % (self.__class__.__name__) + errMessage)
        elif self._driver is None:
            raise AttributeError("%s.driver is None. " % (self.__class__.__name__) + errMessage)

        errMessage = "Please supply a 'database' kwarg to the constructor or in class definition. "
        errMessage += " 'database' is the database name or the filename path if driver is 'sqlite'. "
        if not hasattr(self, '_database'):
            raise AttributeError("%s has no attribute 'database'. " % (self.__class__.__name__) + errMessage)
        elif self._database is None:
            raise AttributeError("%s.database is None. " % (self.__class__.__name__) + errMessage)

        if 'sqlite' in self._driver:
            # When passed a sqlite database, override the default host/port
            self._host = None
            self._port = None

    def __eq__(self, other):
        return (str(self._database) == str(other._database)) and \
               (str(self._driver) == str(other._driver)) and \
               (str(self._host) == str(other._host)) and \
               (str(self._port) == str(other._port))

    @property
    def engine(self):
        return self._engine

    @property
    def session(self):
        return self._session

    @property
    def metadata(self):
        return self._metadata

    @property
    def database(self):
        return self._database

    @property
    def driver(self):
        return self._driver

    @property
    def host(self):
        return self._host

    @property
    def port(self):
        return self._port

    @property
    def verbose(self):
        return self._verbose
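# A construction sketch (file and host names are hypothetical) showing the two
# typical cases: a local sqlite file, and a remote server whose credentials
# come from DbAuth.
#
#     local = DBConnection(database='my_catalog.db', driver='sqlite')
#     local.engine    # sqlalchemy Engine bound to the sqlite file
#     local.session   # scoped_session for issuing queries
#
#     # remote = DBConnection(database='catsim', driver='mssql+pymssql',
#     #                       host='example.host.edu', port='1433')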

class DBObject(object):

    def __init__(self, database=None, driver=None, host=None, port=None, verbose=False,
                 connection=None, cache_connection=True):
        """
        Initialize DBObject.

        @param [in] database is the name of the database file being connected to

        @param [in] driver is the dialect of the database (e.g. 'sqlite', 'mssql', etc.)

        @param [in] host is the URL of the remote host, if appropriate

        @param [in] port is the port on the remote host to connect to, if appropriate

        @param [in] verbose is a boolean controlling sqlalchemy's verbosity (default False)

        @param [in] connection is an optional instance of DBConnection, in the event that
        this DBObject can share a database connection with another DBObject. This is only
        necessary or even possible in a few specialized cases and should be used carefully.

        @param [in] cache_connection is a boolean. If True, DBObject will use a cache of
        DBConnections (if available) to get the connection to this database.
        """

        self.dtype = None
        # this is a cache for the query, so that any one query does not have to guess dtype multiple times

        if connection is None:
            # Explicit constructor to DBObject preferred
            kwargDict = dict(database=database,
                             driver=driver,
                             host=host,
                             port=port,
                             verbose=verbose)

            for key, value in kwargDict.items():
                if value is not None or not hasattr(self, key):
                    setattr(self, key, value)

            self.connection = self._get_connection(self.database, self.driver, self.host, self.port,
                                                   use_cache=cache_connection)

        else:
            self.connection = connection
            self.database = connection.database
            self.driver = connection.driver
            self.host = connection.host
            self.port = connection.port
            self.verbose = connection.verbose

    def _get_connection(self, database, driver, host, port, use_cache=True):
        """
        Search self._connection_cache (if it exists; it won't for DBObject, but
        will for CatalogDBObject) for a DBConnection matching the specified
        parameters. If it exists, return it. If not, open a connection to
        the specified database, add it to the cache, and return the connection.

        Parameters
        ----------
        database is the name of the database file being connected to

        driver is the dialect of the database (e.g. 'sqlite', 'mssql', etc.)

        host is the URL of the remote host, if appropriate

        port is the port on the remote host to connect to, if appropriate

        use_cache is a boolean specifying whether or not we try to use the
        cache of database connections (you don't want to if opening many
        connections in many threads).
        """

        if use_cache and hasattr(self, '_connection_cache'):
            for conn in self._connection_cache:
                if str(conn.database) == str(database):
                    if str(conn.driver) == str(driver):
                        if str(conn.host) == str(host):
                            if str(conn.port) == str(port):
                                return conn

        conn = DBConnection(database=database, driver=driver, host=host, port=port)

        if use_cache and hasattr(self, '_connection_cache'):
            self._connection_cache.append(conn)

        return conn

    def get_table_names(self):
        """Return a list of the names of the tables (and views) in the database"""
        return [str(xx) for xx in reflection.Inspector.from_engine(self.connection.engine).get_table_names()] + \
               [str(xx) for xx in reflection.Inspector.from_engine(self.connection.engine).get_view_names()]

    def get_column_names(self, tableName=None):
        """
        Return a list of the names of the columns in the specified table.
        If no table is specified, return a dict of lists. The dict will be keyed
        to the table names. The lists will be of the column names in that table.
        """
        tableNameList = self.get_table_names()
        if tableName is not None:
            if tableName not in tableNameList:
                return []
            return [str_cast(xx['name'])
                    for xx in reflection.Inspector.from_engine(self.connection.engine).get_columns(tableName)]
        else:
            columnDict = {}
            for name in tableNameList:
                columnList = [str_cast(xx['name'])
                              for xx in reflection.Inspector.from_engine(self.connection.engine).get_columns(name)]
                columnDict[name] = columnList
            return columnDict

    def _final_pass(self, results):
        """ Make final modifications to a set of data before returning it to the user

        **Parameters**

        * results : a structured array constructed from the result set from a query

        **Returns**

        * results : a potentially modified structured array. The default is to do nothing.

        """
        return results

    def _convert_results_to_numpy_recarray_dbobj(self, results):
        if self.dtype is None:
            # Determine the dtype from the data.
            # Store it in an attribute so we do not have to repeat this on every chunk.
            dataString = ''

            # We are going to detect the dtype by reading in a single row
            # of data with np.genfromtxt. To do this, we must pass the
            # row as a string delimited by a specified character. Here we
            # select a character that does not occur anywhere in the data.
            delimit_char_list = [',', ';', '|', ':', '/', '\\']
            delimit_char = None
            for cc in delimit_char_list:
                is_valid = True
                for xx in results[0]:
                    if cc in str(xx):
                        is_valid = False
                        break

                if is_valid:
                    delimit_char = cc
                    break

            if delimit_char is None:
                raise RuntimeError("DBObject could not detect the dtype of your return rows\n"
                                   "Please specify a dtype with the 'dtype' kwarg.")

            for xx in results[0]:
                if dataString != '':
                    dataString += delimit_char
                dataString += str(xx)
            names = [str_cast(ww) for ww in results[0].keys()]
            dataArr = numpy.genfromtxt(BytesIO(dataString.encode()), dtype=None,
                                       names=names, delimiter=delimit_char,
                                       encoding='utf-8')
            dt_list = []
            for name in dataArr.dtype.names:
                type_name = str(dataArr.dtype[name])
                sub_list = [name]
                if type_name.startswith('S') or type_name.startswith('|S'):
                    sub_list.append(str_cast)
                    sub_list.append(int(type_name.replace('S', '').replace('|', '')))
                else:
                    sub_list.append(dataArr.dtype[name])
                dt_list.append(tuple(sub_list))

            self.dtype = numpy.dtype(dt_list)

        if len(results) == 0:
            return numpy.recarray((0,), dtype=self.dtype)

        retresults = numpy.rec.fromrecords([tuple(xx) for xx in results], dtype=self.dtype)
        return retresults
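    # The dtype-guessing trick above, in isolation (a standalone sketch with
    # made-up values, not part of the original module): a single row is
    # serialized with an unused delimiter and handed to numpy.genfromtxt,
    # whose inferred dtype is then reused for every subsequent chunk.
    #
    #     row = (42, 3.5, 'NGC1234')
    #     line = '|'.join(str(x) for x in row)
    #     arr = numpy.genfromtxt(BytesIO(line.encode()), dtype=None,
    #                            names=['id', 'mag', 'name'], delimiter='|',
    #                            encoding='utf-8')
    #     guessed_dtype = arr.dtype   # e.g. [('id', '<i8'), ('mag', '<f8'), ('name', '<U7')]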

    def _postprocess_results(self, results):
        """
        This wrapper exists so that a ChunkIterator built from a DBObject
        can have the same API as a ChunkIterator built from a CatalogDBObject
        """
        return self._postprocess_arbitrary_results(results)

    def _postprocess_arbitrary_results(self, results):

        if not isinstance(results, numpy.recarray):
            retresults = self._convert_results_to_numpy_recarray_dbobj(results)
        else:
            retresults = results

        return self._final_pass(retresults)

    def execute_arbitrary(self, query, dtype=None):
        """
        Executes an arbitrary query. Returns a recarray of the results.

        dtype will be the dtype of the output recarray. If it is None, then
        the code will guess the datatype and assign generic names to the columns
        """

        try:
            is_string = isinstance(query, basestring)
        except NameError:
            # basestring does not exist in python 3
            is_string = isinstance(query, str)

        if not is_string:
            raise RuntimeError("DBObject execute must be called with a string query")

        unacceptableCommands = ["delete", "drop", "insert", "update"]
        for badCommand in unacceptableCommands:
            if query.lower().find(badCommand.lower()) >= 0:
                raise RuntimeError("query made to DBObject execute contained %s " % badCommand)

        self.dtype = dtype
        retresults = self._postprocess_arbitrary_results(self.connection.session.execute(query).fetchall())
        return retresults
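    # Usage sketch for execute_arbitrary (database, table, and column names are
    # hypothetical): only read-only statements are allowed; DELETE, DROP,
    # INSERT, and UPDATE are rejected.
    #
    #     db = DBObject(database='my_db.db', driver='sqlite')
    #     data = db.execute_arbitrary('SELECT id, ra, decl FROM stars',
    #                                 dtype=numpy.dtype([('id', int), ('ra', float), ('decl', float)]))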

    def get_arbitrary_chunk_iterator(self, query, chunk_size=None, dtype=None):
        """
        This wrapper exists so that CatalogDBObjects can refer to
        get_arbitrary_chunk_iterator and DBObjects can refer to
        get_chunk_iterator
        """
        return self.get_chunk_iterator(query, chunk_size=chunk_size, dtype=dtype)

    def get_chunk_iterator(self, query, chunk_size=None, dtype=None):
        """
        Take an arbitrary, user-specified query and return a ChunkIterator that
        executes that query

        dtype will tell the ChunkIterator what datatype to expect for this query.
        This information gets passed to _postprocess_results.

        If 'None', then _postprocess_results will just guess the datatype
        and return generic names for the columns.
        """
        self.dtype = dtype
        return ChunkIterator(self, query, chunk_size, arbitrarySQL=True)
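    # Chunked-iteration sketch (the query is hypothetical): the same kind of
    # query as above, but streamed in pieces so very large result sets never
    # have to fit in memory all at once.
    #
    #     for chunk in db.get_chunk_iterator('SELECT id, ra, decl FROM stars',
    #                                        chunk_size=50000):
    #         process(chunk)   # 'process' is a placeholder for user code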

class CatalogDBObjectMeta(type):
    """Meta class for registering new objects.

    When any new type of object class is created, this registers it
    in a `registry` class attribute, available to all derived instance
    catalogs.
    """
    def __new__(cls, name, bases, dct):
        # check if the attribute objid is specified.
        # If not, create a default
        if 'registry' in dct:
            warnings.warn("registry class attribute should not be "
                          "over-ridden in InstanceCatalog classes. "
                          "Proceed with caution")
        if 'objid' not in dct:
            dct['objid'] = name
        return super(CatalogDBObjectMeta, cls).__new__(cls, name, bases, dct)

    def __init__(cls, name, bases, dct):
        # check if 'registry' is specified.
        # if not, then this is the base class: add the registry
        if not hasattr(cls, 'registry'):
            cls.registry = {}
        else:
            if not cls.skipRegistration:
                # add this class to the registry
                if cls.objid in cls.registry:
                    srcfile = inspect.getsourcefile(cls.registry[cls.objid])
                    srcline = inspect.getsourcelines(cls.registry[cls.objid])[1]
                    warnings.warn('duplicate object identifier %s specified. ' % (cls.objid) +
                                  'This will override the previous definition on line %i of %s' %
                                  (srcline, srcfile))
                cls.registry[cls.objid] = cls

        # check if the list of unique ids is specified
        # if not, then this is the base class: add the list
        if not hasattr(cls, 'objectTypeIdList'):
            cls.objectTypeIdList = []
        else:
            if cls.skipRegistration:
                pass
            elif cls.objectTypeId is None:
                pass  # Don't add typeIds that are None
            elif cls.objectTypeId in cls.objectTypeIdList:
                warnings.warn('Duplicate object type id %s specified: ' % cls.objectTypeId +
                              '\nOutput object ids may not be unique.\nThis may not be a problem if you do not ' +
                              'want globally unique id values')
            else:
                cls.objectTypeIdList.append(cls.objectTypeId)
        return super(CatalogDBObjectMeta, cls).__init__(name, bases, dct)

    def __str__(cls):
        dbObjects = cls.registry.keys()
        outstr = "++++++++++++++++++++++++++++++++++++++++++++++\n" + \
                 "Registered object types are:\n"
        for dbObject in dbObjects:
            outstr += "%s\n" % (dbObject)
        outstr += "\n\n"
        outstr += "To query the possible column names do:\n"
        outstr += "$> CatalogDBObject.from_objid([name]).show_mapped_columns()\n"
        outstr += "+++++++++++++++++++++++++++++++++++++++++++++"
        return outstr
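# Registration sketch (all names and the database file are hypothetical):
# simply defining a subclass of CatalogDBObject with an `objid` adds it to the
# shared registry, after which it can be instantiated by name.
#
#     class StarObj(CatalogDBObject):
#         objid = 'my_stars'
#         tableid = 'stars'
#         idColKey = 'id'
#         database = 'my_db.db'
#         driver = 'sqlite'
#
#     star_db = CatalogDBObject.from_objid('my_stars')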

class CatalogDBObject(with_metaclass(CatalogDBObjectMeta, DBObject)):
    """Database Object base class"""

    epoch = 2000.0
    skipRegistration = False
    objid = None
    tableid = None
    idColKey = None
    objectTypeId = None
    columns = None
    generateDefaultColumnMap = True
    dbDefaultValues = {}
    raColName = None
    decColName = None

    _connection_cache = []  # a list to store open database connections in

    # Provide information if this object should be tested in the unit test
    doRunTest = False
    testObservationMetaData = None

    #: Mapping of DDL types to python types. Strings are assumed to be 256 characters.
    #: This can be overridden by modifying the dbTypeMap or by making a custom columns
    #: list.
    #: numpy doesn't know how to convert decimal.Decimal types, so DECIMAL maps to float.
    #: TODO this doesn't seem to make a difference, but make sure.
    dbTypeMap = {'BIGINT': (int,), 'BOOLEAN': (bool,), 'FLOAT': (float,), 'INTEGER': (int,),
                 'NUMERIC': (float,), 'SMALLINT': (int,), 'TINYINT': (int,), 'VARCHAR': (str, 256),
                 'TEXT': (str, 256), 'CLOB': (str, 256), 'NVARCHAR': (str, 256),
                 'NCLOB': (str, 256), 'NTEXT': (str, 256), 'CHAR': (str, 1), 'INT': (int,),
                 'REAL': (float,), 'DOUBLE': (float,), 'STRING': (str, 256), 'DOUBLE_PRECISION': (float,),
                 'DECIMAL': (float,)}

    @classmethod
    def from_objid(cls, objid, *args, **kwargs):
        """Given a string objid, return an instance of
        the appropriate CatalogDBObject class.
        """
        if objid not in cls.registry:
            raise RuntimeError('Attempting to construct an object that does not exist')
        cls = cls.registry.get(objid, CatalogDBObject)
        return cls(*args, **kwargs)

    def __init__(self, database=None, driver=None, host=None, port=None, verbose=False,
                 table=None, objid=None, idColKey=None, connection=None,
                 cache_connection=True):
        if not verbose:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", category=sa_exc.SAWarning)

        if self.tableid is not None and table is not None:
            raise ValueError("Double-specified tableid in CatalogDBObject:"
                             " once in class definition, once in __init__")

        if table is not None:
            self.tableid = table

        if self.objid is not None and objid is not None:
            raise ValueError("Double-specified objid in CatalogDBObject:"
                             " once in class definition, once in __init__")

        if objid is not None:
            self.objid = objid

        if self.idColKey is not None and idColKey is not None:
            raise ValueError("Double-specified idColKey in CatalogDBObject:"
                             " once in class definition, once in __init__")

        if idColKey is not None:
            self.idColKey = idColKey

        if self.idColKey is None:
            self.idColKey = self.getIdColKey()
        if (self.objid is None) or (self.tableid is None) or (self.idColKey is None):
            msg = ("CatalogDBObject must be subclassed, and "
                   "define objid, tableid and idColKey. You are missing: ")
            if self.objid is None:
                msg += "objid, "
            if self.tableid is None:
                msg += "tableid, "
            if self.idColKey is None:
                msg += "idColKey"
            raise ValueError(msg)

        if (self.objectTypeId is None) and verbose:
            warnings.warn("objectTypeId has not "
                          "been set. Input files for phosim are not "
                          "possible.")

        super(CatalogDBObject, self).__init__(database=database, driver=driver, host=host, port=port,
                                              verbose=verbose, connection=connection,
                                              cache_connection=cache_connection)

        try:
            self._get_table()
        except sa_exc.OperationalError as e:
            if self.driver == 'mssql+pymssql':
                message = "\n To connect to the UW CATSIM database: "
                message += " Check that you have valid connection parameters, an open ssh tunnel "
                message += "and that your $HOME/.lsst/db-auth.paf contains the appropriate credentials. "
                message += "Please consult the following link for more information on access: "
                message += " https://confluence.lsstcorp.org/display/SIM/Accessing+the+UW+CATSIM+Database "
            else:
                message = ''
            raise RuntimeError("Failed to connect to %s: sqlalchemy.%s %s" % (self.connection.engine, e.args[0], message))

        # Need to do this after the table is instantiated so that
        # the default columns can be filled from the table object.
        if self.generateDefaultColumnMap:
            self._make_default_columns()
        # build column mapping and type mapping dicts from columns
        self._make_column_map()
        self._make_type_map()

    def show_mapped_columns(self):
        for col in self.columnMap.keys():
            print("%s -- %s" % (col, self.typeMap[col][0].__name__))

    def show_db_columns(self):
        for col in self.table.c.keys():
            print("%s -- %s" % (col, self.table.c[col].type.__visit_name__))

    def getCatalog(self, ftype, *args, **kwargs):
        try:
            from lsst.sims.catalogs.definitions import InstanceCatalog
            return InstanceCatalog.new_catalog(ftype, self, *args, **kwargs)
        except ImportError:
            raise ImportError("sims_catalogs not set up. Cannot get InstanceCatalog from the object.")

    def getIdColKey(self):
        return self.idColKey

    def getObjectTypeId(self):
        return self.objectTypeId

    def _get_table(self):
        self.table = Table(self.tableid, self.connection.metadata,
                           autoload=True)

    def _make_column_map(self):
        self.columnMap = OrderedDict([(el[0], el[1] if el[1] else el[0])
                                      for el in self.columns])

    def _make_type_map(self):
        self.typeMap = OrderedDict([(el[0], el[2:] if len(el) > 2 else (float,))
                                    for el in self.columns])

    def _make_default_columns(self):
        if self.columns:
            colnames = [el[0] for el in self.columns]
        else:
            self.columns = []
            colnames = []
        for col in self.table.c.keys():
            dbtypestr = self.table.c[col].type.__visit_name__
            dbtypestr = dbtypestr.upper()
            if col in colnames:
                if self.verbose:  # Warn for possible column redefinition
                    warnings.warn("Database column, %s, overridden in self.columns... " % (col) +
                                  "Skipping default assignment.")
            elif dbtypestr in self.dbTypeMap:
                self.columns.append((col, col) + self.dbTypeMap[dbtypestr])
            else:
                if self.verbose:
                    warnings.warn("Can't create default column for %s. There is no mapping " % (col) +
                                  "for type %s. Modify the dbTypeMap, or make a custom columns " % (dbtypestr) +
                                  "list.")

    def _get_column_query(self, colnames=None):
        """Given a list of valid column names, return the query object"""
        if colnames is None:
            colnames = [k for k in self.columnMap]
        try:
            vals = [self.columnMap[k] for k in colnames]
        except KeyError:
            offending_columns = '\n'
            for col in colnames:
                if col in self.columnMap:
                    continue
                else:
                    offending_columns += '%s\n' % col
            raise ValueError('entries in colnames must be in self.columnMap. '
                             'These:%sare not' % offending_columns)

        # Get the first query
        idColName = self.columnMap[self.idColKey]
        if idColName in vals:
            idLabel = self.idColKey
        else:
            idLabel = idColName

        query = self.connection.session.query(self.table.c[idColName].label(idLabel))

        for col, val in zip(colnames, vals):
            if val is idColName:
                continue
            # Check if the column is a default column (col == val)
            if col == val:
                # If the column is in the table, use it.
                query = query.add_columns(self.table.c[col].label(col))
            else:
                # If not, assume the user specified the column correctly
                query = query.add_columns(expression.literal_column(val).label(col))

        return query

    def filter(self, query, bounds):
        """Filter the query by the associated metadata"""
        if bounds is not None:
            on_clause = bounds.to_SQL(self.raColName, self.decColName)
            query = query.filter(text(on_clause))
        return query

    def _convert_results_to_numpy_recarray_catalogDBObj(self, results):
        """Post-process the query results to put them
        in a structured array.

        **Parameters**

        * results : a result set as returned by execution of the query

        **Returns**

        * _final_pass(retresults) : the result of calling the _final_pass method on a
          structured array constructed from the query data.
        """

        if len(results) > 0:
            cols = [str(k) for k in results[0].keys()]
        else:
            return results

        if sys.version_info.major == 2:
            dt_list = []
            for k in cols:
                sub_list = [past_str(k)]
                if self.typeMap[k][0] is not str:
                    for el in self.typeMap[k]:
                        sub_list.append(el)
                else:
                    sub_list.append(past_str)
                    for el in self.typeMap[k][1:]:
                        sub_list.append(el)
                dt_list.append(tuple(sub_list))

            dtype = numpy.dtype(dt_list)

        else:
            dtype = numpy.dtype([(k,) + self.typeMap[k] for k in cols])

        if len(set(cols) & set(self.dbDefaultValues)) > 0:

            results_array = []

            for result in results:
                results_array.append(tuple(result[colName]
                                           if result[colName] or
                                           colName not in self.dbDefaultValues
                                           else self.dbDefaultValues[colName]
                                           for colName in cols))

        else:
            results_array = [tuple(rr) for rr in results]

        retresults = numpy.rec.fromrecords(results_array, dtype=dtype)
        return retresults

    def _postprocess_results(self, results):
        if not isinstance(results, numpy.recarray):
            retresults = self._convert_results_to_numpy_recarray_catalogDBObj(results)
        else:
            retresults = results
        return self._final_pass(retresults)

    def query_columns(self, colnames=None, chunk_size=None,
                      obs_metadata=None, constraint=None, limit=None):
        """Execute a query

        **Parameters**

        * colnames : list or None
          a list of valid column names, corresponding to entries in the
          `columns` class attribute. If not specified, all columns are
          queried.
        * chunk_size : int (optional)
          if specified, then return an iterator object to query the database,
          each time returning the next `chunk_size` elements. If not
          specified, all matching results will be returned.
        * obs_metadata : object (optional)
          an observation metadata object which has a "filter" method, which
          will add a filter string to the query.
        * constraint : str (optional)
          a string which is interpreted as SQL and used as a predicate on the query
        * limit : int (optional)
          limits the number of rows returned by the query

        **Returns**

        * result : list or iterator
          If chunk_size is not specified, then result is a list of all
          items which match the specified query. If chunk_size is specified,
          then result is an iterator over lists of the given size.

        """
        query = self._get_column_query(colnames)

        if obs_metadata is not None:
            query = self.filter(query, obs_metadata.bounds)

        if constraint is not None:
            query = query.filter(text(constraint))

        if limit is not None:
            query = query.limit(limit)

        return ChunkIterator(self, query, chunk_size)
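    # Typical use of query_columns (all names are hypothetical): iterate over
    # the mapped columns of a registered CatalogDBObject subclass in chunks,
    # optionally restricted by an extra SQL predicate.
    #
    #     star_db = CatalogDBObject.from_objid('my_stars')
    #     results = star_db.query_columns(colnames=['id', 'ra', 'decl'],
    #                                     constraint='ra > 200.0',
    #                                     chunk_size=10000)
    #     for chunk in results:
    #         do_something(chunk)   # placeholder for user code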

sims_clean_up.targets.append(CatalogDBObject._connection_cache)

class fileDBObject(CatalogDBObject):
    """Class to read a file into a database and then query it"""

    # Column names to index. Specify compound indexes using tuples of column names.
    indexCols = []

    def __init__(self, dataLocatorString, runtable=None, driver="sqlite", host=None, port=None, database=":memory:",
                 dtype=None, numGuess=1000, delimiter=None, verbose=False, idColKey=None, **kwargs):
        """
        Initialize an object for querying databases loaded from a file

        Keyword arguments:
        @param dataLocatorString: Path to the file to load
        @param runtable: The name of the table to create. If None, a random table name will be used.
        @param driver: name of database driver (e.g. 'sqlite', 'mssql+pymssql')
        @param host: hostname for database connection (None if sqlite)
        @param port: port for database connection (None if sqlite)
        @param database: name of database (filename if sqlite)
        @param dtype: The numpy dtype to use when loading the file. If None, the dtype will be guessed.
        @param numGuess: The number of lines to use in guessing the dtype from the file.
        @param delimiter: The delimiter to use when parsing the file; the default is whitespace.
        @param idColKey: The name of the column that uniquely identifies each row in the database
        """
        self.verbose = verbose

        if idColKey is not None:
            self.idColKey = idColKey

        if (self.objid is None) or (self.idColKey is None):
            raise ValueError("fileDBObject must be subclassed, and "
                             "define objid and idColKey.")

        if (self.objectTypeId is None) and self.verbose:
            warnings.warn("objectTypeId has not "
                          "been set. Input files for phosim are not "
                          "possible.")

        if os.path.exists(dataLocatorString):
            self.driver = driver
            self.host = host
            self.port = port
            self.database = database
            self.connection = DBConnection(database=self.database, driver=self.driver, host=self.host,
                                           port=self.port, verbose=verbose)
            self.tableid = loadData(dataLocatorString, dtype, delimiter, runtable, self.idColKey,
                                    self.connection.engine, self.connection.metadata, numGuess,
                                    indexCols=self.indexCols, **kwargs)
            self._get_table()
        else:
            raise ValueError("Could not locate file %s." % (dataLocatorString))

        if self.generateDefaultColumnMap:
            self._make_default_columns()

        self._make_column_map()
        self._make_type_map()

    @classmethod
    def from_objid(cls, objid, *args, **kwargs):
        """Given a string objid, return an instance of
        the appropriate fileDBObject class.
        """
        cls = cls.registry.get(objid, CatalogDBObject)
        return cls(*args, **kwargs)
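# End-to-end sketch for fileDBObject (the file name, table name, and columns
# are all hypothetical): load a whitespace-delimited text file into an
# in-memory sqlite database and query it like any other CatalogDBObject.
#
#     class MyFileDB(fileDBObject):
#         objid = 'my_file_data'
#         idColKey = 'id'
#
#     file_db = MyFileDB('my_data.txt', runtable='data', database=':memory:')
#     for chunk in file_db.query_columns(chunk_size=1000):
#         print(chunk)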