# Coverage for python/lsst/sims/catalogs/db/dbConnection.py : 19%
#
# Hot-keys on this page
#   r m x p   toggle line displays
#   j k       next/prev highlighted chunk
#   0 (zero)  top of page
#   1 (one)   first highlighted chunk
1from __future__ import print_function
2from future import standard_library
3standard_library.install_aliases()
4import sys
5from builtins import str
7# 2017 March 9
8# str_cast exists because numpy.dtype does
9# not like unicode-like things as the names
10# of columns. Unfortunately, in python 2,
11# builtins.str looks unicode-like. We will
12# use str_cast in python 2 to maintain
13# both python 3 compatibility and our use of
14# numpy dtype
15str_cast = str
16if sys.version_info.major == 2: 16 ↛ 17line 16 didn't jump to line 17, because the condition on line 16 was never true
17 from past.builtins import str as past_str
18 str_cast = past_str
20from builtins import zip
21from builtins import object
22import warnings
23import numpy
24import os
25import inspect
26from io import BytesIO
27from collections import OrderedDict
29from .utils import loadData
30from sqlalchemy.orm import scoped_session, sessionmaker
31from sqlalchemy.sql import expression
32from sqlalchemy.engine import reflection, url
33from sqlalchemy import (create_engine, MetaData,
34 Table, event, text)
35from sqlalchemy import exc as sa_exc
36from lsst.daf.persistence import DbAuth
37from lsst.sims.utils.CodeUtilities import sims_clean_up
39#The documentation at http://docs.sqlalchemy.org/en/rel_0_7/core/types.html#sqlalchemy.types.Numeric
40#suggests using the cdecimal module. Since it is not standard, import decimal.
41#TODO: test for cdecimal and use it if it exists.
42import decimal
43from future.utils import with_metaclass
45__all__ = ["ChunkIterator", "DBObject", "CatalogDBObject", "fileDBObject"]
def valueOfPi():
    """
    Return the value of pi.

    This exists so that a PI() SQL function can be registered on sqlite
    databases, which do not provide one natively.
    """
    pi_value = numpy.pi
    return pi_value
def declareTrigFunctions(conn, connection_rec, connection_proxy):
    """
    A database event listener which will define the math functions
    necessary for evaluating the Haversine function in sqlite databases
    (where they are not otherwise defined)

    see: http://docs.sqlalchemy.org/en/latest/core/events.html
    """
    # (SQL name, argument count, python implementation)
    function_table = (("COS", 1, numpy.cos),
                      ("SIN", 1, numpy.sin),
                      ("ASIN", 1, numpy.arcsin),
                      ("SQRT", 1, numpy.sqrt),
                      ("POWER", 2, numpy.power),
                      ("PI", 0, valueOfPi))

    for sql_name, n_args, implementation in function_table:
        conn.create_function(sql_name, n_args, implementation)
71#------------------------------------------------------------
72# Iterator for database chunks
class ChunkIterator(object):
    """Iterator over the results of a database query.

    Each call to next() yields the next ``chunk_size`` rows of the query
    (or every row at once when ``chunk_size`` is None), post-processed by
    the parent DBObject.
    """

    def __init__(self, dbobj, query, chunk_size, arbitrarySQL=False):
        self.dbobj = dbobj
        self.exec_query = dbobj.connection.session.execute(query)
        self.chunk_size = chunk_size

        # arbitrarySQL exists in case a CatalogDBObject calls
        # get_arbitrary_chunk_iterator; in that case, we need to
        # be able to tell this object to call _postprocess_arbitrary_results,
        # rather than _postprocess_results
        self.arbitrarySQL = arbitrarySQL

    def __iter__(self):
        return self

    def __next__(self):
        if self.chunk_size is not None:
            # chunked mode: pull the next chunk_size rows
            rows = self.exec_query.fetchmany(self.chunk_size)
            return self._postprocess_results(rows)
        if not self.exec_query.closed:
            # single-shot mode: pull everything the first time through
            rows = self.exec_query.fetchall()
            return self._postprocess_results(rows)
        raise StopIteration

    def _postprocess_results(self, chunk):
        # An empty chunk means the query is exhausted.
        if len(chunk) == 0:
            raise StopIteration
        if self.arbitrarySQL:
            return self.dbobj._postprocess_arbitrary_results(chunk)
        return self.dbobj._postprocess_results(chunk)
