Coverage for python / felis / db / database_context.py: 32%

353 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-14 23:37 +0000

1"""API for managing database operations across different dialects.""" 

2 

3# This file is part of felis. 

4# 

5# Developed for the LSST Data Management System. 

6# This product includes software developed by the LSST Project 

7# (https://www.lsst.org). 

8# See the COPYRIGHT file at the top-level directory of this distribution 

9# for details of code ownership. 

10# 

11# This program is free software: you can redistribute it and/or modify 

12# it under the terms of the GNU General Public License as published by 

13# the Free Software Foundation, either version 3 of the License, or 

14# (at your option) any later version. 

15# 

16# This program is distributed in the hope that it will be useful, 

17# but WITHOUT ANY WARRANTY; without even the implied warranty of 

18# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

19# GNU General Public License for more details. 

20# 

21# You should have received a copy of the GNU General Public License 

22# along with this program. If not, see <https://www.gnu.org/licenses/>. 

23 

24from __future__ import annotations 

25 

26import logging 

27from abc import abstractmethod 

28from collections.abc import Callable, Iterator 

29from contextlib import AbstractContextManager, contextmanager 

30from typing import IO, Any, Literal, TypeAlias 

31 

32from sqlalchemy import ( 

33 Engine, 

34 MetaData, 

35 create_engine, 

36 inspect, 

37 make_url, 

38 quoted_name, 

39) 

40from sqlalchemy.engine import ( 

41 Connection, 

42 Dialect, 

43 Result, 

44) 

45from sqlalchemy.engine.mock import MockConnection, create_mock_engine 

46from sqlalchemy.engine.url import URL 

47from sqlalchemy.exc import SQLAlchemyError 

48from sqlalchemy.schema import ( 

49 CreateSchema, 

50 DropSchema, 

51) 

52from sqlalchemy.sql import ( 

53 Executable, 

54 text, 

55) 

56from sqlalchemy.sql.elements import TextClause 

57 

__all__ = [
    "DatabaseContext",
    "DatabaseContextError",
    "MockContext",
    "MySQLContext",
    "PostgreSQLContext",
    "SQLiteContext",
    "create_database_context",
]

logger = logging.getLogger("felis")

# Statement types accepted by ``DatabaseContext.execute``: raw SQL strings
# (wrapped with ``sqlalchemy.text`` before execution) or already-constructed
# SQLAlchemy executables.
SQLStatement: TypeAlias = str | Executable | TextClause

71 

72 

73def _normalize_statement(statement: SQLStatement) -> Executable | TextClause: 

74 if isinstance(statement, str): 

75 return text(statement) 

76 return statement 

77 

78 

def _create_mock_connection(engine_url: str | URL, output_file: IO[str] | None = None) -> MockConnection:
    """Build a mock connection that writes executed statements as SQL text.

    Parameters
    ----------
    engine_url
        URL (or URL string) selecting the dialect used to compile statements.
    output_file
        Destination for the generated SQL; None means stdout.

    Returns
    -------
    MockConnection
        The mock connection returned by `create_mock_engine`.
    """
    sql_writer = _SQLWriter(output_file)
    mock_connection = create_mock_engine(engine_url, executor=sql_writer.write, paramstyle="pyformat")
    # The writer compiles statements with the mock's dialect, which only
    # exists once the mock engine has been built.
    sql_writer.dialect = mock_connection.dialect
    return mock_connection

84 

85 

86def _dialect_name(url: URL) -> str: 

87 dialect_name = url.drivername 

88 # Normalize dialect name (e.g., "postgresql+psycopg2" -> "postgresql") 

89 if "+" in dialect_name: 

90 dialect_name = dialect_name.split("+")[0] 

91 return dialect_name 

92 

93 

94def _clear_schema(metadata: MetaData) -> None: 

95 if metadata.schema: 

96 metadata.schema = None 

97 for table in metadata.tables.values(): 

98 table.schema = None 

