Coverage report for python/lsst/sims/catalogs/db/dbConnection.py: 20% line coverage.

Report-viewer hot-keys:
  r, m, x, p : toggle line displays
  j, k       : next / previous highlighted chunk
  0 (zero)   : top of page
  1 (one)    : first highlighted chunk
from future import standard_library
standard_library.install_aliases()
import sys
from builtins import str

# 2017 March 9
# str_cast exists because numpy.dtype does
# not like unicode-like things as the names
# of columns.  Unfortunately, in python 2,
# builtins.str looks unicode-like.  We will
# use str_cast in python 2 to maintain
# both python 3 compatibility and our use of
# numpy dtype
str_cast = str
if sys.version_info.major == 2:
    # On python 2, fall back to the byte-string type from past.builtins.
    from past.builtins import str as past_str
    str_cast = past_str
19from builtins import zip
20from builtins import object
21import warnings
22import numpy
23import os
24import inspect
25from io import BytesIO
26from collections import OrderedDict
28from .utils import loadData
29from sqlalchemy.orm import scoped_session, sessionmaker
30from sqlalchemy.sql import expression
31from sqlalchemy.engine import reflection, url
32from sqlalchemy import (create_engine, MetaData,
33 Table, event, text)
34from sqlalchemy import exc as sa_exc
35from lsst.daf.butler.registry import DbAuth
36from lsst.sims.utils.CodeUtilities import sims_clean_up
37from lsst.utils import getPackageDir
39#The documentation at http://docs.sqlalchemy.org/en/rel_0_7/core/types.html#sqlalchemy.types.Numeric
40#suggests using the cdecimal module. Since it is not standard, import decimal.
41#TODO: test for cdecimal and use it if it exists.
42import decimal
43from future.utils import with_metaclass
# Public API of this module.
__all__ = ["ChunkIterator", "DBObject", "CatalogDBObject", "fileDBObject"]
def valueOfPi():
    """Return the numerical value of pi.

    sqlite has no built-in PI() function, so this zero-argument callable
    is registered on sqlite connections to provide one.
    """
    pi_value = numpy.pi
    return pi_value
def declareTrigFunctions(conn, connection_rec, connection_proxy):
    """
    A database event listener
    which will define the math functions necessary for evaluating the
    Haversine function in sqlite databases (where they are not otherwise
    defined)

    see: http://docs.sqlalchemy.org/en/latest/core/events.html
    """
    # (name, number of arguments, implementation) for each SQL function
    # that sqlite is missing.
    registrations = (("COS", 1, numpy.cos),
                     ("SIN", 1, numpy.sin),
                     ("ASIN", 1, numpy.arcsin),
                     ("SQRT", 1, numpy.sqrt),
                     ("POWER", 2, numpy.power),
                     ("PI", 0, valueOfPi))

    for fn_name, n_args, implementation in registrations:
        conn.create_function(fn_name, n_args, implementation)
71#------------------------------------------------------------
72# Iterator for database chunks
class ChunkIterator(object):
    """Iterate over the rows returned by a query, one chunk at a time.

    If chunk_size is None, the entire result set is returned by the first
    call to next(); otherwise each call returns up to chunk_size rows.
    Iteration stops when the result set is exhausted.
    """

    def __init__(self, dbobj, query, chunk_size, arbitrarySQL=False):
        self.dbobj = dbobj
        self.exec_query = dbobj.connection.session.execute(query)
        self.chunk_size = chunk_size

        # arbitrarySQL exists in case a CatalogDBObject calls
        # get_arbitrary_chunk_iterator; in that case, results must be run
        # through dbobj._postprocess_arbitrary_results rather than
        # dbobj._postprocess_results.
        self.arbitrarySQL = arbitrarySQL

    def __iter__(self):
        return self

    def __next__(self):
        if self.chunk_size is not None:
            rows = self.exec_query.fetchmany(self.chunk_size)
            return self._postprocess_results(rows)
        if not self.exec_query.closed:
            rows = self.exec_query.fetchall()
            return self._postprocess_results(rows)
        raise StopIteration

    def _postprocess_results(self, chunk):
        # An empty chunk means the result set is exhausted.
        if not chunk:
            raise StopIteration
        if self.arbitrarySQL:
            handler = self.dbobj._postprocess_arbitrary_results
        else:
            handler = self.dbobj._postprocess_results
        return handler(chunk)