class DBConnection(object):
    """
    This is a class that will hold the engine, session, and metadata for a
    DBObject. This will allow multiple DBObjects to share the same
    sqlalchemy connection, when appropriate.
    """

    def __init__(self, database=None, driver=None, host=None, port=None, verbose=False):
        """
        @param [in] database is the name of the database file being connected to

        @param [in] driver is the dialect of the database (e.g. 'sqlite', 'mssql', etc.)

        @param [in] host is the URL of the remote host, if appropriate

        @param [in] port is the port on the remote host to connect to, if appropriate

        @param [in] verbose is a boolean controlling sqlalchemy's verbosity
        """

        self._database = database
        self._driver = driver
        self._host = host
        self._port = port
        self._verbose = verbose

        # Validation may rewrite driver/host/port (e.g. when a dbAddress
        # string was passed as the database) before the engine is built.
        self._validate_conn_params()
        self._connect_to_engine()

    def __del__(self):
        # Drop the sqlalchemy handles explicitly on garbage collection.
        # Each attribute may be absent if __init__ failed part-way through,
        # hence the individual AttributeError guards.
        try:
            del self._metadata
        except AttributeError:
            pass

        try:
            del self._engine
        except AttributeError:
            pass

        try:
            del self._session
        except AttributeError:
            pass

    def _connect_to_engine(self):
        # Build the sqlalchemy engine, scoped session, and metadata from the
        # (already validated) connection parameters.

        #DbAuth will not look up hosts that are None, '' or 0
        if self._host:
            try:
                authDict = {'username': DbAuth.username(self._host, str(self._port)),
                            'password': DbAuth.password(self._host, str(self._port))}
            except:
                # NOTE(review): bare except -- ANY failure during the
                # credential lookup (not just a missing db-auth entry) lands
                # here before being re-raised.
                if self._driver == 'mssql+pymssql':
                    print("\nFor more information on database authentication using the db-auth.paf"
                          " policy file see: "
                          "https://confluence.lsstcorp.org/display/SIM/Accessing+the+UW+CATSIM+Database\n")
                raise

            dbUrl = url.URL(self._driver,
                            host=self._host,
                            port=self._port,
                            database=self._database,
                            **authDict)
        else:
            dbUrl = url.URL(self._driver,
                            database=self._database)

        self._engine = create_engine(dbUrl, echo=self._verbose)

        # sqlite lacks the trig functions needed by the Haversine formula;
        # register numpy-backed implementations on every connection checkout.
        if self._engine.dialect.name == 'sqlite':
            event.listen(self._engine, 'checkout', declareTrigFunctions)

        self._session = scoped_session(sessionmaker(autoflush=True,
                                                    bind=self._engine))
        self._metadata = MetaData(bind=self._engine)

    def _validate_conn_params(self):
        """Validate connection parameters

        - Check if user passed dbAddress instead of an database. Convert and warn.
        - Check that required connection paramters are present
        - Replace default host/port if driver is 'sqlite'
        """

        if self._database is None:
            raise AttributeError("Cannot instantiate DBConnection; database is 'None'")

        if '//' in self._database:
            # Looks like a full dbAddress URL was passed as the database name;
            # split it into driver/host/port/database components instead.
            warnings.warn("Database name '%s' is invalid but looks like a dbAddress. "
                          "Attempting to convert to database, driver, host, "
                          "and port parameters. Any usernames and passwords are ignored and must "
                          "be in the db-auth.paf policy file. "%(self.database), FutureWarning)

            dbUrl = url.make_url(self._database)
            dialect = dbUrl.get_dialect()
            self._driver = dialect.name + '+' + dialect.driver if dialect.driver else dialect.name
            # translate_connect_args yields the standard kwarg names
            # (host, port, database, ...); copy the non-None ones onto the
            # matching private attributes.
            for key, value in dbUrl.translate_connect_args().items():
                if value is not None:
                    setattr(self, '_'+key, value)

        errMessage = "Please supply a 'driver' kwarg to the constructor or in class definition. "
        errMessage += "'driver' is formatted as dialect+driver, such as 'sqlite' or 'mssql+pymssql'."
        if not hasattr(self, '_driver'):
            raise AttributeError("%s has no attribute 'driver'. "%(self.__class__.__name__) + errMessage)
        elif self._driver is None:
            raise AttributeError("%s.driver is None. "%(self.__class__.__name__) + errMessage)

        errMessage = "Please supply a 'database' kwarg to the constructor or in class definition. "
        errMessage += " 'database' is the database name or the filename path if driver is 'sqlite'. "
        if not hasattr(self, '_database'):
            raise AttributeError("%s has no attribute 'database'. "%(self.__class__.__name__) + errMessage)
        elif self._database is None:
            raise AttributeError("%s.database is None. "%(self.__class__.__name__) + errMessage)

        if 'sqlite' in self._driver:
            # When passed sqlite database, override default host/port
            self._host = None
            self._port = None

    def __eq__(self, other):
        # Two connections are considered equal when they point at the same
        # database through the same driver, host, and port.
        # NOTE(review): defining __eq__ without __hash__ makes instances
        # unhashable in python 3 -- confirm no caller relies on hashing.
        return (str(self._database) == str(other._database)) and \
               (str(self._driver) == str(other._driver)) and \
               (str(self._host) == str(other._host)) and \
               (str(self._port) == str(other._port))

    @property
    def engine(self):
        # sqlalchemy Engine bound to this database
        return self._engine

    @property
    def session(self):
        # scoped_session used to execute queries
        return self._session

    @property
    def metadata(self):
        # sqlalchemy MetaData bound to the engine (used for table reflection)
        return self._metadata

    @property
    def database(self):
        return self._database

    @property
    def driver(self):
        return self._driver

    @property
    def host(self):
        return self._host

    @property
    def port(self):
        return self._port

    @property
    def verbose(self):
        return self._verbose