99 

100 

101def _get_existing_indexes(inspector: Any, table_name: str, schema: str | None) -> set[str]: 

102 return { 

103 ix["name"] 

104 for ix in inspector.get_indexes(table_name, schema=schema) 

105 if "name" in ix and ix["name"] is not None 

106 } 

107 

108 

def is_mock_url(url: URL) -> bool:
    """Check if the engine URL points to a mock connection.

    A SQLite URL is considered a mock URL when it names no database (e.g.
    ``sqlite://``). Any other dialect is considered a mock URL when it has
    no host (e.g. ``postgresql://``). The dialect is compared after removing
    any ``+driver`` suffix, so ``sqlite+pysqlite:///file.db`` is recognized
    as a real SQLite URL rather than a host-less (mock) URL.

    Parameters
    ----------
    url
        The SQLAlchemy engine URL.

    Returns
    -------
    bool
        True if the URL is a mock URL, False otherwise.
    """
    # Same normalization as ``_dialect_name``: drop the "+driver" suffix.
    backend = url.drivername.partition("+")[0]
    if backend == "sqlite":
        return url.database is None
    return url.host is None

125 

126 

def is_sqlite_url(url: URL | str) -> bool:
    """Check if the engine URL points to a SQLite database.

    Parameters
    ----------
    url
        The SQLAlchemy engine URL, or a string that is parsed with
        `sqlalchemy.make_url` first.

    Returns
    -------
    bool
        True if the driver name starts with ``sqlite`` (with or without a
        ``+driver`` suffix), False otherwise.
    """
    parsed = make_url(url) if isinstance(url, str) else url
    return parsed.drivername.startswith("sqlite")

143 

144 

class DatabaseContextError(Exception):
    """Raised when an operation performed through a database context fails.

    Underlying SQLAlchemy errors are usually chained as the ``__cause__``.
    """

147 

148 

class DatabaseContext(AbstractContextManager):
    """Interface for managing database operations across different
    SQL dialects.

    Instances are context managers: leaving a ``with`` block calls `close`.
    Errors raised by `close` during that cleanup are logged rather than
    propagated, and exceptions raised inside the ``with`` block are never
    suppressed.
    """

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> Literal[False]:
        """Exit the context manager and clean up resources.

        Returns
        -------
        bool
            Always False, so an exception raised inside the ``with`` block
            propagates to the caller.
        """
        try:
            self.close()
        except Exception:
            # Cleanup failures are logged so they cannot mask the original
            # exception (if any) from the ``with`` body.
            logger.exception("Error during cleanup of database context")
        return False

    @abstractmethod
    def close(self) -> None:
        """Close and clean up database resources."""
        ...

    @property
    @abstractmethod
    def metadata(self) -> MetaData:
        """The SQLAlchemy metadata representing the database for the context
        (`~sqlalchemy.sql.schema.MetaData`).
        """
        ...

    @property
    @abstractmethod
    def engine(self) -> Engine:
        """The SQLAlchemy engine for the context
        (`~sqlalchemy.engine.Engine`).
        """
        ...

    @property
    @abstractmethod
    def dialect(self) -> Dialect:
        """The SQLAlchemy dialect for the context
        (`~sqlalchemy.engine.Dialect`).
        """
        ...

    @property
    @abstractmethod
    def dialect_name(self) -> str:
        """Get the dialect name for this database context (``str``)."""
        ...

    @abstractmethod
    def initialize(self) -> None:
        """Create the target schema in the database if it does not exist
        already.

        Sub-classes should implement idempotent behavior so that calling this
        method multiple times has no adverse effects. If the schema already
        exists, the method should simply return without raising an error. (A
        warning message may be logged in this case.)

        Raises
        ------
        DatabaseContextError
            If there is an error instantiating the schema.
        """
        ...

    @abstractmethod
    def drop(self) -> None:
        """Drop the schema in the database if it exists.

        Implementations should use ``IF EXISTS`` semantics to avoid raising
        an error if the schema does not exist.

        Raises
        ------
        DatabaseContextError
            If there is an error dropping the schema.
        """
        ...

    @abstractmethod
    def create_all(self) -> None:
        """Create all database objects in the schema using the metadata
        object.

        Raises
        ------
        DatabaseContextError
            If there is an error creating the schema objects in the database.
        """
        ...

    @abstractmethod
    def create_indexes(self) -> None:
        """Create all indexes in the schema using the metadata object.

        Raises
        ------
        DatabaseContextError
            If there is an error creating the indexes in the database.
        """
        ...

    @abstractmethod
    def drop_indexes(self) -> None:
        """Drop all indexes in the schema using the metadata object.

        Raises
        ------
        DatabaseContextError
            If there is an error dropping the indexes in the database.
        """
        ...

    @abstractmethod
    def execute(self, statement: SQLStatement, parameters: dict[str, Any] | None = None) -> Result:
        """Execute a SQL statement and return the result.

        Parameters
        ----------
        statement
            The SQL statement to execute. Plain strings are treated as raw
            SQL text.
        parameters
            Optional mapping of bound-parameter names to values for the
            statement.

        Returns
        -------
        `~sqlalchemy.engine.Result`
            The result of the statement execution.

        Raises
        ------
        DatabaseContextError
            If there is an error executing the SQL statement.
        """
        ...

