Coverage for python / felis / db / _dialects.py: 43%

49 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-14 23:37 +0000

1"""Utilities for accessing SQLAlchemy dialects and their type modules.""" 

2 

3# This file is part of felis. 

4# 

5# Developed for the LSST Data Management System. 

6# This product includes software developed by the LSST Project 

7# (https://www.lsst.org). 

8# See the COPYRIGHT file at the top-level directory of this distribution 

9# for details of code ownership. 

10# 

11# This program is free software: you can redistribute it and/or modify 

12# it under the terms of the GNU General Public License as published by 

13# the Free Software Foundation, either version 3 of the License, or 

14# (at your option) any later version. 

15# 

16# This program is distributed in the hope that it will be useful, 

17# but WITHOUT ANY WARRANTY; without even the implied warranty of 

18# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

19# GNU General Public License for more details. 

20# 

21# You should have received a copy of the GNU General Public License 

22# along with this program. If not, see <https://www.gnu.org/licenses/>. 

23 

24from __future__ import annotations 

25 

26import re 

27from collections.abc import Mapping 

28from types import MappingProxyType, ModuleType 

29 

30from sqlalchemy import dialects, types 

31from sqlalchemy.engine import Dialect 

32from sqlalchemy.engine.mock import create_mock_engine 

33from sqlalchemy.types import TypeEngine 

34 

35from ._sqltypes import MYSQL, POSTGRES, SQLITE 

36 

__all__ = ["get_dialect_module", "get_supported_dialects", "string_to_typeengine"]

_DIALECT_NAMES = (MYSQL, POSTGRES, SQLITE)
"""Tuple of supported dialect names.

These names are used to build the dialect and dialect-module mappings.
"""

44 

_DATATYPE_REGEXP: re.Pattern[str] = re.compile(r"(\w+)(\((.*)\))?")
"""Compiled pattern for a data type name with optional parameters.

Group 1 captures the type name, group 2 the parenthesized parameter list
(parentheses included, or None when absent) and group 3 the raw text
between the outermost parentheses.
"""

47 

48 

def _dialect(dialect_name: str) -> Dialect:
    """Build a SQLAlchemy dialect object for a dialect name.

    A mock engine is created from the URL ``<dialect_name>://`` with no
    executor, and only its dialect is kept.

    Parameters
    ----------
    dialect_name
        The name of the dialect to build, used as the URL scheme.

    Returns
    -------
    `~sqlalchemy.engine.Dialect`
        The dialect attached to the mock engine.
    """
    url = f"{dialect_name}://"
    mock_engine = create_mock_engine(url, executor=None)
    return mock_engine.dialect

63 

64 

_DIALECTS = MappingProxyType(dict(zip(_DIALECT_NAMES, map(_dialect, _DIALECT_NAMES))))
"""Read-only mapping of supported dialect names to SQLAlchemy dialects."""

67 

68 

def get_supported_dialects() -> Mapping[str, Dialect]:
    """Get a mapping of the supported SQLAlchemy dialects.

    Returns
    -------
    `~collections.abc.Mapping` [ `str`, `~sqlalchemy.engine.Dialect` ]
        A read-only mapping of the supported SQLAlchemy dialects.

    Notes
    -----
    The mapping is keyed by the dialect name and the value is the SQLAlchemy
    dialect object. It is a `types.MappingProxyType`, so callers cannot
    modify it, and the same object is returned on every call. This function
    is intended as the primary interface for getting the supported dialects.
    """
    return _DIALECTS

84 

85 

def _dialect_module(dialect_name: str) -> ModuleType:
    """Get the SQLAlchemy dialect module for the given name.

    Parameters
    ----------
    dialect_name
        The name of the dialect submodule of `sqlalchemy.dialects` to get,
        e.g. ``"mysql"``.

    Returns
    -------
    `~types.ModuleType`
        The ``sqlalchemy.dialects.<dialect_name>`` module.

    Notes
    -----
    The submodule is imported explicitly instead of being read as an
    attribute of the `sqlalchemy.dialects` package. An attribute lookup on a
    package only finds submodules that have already been imported, which
    would make this function depend on some earlier code (such as engine
    creation) having imported the dialect as a side effect.
    """
    import importlib

    return importlib.import_module(f"{dialects.__name__}.{dialect_name}")

95 

96 

_DIALECT_MODULES = MappingProxyType(dict(zip(_DIALECT_NAMES, map(_dialect_module, _DIALECT_NAMES))))
"""Read-only mapping of supported dialect names to SQLAlchemy dialect modules."""

99 

100 

def get_dialect_module(dialect_name: str) -> ModuleType:
    """Look up the SQLAlchemy dialect module for a supported dialect name.

    Parameters
    ----------
    dialect_name
        The name of the dialect whose SQLAlchemy module is wanted.

    Returns
    -------
    `~types.ModuleType`
        The SQLAlchemy dialect module registered under ``dialect_name``.

    Raises
    ------
    ValueError
        Raised if ``dialect_name`` is not one of the supported dialects.
    """
    module = _DIALECT_MODULES.get(dialect_name)
    if module is None:
        raise ValueError(f"Unsupported dialect: {dialect_name}")
    return module

122 

123 

def string_to_typeengine(
    type_string: str, dialect: Dialect | None = None, length: int | None = None
) -> TypeEngine:
    """Convert a string representation of a datatype to a SQLAlchemy type.

    Parameters
    ----------
    type_string
        The string representation of the data type, such as ``INTEGER``,
        ``VARCHAR(255)`` or ``DECIMAL(10, 2)``. The type name is upper-cased
        before lookup, so it is matched case-insensitively. Parameters inside
        the parentheses are split on commas and stripped of surrounding
        whitespace. Parameters made up only of decimal digits are passed to
        the type constructor as `int`, and all others as `str`.
    dialect
        The SQLAlchemy dialect whose type module is searched for the type.
        If None, the generic types in `sqlalchemy.types` are used.
    length
        The length of the data type. It is applied only if the constructed
        type has a ``length`` attribute that is still None (e.g. no length
        was given in ``type_string``); otherwise it is ignored.

    Returns
    -------
    `sqlalchemy.types.TypeEngine`
        The SQLAlchemy type engine object.

    Raises
    ------
    ValueError
        Raised if the type string is invalid, the dialect is not supported,
        or the type is not supported.

    Notes
    -----
    This function is used when converting type override strings defined in
    fields such as ``mysql:datatype`` in the schema data.
    """
    match = _DATATYPE_REGEXP.search(type_string)
    if not match:
        raise ValueError(f"Invalid type string: {type_string}")

    type_name, _, params = match.groups()
    type_name = type_name.upper()

    if dialect is None:
        type_module = types
    else:
        # get_dialect_module raises ValueError for an unsupported dialect;
        # it is deliberately left to propagate to the caller unchanged.
        type_module = get_dialect_module(dialect.name)

    type_class = getattr(type_module, type_name, None)
    if not type_class:
        raise ValueError(f"Unsupported type: {type_name}")

    if params:
        # Strip each parameter so "DECIMAL(10, 2)" gives (10, 2) rather than
        # (10, " 2"); isdecimal() (unlike isdigit()) only accepts characters
        # that int() can actually parse.
        args = [
            int(arg) if arg.isdecimal() else arg
            for arg in (raw.strip() for raw in params.split(","))
        ]
        type_obj = type_class(*args)
    else:
        type_obj = type_class()

    # Only fill in a length the type string did not already provide.
    if hasattr(type_obj, "length") and type_obj.length is None and length is not None:
        type_obj.length = length

    return type_obj