Coverage for python/lsst/sims/catalogs/db/dbConnection.py : 20%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1from __future__ import print_function
2from future import standard_library
3standard_library.install_aliases()
4import sys
5from builtins import str
# 2017 March 9
# str_cast exists because numpy.dtype does not accept unicode-like
# strings as column names.  In python 2, builtins.str is unicode-like,
# so on that interpreter we swap in the native (bytes-like) str from
# past.builtins; on python 3 the builtin str is already what we want.
str_cast = str
if sys.version_info[0] == 2:
    from past.builtins import str as past_str
    str_cast = past_str
20from builtins import zip
21from builtins import object
22import warnings
23import numpy
24import os
25import inspect
26from io import BytesIO
27from collections import OrderedDict
29from .utils import loadData
30from sqlalchemy.orm import scoped_session, sessionmaker
31from sqlalchemy.sql import expression
32from sqlalchemy.engine import reflection, url
33from sqlalchemy import (create_engine, MetaData,
34 Table, event, text)
35from sqlalchemy import exc as sa_exc
36from lsst.daf.butler.registry import DbAuth
37from lsst.sims.utils.CodeUtilities import sims_clean_up
39#The documentation at http://docs.sqlalchemy.org/en/rel_0_7/core/types.html#sqlalchemy.types.Numeric
40#suggests using the cdecimal module. Since it is not standard, import decimal.
41#TODO: test for cdecimal and use it if it exists.
42import decimal
43from future.utils import with_metaclass
45__all__ = ["ChunkIterator", "DBObject", "CatalogDBObject", "fileDBObject"]
def valueOfPi():
    """
    Return the value of pi.

    sqlite does not provide a PI() SQL function natively, so this
    callable is registered on sqlite connections (see
    declareTrigFunctions) to supply one.
    """
    return numpy.pi
def declareTrigFunctions(conn, connection_rec, connection_proxy):
    """
    A database event listener which defines the math functions necessary
    for evaluating the Haversine function in sqlite databases (where they
    are not otherwise defined).

    see: http://docs.sqlalchemy.org/en/latest/core/events.html
    """
    # (SQL name, number of arguments, python implementation) — registered
    # in this order on every connection checkout
    sqlite_math_functions = (("COS", 1, numpy.cos),
                             ("SIN", 1, numpy.sin),
                             ("ASIN", 1, numpy.arcsin),
                             ("SQRT", 1, numpy.sqrt),
                             ("POWER", 2, numpy.power),
                             ("PI", 0, valueOfPi))
    for sql_name, n_args, implementation in sqlite_math_functions:
        conn.create_function(sql_name, n_args, implementation)
71#------------------------------------------------------------
72# Iterator for database chunks
class ChunkIterator(object):
    """Iterator over the results of a database query, in chunks.

    Each call to next() fetches up to chunk_size rows (or all rows when
    chunk_size is None) and runs them through the owning DBObject's
    post-processing before returning them.
    """

    def __init__(self, dbobj, query, chunk_size, arbitrarySQL=False):
        self.dbobj = dbobj
        self.exec_query = dbobj.connection.session.execute(query)
        self.chunk_size = chunk_size

        # arbitrarySQL exists in case a CatalogDBObject calls
        # get_arbitrary_chunk_iterator; in that case, results must be
        # routed through _postprocess_arbitrary_results rather than
        # _postprocess_results
        self.arbitrarySQL = arbitrarySQL

    def __iter__(self):
        return self

    def __next__(self):
        # chunked mode: pull the next chunk_size rows
        if self.chunk_size is not None:
            rows = self.exec_query.fetchmany(self.chunk_size)
            return self._postprocess_results(rows)
        # unchunked mode: pull everything in one shot while the result
        # proxy is still open
        if not self.exec_query.closed:
            return self._postprocess_results(self.exec_query.fetchall())
        raise StopIteration

    def _postprocess_results(self, chunk):
        # an empty fetch marks exhaustion of the query
        if not chunk:
            raise StopIteration
        handler = (self.dbobj._postprocess_arbitrary_results
                   if self.arbitrarySQL
                   else self.dbobj._postprocess_results)
        return handler(chunk)