282 

283 

class _BaseContext(DatabaseContext):
    """Base database context providing common behavior.

    Parameters
    ----------
    engine_url
        The SQLAlchemy URL used to lazily create the engine that connects to
        the database.
    metadata
        The SQLAlchemy metadata representing the database objects.
    require_schema
        True if a valid schema name is required on the MetaData, False if not.

    Raises
    ------
    DatabaseContextError
        If the dialect of ``engine_url`` does not match ``DIALECT``, or if
        ``require_schema`` is True and ``metadata.schema`` is None.
    """

    # Subclasses should set this to the dialect name.
    DIALECT: str

    def __init__(self, engine_url: URL, metadata: MetaData, require_schema: bool = False) -> None:
        self._engine_url = engine_url
        self._metadata = metadata
        self._schema_name: str | None = metadata.schema
        # The engine is created on first access of ``engine``.
        self._engine: Engine | None = None
        self._echo: bool = False

        # Check that the URL dialect matches this context's expected dialect
        self._validate_dialect(engine_url)

        # Ensure the schema name is set for dialects that require it
        if require_schema and self._schema_name is None:
            raise DatabaseContextError(f"Schema name must be set for context: {self.dialect_name}")

    @property
    def echo(self) -> bool:
        """Whether to log all SQL statements executed by the engine
        (``bool``).

        Setting this updates an engine that already exists; otherwise the
        value is applied when the engine is lazily created.
        """
        return self._echo

    @echo.setter
    def echo(self, value: bool) -> None:
        self._echo = value
        # Only touch an engine that already exists. Reading ``self.engine``
        # here would force engine creation.
        if self._engine is not None:
            self._engine.echo = value

    @classmethod
    def _validate_dialect(cls, engine_url: URL) -> None:
        """Validate that the engine dialect matches this context's expected
        dialect.

        Parameters
        ----------
        engine_url
            The SQLAlchemy database URL to validate.

        Raises
        ------
        DatabaseContextError
            If the engine dialect doesn't match the context's expected dialect.
        """
        # Normalize both the engine dialect and expected dialect for
        # comparison (lower-case, no "+driver" suffix), matching the
        # normalization applied by the context factory.
        engine_dialect = _dialect_name(engine_url).lower()
        expected_dialect = cls.DIALECT.lower()

        if engine_dialect != expected_dialect:
            raise DatabaseContextError(
                f"Engine dialect '{engine_dialect}' does not match the context's expected dialect: "
                f"{expected_dialect}"
            )

    @property
    def engine(self) -> Engine:
        if self._engine is None:
            # Pass the stored echo flag so it also applies to an engine that
            # is re-created after ``close()``.
            self._engine = create_engine(self._engine_url, echo=self._echo)
        return self._engine

    @property
    def metadata(self) -> MetaData:
        return self._metadata

    @property
    def dialect(self) -> Dialect:
        return self.engine.dialect

    @property
    def dialect_name(self) -> str:
        """Get the dialect name for this database context.

        Returns
        -------
        str
            The normalized dialect name.
        """
        return self.DIALECT

    @property
    def schema_name(self) -> str | None:
        """Effective schema name for this context (may be None).

        Returns
        -------
        str | None
            The schema name, or None if no schema is set.
        """
        return self._schema_name

    @contextmanager
    def connect(self) -> Iterator[Connection]:
        """Context manager yielding a connection from the engine; the
        connection is closed when the block exits.
        """
        with self.engine.connect() as connection:
            yield connection

    def execute(self, statement: SQLStatement, parameters: dict[str, Any] | None = None) -> Result:
        """Execute a SQL statement in its own transaction and return the
        result.

        Parameters
        ----------
        statement
            The SQL statement to execute; strings are wrapped with
            `sqlalchemy.text`.
        parameters
            Optional bound-parameter values for the statement.

        Returns
        -------
        `~sqlalchemy.engine.Result`
            The result of the statement execution.

        Raises
        ------
        DatabaseContextError
            If SQLAlchemy raises an error while connecting, executing, or
            committing.
        """
        statement = _normalize_statement(statement)
        try:
            with self.connect() as conn, conn.begin():
                if parameters:
                    result = conn.execute(statement, parameters)
                else:
                    result = conn.execute(statement)
            # NOTE(review): the connection has been released by this point,
            # so reading rows from ``result`` relies on the DBAPI driver
            # having buffered them; confirm for each supported driver.
            return result
        except SQLAlchemyError as e:
            raise DatabaseContextError(f"Error executing statement: {e}") from e

    def create_all(self) -> None:
        """Create all metadata objects in a single transaction.

        Raises
        ------
        DatabaseContextError
            If SQLAlchemy raises an error while connecting, creating the
            objects, or committing.
        """
        try:
            with self.connect() as conn, conn.begin():
                self.metadata.create_all(bind=conn)
        except SQLAlchemyError as e:
            raise DatabaseContextError(f"Error creating database: {e}") from e

    def _manage_indexes(self, action: str) -> None:
        """Manage indexes by creating or dropping them.

        Indexes that already exist (for "create") or that are missing (for
        "drop") are skipped with a warning, as are anonymous indexes.

        Parameters
        ----------
        action
            The action to perform, either "create" or "drop".

        Raises
        ------
        ValueError
            If ``action`` is not "create" or "drop".
        DatabaseContextError
            If there is an error managing the indexes in the database.
        """
        # Validate up front so an invalid action is reported even when there
        # are no named indexes to visit.
        if action not in ("create", "drop"):
            raise ValueError(f"Invalid action '{action}'. Must be 'create' or 'drop'.")
        verb = "creating" if action == "create" else "dropping"

        try:
            with self.connect() as conn, conn.begin():
                inspector = inspect(conn)
                for table in self.metadata.tables.values():
                    # Fetch all existing indexes for this table once
                    existing_indexes = _get_existing_indexes(inspector, table.name, self.schema_name)

                    for index in table.indexes:
                        if index.name is None:
                            # Anonymous indexes can't be checked by name
                            logger.warning(f"Skipping anonymous index on table '{table.name}'")
                            continue

                        exists = index.name in existing_indexes
                        if action == "create":
                            if exists:
                                logger.warning(f"Skipping creation of index '{index.name}' which already exists")
                                continue
                            index.create(bind=conn, checkfirst=False)  # We already checked
                            logger.info(f"Created index '{index.name}'")
                        else:
                            if not exists:
                                logger.warning(f"Skipping index '{index.name}' which does not exist")
                                continue
                            index.drop(bind=conn, checkfirst=False)  # We already checked
                            logger.info(f"Dropped index '{index.name}'")
        except SQLAlchemyError as e:
            raise DatabaseContextError(f"Error {verb} indexes: {e}") from e

    def create_indexes(self) -> None:
        """Create all indexes in the schema using the metadata object.

        Raises
        ------
        DatabaseContextError
            If there is an error creating the indexes in the database.
        """
        self._manage_indexes("create")

    def drop_indexes(self) -> None:
        """Drop all indexes in the schema using the metadata object.

        Raises
        ------
        DatabaseContextError
            If there is an error dropping the indexes in the database.
        """
        self._manage_indexes("drop")

    def _required_schema_name(self) -> str:
        """Return the schema name, ensuring that it is set.

        This is mainly here for typing purposes, because the schema_name
        property may be None, and mypy doesn't understand that we already
        checked it during initialization.
        """
        if self.schema_name is None:
            raise DatabaseContextError("Schema name is required but not set.")
        return self.schema_name

    def close(self) -> None:
        """Close and dispose of the database engine, if one was created.

        A later access of ``engine`` creates a fresh engine.
        """
        if self._engine is not None:
            self._engine.dispose()
            self._engine = None