class DBConnection(object):
    """
    This is a class that will hold the engine, session, and metadata for a
    DBObject. This will allow multiple DBObjects to share the same
    sqlalchemy connection, when appropriate.
    """

    def __init__(self, database=None, driver=None, host=None, port=None, verbose=False):
        """
        @param [in] database is the name of the database file being connected to

        @param [in] driver is the dialect of the database (e.g. 'sqlite', 'mssql', etc.)

        @param [in] host is the URL of the remote host, if appropriate

        @param [in] port is the port on the remote host to connect to, if appropriate

        @param [in] verbose is a boolean controlling sqlalchemy's verbosity
        """
        self._database = database
        self._driver = driver
        self._host = host
        self._port = port
        self._verbose = verbose

        # Validate (and possibly rewrite) the parameters before connecting.
        self._validate_conn_params()
        self._connect_to_engine()

    def __del__(self):
        # Drop our references to the sqlalchemy objects.  Each delete is
        # guarded because __del__ may run on a partially-constructed
        # instance (e.g. if __init__ raised before connecting).
        try:
            del self._metadata
        except AttributeError:
            pass

        try:
            del self._engine
        except AttributeError:
            pass

        try:
            del self._session
        except AttributeError:
            pass

    def _connect_to_engine(self):
        """Create the sqlalchemy engine, scoped session, and metadata for
        this connection, looking up credentials with DbAuth when a remote
        host is specified."""

        # DbAuth will not look up hosts that are None, '' or 0
        if self._host:
            # This is triggered when you need to connect to a remote database.
            # Use 'HOME' as the default location (backwards compatibility) but fail graciously
            authdir = os.getenv('HOME')
            if authdir is None:
                # Use an empty file in this package, which causes
                # a fallback to database-native authentication.
                authdir = getPackageDir('sims_catalogs')
                auth = DbAuth(os.path.join(authdir, "tests", "db-auth.yaml"))
            else:
                auth = DbAuth(os.path.join(authdir, ".lsst", "db-auth.yaml"))
            username, password = auth.getAuth(
                self._driver, host=self._host, port=self._port,
                database=self._database)
            dbUrl = url.URL(self._driver,
                            host=self._host,
                            port=self._port,
                            database=self._database,
                            username=username,
                            password=password)
        else:
            # Local connections (e.g. sqlite) need no credentials.
            dbUrl = url.URL(self._driver,
                            database=self._database)

        self._engine = create_engine(dbUrl, echo=self._verbose)

        # sqlite lacks the trig functions needed by the Haversine queries;
        # register them on every connection checked out of the pool.
        if self._engine.dialect.name == 'sqlite':
            event.listen(self._engine, 'checkout', declareTrigFunctions)

        self._session = scoped_session(sessionmaker(autoflush=True,
                                                    bind=self._engine))
        self._metadata = MetaData(bind=self._engine)

    def _validate_conn_params(self):
        """Validate connection parameters

        - Check if user passed dbAddress instead of an database. Convert and warn.
        - Check that required connection paramters are present
        - Replace default host/port if driver is 'sqlite'
        """
        if self._database is None:
            raise AttributeError("Cannot instantiate DBConnection; database is 'None'")

        # A '//' strongly suggests the caller passed a full dbAddress URL
        # (e.g. 'dialect://host/db') rather than a bare database name.
        if '//' in self._database:
            warnings.warn("Database name '%s' is invalid but looks like a dbAddress. "
                          "Attempting to convert to database, driver, host, "
                          "and port parameters. Any usernames and passwords are ignored and must "
                          "be in the db-auth.paf policy file. "%(self.database), FutureWarning)

            dbUrl = url.make_url(self._database)
            dialect = dbUrl.get_dialect()
            self._driver = dialect.name + '+' + dialect.driver if dialect.driver else dialect.name
            for key, value in dbUrl.translate_connect_args().items():
                if value is not None:
                    setattr(self, '_'+key, value)

        # NOTE(review): __init__ always assigns _driver and _database, so
        # the hasattr branches below can only fire if this method is called
        # on a partially-initialized instance -- confirm they are needed.
        errMessage = "Please supply a 'driver' kwarg to the constructor or in class definition. "
        errMessage += "'driver' is formatted as dialect+driver, such as 'sqlite' or 'mssql+pymssql'."
        if not hasattr(self, '_driver'):
            raise AttributeError("%s has no attribute 'driver'. "%(self.__class__.__name__) + errMessage)
        elif self._driver is None:
            raise AttributeError("%s.driver is None. "%(self.__class__.__name__) + errMessage)

        errMessage = "Please supply a 'database' kwarg to the constructor or in class definition. "
        errMessage += " 'database' is the database name or the filename path if driver is 'sqlite'. "
        if not hasattr(self, '_database'):
            raise AttributeError("%s has no attribute 'database'. "%(self.__class__.__name__) + errMessage)
        elif self._database is None:
            raise AttributeError("%s.database is None. "%(self.__class__.__name__) + errMessage)

        if 'sqlite' in self._driver:
            # When passed sqlite database, override default host/port
            self._host = None
            self._port = None

    def __eq__(self, other):
        # Two connections are equal when all four connection parameters
        # match (string-compared so that None == None works uniformly).
        # NOTE(review): defining __eq__ without __hash__ makes instances
        # unhashable in python 3 -- confirm DBConnection is never used as
        # a dict key or set member.
        return (str(self._database) == str(other._database)) and \
               (str(self._driver) == str(other._driver)) and \
               (str(self._host) == str(other._host)) and \
               (str(self._port) == str(other._port))

    @property
    def engine(self):
        # The underlying sqlalchemy Engine.
        return self._engine

    @property
    def session(self):
        # The scoped_session bound to the engine; used to execute queries.
        return self._session

    @property
    def metadata(self):
        # MetaData bound to the engine; used for table reflection.
        return self._metadata

    @property
    def database(self):
        return self._database

    @property
    def driver(self):
        return self._driver

    @property
    def host(self):
        return self._host

    @property
    def port(self):
        return self._port

    @property
    def verbose(self):
        return self._verbose