class DBObject(object):
    """
    A generic object for connecting to a database and executing arbitrary
    read-only SQL queries, returning the results as numpy recarrays.
    """

    def __init__(self, database=None, driver=None, host=None, port=None, verbose=False,
                 connection=None, cache_connection=True):
        """
        Initialize DBObject.

        @param [in] database is the name of the database file being connected to

        @param [in] driver is the dialect of the database (e.g. 'sqlite', 'mssql', etc.)

        @param [in] host is the URL of the remote host, if appropriate

        @param [in] port is the port on the remote host to connect to, if appropriate

        @param [in] verbose is a boolean controlling sqlalchemy's verbosity (default False)

        @param [in] connection is an optional instance of DBConnection, in the event that
        this DBObject can share a database connection with another DBObject. This is only
        necessary or even possible in a few specialized cases and should be used carefully.

        @param [in] cache_connection is a boolean. If True, DBObject will use a cache of
        DBConnections (if available) to get the connection to this database.
        """
        # Cache for the query dtype, so that any one query does not have to
        # guess its dtype multiple times (one chunk per guess otherwise).
        self.dtype = None

        if connection is None:
            # Explicit constructor to DBObject preferred
            kwargDict = dict(database=database,
                             driver=driver,
                             host=host,
                             port=port,
                             verbose=verbose)

            # Only overwrite pre-existing class-level values with non-None
            # kwargs (subclasses may declare database/driver/etc. as class
            # attributes).
            for key, value in kwargDict.items():
                if value is not None or not hasattr(self, key):
                    setattr(self, key, value)

            self.connection = self._get_connection(self.database, self.driver, self.host, self.port,
                                                   use_cache=cache_connection)
        else:
            # Share the provided connection and mirror its parameters.
            self.connection = connection
            self.database = connection.database
            self.driver = connection.driver
            self.host = connection.host
            self.port = connection.port
            self.verbose = connection.verbose

    def _get_connection(self, database, driver, host, port, use_cache=True):
        """
        Search self._connection_cache (if it exists; it won't for DBObject, but
        will for CatalogDBObject) for a DBConnection matching the specified
        parameters. If it exists, return it. If not, open a connection to
        the specified database, add it to the cache, and return the connection.

        Parameters
        ----------
        database is the name of the database file being connected to

        driver is the dialect of the database (e.g. 'sqlite', 'mssql', etc.)

        host is the URL of the remote host, if appropriate

        port is the port on the remote host to connect to, if appropriate

        use_cache is a boolean specifying whether or not we try to use the
        cache of database connections (you don't want to if opening many
        connections in many threads).
        """
        if use_cache and hasattr(self, '_connection_cache'):
            for conn in self._connection_cache:
                # Compare as strings so that e.g. int and str ports match.
                if (str(conn.database) == str(database) and
                        str(conn.driver) == str(driver) and
                        str(conn.host) == str(host) and
                        str(conn.port) == str(port)):
                    return conn

        conn = DBConnection(database=database, driver=driver, host=host, port=port)

        if use_cache and hasattr(self, '_connection_cache'):
            self._connection_cache.append(conn)

        return conn

    def get_table_names(self):
        """Return a list of the names of the tables (and views) in the database"""
        # Build the Inspector once instead of once per call site.
        inspector = reflection.Inspector.from_engine(self.connection.engine)
        return ([str(xx) for xx in inspector.get_table_names()] +
                [str(xx) for xx in inspector.get_view_names()])

    def get_column_names(self, tableName=None):
        """
        Return a list of the names of the columns in the specified table.
        If no table is specified, return a dict of lists. The dict will be keyed
        to the table names. The lists will be of the column names in that table
        """
        tableNameList = self.get_table_names()
        inspector = reflection.Inspector.from_engine(self.connection.engine)
        if tableName is not None:
            if tableName not in tableNameList:
                return []
            return [str_cast(xx['name']) for xx in inspector.get_columns(tableName)]

        columnDict = {}
        for name in tableNameList:
            columnDict[name] = [str_cast(xx['name']) for xx in inspector.get_columns(name)]
        return columnDict

    def _final_pass(self, results):
        """ Make final modifications to a set of data before returning it to the user

        **Parameters**

        * results : a structured array constructed from the result set from a query

        **Returns**

        * results : a potentially modified structured array. The default is to do nothing.

        """
        return results

    def _convert_results_to_numpy_recarray_dbobj(self, results):
        """
        Convert a raw result set into a numpy recarray.

        If self.dtype is None, the dtype is guessed from the first row of
        results (by round-tripping it through numpy.genfromtxt) and cached on
        self.dtype so that subsequent chunks of the same query reuse it.
        """
        if self.dtype is None:
            # Determine the dtype from the data and store it so we do not
            # have to repeat the guess on every chunk.
            dataString = ''

            # We are going to detect the dtype by reading in a single row
            # of data with np.genfromtxt. To do this, we must pass the
            # row as a string delimited by a specified character. Here we
            # select a character that does not occur anywhere in the data.
            delimit_char_list = [',', ';', '|', ':', '/', '\\']
            delimit_char = None
            for cc in delimit_char_list:
                is_valid = True
                for xx in results[0]:
                    if cc in str(xx):
                        is_valid = False
                        break

                if is_valid:
                    delimit_char = cc
                    break

            if delimit_char is None:
                raise RuntimeError("DBObject could not detect the dtype of your return rows\n"
                                   "Please specify a dtype with the 'dtype' kwarg.")

            for xx in results[0]:
                # bug fix: use '!=' rather than 'is not' -- identity
                # comparison against a string literal is implementation
                # dependent and raises a SyntaxWarning on CPython >= 3.8
                if dataString != '':
                    dataString += delimit_char
                dataString += str(xx)

            # str_cast keeps numpy.dtype happy on python 2 (it rejects
            # unicode-like column names).
            names = [str_cast(ww) for ww in results[0].keys()]
            dataArr = numpy.genfromtxt(BytesIO(dataString.encode()), dtype=None,
                                       names=names, delimiter=delimit_char,
                                       encoding='utf-8')
            dt_list = []
            for name in dataArr.dtype.names:
                type_name = str(dataArr.dtype[name])
                sub_list = [name]
                if type_name.startswith('S') or type_name.startswith('|S'):
                    # fixed-width byte string: record as (str_cast, width)
                    sub_list.append(str_cast)
                    sub_list.append(int(type_name.replace('S', '').replace('|', '')))
                else:
                    sub_list.append(dataArr.dtype[name])
                dt_list.append(tuple(sub_list))

            self.dtype = numpy.dtype(dt_list)

        if len(results) == 0:
            # Preserve the (cached) dtype even for an empty result set.
            return numpy.recarray((0,), dtype=self.dtype)

        retresults = numpy.rec.fromrecords([tuple(xx) for xx in results], dtype=self.dtype)
        return retresults

    def _postprocess_results(self, results):
        """
        This wrapper exists so that a ChunkIterator built from a DBObject
        can have the same API as a ChunkIterator built from a CatalogDBObject
        """
        return self._postprocess_arbitrary_results(results)

    def _postprocess_arbitrary_results(self, results):
        """Convert results to a recarray (if necessary) and apply _final_pass."""
        if not isinstance(results, numpy.recarray):
            retresults = self._convert_results_to_numpy_recarray_dbobj(results)
        else:
            retresults = results

        return self._final_pass(retresults)

    def execute_arbitrary(self, query, dtype=None):
        """
        Executes an arbitrary query. Returns a recarray of the results.

        dtype will be the dtype of the output recarray. If it is None, then
        the code will guess the datatype and assign generic names to the columns

        Raises RuntimeError if query is not a string, or if it contains a
        potentially destructive SQL command (delete/drop/insert/update).
        """
        try:
            # python 2: str and unicode are both instances of basestring
            is_string = isinstance(query, basestring)
        except NameError:
            # python 3: basestring no longer exists
            is_string = isinstance(query, str)

        if not is_string:
            raise RuntimeError("DBObject execute must be called with a string query")

        # Refuse anything that looks like it could modify the database.
        unacceptableCommands = ["delete", "drop", "insert", "update"]
        for badCommand in unacceptableCommands:
            if query.lower().find(badCommand.lower()) >= 0:
                raise RuntimeError("query made to DBObject execute contained %s " % badCommand)

        self.dtype = dtype
        retresults = self._postprocess_arbitrary_results(self.connection.session.execute(query).fetchall())
        return retresults

    def get_arbitrary_chunk_iterator(self, query, chunk_size=None, dtype=None):
        """
        This wrapper exists so that CatalogDBObjects can refer to
        get_arbitrary_chunk_iterator and DBObjects can refer to
        get_chunk_iterator
        """
        return self.get_chunk_iterator(query, chunk_size=chunk_size, dtype=dtype)

    def get_chunk_iterator(self, query, chunk_size=None, dtype=None):
        """
        Take an arbitrary, user-specified query and return a ChunkIterator that
        executes that query

        dtype will tell the ChunkIterator what datatype to expect for this query.
        This information gets passed to _postprocess_results.

        If 'None', then _postprocess_results will just guess the datatype
        and return generic names for the columns.
        """
        self.dtype = dtype
        return ChunkIterator(self, query, chunk_size, arbitrarySQL=True)
