Coverage for python/lsst/daf/butler/registry/databases/oracle.py : 22%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
21from __future__ import annotations
23__all__ = ["OracleDatabase"]
from contextlib import closing, contextmanager
from typing import Iterator, Optional

import sqlalchemy
import sqlalchemy.ext.compiler

from ..interfaces import Database, ReadOnlyDatabaseError
from ..nameShrinker import NameShrinker
class _Merge(sqlalchemy.sql.expression.Executable, sqlalchemy.sql.ClauseElement):
    """Executable SQLAlchemy construct that stands in for an Oracle ``MERGE``.

    It plays the role that ``INSERT ... ON CONFLICT REPLACE`` plays on
    PostgreSQL and SQLite, keyed on the table's primary key constraint.
    The construct carries no SQL text of its own; the text is produced by
    the ``compiles`` hook registered for the ``"oracle"`` dialect.

    Parameters
    ----------
    table : `sqlalchemy.schema.Table`
        The table that rows are merged into.
    """

    def __init__(self, table: sqlalchemy.schema.Table):
        super().__init__()
        # The compiler hook reads this attribute to find the table's columns
        # and primary key.
        self.table = table
@sqlalchemy.ext.compiler.compiles(_Merge, "oracle")
def _merge(merge, compiler, **kw):
    """Generate MERGE query for inserting or updating records.

    The generated statement selects one row of bound parameters (one per
    table column, each named after its column) as alias ``d`` and merges it
    into the table (alias ``t``), matching on the primary key columns.
    Matched rows have their non-primary-key columns updated; unmatched rows
    are inserted.

    Parameters
    ----------
    merge : `_Merge`
        The construct being compiled; only its ``table`` attribute is used.
    compiler : `sqlalchemy.sql.compiler.SQLCompiler`
        The compiler performing this compilation.
    **kw
        Keyword arguments forwarded to ``compiler.process``.

    Returns
    -------
    query : `str`
        SQL text of the MERGE statement.

    Raises
    ------
    sqlalchemy.exc.CompileError
        Raised if the table has no primary key, because there would be no
        columns to build the ``ON`` condition from.
    """
    table = merge.table
    preparer = compiler.preparer

    allColumns = [col.name for col in table.columns]
    pkColumns = [col.name for col in table.primary_key]
    if not pkColumns:
        # An empty ON () clause is not valid SQL; fail early with a clear
        # message instead of a server-side syntax error.
        raise sqlalchemy.exc.CompileError(
            f"Cannot generate a MERGE statement for table '{table.name}' without a primary key."
        )
    pkColumnSet = set(pkColumns)
    nonPkColumns = [col for col in allColumns if col not in pkColumnSet]

    # To properly support type decorators defined in core/ddl.py we need
    # to pass column type to `bindparam`.
    selectColumns = [sqlalchemy.sql.bindparam(col.name, type_=col.type).label(col.name)
                     for col in table.columns]
    selectClause = sqlalchemy.sql.select(selectColumns)

    tableAlias = table.alias("t")
    tableAliasText = compiler.process(tableAlias, asfrom=True, **kw)
    selectAlias = selectClause.alias("d")
    selectAliasText = compiler.process(selectAlias, asfrom=True, **kw)

    condition = sqlalchemy.sql.and_(
        *[tableAlias.columns[col] == selectAlias.columns[col] for col in pkColumns]
    )
    conditionText = compiler.process(condition, **kw)

    query = f"MERGE INTO {tableAliasText}" \
            f"\nUSING {selectAliasText}" \
            f"\nON ({conditionText})"

    # Only non-primary-key columns are updated (Oracle does not allow
    # updating columns referenced in the ON clause).  If every column is in
    # the primary key there is nothing to update, and "UPDATE SET" with an
    # empty list would be a syntax error, so the WHEN MATCHED branch is
    # omitted and matching rows are left as they are.
    if nonPkColumns:
        updates = []
        for col in nonPkColumns:
            src = compiler.process(selectAlias.columns[col], **kw)
            dst = compiler.process(tableAlias.columns[col], **kw)
            updates.append(f"{dst} = {src}")
        query += f"\nWHEN MATCHED THEN UPDATE SET {', '.join(updates)}"

    insertColumns = ", ".join(preparer.format_column(col) for col in table.columns)
    insertValues = ", ".join(compiler.process(selectAlias.columns[col], **kw) for col in allColumns)

    query += f"\nWHEN NOT MATCHED THEN INSERT ({insertColumns}) VALUES ({insertValues})"
    return query
