Coverage for python / felis / db / _dialects.py: 43%
49 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-14 23:37 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-14 23:37 +0000
1"""Utilities for accessing SQLAlchemy dialects and their type modules."""
# This file is part of felis.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
24from __future__ import annotations
26import re
27from collections.abc import Mapping
28from types import MappingProxyType, ModuleType
30from sqlalchemy import dialects, types
31from sqlalchemy.engine import Dialect
32from sqlalchemy.engine.mock import create_mock_engine
33from sqlalchemy.types import TypeEngine
35from ._sqltypes import MYSQL, POSTGRES, SQLITE
__all__ = [
    "get_dialect_module",
    "get_supported_dialects",
    "string_to_typeengine",
]

_DIALECT_NAMES = (MYSQL, POSTGRES, SQLITE)
"""Tuple of the supported dialect names.

The dialect and dialect-module lookup mappings are both built by iterating
over these names.
"""

# Group 1 captures the bare type name, group 2 the optional parenthesized
# suffix, and group 3 the raw text between the parentheses (if any).
_DATATYPE_REGEXP = re.compile(r"(\w+)(\((.*)\))?")
"""Regular expression to match data types with parameters in parentheses."""
def _dialect(dialect_name: str) -> Dialect:
    """Build the SQLAlchemy dialect object registered under ``dialect_name``.

    A mock engine is created for a URL made of the dialect name alone, and
    only its dialect is kept.

    Parameters
    ----------
    dialect_name
        Name of the dialect; it is used as the scheme of the engine URL.

    Returns
    -------
    `~sqlalchemy.engine.Dialect`
        The dialect attached to the mock engine.
    """
    url = f"{dialect_name}://"
    mock_engine = create_mock_engine(url, executor=None)
    return mock_engine.dialect
_DIALECTS = MappingProxyType(
    dict(zip(_DIALECT_NAMES, map(_dialect, _DIALECT_NAMES)))
)
"""Read-only mapping of dialect names to SQLAlchemy dialects."""
def get_supported_dialects() -> Mapping[str, Dialect]:
    """Get the supported SQLAlchemy dialects.

    Returns
    -------
    `~collections.abc.Mapping` [ `str`, `~sqlalchemy.engine.Dialect` ]
        A read-only mapping from dialect name to SQLAlchemy dialect object.

    Notes
    -----
    The same module-level mapping object is returned on every call. It is a
    `types.MappingProxyType`, so it cannot be modified: item assignment
    raises `TypeError`. This function is intended as the primary interface
    for getting the supported dialects.
    """
    return _DIALECTS
def _dialect_module(dialect_name: str) -> ModuleType:
    """Get the SQLAlchemy dialect module for the given name.

    Parameters
    ----------
    dialect_name
        The name of the dialect module to get from the SQLAlchemy package.

    Returns
    -------
    `~types.ModuleType`
        The attribute named ``dialect_name`` on the ``sqlalchemy.dialects``
        package.

    Raises
    ------
    AttributeError
        Raised if ``sqlalchemy.dialects`` has no attribute named
        ``dialect_name``.
    """
    return getattr(dialects, dialect_name)
_DIALECT_MODULES = MappingProxyType(
    dict(zip(_DIALECT_NAMES, map(_dialect_module, _DIALECT_NAMES)))
)
"""Read-only mapping of dialect names to SQLAlchemy dialect modules."""
def get_dialect_module(dialect_name: str) -> ModuleType:
    """Look up the SQLAlchemy dialect module registered under a name.

    Parameters
    ----------
    dialect_name
        Name of a supported dialect whose module should be returned.

    Returns
    -------
    `~types.ModuleType`
        The matching module from the ``sqlalchemy.dialects`` package.

    Raises
    ------
    ValueError
        Raised if ``dialect_name`` is not one of the supported dialects.
    """
    module = _DIALECT_MODULES.get(dialect_name)
    if module is None:
        raise ValueError(f"Unsupported dialect: {dialect_name}")
    return module
def string_to_typeengine(
    type_string: str, dialect: Dialect | None = None, length: int | None = None
) -> TypeEngine:
    """Convert a string representation of a datatype to a SQLAlchemy type.

    Parameters
    ----------
    type_string
        The string representation of the data type, e.g. ``VARCHAR(128)`` or
        ``DECIMAL(10, 2)``. Comma-separated parameters inside the parentheses
        are stripped of surrounding whitespace and passed positionally to the
        type's constructor. Parameters made only of decimal digits are
        converted to `int`; all others are passed as strings.
    dialect
        The SQLAlchemy dialect to use. If None, the type is looked up in the
        generic ``sqlalchemy.types`` module.
    length
        The length of the data type. It is applied only if the resulting
        type object has a ``length`` attribute that is still None; otherwise
        this parameter is ignored.

    Returns
    -------
    `sqlalchemy.types.TypeEngine`
        The SQLAlchemy type engine object.

    Raises
    ------
    ValueError
        Raised if the type string is invalid, the dialect is not supported,
        or the type is not supported.

    Notes
    -----
    This function is used when converting type override strings defined in
    fields such as ``mysql:datatype`` in the schema data.
    """
    match = _DATATYPE_REGEXP.search(type_string)
    if not match:
        raise ValueError(f"Invalid type string: {type_string}")

    type_name, _, params_text = match.groups()
    type_key = type_name.upper()

    if dialect is None:
        type_class = getattr(types, type_key, None)
    else:
        # get_dialect_module already raises ValueError for an unsupported
        # dialect, so no exception translation is needed here.
        dialect_module = get_dialect_module(dialect.name)
        type_class = getattr(dialect_module, type_key, None)

    if not type_class:
        raise ValueError(f"Unsupported type: {type_key}")

    # Group 3 is None when the type string has no parentheses at all.
    params_text = (params_text or "").strip()
    if params_text:
        # Strip each parameter so "DECIMAL(10, 2)" yields (10, 2) and not
        # (10, " 2"). isdecimal() (unlike isdigit()) only accepts characters
        # that int() can actually parse.
        stripped = [param.strip() for param in params_text.split(",")]
        args = [int(param) if param.isdecimal() else param for param in stripped]
        type_obj = type_class(*args)
    else:
        type_obj = type_class()

    # Only fill in a length the type supports and that was not already given
    # explicitly in the type string.
    if length is not None and hasattr(type_obj, "length") and type_obj.length is None:
        type_obj.length = length

    return type_obj