Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ["OracleDatabase"] 

24 

25from contextlib import closing, contextmanager 

26import copy 

27from typing import Optional 

28 

29import sqlalchemy 

30import sqlalchemy.ext.compiler 

31 

32from ..interfaces import Database, ReadOnlyDatabaseError 

33from ...core import ddl 

34from ..nameShrinker import NameShrinker 

35 

36 

class _Merge(sqlalchemy.sql.expression.Executable, sqlalchemy.sql.ClauseElement):
    """Executable SQLAlchemy construct for an Oracle ``MERGE`` upsert.

    Executing it inserts new rows and overwrites existing ones, matching on
    the table's primary key constraint. This is the Oracle counterpart of
    PostgreSQL and SQLite's ``INSERT ... ON CONFLICT REPLACE``.

    Parameters
    ----------
    table : `sqlalchemy.schema.Table`
        The table rows will be merged into.
    """

    def __init__(self, table):
        # Initialise the SQLAlchemy base classes before attaching our state.
        super(_Merge, self).__init__()
        self.table = table

46 

47 

@sqlalchemy.ext.compiler.compiles(_Merge, "oracle")
def _merge(merge, compiler, **kw):
    """Generate MERGE query for inserting or updating records.

    Parameters
    ----------
    merge : `_Merge`
        The construct being compiled; only its ``table`` attribute is used.
    compiler : `sqlalchemy.sql.compiler.SQLCompiler`
        The statement compiler for the Oracle dialect.
    **kw
        Compiler keyword arguments, forwarded to every nested
        ``compiler.process`` call.

    Returns
    -------
    query : `str`
        SQL text of the form::

            MERGE INTO <table> t
            USING (SELECT :col AS col, ...) d
            ON (<all primary-key columns equal>)
            [WHEN MATCHED THEN UPDATE SET <non-key columns>]
            WHEN NOT MATCHED THEN INSERT (<all columns>) VALUES (<all columns>)

        Every column becomes a bind parameter named after the column, so
        rows are supplied at execution time as dicts keyed by column name.
        The ``WHEN MATCHED`` branch is omitted when the table has no
        non-primary-key columns.
    """
    table = merge.table
    preparer = compiler.preparer

    allColumns = [col.name for col in table.columns]
    pkColumns = [col.name for col in table.primary_key]
    nonPkColumns = [col for col in allColumns if col not in pkColumns]

    # To properly support type decorators defined in core/ddl.py we need
    # to pass column type to `bindparam`.
    selectColumns = [sqlalchemy.sql.bindparam(col.name, type_=col.type).label(col.name)
                     for col in table.columns]
    selectClause = sqlalchemy.sql.select(selectColumns)

    tableAlias = table.alias("t")
    tableAliasText = compiler.process(tableAlias, asfrom=True, **kw)
    selectAlias = selectClause.alias("d")
    selectAliasText = compiler.process(selectAlias, asfrom=True, **kw)

    # An incoming row matches an existing one when all primary-key columns
    # are equal.
    condition = sqlalchemy.sql.and_(
        *[tableAlias.columns[col] == selectAlias.columns[col] for col in pkColumns]
    )
    conditionText = compiler.process(condition, **kw)

    query = f"MERGE INTO {tableAliasText}" \
            f"\nUSING {selectAliasText}" \
            f"\nON ({conditionText})"

    # When every column belongs to the primary key there is nothing to
    # update. An empty "UPDATE SET" list is invalid SQL, so the WHEN MATCHED
    # branch is omitted entirely. A matched row already equals the incoming
    # row in that case.
    if nonPkColumns:
        assignments = []
        for col in nonPkColumns:
            src = compiler.process(selectAlias.columns[col], **kw)
            dst = compiler.process(tableAlias.columns[col], **kw)
            assignments.append(f"{dst} = {src}")
        updatesText = ", ".join(assignments)
        query += f"\nWHEN MATCHED THEN UPDATE SET {updatesText}"

    insertColumns = ", ".join([preparer.format_column(col) for col in table.columns])
    insertValues = ", ".join([compiler.process(selectAlias.columns[col], **kw) for col in allColumns])

    query += f"\nWHEN NOT MATCHED THEN INSERT ({insertColumns}) VALUES ({insertValues})"
    return query

91 

92 