class DBConnection(object):
    """
    This is a class that will hold the engine, session, and metadata for a
    DBObject. This will allow multiple DBObjects to share the same
    sqlalchemy connection, when appropriate.
    """

    def __init__(self, database=None, driver=None, host=None, port=None, verbose=False):
        """
        @param [in] database is the name of the database file being connected to

        @param [in] driver is the dialect of the database (e.g. 'sqlite', 'mssql', etc.)

        @param [in] host is the URL of the remote host, if appropriate

        @param [in] port is the port on the remote host to connect to, if appropriate

        @param [in] verbose is a boolean controlling sqlalchemy's verbosity
        """
        self._database = database
        self._driver = driver
        self._host = host
        self._port = port
        self._verbose = verbose

        # raises AttributeError on missing/invalid parameters, then opens
        # the actual sqlalchemy engine/session/metadata
        self._validate_conn_params()
        self._connect_to_engine()

    def __del__(self):
        # Drop references to the sqlalchemy objects.  Each attribute may be
        # absent if __init__ failed part-way through, hence the guards.
        try:
            del self._metadata
        except AttributeError:
            pass

        try:
            del self._engine
        except AttributeError:
            pass

        try:
            del self._session
        except AttributeError:
            pass

    def _connect_to_engine(self):
        """Create self._engine, self._session and self._metadata from the
        already-validated connection parameters."""

        # DbAuth will not look up hosts that are None, '' or 0
        if self._host:
            # remote connection: look up credentials in the user's
            # db-auth.yaml file
            auth = DbAuth(
                os.path.join(os.environ["HOME"], ".lsst", "db-auth.yaml"))
            username, password = auth.getAuth(
                self._driver, host=self._host, port=self._port,
                database=self._database)
            dbUrl = url.URL(self._driver,
                            host=self._host,
                            port=self._port,
                            database=self._database,
                            username=username,
                            password=password)
        else:
            # local (e.g. sqlite) connection: no credentials needed
            dbUrl = url.URL(self._driver,
                            database=self._database)

        self._engine = create_engine(dbUrl, echo=self._verbose)

        # sqlite lacks the trig functions needed by the Haversine formula;
        # install them on every connection checkout
        if self._engine.dialect.name == 'sqlite':
            event.listen(self._engine, 'checkout', declareTrigFunctions)

        self._session = scoped_session(sessionmaker(autoflush=True,
                                                    bind=self._engine))
        self._metadata = MetaData(bind=self._engine)

    def _validate_conn_params(self):
        """Validate connection parameters

        - Check if user passed a dbAddress instead of a database. Convert and warn.
        - Check that required connection parameters are present
        - Replace default host/port if driver is 'sqlite'
        """
        if self._database is None:
            raise AttributeError("Cannot instantiate DBConnection; database is 'None'")

        if '//' in self._database:
            # NOTE(review): the message below says credentials belong in
            # "db-auth.paf", but _connect_to_engine reads "db-auth.yaml" —
            # confirm which file name is correct.
            warnings.warn("Database name '%s' is invalid but looks like a dbAddress. "
                          "Attempting to convert to database, driver, host, "
                          "and port parameters. Any usernames and passwords are ignored and must "
                          "be in the db-auth.paf policy file. "%(self.database), FutureWarning)

            # the conditional expression applies to the whole concatenation:
            # 'dialect+driver' when a driver is named, else just the dialect
            dbUrl = url.make_url(self._database)
            dialect = dbUrl.get_dialect()
            self._driver = dialect.name + '+' + dialect.driver if dialect.driver else dialect.name
            # copies host/port/database/etc. out of the parsed URL into the
            # corresponding underscore-prefixed attributes
            for key, value in dbUrl.translate_connect_args().items():
                if value is not None:
                    setattr(self, '_'+key, value)

        errMessage = "Please supply a 'driver' kwarg to the constructor or in class definition. "
        errMessage += "'driver' is formatted as dialect+driver, such as 'sqlite' or 'mssql+pymssql'."
        if not hasattr(self, '_driver'):
            raise AttributeError("%s has no attribute 'driver'. "%(self.__class__.__name__) + errMessage)
        elif self._driver is None:
            raise AttributeError("%s.driver is None. "%(self.__class__.__name__) + errMessage)

        errMessage = "Please supply a 'database' kwarg to the constructor or in class definition. "
        errMessage += " 'database' is the database name or the filename path if driver is 'sqlite'. "
        if not hasattr(self, '_database'):
            raise AttributeError("%s has no attribute 'database'. "%(self.__class__.__name__) + errMessage)
        elif self._database is None:
            raise AttributeError("%s.database is None. "%(self.__class__.__name__) + errMessage)

        if 'sqlite' in self._driver:
            #When passed sqlite database, override default host/port
            self._host = None
            self._port = None

    def __eq__(self, other):
        # Two connections are considered equal when all four connection
        # parameters match (string-compared, so None == 'None' by design).
        # NOTE(review): no __hash__ is defined, so defining __eq__ makes
        # DBConnection instances unhashable in python 3 — confirm intended.
        return (str(self._database) == str(other._database)) and \
               (str(self._driver) == str(other._driver)) and \
               (str(self._host) == str(other._host)) and \
               (str(self._port) == str(other._port))

    @property
    def engine(self):
        """The underlying sqlalchemy engine."""
        return self._engine

    @property
    def session(self):
        """The scoped sqlalchemy session bound to the engine."""
        return self._session

    @property
    def metadata(self):
        """The sqlalchemy MetaData bound to the engine."""
        return self._metadata

    @property
    def database(self):
        """Name of the database (filename path for sqlite)."""
        return self._database

    @property
    def driver(self):
        """Dialect(+driver) string, e.g. 'sqlite' or 'mssql+pymssql'."""
        return self._driver

    @property
    def host(self):
        """Remote host URL (None for sqlite)."""
        return self._host

    @property
    def port(self):
        """Remote port (None for sqlite)."""
        return self._port

    @property
    def verbose(self):
        """Boolean controlling sqlalchemy's verbosity."""
        return self._verbose