497 

498 

_ContextClass: TypeAlias = type[_BaseContext]
_ContextDecorator: TypeAlias = Callable[[_ContextClass], _ContextClass]


class DatabaseContextFactory:
    """Factory for creating DatabaseContext instances based on dialect type.

    Context classes are stored in a class-level registry keyed by normalized
    dialect name (lower-case, with any ``+driver`` suffix removed).
    """

    _registry: dict[str, _ContextClass] = {}

    @staticmethod
    def _normalize_dialect(dialect: str) -> str:
        """Return ``dialect`` lower-cased with any ``+driver`` suffix removed
        (e.g. ``"PostgreSQL+psycopg2"`` becomes ``"postgresql"``).
        """
        return dialect.lower().split("+")[0]

    @classmethod
    def register(cls) -> _ContextDecorator:
        """Register a context class for its dialect.

        The dialect is determined by reading the DIALECT attribute from the
        decorated class and is normalized in the same way as in
        `register_class` and `create_context`.

        Returns
        -------
        Callable
            The decorator function that registers the context class.

        Raises
        ------
        ValueError
            (From the returned decorator) if the decorated class does not
            define a DIALECT attribute.

        Examples
        --------
        >>> @DatabaseContextFactory.register()
        ... class PostgreSQLContext(_BaseContext):
        ...     DIALECT = "postgresql"
        ...     pass

        Notes
        -----
        The registry is populated at module import time and afterwards should
        be treated as read-only.
        """

        def decorator(context_class: type[_BaseContext]) -> type[_BaseContext]:
            # Get the dialect from the class's DIALECT attribute
            if not hasattr(context_class, "DIALECT"):
                raise ValueError(f"Context class {context_class.__name__} must define a DIALECT attribute")
            cls._registry[cls._normalize_dialect(context_class.DIALECT)] = context_class
            return context_class

        return decorator

    @classmethod
    def register_class(cls, dialect: str, context_class: type[_BaseContext]) -> None:
        """Register a context class for a specific dialect programmatically.

        Parameters
        ----------
        dialect
            The dialect name to register; normalized to lower case without
            any ``+driver`` suffix.
        context_class
            The context class to use for this dialect.
        """
        cls._registry[cls._normalize_dialect(dialect)] = context_class

    @classmethod
    def create_context(cls, dialect: str, engine_url: URL, metadata: MetaData) -> DatabaseContext:
        """Create a context instance for the given dialect.

        Parameters
        ----------
        dialect
            The database dialect name; normalized to lower case without any
            ``+driver`` suffix before lookup.
        engine_url
            The SQLAlchemy database URL.
        metadata
            The SQLAlchemy metadata.

        Returns
        -------
        DatabaseContext
            The appropriate context instance.

        Raises
        ------
        ValueError
            If no context class is registered for the dialect.
        """
        dialect_name = cls._normalize_dialect(dialect)

        if dialect_name not in cls._registry:
            supported = cls.get_supported_dialects()
            raise ValueError(
                f"No context class registered for dialect: {dialect_name}. "
                f"Supported dialects: {', '.join(supported)}"
            )

        context_class = cls._registry[dialect_name]
        return context_class(engine_url, metadata)

    @classmethod
    def get_supported_dialects(cls) -> list[str]:
        """Get a list of supported dialect names.

        Returns
        -------
        list[str]
            List of supported dialect names, in registration order.
        """
        return list(cls._registry.keys())