class OracleDatabase(Database):
    """An implementation of the `Database` interface for Oracle.

    Parameters
    ----------
    connection : `sqlalchemy.engine.Connection`
        An existing connection created by a previous call to `connect`.
    origin : `int`
        An integer ID that should be used as the default for any datasets,
        quanta, or other entities that use a (autoincrement, origin) compound
        primary key.
    namespace : `str`, optional
        The namespace (schema) this database is associated with. If `None`,
        the current schema of the connection is used.
    writeable : `bool`, optional
        If `True`, allow write operations on the database, including
        ``CREATE TABLE``.
    prefix : `str`, optional
        Prefix to add to all table names, effectively defining a virtual
        schema that can coexist with others within the same actual database
        schema. This prefix must not be used in the un-prefixed names of
        tables.

    Notes
    -----
    To use a prefix from standardized factory functions like `Database.fromUri`
    and `Database.fromConnectionStruct`, a '+' character in the namespace will
    be interpreted as a combination of ``namespace`` (first) and ``prefix``
    (second). Either may be empty. This does *not* work when constructing
    an `OracleDatabase` instance directly.
    """

    def __init__(self, *, connection: sqlalchemy.engine.Connection, origin: int,
                 namespace: Optional[str] = None, writeable: bool = True, prefix: Optional[str] = None):
        # Use the DBAPI connection already owned by ``connection``.
        # ``engine.raw_connection()`` would check a second connection out of
        # the pool and never explicitly return it.
        dbapi = connection.connection
        if namespace is None:
            # Fall back to the schema that was included/implicit in the URI
            # we used to connect.
            namespace = dbapi.current_schema
        super().__init__(connection=connection, origin=origin, namespace=namespace)
        self._writeable = writeable
        self.dsn = dbapi.dsn
        self.prefix = prefix
        self._shrinker = NameShrinker(connection.engine.dialect.max_identifier_length)

    @classmethod
    def connect(cls, uri: str, *, writeable: bool = True) -> sqlalchemy.engine.Connection:
        # ``writeable`` is accepted for interface compatibility; it is not
        # used when opening the connection.
        connection = sqlalchemy.engine.create_engine(uri, pool_size=1).connect()
        oracle_ver = connection.engine.dialect._get_server_version_info(connection)
        if oracle_ver < (12, 2):
            # Release the connection (and the engine's pool) we just opened
            # instead of leaking them on the error path.
            engine = connection.engine
            connection.close()
            engine.dispose()
            raise RuntimeError("Oracle server version >= 12.2 required.")
        # Work around SQLAlchemy assuming that the Oracle limit on identifier
        # names is even shorter than it is after 12.2.
        connection.engine.dialect.max_identifier_length = 128
        return connection

    @classmethod
    def fromConnection(cls, connection: sqlalchemy.engine.Connection, *, origin: int,
                       namespace: Optional[str] = None, writeable: bool = True) -> Database:
        # A namespace of the form "namespace+prefix" carries both values;
        # empty halves are normalized to `None`.
        if namespace and "+" in namespace:
            namespace, prefix = namespace.split("+")
            if not namespace:
                namespace = None
            if not prefix:
                prefix = None
        else:
            prefix = None
        return cls(connection=connection, origin=origin, writeable=writeable, namespace=namespace,
                   prefix=prefix)

    @contextmanager
    def transaction(self, *, interrupting: bool = False) -> None:
        # Wrap the base-class transaction. For read-only instances, also
        # issue Oracle's "SET TRANSACTION READ ONLY" on the underlying DBAPI
        # connection so that the server itself rejects writes.
        with super().transaction(interrupting=interrupting):
            if not self.isWriteable():
                with closing(self._connection.connection.cursor()) as cursor:
                    cursor.execute("SET TRANSACTION READ ONLY")
            yield

    def isWriteable(self) -> bool:
        return self._writeable

    def __str__(self) -> str:
        # The original ``f"{self.dsn:self.namespace}"`` used the namespace as
        # a *format spec*, which raises ValueError; join with a colon instead.
        if self.namespace is None:
            name = self.dsn
        else:
            name = f"{self.dsn}:{self.namespace}"
        return f"Oracle@{name}"

    def shrinkDatabaseEntityName(self, original: str) -> str:
        # Delegate to the NameShrinker sized to the dialect's identifier limit.
        return self._shrinker.shrink(original)

    def expandDatabaseEntityName(self, shrunk: str) -> str:
        # Inverse of `shrinkDatabaseEntityName`, via the same NameShrinker.
        return self._shrinker.expand(shrunk)

    def _convertForeignKeySpec(self, table: str, spec: ddl.ForeignKeySpec, metadata: sqlalchemy.MetaData,
                               **kwds) -> sqlalchemy.schema.ForeignKeyConstraint:
        # Point the foreign key at the prefixed target table. Copy the spec
        # so the caller's object is not mutated. Skip names that are already
        # prefixed, matching `_convertTableSpec` and `getExistingTable`.
        if self.prefix is not None and not spec.table.startswith(self.prefix):
            spec = copy.copy(spec)
            spec.table = self.prefix + spec.table
        return super()._convertForeignKeySpec(table, spec, metadata, **kwds)

    def _convertTableSpec(self, name: str, spec: ddl.TableSpec, metadata: sqlalchemy.MetaData,
                          **kwds) -> sqlalchemy.schema.Table:
        # Apply the virtual-schema prefix unless it is already present.
        if self.prefix is not None and not name.startswith(self.prefix):
            name = self.prefix + name
        return super()._convertTableSpec(name, spec, metadata, **kwds)

    def getExistingTable(self, name: str, spec: ddl.TableSpec) -> Optional[sqlalchemy.schema.Table]:
        # Apply the virtual-schema prefix unless it is already present.
        if self.prefix is not None and not name.startswith(self.prefix):
            name = self.prefix + name
        return super().getExistingTable(name, spec)

    def replace(self, table: sqlalchemy.schema.Table, *rows: dict):
        # Upsert ``rows`` (dicts keyed by column name) into ``table`` using a
        # MERGE statement keyed on the table's primary key.
        if not self.isWriteable():
            raise ReadOnlyDatabaseError(f"Attempt to replace into read-only database '{self}'.")
        if not rows:
            # Executing the MERGE with no parameter sets would leave its bind
            # parameters unbound; there is nothing to do.
            return
        self._connection.execute(_Merge(table), *rows)

    prefix: Optional[str]
    """A prefix included in all table names to simulate a database namespace
    (`str` or `None`).
    """

    dsn: str
    """The TNS entry of the database this instance is connected to (`str`).
    """