class DBObject(object):
    """
    Generic wrapper around a database connection.

    A DBObject opens (or reuses) a DBConnection and provides methods to
    inspect the schema, execute arbitrary read-only SQL, and return
    results as numpy recarrays (guessing the dtype from the first row
    when none is supplied).
    """

    def __init__(self, database=None, driver=None, host=None, port=None, verbose=False,
                 connection=None, cache_connection=True):
        """
        Initialize DBObject.

        @param [in] database is the name of the database file being connected to

        @param [in] driver is the dialect of the database (e.g. 'sqlite', 'mssql', etc.)

        @param [in] host is the URL of the remote host, if appropriate

        @param [in] port is the port on the remote host to connect to, if appropriate

        @param [in] verbose is a boolean controlling sqlalchemy's verbosity (default False)

        @param [in] connection is an optional instance of DBConnection, in the event that
        this DBObject can share a database connection with another DBObject. This is only
        necessary or even possible in a few specialized cases and should be used carefully.

        @param [in] cache_connection is a boolean. If True, DBObject will use a cache of
        DBConnections (if available) to get the connection to this database.
        """
        # cache for the query dtype, so that any one query does not have
        # to guess its dtype multiple times (once per chunk)
        self.dtype = None

        if connection is None:
            # Explicit constructor to DBObject preferred
            kwargDict = dict(database=database,
                             driver=driver,
                             host=host,
                             port=port,
                             verbose=verbose)

            # only override pre-existing (class-level) attributes when an
            # explicit value was passed in
            for key, value in kwargDict.items():
                if value is not None or not hasattr(self, key):
                    setattr(self, key, value)

            self.connection = self._get_connection(self.database, self.driver, self.host,
                                                   self.port, use_cache=cache_connection)
        else:
            # adopt the caller-supplied connection and mirror its parameters
            self.connection = connection
            self.database = connection.database
            self.driver = connection.driver
            self.host = connection.host
            self.port = connection.port
            self.verbose = connection.verbose

    def _get_connection(self, database, driver, host, port, use_cache=True):
        """
        Search self._connection_cache (if it exists; it won't for DBObject, but
        will for CatalogDBObject) for a DBConnection matching the specified
        parameters. If it exists, return it. If not, open a connection to
        the specified database, add it to the cache, and return the connection.

        Parameters
        ----------
        database is the name of the database file being connected to

        driver is the dialect of the database (e.g. 'sqlite', 'mssql', etc.)

        host is the URL of the remote host, if appropriate

        port is the port on the remote host to connect to, if appropriate

        use_cache is a boolean specifying whether or not we try to use the
        cache of database connections (you don't want to if opening many
        connections in many threads).
        """
        if use_cache and hasattr(self, '_connection_cache'):
            for conn in self._connection_cache:
                # string-compare each parameter so that None matches 'None'
                if (str(conn.database) == str(database) and
                        str(conn.driver) == str(driver) and
                        str(conn.host) == str(host) and
                        str(conn.port) == str(port)):
                    return conn

        conn = DBConnection(database=database, driver=driver, host=host, port=port)

        if use_cache and hasattr(self, '_connection_cache'):
            self._connection_cache.append(conn)

        return conn

    def get_table_names(self):
        """Return a list of the names of the tables (and views) in the database"""
        # build the inspector once rather than once per listing
        inspector = reflection.Inspector.from_engine(self.connection.engine)
        return ([str(xx) for xx in inspector.get_table_names()] +
                [str(xx) for xx in inspector.get_view_names()])

    def get_column_names(self, tableName=None):
        """
        Return a list of the names of the columns in the specified table.
        If no table is specified, return a dict of lists. The dict will be keyed
        to the table names. The lists will be of the column names in that table
        """
        tableNameList = self.get_table_names()
        inspector = reflection.Inspector.from_engine(self.connection.engine)
        if tableName is not None:
            # unknown table: empty list rather than an exception
            if tableName not in tableNameList:
                return []
            return [str_cast(xx['name']) for xx in inspector.get_columns(tableName)]

        columnDict = {}
        for name in tableNameList:
            columnDict[name] = [str_cast(xx['name']) for xx in inspector.get_columns(name)]
        return columnDict

    def _final_pass(self, results):
        """ Make final modifications to a set of data before returning it to the user

        **Parameters**

        * results : a structured array constructed from the result set from a query

        **Returns**

        * results : a potentially modified structured array. The default is to do nothing.
        """
        return results

    def _convert_results_to_numpy_recarray_dbobj(self, results):
        """
        Convert a list of query result rows into a numpy recarray.

        If self.dtype is None, the dtype is guessed from the first row
        (via numpy.genfromtxt) and cached on self so the guess is not
        repeated for every chunk of the same query.
        """
        if self.dtype is None:
            # We are going to detect the dtype by reading in a single row
            # of data with np.genfromtxt. To do this, we must pass the
            # row as a string delimited by a specified character. Here we
            # select a character that does not occur anywhere in the data.
            delimit_char_list = [',', ';', '|', ':', '/', '\\']
            delimit_char = None
            for cc in delimit_char_list:
                is_valid = True
                for xx in results[0]:
                    if cc in str(xx):
                        is_valid = False
                        break

                if is_valid:
                    delimit_char = cc
                    break

            if delimit_char is None:
                raise RuntimeError("DBObject could not detect the dtype of your return rows\n"
                                   "Please specify a dtype with the 'dtype' kwarg.")

            # bug fix: the original built this with "if dataString is not '':",
            # an identity comparison on a string literal; join is equivalent
            # and correct
            dataString = delimit_char.join(str(xx) for xx in results[0])

            names = [str_cast(ww) for ww in results[0].keys()]
            dataArr = numpy.genfromtxt(BytesIO(dataString.encode()), dtype=None,
                                       names=names, delimiter=delimit_char,
                                       encoding='utf-8')

            # replace byte-string columns ('S<n>'/'|S<n>') with fixed-width
            # str columns of the same length, since numpy.dtype rejects
            # unicode-like names/types in python 2 (see str_cast above)
            dt_list = []
            for name in dataArr.dtype.names:
                type_name = str(dataArr.dtype[name])
                sub_list = [name]
                if type_name.startswith('S') or type_name.startswith('|S'):
                    sub_list.append(str_cast)
                    sub_list.append(int(type_name.replace('S', '').replace('|', '')))
                else:
                    sub_list.append(dataArr.dtype[name])
                dt_list.append(tuple(sub_list))

            self.dtype = numpy.dtype(dt_list)

        if len(results) == 0:
            # preserve the dtype even for an empty result set
            return numpy.recarray((0,), dtype=self.dtype)

        retresults = numpy.rec.fromrecords([tuple(xx) for xx in results], dtype=self.dtype)
        return retresults

    def _postprocess_results(self, results):
        """
        This wrapper exists so that a ChunkIterator built from a DBObject
        can have the same API as a ChunkIterator built from a CatalogDBObject
        """
        return self._postprocess_arbitrary_results(results)

    def _postprocess_arbitrary_results(self, results):
        """Coerce raw query results into a recarray and apply _final_pass."""
        if not isinstance(results, numpy.recarray):
            retresults = self._convert_results_to_numpy_recarray_dbobj(results)
        else:
            retresults = results

        return self._final_pass(retresults)

    def execute_arbitrary(self, query, dtype=None):
        """
        Executes an arbitrary query. Returns a recarray of the results.

        dtype will be the dtype of the output recarray. If it is None, then
        the code will guess the datatype and assign generic names to the columns

        Raises RuntimeError if query is not a string or if it contains a
        mutating SQL command (delete/drop/insert/update).
        """
        # basestring only exists in python 2; fall back to str on python 3
        # (bug fix: the original used a bare except here)
        try:
            is_string = isinstance(query, basestring)
        except NameError:
            is_string = isinstance(query, str)

        if not is_string:
            raise RuntimeError("DBObject execute must be called with a string query")

        # refuse queries that would mutate the database
        unacceptableCommands = ["delete", "drop", "insert", "update"]
        for badCommand in unacceptableCommands:
            if query.lower().find(badCommand.lower()) >= 0:
                raise RuntimeError("query made to DBObject execute contained %s " % badCommand)

        self.dtype = dtype
        retresults = self._postprocess_arbitrary_results(self.connection.session.execute(query).fetchall())
        return retresults

    def get_arbitrary_chunk_iterator(self, query, chunk_size=None, dtype=None):
        """
        This wrapper exists so that CatalogDBObjects can refer to
        get_arbitrary_chunk_iterator and DBObjects can refer to
        get_chunk_iterator
        """
        return self.get_chunk_iterator(query, chunk_size=chunk_size, dtype=dtype)

    def get_chunk_iterator(self, query, chunk_size=None, dtype=None):
        """
        Take an arbitrary, user-specified query and return a ChunkIterator that
        executes that query

        dtype will tell the ChunkIterator what datatype to expect for this query.
        This information gets passed to _postprocess_results.

        If 'None', then _postprocess_results will just guess the datatype
        and return generic names for the columns.
        """
        self.dtype = dtype
        return ChunkIterator(self, query, chunk_size, arbitrarySQL=True)