605 

606 

607class _SQLWriter: 

608 """Write SQL statements to stdout or a file. 

609 

610 Parameters 

611 ---------- 

612 file 

613 The file to write the SQL statements to. If None, the statements 

614 will be written to stdout. 

615 """ 

616 

617 def __init__(self, file: IO[str] | None = None) -> None: 

618 """Initialize the SQL writer.""" 

619 self.file = file 

620 self.dialect: Dialect | None = None 

621 

622 def write(self, sql: Any, *multiparams: Any, **params: Any) -> None: 

623 """Write the SQL statement to a file or stdout. 

624 

625 Statements with parameters will be formatted with the values 

626 inserted into the resultant SQL output. 

627 

628 Parameters 

629 ---------- 

630 sql 

631 The SQL statement to write. 

632 *multiparams 

633 The multiparams to use for the SQL statement. 

634 **params 

635 The params to use for the SQL statement. 

636 

637 Notes 

638 ----- 

639 The functions arguments are typed very loosely because this method in 

640 SQLAlchemy is untyped, amd we do not call it directly. 

641 """ 

642 compiled = sql.compile(dialect=self.dialect) 

643 sql_str = str(compiled) + ";" 

644 params_list = [compiled.params] 

645 for params in params_list: 

646 if not params: 

647 print(sql_str, file=self.file) 