class DBObject(object):
    """Thin wrapper around a DBConnection that can execute arbitrary
    read-only queries and return the results as numpy recarrays."""

    def __init__(self, database=None, driver=None, host=None, port=None, verbose=False,
                 connection=None, cache_connection=True):
        """
        Initialize DBObject.

        @param [in] database is the name of the database file being connected to

        @param [in] driver is the dialect of the database (e.g. 'sqlite', 'mssql', etc.)

        @param [in] host is the URL of the remote host, if appropriate

        @param [in] port is the port on the remote host to connect to, if appropriate

        @param [in] verbose is a boolean controlling sqlalchemy's verbosity (default False)

        @param [in] connection is an optional instance of DBConnection, in the event that
        this DBObject can share a database connection with another DBObject. This is only
        necessary or even possible in a few specialized cases and should be used carefully.

        @param [in] cache_connection is a boolean. If True, DBObject will use a cache of
        DBConnections (if available) to get the connection to this database.
        """
        # this is a cache for the query, so that any one query does not
        # have to guess dtype multiple times
        self.dtype = None

        if connection is None:
            # Explicit constructor to DBObject preferred
            kwargDict = dict(database=database,
                             driver=driver,
                             host=host,
                             port=port,
                             verbose=verbose)

            # Only override attributes that were explicitly passed in, or
            # that have not already been set (e.g. by a subclass definition).
            for key, value in kwargDict.items():
                if value is not None or not hasattr(self, key):
                    setattr(self, key, value)

            self.connection = self._get_connection(self.database, self.driver, self.host, self.port,
                                                   use_cache=cache_connection)

        else:
            # Share the caller-supplied connection and mirror its parameters.
            self.connection = connection
            self.database = connection.database
            self.driver = connection.driver
            self.host = connection.host
            self.port = connection.port
            self.verbose = connection.verbose

    def _get_connection(self, database, driver, host, port, use_cache=True):
        """
        Search self._connection_cache (if it exists; it won't for DBObject, but
        will for CatalogDBObject) for a DBConnection matching the specified
        parameters.  If it exists, return it.  If not, open a connection to
        the specified database, add it to the cache, and return the connection.

        Parameters
        ----------
        database is the name of the database file being connected to

        driver is the dialect of the database (e.g. 'sqlite', 'mssql', etc.)

        host is the URL of the remote host, if appropriate

        port is the port on the remote host to connect to, if appropriate

        use_cache is a boolean specifying whether or not we try to use the
        cache of database connections (you don't want to if opening many
        connections in many threads).
        """
        # All four parameters must match (string-compared) for a cached
        # connection to be reused.
        if use_cache and hasattr(self, '_connection_cache'):
            for conn in self._connection_cache:
                if str(conn.database) == str(database):
                    if str(conn.driver) == str(driver):
                        if str(conn.host) == str(host):
                            if str(conn.port) == str(port):
                                return conn

        conn = DBConnection(database=database, driver=driver, host=host, port=port)

        if use_cache and hasattr(self, '_connection_cache'):
            self._connection_cache.append(conn)

        return conn

    def get_table_names(self):
        """Return a list of the names of the tables (and views) in the database"""
        return [str(xx) for xx in reflection.Inspector.from_engine(self.connection.engine).get_table_names()] + \
               [str(xx) for xx in reflection.Inspector.from_engine(self.connection.engine).get_view_names()]

    def get_column_names(self, tableName=None):
        """
        Return a list of the names of the columns in the specified table.
        If no table is specified, return a dict of lists.  The dict will be keyed
        to the table names.  The lists will be of the column names in that table
        """
        tableNameList = self.get_table_names()
        if tableName is not None:
            # Unknown table names yield an empty list rather than an error.
            if tableName not in tableNameList:
                return []
            return [str_cast(xx['name']) for xx in reflection.Inspector.from_engine(self.connection.engine).get_columns(tableName)]
        else:
            columnDict = {}
            for name in tableNameList:
                columnList = [str_cast(xx['name']) for xx in reflection.Inspector.from_engine(self.connection.engine).get_columns(name)]
                columnDict[name] = columnList
            return columnDict

    def _final_pass(self, results):
        """ Make final modifications to a set of data before returning it to the user

        **Parameters**

            * results : a structured array constructed from the result set from a query

        **Returns**

            * results : a potentially modified structured array. The default is to do nothing.

        """
        return results

    def _convert_results_to_numpy_recarray_dbobj(self, results):
        # Convert a list of database rows into a numpy recarray, guessing
        # the dtype from the first row if self.dtype has not been set.
        if self.dtype is None:
            """
            Determine the dtype from the data.
            Store it in a global variable so we do not have to repeat on every chunk.
            """
            dataString = ''

            # We are going to detect the dtype by reading in a single row
            # of data with np.genfromtxt.  To do this, we must pass the
            # row as a string delimited by a specified character.  Here we
            # select a character that does not occur anywhere in the data.
            delimit_char_list = [',', ';', '|', ':', '/', '\\']
            delimit_char = None
            for cc in delimit_char_list:
                is_valid = True
                for xx in results[0]:
                    if cc in str(xx):
                        is_valid = False
                        break

                if is_valid:
                    delimit_char = cc
                    break

            if delimit_char is None:
                raise RuntimeError("DBObject could not detect the dtype of your return rows\n"
                                   "Please specify a dtype with the 'dtype' kwarg.")

            # Serialize the first row with the chosen delimiter.
            for xx in results[0]:
                if dataString != '':
                    dataString += delimit_char
                dataString += str(xx)
            names = [str_cast(ww) for ww in results[0].keys()]
            dataArr = numpy.genfromtxt(BytesIO(dataString.encode()), dtype=None,
                                       names=names, delimiter=delimit_char,
                                       encoding='utf-8')

            # Re-cast byte-string columns ('S<n>' / '|S<n>') through
            # str_cast so the final dtype is friendly to numpy on both
            # python 2 and python 3.
            dt_list = []
            for name in dataArr.dtype.names:
                type_name = str(dataArr.dtype[name])
                sub_list = [name]
                if type_name.startswith('S') or type_name.startswith('|S'):
                    sub_list.append(str_cast)
                    sub_list.append(int(type_name.replace('S','').replace('|','')))
                else:
                    sub_list.append(dataArr.dtype[name])
                dt_list.append(tuple(sub_list))

            self.dtype = numpy.dtype(dt_list)

        if len(results) == 0:
            return numpy.recarray((0,), dtype = self.dtype)

        retresults = numpy.rec.fromrecords([tuple(xx) for xx in results],dtype = self.dtype)
        return retresults

    def _postprocess_results(self, results):
        """
        This wrapper exists so that a ChunkIterator built from a DBObject
        can have the same API as a ChunkIterator built from a CatalogDBObject
        """
        return self._postprocess_arbitrary_results(results)

    def _postprocess_arbitrary_results(self, results):
        # Convert raw rows to a recarray (unless already converted), then
        # give subclasses a chance to modify the data via _final_pass.
        if not isinstance(results, numpy.recarray):
            retresults = self._convert_results_to_numpy_recarray_dbobj(results)
        else:
            retresults = results

        return self._final_pass(retresults)

    def execute_arbitrary(self, query, dtype = None):
        """
        Executes an arbitrary query.  Returns a recarray of the results.

        dtype will be the dtype of the output recarray.  If it is None, then
        the code will guess the datatype and assign generic names to the columns
        """
        # NOTE(review): 'basestring' only exists on python 2; the bare
        # except falls through to the str check on python 3.  'except
        # NameError:' would be clearer.
        try:
            is_string = isinstance(query, basestring)
        except:
            is_string = isinstance(query, str)

        if not is_string:
            raise RuntimeError("DBObject execute must be called with a string query")

        # Refuse queries containing keywords that could modify the database.
        unacceptableCommands = ["delete","drop","insert","update"]
        for badCommand in unacceptableCommands:
            if query.lower().find(badCommand.lower())>=0:
                raise RuntimeError("query made to DBObject execute contained %s " % badCommand)

        self.dtype = dtype
        retresults = self._postprocess_arbitrary_results(self.connection.session.execute(query).fetchall())
        return retresults

    def get_arbitrary_chunk_iterator(self, query, chunk_size = None, dtype =None):
        """
        This wrapper exists so that CatalogDBObjects can refer to
        get_arbitrary_chunk_iterator and DBObjects can refer to
        get_chunk_iterator
        """
        return self.get_chunk_iterator(query, chunk_size = chunk_size, dtype = dtype)

    def get_chunk_iterator(self, query, chunk_size = None, dtype = None):
        """
        Take an arbitrary, user-specified query and return a ChunkIterator that
        executes that query

        dtype will tell the ChunkIterator what datatype to expect for this query.
        This information gets passed to _postprocess_results.

        If 'None', then _postprocess_results will just guess the datatype
        and return generic names for the columns.
        """
        self.dtype = dtype
        return ChunkIterator(self, query, chunk_size, arbitrarySQL = True)