class CatalogDBObjectMeta(type):
    """Meta class for registering new objects.

    When any new type of object class is created, this registers it
    in a `registry` class attribute, available to all derived instance
    catalog.
    """

    def __new__(cls, name, bases, dct):
        # Guarantee every class carries an 'objid'; default to the class name.
        if 'registry' in dct:
            warnings.warn("registry class attribute should not be "
                          "over-ridden in InstanceCatalog classes. "
                          "Proceed with caution")
        dct.setdefault('objid', name)
        return super(CatalogDBObjectMeta, cls).__new__(cls, name, bases, dct)

    def __init__(cls, name, bases, dct):
        # The first class built with this metaclass owns the registry;
        # every later class registers itself into it unless it opts out
        # via skipRegistration.
        if not hasattr(cls, 'registry'):
            cls.registry = {}
        elif not cls.skipRegistration:
            if cls.objid in cls.registry:
                previous = cls.registry[cls.objid]
                srcfile = inspect.getsourcefile(previous)
                srcline = inspect.getsourcelines(previous)[1]
                warnings.warn('duplicate object identifier %s specified. ' % (cls.objid) +
                              'This will override previous definition on line %i of %s' %
                              (srcline, srcfile))
            cls.registry[cls.objid] = cls

        # Same pattern for the list of unique object type ids: the first
        # class creates the list, later classes append their objectTypeId.
        if not hasattr(cls, 'objectTypeIdList'):
            cls.objectTypeIdList = []
        elif cls.skipRegistration or cls.objectTypeId is None:
            # explicitly skipped, or no typeId to record
            pass
        elif cls.objectTypeId in cls.objectTypeIdList:
            warnings.warn('Duplicate object type id %s specified: ' % cls.objectTypeId +
                          '\nOutput object ids may not be unique.\nThis may not be a problem if you do not ' +
                          'want globally unique id values')
        else:
            cls.objectTypeIdList.append(cls.objectTypeId)
        return super(CatalogDBObjectMeta, cls).__init__(name, bases, dct)

    def __str__(cls):
        # Human-readable dump of every registered object type.
        outstr = "++++++++++++++++++++++++++++++++++++++++++++++\n" + \
                 "Registered object types are:\n"
        for dbObject in cls.registry.keys():
            outstr += "%s\n" % (dbObject)
        outstr += "\n\n"
        outstr += "To query the possible column names do:\n"
        outstr += "$> CatalogDBObject.from_objid([name]).show_mapped_columns()\n"
        outstr += "+++++++++++++++++++++++++++++++++++++++++++++"
        return outstr