class CatalogDBObjectMeta(type):
    """Meta class for registering new objects.

    When any new type of object class is created, this registers it
    in a `registry` class attribute, available to all derived instance
    catalog.
    """

    def __new__(cls, name, bases, dct):
        # Subclasses should never ship their own `registry`; warn if one does.
        if 'registry' in dct:
            warnings.warn("registry class attribute should not be "
                          "over-ridden in InstanceCatalog classes. "
                          "Proceed with caution")
        # Default the objid to the class name when none was specified.
        dct.setdefault('objid', name)
        return super(CatalogDBObjectMeta, cls).__new__(cls, name, bases, dct)

    def __init__(cls, name, bases, dct):
        # The first class built with this metaclass is the base class: it
        # owns the registry.  Every later class registers itself unless it
        # opted out via skipRegistration.
        if not hasattr(cls, 'registry'):
            cls.registry = {}
        elif not cls.skipRegistration:
            if cls.objid in cls.registry:
                srcfile = inspect.getsourcefile(cls.registry[cls.objid])
                srcline = inspect.getsourcelines(cls.registry[cls.objid])[1]
                warnings.warn('duplicate object identifier %s specified. '%(cls.objid)+\
                              'This will override previous definition on line %i of %s'%
                              (srcline, srcfile))
            cls.registry[cls.objid] = cls

        # Same pattern for the list of unique object type ids: the base
        # class creates the list; later classes append their (non-None,
        # not-yet-seen) objectTypeId, warning on duplicates.
        if not hasattr(cls, 'objectTypeIdList'):
            cls.objectTypeIdList = []
        elif not cls.skipRegistration and cls.objectTypeId is not None:
            if cls.objectTypeId in cls.objectTypeIdList:
                warnings.warn('Duplicate object type id %s specified: '%cls.objectTypeId+\
                              '\nOutput object ids may not be unique.\nThis may not be a problem if you do not '+\
                              'want globally unique id values')
            else:
                cls.objectTypeIdList.append(cls.objectTypeId)
        return super(CatalogDBObjectMeta, cls).__init__(name, bases, dct)

    def __str__(cls):
        # Human-readable dump of everything currently in the registry.
        pieces = ["++++++++++++++++++++++++++++++++++++++++++++++\n"
                  "Registered object types are:\n"]
        for dbObject in cls.registry.keys():
            pieces.append("%s\n"%(dbObject))
        pieces.append("\n\n")
        pieces.append("To query the possible column names do:\n")
        pieces.append("$> CatalogDBObject.from_objid([name]).show_mapped_columns()\n")
        pieces.append("+++++++++++++++++++++++++++++++++++++++++++++")
        return "".join(pieces)