class CatalogDBObjectMeta(type):
    """Meta class for registering new objects.

    When any new type of object class is created, this registers it
    in a `registry` class attribute, available to all derived instance
    catalog.
    """
    def __new__(cls, name, bases, dct):
        # check if attribute objid is specified.
        # If not, create a default (the class name itself).
        if 'registry' in dct:
            warnings.warn("registry class attribute should not be "
                          "over-ridden in InstanceCatalog classes. "
                          "Proceed with caution")
        if 'objid' not in dct:
            dct['objid'] = name
        return super(CatalogDBObjectMeta, cls).__new__(cls, name, bases, dct)

    def __init__(cls, name, bases, dct):
        # check if 'registry' is specified.
        # if not, then this is the base class: add the registry
        if not hasattr(cls, 'registry'):
            cls.registry = {}
        else:
            if not cls.skipRegistration:
                # add this class to the registry
                if cls.objid in cls.registry:
                    # Warn (rather than fail) on a reused objid, pointing
                    # at the definition that is being overridden.
                    srcfile = inspect.getsourcefile(cls.registry[cls.objid])
                    srcline = inspect.getsourcelines(cls.registry[cls.objid])[1]
                    warnings.warn('duplicate object identifier %s specified. '%(cls.objid)+\
                                  'This will override previous definition on line %i of %s'%
                                  (srcline, srcfile))
                cls.registry[cls.objid] = cls

        # check if the list of unique ids is specified
        # if not, then this is the base class: add the list
        if not hasattr(cls, 'objectTypeIdList'):
            cls.objectTypeIdList = []
        else:
            if cls.skipRegistration:
                pass
            elif cls.objectTypeId is None:
                pass  # Don't add typeIds that are None
            elif cls.objectTypeId in cls.objectTypeIdList:
                warnings.warn('Duplicate object type id %s specified: '%cls.objectTypeId+\
                              '\nOutput object ids may not be unique.\nThis may not be a problem if you do not '+\
                              'want globally unique id values')
            else:
                cls.objectTypeIdList.append(cls.objectTypeId)
        return super(CatalogDBObjectMeta, cls).__init__(name, bases, dct)

    def __str__(cls):
        # Human-readable listing of every registered CatalogDBObject type.
        dbObjects = cls.registry.keys()
        outstr = "++++++++++++++++++++++++++++++++++++++++++++++\n"+\
                 "Registered object types are:\n"
        for dbObject in dbObjects:
            outstr += "%s\n"%(dbObject)
        outstr += "\n\n"
        outstr += "To query the possible column names do:\n"
        outstr += "$> CatalogDBObject.from_objid([name]).show_mapped_columns()\n"
        outstr += "+++++++++++++++++++++++++++++++++++++++++++++"
        return outstr