class CatalogDBObject(with_metaclass(CatalogDBObjectMeta, DBObject)):
    """Database Object base class

    Subclasses must define objid, tableid and idColKey (and will be
    automatically registered by CatalogDBObjectMeta under objid).
    """
    epoch = 2000.0
    skipRegistration = False
    objid = None
    tableid = None
    idColKey = None
    objectTypeId = None
    columns = None
    generateDefaultColumnMap = True
    dbDefaultValues = {}
    raColName = None
    decColName = None

    _connection_cache = []  # a list to store open database connections in

    # Provide information if this object should be tested in the unit test
    doRunTest = False
    testObservationMetaData = None

    #: Mapping of DDL types to python types. Strings are assumed to be 256 characters
    #: this can be overridden by modifying the dbTypeMap or by making a custom columns
    #: list.
    #: numpy doesn't know how to convert decimal.Decimal types, so I changed this to float
    #: TODO this doesn't seem to make a difference but make sure.
    dbTypeMap = {'BIGINT': (int,), 'BOOLEAN': (bool,), 'FLOAT': (float,), 'INTEGER': (int,),
                 'NUMERIC': (float,), 'SMALLINT': (int,), 'TINYINT': (int,), 'VARCHAR': (str, 256),
                 'TEXT': (str, 256), 'CLOB': (str, 256), 'NVARCHAR': (str, 256),
                 'NCLOB': (str, 256), 'NTEXT': (str, 256), 'CHAR': (str, 1), 'INT': (int,),
                 'REAL': (float,), 'DOUBLE': (float,), 'STRING': (str, 256), 'DOUBLE_PRECISION': (float,),
                 'DECIMAL': (float,)}

    @classmethod
    def from_objid(cls, objid, *args, **kwargs):
        """Given a string objid, return an instance of
        the appropriate CatalogDBObject class.

        Raises RuntimeError if objid is not in the registry.
        """
        if objid not in cls.registry:
            raise RuntimeError('Attempting to construct an object that does not exist')
        cls = cls.registry.get(objid, CatalogDBObject)
        return cls(*args, **kwargs)

    def __init__(self, database=None, driver=None, host=None, port=None, verbose=False,
                 table=None, objid=None, idColKey=None, connection=None,
                 cache_connection=True):
        """
        See DBObject.__init__ for the connection parameters.

        @param [in] table optionally overrides the class-level tableid
        @param [in] objid optionally overrides the class-level objid
        @param [in] idColKey optionally overrides the class-level idColKey
        """
        if not verbose:
            # NOTE(review): as written this suppresses SAWarning only for the
            # duration of this (otherwise empty) with-block; confirm whether
            # the intent was to wrap the rest of __init__.
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", category=sa_exc.SAWarning)

        # Each of tableid/objid/idColKey may be set either in the class
        # definition or via the constructor, but not both.
        if self.tableid is not None and table is not None:
            raise ValueError("Double-specified tableid in CatalogDBObject:"
                             " once in class definition, once in __init__")

        if table is not None:
            self.tableid = table

        if self.objid is not None and objid is not None:
            raise ValueError("Double-specified objid in CatalogDBObject:"
                             " once in class definition, once in __init__")

        if objid is not None:
            self.objid = objid

        if self.idColKey is not None and idColKey is not None:
            raise ValueError("Double-specified idColKey in CatalogDBObject:"
                             " once in class definition, once in __init__")

        if idColKey is not None:
            self.idColKey = idColKey

        if self.idColKey is None:
            self.idColKey = self.getIdColKey()
        if (self.objid is None) or (self.tableid is None) or (self.idColKey is None):
            msg = ("CatalogDBObject must be subclassed, and "
                   "define objid, tableid and idColKey. You are missing: ")
            if self.objid is None:
                msg += "objid, "
            if self.tableid is None:
                msg += "tableid, "
            if self.idColKey is None:
                msg += "idColKey"
            raise ValueError(msg)

        if (self.objectTypeId is None) and verbose:
            warnings.warn("objectTypeId has not "
                          "been set. Input files for phosim are not "
                          "possible.")

        # bug fix: pass the caller's cache_connection through instead of
        # hard-coding True (previously the kwarg was silently ignored)
        super(CatalogDBObject, self).__init__(database=database, driver=driver, host=host, port=port,
                                              verbose=verbose, connection=connection,
                                              cache_connection=cache_connection)

        try:
            self._get_table()
        except sa_exc.OperationalError as e:
            if self.driver == 'mssql+pymssql':
                message = "\n To connect to the UW CATSIM database: "
                message += " Check that you have valid connection parameters, an open ssh tunnel "
                message += "and that your $HOME/.lsst/db-auth.paf contains the appropriate credientials. "
                message += "Please consult the following link for more information on access: "
                message += " https://confluence.lsstcorp.org/display/SIM/Accessing+the+UW+CATSIM+Database "
            else:
                message = ''
            raise RuntimeError("Failed to connect to %s: sqlalchemy.%s %s" % (self.connection.engine, e.args[0], message))

        # Need to do this after the table is instantiated so that
        # the default columns can be filled from the table object.
        if self.generateDefaultColumnMap:
            self._make_default_columns()
        # build column mapping and type mapping dicts from columns
        self._make_column_map()
        self._make_type_map()

    def show_mapped_columns(self):
        """Print each mapped column name and its python type."""
        for col in self.columnMap.keys():
            print("%s -- %s" % (col, self.typeMap[col][0].__name__))

    def show_db_columns(self):
        """Print each raw database column name and its DDL type."""
        for col in self.table.c.keys():
            print("%s -- %s" % (col, self.table.c[col].type.__visit_name__))

    def getCatalog(self, ftype, *args, **kwargs):
        """Return an InstanceCatalog of type ftype built on this object."""
        try:
            from lsst.sims.catalogs.definitions import InstanceCatalog
            return InstanceCatalog.new_catalog(ftype, self, *args, **kwargs)
        except ImportError:
            raise ImportError("sims_catalogs not set up. Cannot get InstanceCatalog from the object.")

    def getIdColKey(self):
        return self.idColKey

    def getObjectTypeId(self):
        return self.objectTypeId

    def _get_table(self):
        # Reflect the table definition from the live database.
        self.table = Table(self.tableid, self.connection.metadata,
                           autoload=True)

    def _make_column_map(self):
        # column name -> database column (or expression); an empty mapping
        # entry means the column maps to itself.
        self.columnMap = OrderedDict([(el[0], el[1] if el[1] else el[0])
                                      for el in self.columns])

    def _make_type_map(self):
        # column name -> python type tuple; default is (float,)
        self.typeMap = OrderedDict([(el[0], el[2:] if len(el) > 2 else (float,))
                                    for el in self.columns])

    def _make_default_columns(self):
        """Append a default entry to self.columns for every database column
        not already covered by a user-specified column, using dbTypeMap."""
        if self.columns:
            colnames = [el[0] for el in self.columns]
        else:
            self.columns = []
            colnames = []
        for col in self.table.c.keys():
            dbtypestr = self.table.c[col].type.__visit_name__
            dbtypestr = dbtypestr.upper()
            if col in colnames:
                if self.verbose:  # Warn for possible column redefinition
                    warnings.warn("Database column, %s, overridden in self.columns... "%(col)+
                                  "Skipping default assignment.")
            elif dbtypestr in self.dbTypeMap:
                self.columns.append((col, col)+self.dbTypeMap[dbtypestr])
            else:
                if self.verbose:
                    warnings.warn("Can't create default column for %s. There is no mapping "%(col)+
                                  "for type %s. Modify the dbTypeMap, or make a custom columns "%(dbtypestr)+
                                  "list.")

    def _get_column_query(self, colnames=None):
        """Given a list of valid column names, return the query object"""
        if colnames is None:
            colnames = [k for k in self.columnMap]
        try:
            vals = [self.columnMap[k] for k in colnames]
        except KeyError:
            # Collect every offending name for a single, complete error message.
            offending_columns = '\n'
            for col in colnames:
                if col in self.columnMap:
                    continue
                else:
                    offending_columns += '%s\n' % col
            raise ValueError('entries in colnames must be in self.columnMap. '
                             'These:%sare not' % offending_columns)

        # Get the first query
        idColName = self.columnMap[self.idColKey]
        if idColName in vals:
            idLabel = self.idColKey
        else:
            idLabel = idColName

        query = self.connection.session.query(self.table.c[idColName].label(idLabel))

        for col, val in zip(colnames, vals):
            # NOTE(review): identity ('is') comparison -- relies on val being
            # the very same string object as idColName; confirm intent.
            if val is idColName:
                continue
            # Check if the column is a default column (col == val)
            if col == val:
                # If column is in the table, use it.
                query = query.add_columns(self.table.c[col].label(col))
            else:
                # If not assume the user specified the column correctly
                query = query.add_columns(expression.literal_column(val).label(col))

        return query

    def filter(self, query, bounds):
        """Filter the query by the associated metadata"""
        if bounds is not None:
            # bounds knows how to render itself as SQL over the RA/Dec columns
            on_clause = bounds.to_SQL(self.raColName, self.decColName)
            query = query.filter(text(on_clause))
        return query

    def _convert_results_to_numpy_recarray_catalogDBObj(self, results):
        """Post-process the query results to put them
        in a structured array.

        **Parameters**

        * results : a result set as returned by execution of the query

        **Returns**

        * retresults : a structured array constructed from the query data
          (the caller, _postprocess_results, applies _final_pass to it).
        """
        if len(results) > 0:
            cols = [str(k) for k in results[0].keys()]
        else:
            # nothing to convert; hand the empty result set back unchanged
            return results

        if sys.version_info.major == 2:
            # numpy.dtype cannot cope with unicode-like column names in
            # python 2; rebuild names (and str column types) with past_str.
            dt_list = []
            for k in cols:
                sub_list = [past_str(k)]
                if self.typeMap[k][0] is not str:
                    for el in self.typeMap[k]:
                        sub_list.append(el)
                else:
                    sub_list.append(past_str)
                    for el in self.typeMap[k][1:]:
                        sub_list.append(el)
                dt_list.append(tuple(sub_list))

            dtype = numpy.dtype(dt_list)
        else:
            dtype = numpy.dtype([(k,)+self.typeMap[k] for k in cols])

        if len(set(cols) & set(self.dbDefaultValues)) > 0:
            # Substitute registered default values for falsy entries in any
            # column that has a default.
            results_array = []
            for result in results:
                results_array.append(tuple(result[colName]
                                           if result[colName] or
                                           colName not in self.dbDefaultValues
                                           else self.dbDefaultValues[colName]
                                           for colName in cols))
        else:
            results_array = [tuple(rr) for rr in results]

        retresults = numpy.rec.fromrecords(results_array, dtype=dtype)
        return retresults

    def _postprocess_results(self, results):
        """Convert results to a recarray (if necessary) and apply _final_pass."""
        if not isinstance(results, numpy.recarray):
            retresults = self._convert_results_to_numpy_recarray_catalogDBObj(results)
        else:
            retresults = results
        return self._final_pass(retresults)

    def query_columns(self, colnames=None, chunk_size=None,
                      obs_metadata=None, constraint=None, limit=None):
        """Execute a query

        **Parameters**

        * colnames : list or None
          a list of valid column names, corresponding to entries in the
          `columns` class attribute. If not specified, all columns are
          queried.
        * chunk_size : int (optional)
          if specified, then return an iterator object to query the database,
          each time returning the next `chunk_size` elements. If not
          specified, all matching results will be returned.
        * obs_metadata : object (optional)
          an observation metadata object which has a "filter" method, which
          will add a filter string to the query.
        * constraint : str (optional)
          a string which is interpreted as SQL and used as a predicate on the query
        * limit : int (optional)
          limits the number of rows returned by the query

        **Returns**

        * result : list or iterator
          If chunk_size is not specified, then result is a list of all
          items which match the specified query. If chunk_size is specified,
          then result is an iterator over lists of the given size.

        """
        query = self._get_column_query(colnames)

        if obs_metadata is not None:
            query = self.filter(query, obs_metadata.bounds)

        if constraint is not None:
            query = query.filter(text(constraint))

        if limit is not None:
            query = query.limit(limit)

        return ChunkIterator(self, query, chunk_size)