class OracleDatabase(Database):
    """An implementation of the `Database` interface for Oracle.

    Parameters
    ----------
    connection : `sqlalchemy.engine.Connection`
        An existing connection created by a previous call to `connect`.
    origin : `int`
        An integer ID that should be used as the default for any datasets,
        quanta, or other entities that use a (autoincrement, origin) compound
        primary key.
    namespace : `str`, optional
        The namespace (schema) this database is associated with.  If `None`,
        the default schema for the connection is used (which may be `None`).
    writeable : `bool`, optional
        If `True`, allow write operations on the database, including
        ``CREATE TABLE``.
    prefix : `str`, optional
        Prefix to add to all table names, effectively defining a virtual
        schema that can coexist with others within the same actual database
        schema.  This prefix must not be used in the un-prefixed names of
        tables.

    Notes
    -----
    To use a prefix from standardized factory functions like `Database.fromUri`
    and `Database.fromConnectionStruct`, a '+' character in the namespace will
    be interpreted as a combination of ``namespace`` (first) and ``prefix``
    (second).  Either may be empty.  This does *not* work when constructing
    an `OracleDatabase` instance directly.
    """

    def __init__(self, *, connection: sqlalchemy.engine.Connection, origin: int,
                 namespace: Optional[str] = None, writeable: bool = True, prefix: Optional[str] = None):
        # Check out a DBAPI connection only long enough to read the session
        # attributes needed here; ``closing`` returns it to the pool instead
        # of leaving it checked out until garbage collection.
        with closing(connection.engine.raw_connection()) as dbapi:
            if namespace is None:
                # Fall back to the schema that was included/implicit in the
                # URI we used to connect.
                namespace = dbapi.current_schema
            dsn = dbapi.dsn
        super().__init__(connection=connection, origin=origin, namespace=namespace)
        self._writeable = writeable
        self.dsn = dsn
        self.prefix = prefix
        self._shrinker = NameShrinker(connection.engine.dialect.max_identifier_length)

    @classmethod
    def connect(cls, uri: str, *, writeable: bool = True) -> sqlalchemy.engine.Connection:
        """Open a connection to the Oracle server at ``uri``.

        Parameters
        ----------
        uri : `str`
            SQLAlchemy connection URI.
        writeable : `bool`, optional
            Accepted for interface compatibility; not used by this method.

        Returns
        -------
        connection : `sqlalchemy.engine.Connection`
            The new connection, with the dialect's identifier length limit
            raised to 128.

        Raises
        ------
        RuntimeError
            Raised if the server version is older than 12.2.  The connection
            is closed before raising.
        """
        connection = sqlalchemy.engine.create_engine(uri, pool_size=1).connect()
        # Work around SQLAlchemy assuming that the Oracle limit on identifier
        # names is even shorter than it is after 12.2.
        oracle_ver = connection.engine.dialect._get_server_version_info(connection)
        if oracle_ver < (12, 2):
            # Don't leak the connection we just opened.
            connection.close()
            raise RuntimeError("Oracle server version >= 12.2 required.")
        connection.engine.dialect.max_identifier_length = 128
        return connection

    @classmethod
    def fromConnection(cls, connection: sqlalchemy.engine.Connection, *, origin: int,
                       namespace: Optional[str] = None, writeable: bool = True) -> Database:
        """Construct an instance from an existing connection.

        A ``namespace`` of the form ``"schema+prefix"`` is split into its
        namespace and table-name prefix; an empty part becomes `None`.
        """
        prefix = None
        if namespace and "+" in namespace:
            namespace, prefix = namespace.split("+")
            namespace = namespace or None
            prefix = prefix or None
        return cls(connection=connection, origin=origin, writeable=writeable, namespace=namespace,
                   prefix=prefix)

    @contextmanager
    def transaction(self, *, interrupting: bool = False) -> Iterator[None]:
        """Wrap the base-class transaction, marking it ``READ ONLY`` on the
        server when this database is not writeable.
        """
        with super().transaction(interrupting=interrupting):
            if not self.isWriteable():
                with closing(self._connection.connection.cursor()) as cursor:
                    cursor.execute("SET TRANSACTION READ ONLY")
            yield

    def isWriteable(self) -> bool:
        return self._writeable

    def __str__(self) -> str:
        # Render as "Oracle@<dsn>" or "Oracle@<dsn>:<namespace>".
        if self.namespace is None:
            name = self.dsn
        else:
            name = f"{self.dsn}:{self.namespace}"
        return f"Oracle@{name}"

    def shrinkDatabaseEntityName(self, original: str) -> str:
        return self._shrinker.shrink(original)

    def expandDatabaseEntityName(self, shrunk: str) -> str:
        return self._shrinker.expand(shrunk)

    def _mangleTableName(self, name: str) -> str:
        # Apply the virtual-schema prefix unless the name already carries it.
        if self.prefix is not None and not name.startswith(self.prefix):
            name = self.prefix + name
        return name

    def replace(self, table: sqlalchemy.schema.Table, *rows: dict) -> None:
        """Insert ``rows`` into ``table``, replacing existing rows that have
        the same primary key, via a single MERGE statement.

        Raises
        ------
        ReadOnlyDatabaseError
            Raised if this database is not writeable (even when no rows are
            given).
        """
        if not self.isWriteable():
            raise ReadOnlyDatabaseError(f"Attempt to replace into read-only database '{self}'.")
        if not rows:
            # Executing the MERGE without any parameter sets would leave its
            # bind parameters unbound; there is nothing to do anyway.
            return
        self._connection.execute(_Merge(table), *rows)

    prefix: Optional[str]
    """A prefix included in all table names to simulate a database namespace
    (`str` or `None`).
    """

    dsn: str
    """The TNS entry of the database this instance is connected to (`str`).
    """