class CatalogDBObject(with_metaclass(CatalogDBObjectMeta, DBObject)):
    """Database Object base class

    Subclasses must define objid, tableid and idColKey; the metaclass
    registers each subclass in `registry` keyed by objid.
    """
    epoch = 2000.0
    skipRegistration = False   # if True, the metaclass does not register this class
    objid = None               # unique string identifying this object type
    tableid = None             # name of the database table to query
    idColKey = None            # key (in columns) of the unique-id column
    objectTypeId = None        # numeric id used to build globally unique object ids
    columns = None             # list of (name, dbName, type...) column mappings
    generateDefaultColumnMap = True
    dbDefaultValues = {}       # replacement values for falsy/missing column entries
    raColName = None           # database column containing RA (used by filter())
    decColName = None          # database column containing Dec (used by filter())

    _connection_cache = []  # a list to store open database connections in

    # Provide information if this object should be tested in the unit test
    doRunTest = False
    testObservationMetaData = None

    #: Mapping of DDL types to python types. Strings are assumed to be 256 characters
    #: this can be overridden by modifying the dbTypeMap or by making a custom columns
    #: list.
    #: numpy doesn't know how to convert decimal.Decimal types, so I changed this to float
    #: TODO this doesn't seem to make a difference but make sure.
    dbTypeMap = {'BIGINT':(int,), 'BOOLEAN':(bool,), 'FLOAT':(float,), 'INTEGER':(int,),
                 'NUMERIC':(float,), 'SMALLINT':(int,), 'TINYINT':(int,), 'VARCHAR':(str, 256),
                 'TEXT':(str, 256), 'CLOB':(str, 256), 'NVARCHAR':(str, 256),
                 'NCLOB':(str, 256), 'NTEXT':(str, 256), 'CHAR':(str, 1), 'INT':(int,),
                 'REAL':(float,), 'DOUBLE':(float,), 'STRING':(str, 256), 'DOUBLE_PRECISION':(float,),
                 'DECIMAL':(float,)}

    @classmethod
    def from_objid(cls, objid, *args, **kwargs):
        """Given a string objid, return an instance of
        the appropriate CatalogDBObject class.
        """
        if objid not in cls.registry:
            raise RuntimeError('Attempting to construct an object that does not exist')
        cls = cls.registry.get(objid, CatalogDBObject)
        return cls(*args, **kwargs)

    def __init__(self, database=None, driver=None, host=None, port=None, verbose=False,
                 table=None, objid=None, idColKey=None, connection=None,
                 cache_connection=True):
        if not verbose:
            # NOTE(review): as written, this catch_warnings block is empty
            # -- the simplefilter is discarded as soon as the with-block
            # exits, so no warnings are actually suppressed.  Confirm
            # whether the table reflection below was meant to be inside it.
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", category=sa_exc.SAWarning)

        # Each of objid/tableid/idColKey may be set either in the class
        # definition or via the constructor, but not both.
        if self.tableid is not None and table is not None:
            raise ValueError("Double-specified tableid in CatalogDBObject:"
                             " once in class definition, once in __init__")

        if table is not None:
            self.tableid = table

        if self.objid is not None and objid is not None:
            raise ValueError("Double-specified objid in CatalogDBObject:"
                             " once in class definition, once in __init__")

        if objid is not None:
            self.objid = objid

        if self.idColKey is not None and idColKey is not None:
            raise ValueError("Double-specified idColKey in CatalogDBObject:"
                             " once in class definition, once in __init__")

        if idColKey is not None:
            self.idColKey = idColKey

        if self.idColKey is None:
            self.idColKey = self.getIdColKey()
        if (self.objid is None) or (self.tableid is None) or (self.idColKey is None):
            msg = ("CatalogDBObject must be subclassed, and "
                   "define objid, tableid and idColKey. You are missing: ")
            if self.objid is None:
                msg += "objid, "
            if self.tableid is None:
                msg += "tableid, "
            if self.idColKey is None:
                msg += "idColKey"
            raise ValueError(msg)

        if (self.objectTypeId is None) and verbose:
            warnings.warn("objectTypeId has not "
                          "been set. Input files for phosim are not "
                          "possible.")

        # NOTE(review): the cache_connection argument is not forwarded --
        # the literal True is always passed to DBObject.__init__.
        super(CatalogDBObject, self).__init__(database=database, driver=driver, host=host, port=port,
                                              verbose=verbose, connection=connection, cache_connection=True)

        try:
            self._get_table()
        except sa_exc.OperationalError as e:
            if self.driver == 'mssql+pymssql':
                message = "\n To connect to the UW CATSIM database: "
                message += " Check that you have valid connection parameters, an open ssh tunnel "
                message += "and that your $HOME/.lsst/db-auth.paf contains the appropriate credientials. "
                message += "Please consult the following link for more information on access: "
                message += " https://confluence.lsstcorp.org/display/SIM/Accessing+the+UW+CATSIM+Database "
            else:
                message = ''
            raise RuntimeError("Failed to connect to %s: sqlalchemy.%s %s" % (self.connection.engine, e.args[0], message))

        # Need to do this after the table is instantiated so that
        # the default columns can be filled from the table object.
        if self.generateDefaultColumnMap:
            self._make_default_columns()

        # build column mapping and type mapping dicts from columns
        self._make_column_map()
        self._make_type_map()

    def show_mapped_columns(self):
        # Print each mapped column name and its python type.
        for col in self.columnMap.keys():
            print("%s -- %s"%(col, self.typeMap[col][0].__name__))

    def show_db_columns(self):
        # Print each raw database column name and its DDL type.
        for col in self.table.c.keys():
            print("%s -- %s"%(col, self.table.c[col].type.__visit_name__))

    def getCatalog(self, ftype, *args, **kwargs):
        """Return an InstanceCatalog of type ftype built on this database object."""
        try:
            from lsst.sims.catalogs.definitions import InstanceCatalog
            return InstanceCatalog.new_catalog(ftype, self, *args, **kwargs)
        except ImportError:
            raise ImportError("sims_catalogs not set up. Cannot get InstanceCatalog from the object.")

    def getIdColKey(self):
        return self.idColKey

    def getObjectTypeId(self):
        return self.objectTypeId

    def _get_table(self):
        # Reflect the table definition from the database.
        self.table = Table(self.tableid, self.connection.metadata,
                           autoload=True)

    def _make_column_map(self):
        # Map column key -> database column expression (defaulting to the
        # key itself when no explicit mapping is given).
        self.columnMap = OrderedDict([(el[0], el[1] if el[1] else el[0])
                                      for el in self.columns])

    def _make_type_map(self):
        # Map column key -> python type tuple (defaulting to float).
        self.typeMap = OrderedDict([(el[0], el[2:] if len(el)> 2 else (float,))
                                    for el in self.columns])

    def _make_default_columns(self):
        # Append a default mapping for every database column that the
        # user-specified columns list does not already cover.
        if self.columns:
            colnames = [el[0] for el in self.columns]
        else:
            self.columns = []
            colnames = []
        for col in self.table.c.keys():
            dbtypestr = self.table.c[col].type.__visit_name__
            dbtypestr = dbtypestr.upper()
            if col in colnames:
                if self.verbose:  # Warn for possible column redefinition
                    warnings.warn("Database column, %s, overridden in self.columns... "%(col)+
                                  "Skipping default assignment.")
            elif dbtypestr in self.dbTypeMap:
                self.columns.append((col, col)+self.dbTypeMap[dbtypestr])
            else:
                if self.verbose:
                    warnings.warn("Can't create default column for %s.  There is no mapping "%(col)+
                                  "for type %s.  Modify the dbTypeMap, or make a custom columns "%(dbtypestr)+
                                  "list.")

    def _get_column_query(self, colnames=None):
        """Given a list of valid column names, return the query object"""
        if colnames is None:
            colnames = [k for k in self.columnMap]
        try:
            vals = [self.columnMap[k] for k in colnames]
        except KeyError:
            # Collect every unknown name for the error message.
            offending_columns = '\n'
            for col in colnames:
                if col in self.columnMap:
                    continue
                else:
                    offending_columns +='%s\n' % col
            raise ValueError('entries in colnames must be in self.columnMap. '
                             'These:%sare not' % offending_columns)

        # Get the first query
        idColName = self.columnMap[self.idColKey]
        if idColName in vals:
            idLabel = self.idColKey
        else:
            idLabel = idColName

        query = self.connection.session.query(self.table.c[idColName].label(idLabel))

        for col, val in zip(colnames, vals):
            # NOTE(review): 'is' compares identity, not equality; this
            # skips the idColKey entry (same object from columnMap) but an
            # equal-but-distinct string would not be skipped -- confirm.
            if val is idColName:
                continue
            # Check if the column is a default column (col == val)
            if col == val:
                # If column is in the table, use it.
                query = query.add_columns(self.table.c[col].label(col))
            else:
                # If not assume the user specified the column correctly
                query = query.add_columns(expression.literal_column(val).label(col))

        return query

    def filter(self, query, bounds):
        """Filter the query by the associated metadata"""
        if bounds is not None:
            # bounds.to_SQL renders a spatial predicate over the RA/Dec columns.
            on_clause = bounds.to_SQL(self.raColName,self.decColName)
            query = query.filter(text(on_clause))
        return query

    def _convert_results_to_numpy_recarray_catalogDBObj(self, results):
        """Post-process the query results to put them
        in a structured array.

        **Parameters**

            * results : a result set as returned by execution of the query

        **Returns**

            * retresults : a structured array constructed from the query data
              (the caller, _postprocess_results, applies _final_pass to it).
        """
        if len(results) > 0:
            cols = [str(k) for k in results[0].keys()]
        else:
            # Nothing to convert; hand the empty result set back as-is.
            return results

        if sys.version_info.major == 2:
            # On python 2, numpy.dtype rejects unicode column names, so
            # cast the names (and any str-typed columns) through past_str.
            dt_list = []
            for k in cols:
                sub_list = [past_str(k)]
                if self.typeMap[k][0] is not str:
                    for el in self.typeMap[k]:
                        sub_list.append(el)
                else:
                    sub_list.append(past_str)
                    for el in self.typeMap[k][1:]:
                        sub_list.append(el)
                dt_list.append(tuple(sub_list))

            dtype = numpy.dtype(dt_list)

        else:
            dtype = numpy.dtype([(k,)+self.typeMap[k] for k in cols])

        if len(set(cols)&set(self.dbDefaultValues)) > 0:
            # Substitute declared defaults for falsy column values.
            # NOTE(review): 'if result[colName]' also replaces legitimate
            # falsy values such as 0 or '' -- confirm this is intended.
            results_array = []

            for result in results:
                results_array.append(tuple(result[colName]
                                           if result[colName] or
                                           colName not in self.dbDefaultValues
                                           else self.dbDefaultValues[colName]
                                           for colName in cols))

        else:
            results_array = [tuple(rr) for rr in results]

        retresults = numpy.rec.fromrecords(results_array, dtype=dtype)
        return retresults

    def _postprocess_results(self, results):
        # Convert raw rows into a recarray (unless already converted),
        # then apply _final_pass.
        if not isinstance(results, numpy.recarray):
            retresults = self._convert_results_to_numpy_recarray_catalogDBObj(results)
        else:
            retresults = results
        return self._final_pass(retresults)

    def query_columns(self, colnames=None, chunk_size=None,
                      obs_metadata=None, constraint=None, limit=None):
        """Execute a query

        **Parameters**

            * colnames : list or None
              a list of valid column names, corresponding to entries in the
              `columns` class attribute.  If not specified, all columns are
              queried.
            * chunk_size : int (optional)
              if specified, then return an iterator object to query the database,
              each time returning the next `chunk_size` elements.  If not
              specified, all matching results will be returned.
            * obs_metadata : object (optional)
              an observation metadata object which has a "filter" method, which
              will add a filter string to the query.
            * constraint : str (optional)
              a string which is interpreted as SQL and used as a predicate on the query
            * limit : int (optional)
              limits the number of rows returned by the query

        **Returns**

            * result : list or iterator
              If chunk_size is not specified, then result is a list of all
              items which match the specified query.  If chunk_size is specified,
              then result is an iterator over lists of the given size.

        """
        query = self._get_column_query(colnames)

        if obs_metadata is not None:
            query = self.filter(query, obs_metadata.bounds)

        if constraint is not None:
            query = query.filter(text(constraint))

        if limit is not None:
            query = query.limit(limit)

        return ChunkIterator(self, query, chunk_size)