# Register the shared cache of open DBConnections with sims_clean_up so the
# connections can be released during test/teardown cleanup.
sims_clean_up.targets.append(CatalogDBObject._connection_cache)
class fileDBObject(CatalogDBObject):
    ''' Class to read a file into a database and then query it'''

    # Column names to index. Specify compound indexes using tuples of column names
    indexCols = []

    def __init__(self, dataLocatorString, runtable=None, driver="sqlite", host=None, port=None, database=":memory:",
                 dtype=None, numGuess=1000, delimiter=None, verbose=False, idColKey=None, **kwargs):
        """
        Initialize an object for querying databases loaded from a file

        Keyword arguments:
        @param dataLocatorString: Path to the file to load
        @param runtable: The name of the table to create. If None, a random table name will be used.
        @param driver: name of database driver (e.g. 'sqlite', 'mssql+pymssql')
        @param host: hostname for database connection (None if sqlite)
        @param port: port for database connection (None if sqlite)
        @param database: name of database (filename if sqlite)
        @param dtype: The numpy dtype to use when loading the file. If None, it the dtype will be guessed.
        @param numGuess: The number of lines to use in guessing the dtype from the file.
        @param delimiter: The delimiter to use when parsing the file default is white space.
        @param idColKey: The name of the column that uniquely identifies each row in the database
        """
        self.verbose = verbose

        if idColKey is not None:
            self.idColKey = idColKey

        # NOTE(review): the error message mentions tableid, but only objid
        # and idColKey are actually checked here (tableid is produced by
        # loadData below).
        if(self.objid is None) or (self.idColKey is None):
            raise ValueError("CatalogDBObject must be subclassed, and "
                             "define objid and tableid and idColKey.")

        if (self.objectTypeId is None) and self.verbose:
            warnings.warn("objectTypeId has not "
                          "been set. Input files for phosim are not "
                          "possible.")

        if os.path.exists(dataLocatorString):
            # Open (or create) the database, bulk-load the file into a new
            # table, and then reflect that table into self.table.
            self.driver = driver
            self.host = host
            self.port = port
            self.database = database
            self.connection = DBConnection(database=self.database, driver=self.driver, host=self.host,
                                           port=self.port, verbose=verbose)
            self.tableid = loadData(dataLocatorString, dtype, delimiter, runtable, self.idColKey,
                                    self.connection.engine, self.connection.metadata, numGuess,
                                    indexCols=self.indexCols, **kwargs)
            self._get_table()
        else:
            raise ValueError("Could not locate file %s."%(dataLocatorString))

        # Same column/type bookkeeping as CatalogDBObject.__init__, done here
        # because this class bypasses the parent constructor.
        if self.generateDefaultColumnMap:
            self._make_default_columns()

        self._make_column_map()
        self._make_type_map()

    @classmethod
    def from_objid(cls, objid, *args, **kwargs):
        """Given a string objid, return an instance of
        the appropriate fileDBObject class.

        Unlike CatalogDBObject.from_objid, an unrecognized objid falls back
        to CatalogDBObject instead of raising.
        """
        cls = cls.registry.get(objid, CatalogDBObject)
        return cls(*args, **kwargs)