648 continue 

649 new_params = {} 

650 for key, value in params.items(): 

651 if isinstance(value, str): 

652 new_params[key] = f"'{value}'" 

653 elif value is None: 

654 new_params[key] = "null" 

655 else: 

656 new_params[key] = value 

657 print(sql_str % new_params, file=self.file) 

658 

659 

@DatabaseContextFactory.register()
class PostgreSQLContext(_BaseContext):
    """Database context for Postgres.

    A schema name must be set on the metadata; `initialize` creates that
    schema when missing and `drop` removes it with ``CASCADE``.

    Parameters
    ----------
    engine_url
        The SQLAlchemy database URL for connecting to the database.
    metadata
        The SQLAlchemy metadata representing the database objects.
    """

    DIALECT = "postgresql"

    def __init__(self, engine_url: URL, metadata: MetaData):
        # Objects are created inside a named schema, so one is mandatory.
        super().__init__(engine_url, metadata, require_schema=True)

    def initialize(self) -> None:
        schema_name = self._required_schema_name()
        try:
            logger.debug(f"Checking if PG schema exists: {schema_name}")
            existing = self.execute(
                """
                SELECT schema_name
                FROM information_schema.schemata
                WHERE schema_name = :schema_name
                """,
                {"schema_name": schema_name},
            )
            # Only create the schema when the lookup found no row.
            if existing.fetchone() is None:
                logger.debug(f"Creating PG schema: {schema_name}")
                self.execute(CreateSchema(schema_name))
        except SQLAlchemyError as e:
            raise DatabaseContextError(f"Error initializing Postgres schema: {e}") from e

    def drop(self) -> None:
        schema_name = self._required_schema_name()
        try:
            logger.debug(f"Dropping PostgreSQL schema if exists: {schema_name}")
            drop_stmt = DropSchema(schema_name, if_exists=True, cascade=True)
            self.execute(drop_stmt)
        except SQLAlchemyError as e:
            raise DatabaseContextError(f"Error dropping Postgres database: {e}") from e