# Register the shared connection cache with sims_clean_up so the open
# database connections can be released at teardown.
sims_clean_up.targets.append(CatalogDBObject._connection_cache)
class fileDBObject(CatalogDBObject):
    ''' Class to read a file into a database and then query it'''

    # Column names to index.  Specify compound indexes using tuples of column names
    indexCols = []

    def __init__(self, dataLocatorString, runtable=None, driver="sqlite", host=None, port=None, database=":memory:",
                 dtype=None, numGuess=1000, delimiter=None, verbose=False, idColKey=None, **kwargs):
        """
        Initialize an object for querying databases loaded from a file

        Keyword arguments:
        @param dataLocatorString: Path to the file to load
        @param runtable: The name of the table to create.  If None, a random table name will be used.
        @param driver: name of database driver (e.g. 'sqlite', 'mssql+pymssql')
        @param host: hostname for database connection (None if sqlite)
        @param port: port for database connection (None if sqlite)
        @param database: name of database (filename if sqlite)
        @param dtype: The numpy dtype to use when loading the file.  If None, the dtype will be guessed.
        @param numGuess: The number of lines to use in guessing the dtype from the file.
        @param delimiter: The delimiter to use when parsing the file; default is white space.
        @param idColKey: The name of the column that uniquely identifies each row in the database
        """
        self.verbose = verbose

        if idColKey is not None:
            self.idColKey = idColKey

        # NOTE(review): tableid is not checked here (it is produced by
        # loadData below), but the error message still mentions it.
        if(self.objid is None) or (self.idColKey is None):
            raise ValueError("CatalogDBObject must be subclassed, and "
                             "define objid and tableid and idColKey.")

        if (self.objectTypeId is None) and self.verbose:
            warnings.warn("objectTypeId has not "
                          "been set. Input files for phosim are not "
                          "possible.")

        if os.path.exists(dataLocatorString):
            self.driver = driver
            self.host = host
            self.port = port
            self.database = database
            # A fresh (uncached) DBConnection is always opened here, since
            # the file's data is loaded into this specific connection.
            self.connection = DBConnection(database=self.database, driver=self.driver, host=self.host,
                                           port=self.port, verbose=verbose)
            self.tableid = loadData(dataLocatorString, dtype, delimiter, runtable, self.idColKey,
                                    self.connection.engine, self.connection.metadata, numGuess,
                                    indexCols=self.indexCols, **kwargs)
            self._get_table()
        else:
            raise ValueError("Could not locate file %s."%(dataLocatorString))

        # Need to do this after the table is instantiated so that the
        # default columns can be filled from the table object.
        if self.generateDefaultColumnMap:
            self._make_default_columns()

        self._make_column_map()
        self._make_type_map()

    @classmethod
    def from_objid(cls, objid, *args, **kwargs):
        """Given a string objid, return an instance of
        the appropriate fileDBObject class.
        """
        # NOTE(review): unlike CatalogDBObject.from_objid, an unknown objid
        # silently falls back to CatalogDBObject instead of raising.
        cls = cls.registry.get(objid, CatalogDBObject)
        return cls(*args, **kwargs)