class CatalogDBObject(with_metaclass(CatalogDBObjectMeta, DBObject)):
    """Database Object base class

    Subclasses must define objid, tableid and idColKey (either in the
    class definition or via __init__).  The CatalogDBObjectMeta metaclass
    registers each subclass in the shared `registry`, keyed on objid.
    """

    epoch = 2000.0
    skipRegistration = False  # if True, the metaclass does not register this class
    objid = None              # unique registry key for this class (defaults to class name)
    tableid = None            # name of the database table to reflect and query
    idColKey = None           # key of the column that uniquely identifies each row
    objectTypeId = None       # id collected in objectTypeIdList by the metaclass
    columns = None            # optional explicit column mapping; see _make_column_map
    generateDefaultColumnMap = True  # if True, fill `columns` from the table schema
    dbDefaultValues = {}      # per-column fallbacks substituted for falsy query values
    raColName = None          # RA column name, passed to bounds.to_SQL in filter()
    decColName = None         # Dec column name, passed to bounds.to_SQL in filter()

    _connection_cache = []  # a list to store open database connections in

    #Provide information if this object should be tested in the unit test
    doRunTest = False
    testObservationMetaData = None

    #: Mapping of DDL types to python types. Strings are assumed to be 256 characters
    #: this can be overridden by modifying the dbTypeMap or by making a custom columns
    #: list.
    #: numpy doesn't know how to convert decimal.Decimal types, so I changed this to float
    #: TODO this doesn't seem to make a difference but make sure.
    dbTypeMap = {'BIGINT':(int,), 'BOOLEAN':(bool,), 'FLOAT':(float,), 'INTEGER':(int,),
                 'NUMERIC':(float,), 'SMALLINT':(int,), 'TINYINT':(int,), 'VARCHAR':(str, 256),
                 'TEXT':(str, 256), 'CLOB':(str, 256), 'NVARCHAR':(str, 256),
                 'NCLOB':(str, 256), 'NTEXT':(str, 256), 'CHAR':(str, 1), 'INT':(int,),
                 'REAL':(float,), 'DOUBLE':(float,), 'STRING':(str, 256), 'DOUBLE_PRECISION':(float,),
                 'DECIMAL':(float,)}

    @classmethod
    def from_objid(cls, objid, *args, **kwargs):
        """Given a string objid, return an instance of
        the appropriate CatalogDBObject class.

        Raises RuntimeError if objid is not in the registry.
        """
        if objid not in cls.registry:
            raise RuntimeError('Attempting to construct an object that does not exist')
        cls = cls.registry.get(objid, CatalogDBObject)
        return cls(*args, **kwargs)

    def __init__(self, database=None, driver=None, host=None, port=None, verbose=False,
                 table=None, objid=None, idColKey=None, connection=None,
                 cache_connection=True):
        """
        Construct and connect a CatalogDBObject.

        Parameters mirror DBObject.__init__, plus table/objid/idColKey,
        each of which may be supplied here instead of (but not in addition
        to) the class definition.
        """
        if not verbose:
            # NOTE(review): this filter is restored as soon as the `with`
            # block exits -- before any of the sqlalchemy work below runs --
            # so it appears to suppress nothing.  Presumably it was meant to
            # wrap the table reflection; confirm intent.
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", category=sa_exc.SAWarning)

        # tableid/objid/idColKey may each come from the class definition or
        # the constructor, but not both
        if self.tableid is not None and table is not None:
            raise ValueError("Double-specified tableid in CatalogDBObject:"
                             " once in class definition, once in __init__")

        if table is not None:
            self.tableid = table

        if self.objid is not None and objid is not None:
            raise ValueError("Double-specified objid in CatalogDBObject:"
                             " once in class definition, once in __init__")

        if objid is not None:
            self.objid = objid

        if self.idColKey is not None and idColKey is not None:
            raise ValueError("Double-specified idColKey in CatalogDBObject:"
                             " once in class definition, once in __init__")

        if idColKey is not None:
            self.idColKey = idColKey

        if self.idColKey is None:
            self.idColKey = self.getIdColKey()
        if (self.objid is None) or (self.tableid is None) or (self.idColKey is None):
            msg = ("CatalogDBObject must be subclassed, and "
                   "define objid, tableid and idColKey. You are missing: ")
            if self.objid is None:
                msg += "objid, "
            if self.tableid is None:
                msg += "tableid, "
            if self.idColKey is None:
                msg += "idColKey"
            raise ValueError(msg)

        if (self.objectTypeId is None) and verbose:
            warnings.warn("objectTypeId has not "
                          "been set. Input files for phosim are not "
                          "possible.")

        # NOTE(review): cache_connection is hard-coded to True here, silently
        # ignoring the caller's cache_connection argument -- confirm whether
        # this is intentional.
        super(CatalogDBObject, self).__init__(database=database, driver=driver, host=host, port=port,
                                              verbose=verbose, connection=connection, cache_connection=True)

        try:
            self._get_table()
        except sa_exc.OperationalError as e:
            # give UW CATSIM users a hint about the usual failure modes
            if self.driver == 'mssql+pymssql':
                message = "\n To connect to the UW CATSIM database: "
                message += " Check that you have valid connection parameters, an open ssh tunnel "
                message += "and that your $HOME/.lsst/db-auth.paf contains the appropriate credientials. "
                message += "Please consult the following link for more information on access: "
                message += " https://confluence.lsstcorp.org/display/SIM/Accessing+the+UW+CATSIM+Database "
            else:
                message = ''
            raise RuntimeError("Failed to connect to %s: sqlalchemy.%s %s" % (self.connection.engine, e.args[0], message))

        #Need to do this after the table is instantiated so that
        #the default columns can be filled from the table object.
        if self.generateDefaultColumnMap:
            self._make_default_columns()
        # build column mapping and type mapping dicts from columns
        self._make_column_map()
        self._make_type_map()

    def show_mapped_columns(self):
        """Print each mapped column key and the name of its python type."""
        for col in self.columnMap.keys():
            print("%s -- %s"%(col, self.typeMap[col][0].__name__))

    def show_db_columns(self):
        """Print each raw database column name and its DDL type name."""
        for col in self.table.c.keys():
            print("%s -- %s"%(col, self.table.c[col].type.__visit_name__))

    def getCatalog(self, ftype, *args, **kwargs):
        """Return an InstanceCatalog of type ftype built on this DBObject.

        Raises ImportError if sims_catalogs is not set up.
        """
        try:
            from lsst.sims.catalogs.definitions import InstanceCatalog
            return InstanceCatalog.new_catalog(ftype, self, *args, **kwargs)
        except ImportError:
            raise ImportError("sims_catalogs not set up. Cannot get InstanceCatalog from the object.")

    def getIdColKey(self):
        """Return the key of the column that uniquely identifies each row."""
        return self.idColKey

    def getObjectTypeId(self):
        """Return this class's objectTypeId."""
        return self.objectTypeId

    def _get_table(self):
        # reflect the table schema from the database
        self.table = Table(self.tableid, self.connection.metadata,
                           autoload=True)

    def _make_column_map(self):
        # map column key -> database expression; an empty/None second entry
        # means the key itself is the database column name
        self.columnMap = OrderedDict([(el[0], el[1] if el[1] else el[0])
                                      for el in self.columns])

    def _make_type_map(self):
        # map column key -> python type tuple; default type is (float,)
        self.typeMap = OrderedDict([(el[0], el[2:] if len(el)> 2 else (float,))
                                    for el in self.columns])

    def _make_default_columns(self):
        """Append a default (name, name, type...) entry to self.columns for
        every table column not already mapped, using dbTypeMap."""
        if self.columns:
            colnames = [el[0] for el in self.columns]
        else:
            self.columns = []
            colnames = []
        for col in self.table.c.keys():
            dbtypestr = self.table.c[col].type.__visit_name__
            dbtypestr = dbtypestr.upper()
            if col in colnames:
                if self.verbose: #Warn for possible column redefinition
                    warnings.warn("Database column, %s, overridden in self.columns... "%(col)+
                                  "Skipping default assignment.")
            elif dbtypestr in self.dbTypeMap:
                self.columns.append((col, col)+self.dbTypeMap[dbtypestr])
            else:
                if self.verbose:
                    warnings.warn("Can't create default column for %s. There is no mapping "%(col)+
                                  "for type %s. Modify the dbTypeMap, or make a custom columns "%(dbtypestr)+
                                  "list.")

    def _get_column_query(self, colnames=None):
        """Given a list of valid column names, return the query object"""
        if colnames is None:
            colnames = [k for k in self.columnMap]
        try:
            vals = [self.columnMap[k] for k in colnames]
        except KeyError:
            # collect every offending column, not just the first one
            offending_columns = '\n'
            for col in colnames:
                if col in self.columnMap:
                    continue
                else:
                    offending_columns +='%s\n' % col
            raise ValueError('entries in colnames must be in self.columnMap. '
                             'These:%sare not' % offending_columns)

        # Get the first query: the id column, labeled with the user-facing
        # key when it was requested, otherwise with its own database name
        idColName = self.columnMap[self.idColKey]
        if idColName in vals:
            idLabel = self.idColKey
        else:
            idLabel = idColName

        query = self.connection.session.query(self.table.c[idColName].label(idLabel))

        for col, val in zip(colnames, vals):
            # NOTE(review): identity comparison (`is`) relies on `val` and
            # `idColName` being the very same string object pulled out of
            # self.columnMap; `==` would be the safer spelling.
            if val is idColName:
                continue
            #Check if the column is a default column (col == val)
            if col == val:
                #If column is in the table, use it.
                query = query.add_columns(self.table.c[col].label(col))
            else:
                #If not assume the user specified the column correctly
                query = query.add_columns(expression.literal_column(val).label(col))

        return query

    def filter(self, query, bounds):
        """Filter the query by the associated metadata"""
        if bounds is not None:
            # bounds.to_SQL produces the spatial predicate on the RA/Dec columns
            on_clause = bounds.to_SQL(self.raColName,self.decColName)
            query = query.filter(text(on_clause))
        return query

    def _convert_results_to_numpy_recarray_catalogDBObj(self, results):
        """Post-process the query results to put them
        in a structured array.

        **Parameters**

        * results : a result set as returned by execution of the query

        **Returns**

        * retresults : a structured array constructed from the query data,
          typed according to self.typeMap
        """
        if len(results) > 0:
            cols = [str(k) for k in results[0].keys()]
        else:
            # nothing to convert; hand the empty result set back unchanged
            return results

        if sys.version_info.major == 2:
            # on python 2, force names and str-typed columns through
            # past.builtins.str, since numpy.dtype rejects unicode-like
            # names (see str_cast at the top of this module)
            dt_list = []
            for k in cols:
                sub_list = [past_str(k)]
                if self.typeMap[k][0] is not str:
                    for el in self.typeMap[k]:
                        sub_list.append(el)
                else:
                    sub_list.append(past_str)
                    for el in self.typeMap[k][1:]:
                        sub_list.append(el)
                dt_list.append(tuple(sub_list))

            dtype = numpy.dtype(dt_list)
        else:
            dtype = numpy.dtype([(k,)+self.typeMap[k] for k in cols])

        if len(set(cols)&set(self.dbDefaultValues)) > 0:
            # substitute the registered default for any falsy value
            # (None, 0, '', ...) in a column that has a default
            results_array = []

            for result in results:
                results_array.append(tuple(result[colName]
                                           if result[colName] or
                                           colName not in self.dbDefaultValues
                                           else self.dbDefaultValues[colName]
                                           for colName in cols))
        else:
            results_array = [tuple(rr) for rr in results]

        retresults = numpy.rec.fromrecords(results_array, dtype=dtype)
        return retresults

    def _postprocess_results(self, results):
        """Coerce raw query results into a recarray and apply _final_pass."""
        if not isinstance(results, numpy.recarray):
            retresults = self._convert_results_to_numpy_recarray_catalogDBObj(results)
        else:
            retresults = results
        return self._final_pass(retresults)

    def query_columns(self, colnames=None, chunk_size=None,
                      obs_metadata=None, constraint=None, limit=None):
        """Execute a query

        **Parameters**

        * colnames : list or None
          a list of valid column names, corresponding to entries in the
          `columns` class attribute. If not specified, all columns are
          queried.
        * chunk_size : int (optional)
          if specified, then return an iterator object to query the database,
          each time returning the next `chunk_size` elements. If not
          specified, all matching results will be returned.
        * obs_metadata : object (optional)
          an observation metadata object which has a "filter" method, which
          will add a filter string to the query.
        * constraint : str (optional)
          a string which is interpreted as SQL and used as a predicate on the query
        * limit : int (optional)
          limits the number of rows returned by the query

        **Returns**

        * result : list or iterator
          If chunk_size is not specified, then result is a list of all
          items which match the specified query. If chunk_size is specified,
          then result is an iterator over lists of the given size.
        """
        query = self._get_column_query(colnames)

        if obs_metadata is not None:
            query = self.filter(query, obs_metadata.bounds)

        if constraint is not None:
            query = query.filter(text(constraint))

        if limit is not None:
            query = query.limit(limit)

        return ChunkIterator(self, query, chunk_size)