703 

704 

@DatabaseContextFactory.register()
class MySQLContext(_BaseContext):
    """Database context for MySQL.

    Parameters
    ----------
    engine_url
        The SQLAlchemy database URL for connecting to the database.
    metadata
        The SQLAlchemy metadata representing the database objects.
    """

    DIALECT = "mysql"

    def __init__(self, engine_url: URL, metadata: MetaData):
        super().__init__(engine_url, metadata, require_schema=True)

    def initialize(self) -> None:
        # The schema is instantiated as a database, as MySQL does not have a
        # distinct schema concept, unlike Postgres. (In MySQL, CREATE SCHEMA
        # is a synonym for CREATE DATABASE.)
        schema_name = self._required_schema_name()
        try:
            logger.debug(f"Checking if MySQL database exists: {schema_name}")
            # Exact-match lookup: ``SHOW DATABASES LIKE`` would treat "_" and
            # "%" in the name as wildcards and could match another database.
            result = self.execute(
                "SELECT SCHEMA_NAME FROM information_schema.SCHEMATA WHERE SCHEMA_NAME = :schema_name",
                {"schema_name": schema_name},
            )
            if result.fetchone():
                return
            logger.debug(f"Creating MySQL database: {schema_name}")
            # CreateSchema renders the name through the dialect's identifier
            # preparer, so it is quoted/escaped when necessary.
            self.execute(CreateSchema(schema_name))
        except SQLAlchemyError as e:
            raise DatabaseContextError(f"Error initializing MySQL database: {e}") from e

    def drop(self) -> None:
        schema_name = self._required_schema_name()
        try:
            logger.debug(f"Dropping MySQL database if exists: {schema_name}")
            # DROP SCHEMA IF EXISTS is a synonym for DROP DATABASE IF EXISTS in
            # MySQL; no CASCADE, which MySQL does not support.
            self.execute(DropSchema(schema_name, if_exists=True))
        except SQLAlchemyError as e:
            raise DatabaseContextError(f"Error dropping MySQL database: {e}") from e

749 

750 

@DatabaseContextFactory.register()
class SQLiteContext(_BaseContext):
    """Database context for SQLite.

    SQLite has no schemas, so any schema set on the metadata (and its
    tables) is removed before the context is built.

    Parameters
    ----------
    engine_url
        The SQLAlchemy database URL for connecting to the database.
    metadata
        The SQLAlchemy metadata representing the database objects.
    """

    DIALECT = "sqlite"

    def __init__(self, engine_url: URL, metadata: MetaData):
        # Clear the schema first so the base class records it as None; no
        # schema is required for SQLite.
        _clear_schema(metadata)
        super().__init__(engine_url, metadata)

    def initialize(self) -> None:
        """No-op: there is no schema to create for SQLite."""
        return None

    def drop(self) -> None:
        try:
            logger.debug("Dropping tables in SQLite schema")
            # Only tables known to the metadata are dropped; other tables in
            # the database file are left alone.
            target_engine = self.engine
            self.metadata.drop_all(bind=target_engine)
        except SQLAlchemyError as e:
            raise DatabaseContextError(f"Error dropping SQLite database: {e}") from e

782 

783 