# Register the shared connection cache with sims_clean_up, presumably so
# that cached DBConnections can be disposed of at cleanup time (see
# lsst.sims.utils.CodeUtilities.sims_clean_up).
sims_clean_up.targets.append(CatalogDBObject._connection_cache)
class fileDBObject(CatalogDBObject):
    ''' Class to read a file into a database and then query it'''

    #Column names to index. Specify compound indexes using tuples of column names
    indexCols = []

    def __init__(self, dataLocatorString, runtable=None, driver="sqlite", host=None, port=None, database=":memory:",
                 dtype=None, numGuess=1000, delimiter=None, verbose=False, idColKey=None, **kwargs):
        """
        Initialize an object for querying databases loaded from a file

        Keyword arguments:
        @param dataLocatorString: Path to the file to load
        @param runtable: The name of the table to create. If None, a random table name will be used.
        @param driver: name of database driver (e.g. 'sqlite', 'mssql+pymssql')
        @param host: hostname for database connection (None if sqlite)
        @param port: port for database connection (None if sqlite)
        @param database: name of database (filename if sqlite)
        @param dtype: The numpy dtype to use when loading the file. If None, the dtype will be guessed.
        @param numGuess: The number of lines to use in guessing the dtype from the file.
        @param delimiter: The delimiter to use when parsing the file; default is white space.
        @param verbose: boolean controlling warning verbosity
        @param idColKey: The name of the column that uniquely identifies each row in the database
        """
        # NOTE: CatalogDBObject.__init__ is deliberately NOT called; this
        # constructor builds its own DBConnection around the loaded file.
        self.verbose = verbose

        if idColKey is not None:
            self.idColKey = idColKey

        if(self.objid is None) or (self.idColKey is None):
            raise ValueError("CatalogDBObject must be subclassed, and "
                             "define objid and tableid and idColKey.")

        if (self.objectTypeId is None) and self.verbose:
            warnings.warn("objectTypeId has not "
                          "been set. Input files for phosim are not "
                          "possible.")

        if os.path.exists(dataLocatorString):
            self.driver = driver
            self.host = host
            self.port = port
            self.database = database
            self.connection = DBConnection(database=self.database, driver=self.driver, host=self.host,
                                           port=self.port, verbose=verbose)
            # load the file contents into a table and reflect it
            self.tableid = loadData(dataLocatorString, dtype, delimiter, runtable, self.idColKey,
                                    self.connection.engine, self.connection.metadata, numGuess,
                                    indexCols=self.indexCols, **kwargs)
            self._get_table()
        else:
            raise ValueError("Could not locate file %s."%(dataLocatorString))

        # fill in any unmapped columns from the table schema, then build
        # the column/type maps (same post-connection steps as the parent)
        if self.generateDefaultColumnMap:
            self._make_default_columns()
        self._make_column_map()
        self._make_type_map()

    @classmethod
    def from_objid(cls, objid, *args, **kwargs):
        """Given a string objid, return an instance of
        the appropriate fileDBObject class.

        NOTE(review): unlike CatalogDBObject.from_objid, an unknown objid
        silently falls back to CatalogDBObject here instead of raising;
        confirm that this asymmetry is intentional.
        """
        cls = cls.registry.get(objid, CatalogDBObject)
        return cls(*args, **kwargs)