class MockContext(DatabaseContext):
    """Database context for a mock connection.

    Statements are sent to a SQLAlchemy mock connection instead of a live
    database; schema and index management operations are no-ops.

    Parameters
    ----------
    metadata
        The SQLAlchemy metadata defining the database objects.
    connection
        The SQLAlchemy mock connection.
    """

    def __init__(self, metadata: MetaData, connection: MockConnection):
        self._metadata = metadata
        self._connection = connection
        self._dialect = connection.dialect

    @property
    def dialect(self) -> Dialect:
        return self._dialect

    @property
    def dialect_name(self) -> str:
        return self.dialect.name

    @property
    def metadata(self) -> MetaData:
        return self._metadata

    @property
    def engine(self) -> Engine:
        raise DatabaseContextError("MockContext does not provide an engine.")

    def initialize(self) -> None:
        """No-op: a mock connection does no initialization."""

    def drop(self) -> None:
        """No-op: a mock connection does not drop anything."""

    def create_all(self) -> None:
        """Emit the DDL for every metadata object through the mock connection."""
        self._metadata.create_all(self._connection)

    def create_indexes(self) -> None:
        """No-op: a mock connection cannot create indexes."""

    def drop_indexes(self) -> None:
        """No-op: a mock connection cannot drop indexes."""

    def execute(self, statement: SQLStatement, parameters: dict[str, Any] | None = None) -> Result:
        """Send ``statement`` (and ``parameters``, when non-empty) to the mock
        connection and return whatever it returns.
        """
        executable = _normalize_statement(statement)
        conn = self._connection.connect()
        if not parameters:
            return conn.execute(executable)
        return conn.execute(executable, parameters)

    def close(self) -> None:
        """Close the mock connection (no-op)."""

845 

846 

847def create_database_context( 

848 engine_url: str | URL, 

849 metadata: MetaData, 

850 output_file: IO[str] | None = None, 

851 dry_run: bool = False, 

852 echo: bool | None = None, 

853) -> DatabaseContext: 

854 """Create a DatabaseContext object based on the engine URL. 

855 

856 Parameters 

857 ---------- 

858 engine_url 

859 The database URL for the database connection. 

860 metadata 

861 The SQLAlchemy MetaData representing the database objects. 

862 output_file 

863 Output file for writing generated SQL commands. 

864 dry_run 

865 If True, configure the context to perform a dry run, where operations 

866 will not be executed. 

867 If False, use a normal context where operations are executed. 

868 echo 

869 If True, the SQLAlchemy engine will log all statements to the console. 

870 

871 Returns 

872 ------- 

873 DatabaseContext 

874 A database context appropriate for the given engine URL. This will be 

875 a `MockContext` if the URL appears like a mock URL or if ``dry_run`` is 

876 True, otherwise it will be a context based on the dialect using the 

877 factory pattern. 

878 

879 Raises 

880 ------ 

881 DatabaseContextError 

882 If the dialect is not supported or if there's an issue creating 

883 the context. 

884 """ 

885 if isinstance(engine_url, str): 

886 engine_url = make_url(engine_url) 

887 

888 if is_mock_url(engine_url) or dry_run: 

889 # Use a mock context for mock URLs or dry run mode. 

890 dialect_name = _dialect_name(engine_url) 

891 if dialect_name == "sqlite": 

892 _clear_schema(metadata) 

893 mock_connection = _create_mock_connection(engine_url, output_file) 

894 return MockContext(metadata, mock_connection) 

895 else: 

896 # Create a real engine and context for the given dialect. 

897 try: 

898 dialect_name = _dialect_name(engine_url) 

899 

900 # Use the factory to create the appropriate context 

901 try: 

902 db_ctx = DatabaseContextFactory.create_context(dialect_name, engine_url, metadata) 

903 if echo is not None: 

904 # This is settable for real contexts only. 

905 if hasattr(db_ctx, "echo"): 

906 db_ctx.echo = echo 

907 return db_ctx 

908 except ValueError as e: 

909 supported = DatabaseContextFactory.get_supported_dialects() 

910 raise DatabaseContextError( 

911 f"Unsupported dialect: {dialect_name}. Supported dialects are: {', '.join(supported)}" 

912 ) from e 

913 

914 except Exception as e: 

915 if isinstance(e, DatabaseContextError): 

916 raise 

917 raise DatabaseContextError(f"Failed to create database context: {e}") from e