# Advanced Alchemy
Check out the [project documentation][project-docs] ๐ for more information.
## About
A carefully crafted, thoroughly tested, optimized companion library for SQLAlchemy,
offering:
- Sync and async repositories, featuring common CRUD and highly optimized bulk operations
- Integration with major web frameworks including Litestar, Starlette, FastAPI, Sanic
- Custom-built alembic configuration and CLI with optional framework integration
- Utility base classes with audit columns, primary keys and utility functions
- Built in `File Object` data type for storing objects:
- Unified interface for various storage backends ([`fsspec`](https://filesystem-spec.readthedocs.io/en/latest/) and [`obstore`](https://developmentseed.org/obstore/latest/))
- Optional lifecycle event hooks integrated with SQLAlchemy's event system to automatically save and delete files as records are inserted, updated, or deleted.
- Optimized JSON types including a custom JSON type for Oracle
- Integrated support for UUID6 and UUID7 using [`uuid-utils`](https://github.com/aminalaee/uuid-utils) (install with the `uuid` extra)
- Integrated support for Nano ID using [`fastnanoid`](https://github.com/oliverlambson/fastnanoid) (install with the `nanoid` extra)
- Custom encrypted text type with multiple backend support including [`pgcrypto`](https://www.postgresql.org/docs/current/pgcrypto.html) for PostgreSQL and the Fernet implementation from [`cryptography`](https://cryptography.io/en/latest/) for other databases
- Custom password hashing type with multiple backend support including [`Argon2`](https://github.com/P-H-C/phc-winner-argon2), [`Passlib`](https://passlib.readthedocs.io/en/stable/), and [`Pwdlib`](https://pwdlib.readthedocs.io/en/stable/) with automatic salt generation
- Pre-configured base classes with audit columns UUID or Big Integer primary keys and
a [sentinel column](https://docs.sqlalchemy.org/en/20/core/connections.html#configuring-sentinel-columns).
- Synchronous and asynchronous repositories featuring:
- Common CRUD operations for SQLAlchemy models
- Bulk inserts, updates, upserts, and deletes with dialect-specific enhancements
- Integrated counts, pagination, sorting, filtering with `LIKE`, `IN`, and dates before and/or after.
- Tested support for multiple database backends including:
- SQLite via [aiosqlite](https://aiosqlite.omnilib.dev/en/stable/) or [sqlite](https://docs.python.org/3/library/sqlite3.html)
- Postgres via [asyncpg](https://magicstack.github.io/asyncpg/current/) or [psycopg3 (async or sync)](https://www.psycopg.org/psycopg3/)
- MySQL via [asyncmy](https://github.com/long2ice/asyncmy)
- Oracle via [oracledb (async or sync)](https://oracle.github.io/python-oracledb/) (tested on 18c and 23c)
- Google Spanner via [spanner-sqlalchemy](https://github.com/googleapis/python-spanner-sqlalchemy/)
- DuckDB via [duckdb_engine](https://github.com/Mause/duckdb_engine)
- Microsoft SQL Server via [pyodbc](https://github.com/mkleehammer/pyodbc) or [aioodbc](https://github.com/aio-libs/aioodbc)
- CockroachDB via [sqlalchemy-cockroachdb (async or sync)](https://github.com/cockroachdb/sqlalchemy-cockroachdb)
- ...and much more
## Usage
### Installation
```shell
pip install advanced-alchemy
```
> [!IMPORTANT]\
> Check out [the installation guide][install-guide] in our official documentation!
### Repositories
Advanced Alchemy includes a set of asynchronous and synchronous repository classes for easy CRUD
operations on your SQLAlchemy models.
Click to expand the example
```python
from advanced_alchemy import base, repository, config
from sqlalchemy import create_engine
from sqlalchemy.orm import Mapped, sessionmaker
class User(base.UUIDBase):
# you can optionally override the generated table name by manually setting it.
__tablename__ = "user_account" # type: ignore[assignment]
email: Mapped[str]
name: Mapped[str]
class UserRepository(repository.SQLAlchemySyncRepository[User]):
"""User repository."""
model_type = User
db = config.SQLAlchemySyncConfig(connection_string="duckdb:///:memory:", session_config=config.SyncSessionConfig(expire_on_commit=False))
# Initializes the database.
with db.get_engine().begin() as conn:
User.metadata.create_all(conn)
with db.get_session() as db_session:
repo = UserRepository(session=db_session)
# 1) Create multiple users with `add_many`
bulk_users = [
{"email": 'cody@litestar.dev', 'name': 'Cody'},
{"email": 'janek@litestar.dev', 'name': 'Janek'},
{"email": 'peter@litestar.dev', 'name': 'Peter'},
{"email": 'jacob@litestar.dev', 'name': 'Jacob'}
]
objs = repo.add_many([User(**raw_user) for raw_user in bulk_users])
db_session.commit()
print(f"Created {len(objs)} new objects.")
# 2) Select paginated data and total row count. Pass additional filters as kwargs
created_objs, total_objs = repo.list_and_count(LimitOffset(limit=10, offset=0), name="Cody")
print(f"Selected {len(created_objs)} records out of a total of {total_objs}.")
# 3) Let's remove the batch of records selected.
deleted_objs = repo.delete_many([new_obj.id for new_obj in created_objs])
print(f"Removed {len(deleted_objs)} records out of a total of {total_objs}.")
# 4) Let's count the remaining rows
remaining_count = repo.count()
print(f"Found {remaining_count} remaining records after delete.")
```
For a full standalone example, see the sample [here][standalone-example]
### Services
Advanced Alchemy includes an additional service class to make working with a repository easier.
This class is designed to accept data as a dictionary or SQLAlchemy model,
and it will handle the type conversions for you.
Here's the same example from above but using a service to create the data:
```python
from advanced_alchemy import base, repository, filters, service, config
from sqlalchemy import create_engine
from sqlalchemy.orm import Mapped, sessionmaker
class User(base.UUIDBase):
# you can optionally override the generated table name by manually setting it.
__tablename__ = "user_account" # type: ignore[assignment]
email: Mapped[str]
name: Mapped[str]
class UserService(service.SQLAlchemySyncRepositoryService[User]):
"""User repository."""
class Repo(repository.SQLAlchemySyncRepository[User]):
"""User repository."""
model_type = User
repository_type = Repo
db = config.SQLAlchemySyncConfig(connection_string="duckdb:///:memory:", session_config=config.SyncSessionConfig(expire_on_commit=False))
# Initializes the database.
with db.get_engine().begin() as conn:
User.metadata.create_all(conn)
with db.get_session() as db_session:
service = UserService(session=db_session)
# 1) Create multiple users with `add_many`
objs = service.create_many([
{"email": 'cody@litestar.dev', 'name': 'Cody'},
{"email": 'janek@litestar.dev', 'name': 'Janek'},
{"email": 'peter@litestar.dev', 'name': 'Peter'},
{"email": 'jacob@litestar.dev', 'name': 'Jacob'}
])
print(objs)
print(f"Created {len(objs)} new objects.")
# 2) Select paginated data and total row count. Pass additional filters as kwargs
created_objs, total_objs = service.list_and_count(LimitOffset(limit=10, offset=0), name="Cody")
print(f"Selected {len(created_objs)} records out of a total of {total_objs}.")
# 3) Let's remove the batch of records selected.
deleted_objs = service.delete_many([new_obj.id for new_obj in created_objs])
print(f"Removed {len(deleted_objs)} records out of a total of {total_objs}.")
# 4) Let's count the remaining rows
remaining_count = service.count()
print(f"Found {remaining_count} remaining records after delete.")
```
### Web Frameworks
Advanced Alchemy works with nearly all Python web frameworks.
Several helpers for popular libraries are included, and additional PRs to support others are welcomed.
#### Litestar
Advanced Alchemy is the official SQLAlchemy integration for Litestar.
In addition to installing with `pip install advanced-alchemy`,
it can also be installed as a Litestar extra with `pip install litestar[sqlalchemy]`.
Litestar Example
```python
from litestar import Litestar
from litestar.plugins.sqlalchemy import SQLAlchemyPlugin, SQLAlchemyAsyncConfig
# alternately...
# from advanced_alchemy.extensions.litestar import SQLAlchemyAsyncConfig, SQLAlchemyPlugin
alchemy = SQLAlchemyPlugin(
config=SQLAlchemyAsyncConfig(connection_string="sqlite+aiosqlite:///test.sqlite"),
)
app = Litestar(plugins=[alchemy])
```
For a full Litestar example, check [here][litestar-example]
#### Flask
Flask Example
```python
from flask import Flask
from advanced_alchemy.extensions.flask import AdvancedAlchemy, SQLAlchemySyncConfig
app = Flask(__name__)
alchemy = AdvancedAlchemy(
config=SQLAlchemySyncConfig(connection_string="duckdb:///:memory:"), app=app,
)
```
For a full Flask example, see [here][flask-example]
#### FastAPI
FastAPI Example
```python
from advanced_alchemy.extensions.fastapi import AdvancedAlchemy, SQLAlchemyAsyncConfig
from fastapi import FastAPI
app = FastAPI()
alchemy = AdvancedAlchemy(
config=SQLAlchemyAsyncConfig(connection_string="sqlite+aiosqlite:///test.sqlite"), app=app,
)
```
For a full FastAPI example with optional CLI integration, see [here][fastapi-example]
#### Starlette
Pre-built Example Apps
```python
from advanced_alchemy.extensions.starlette import AdvancedAlchemy, SQLAlchemyAsyncConfig
from starlette.applications import Starlette
app = Starlette()
alchemy = AdvancedAlchemy(
config=SQLAlchemyAsyncConfig(connection_string="sqlite+aiosqlite:///test.sqlite"), app=app,
)
```
#### Sanic
Pre-built Example Apps
```python
from sanic import Sanic
from sanic_ext import Extend
from advanced_alchemy.extensions.sanic import AdvancedAlchemy, SQLAlchemyAsyncConfig
app = Sanic("AlchemySanicApp")
alchemy = AdvancedAlchemy(
sqlalchemy_config=SQLAlchemyAsyncConfig(connection_string="sqlite+aiosqlite:///test.sqlite"),
)
Extend.register(alchemy)
```
## Contributing
All [Litestar Organization][litestar-org] projects will always be a community-centered, available for contributions of any size.
Before contributing, please review the [contribution guide][contributing].
If you have any questions, reach out to us on [Discord][discord], our org-wide [GitHub discussions][litestar-discussions] page,
or the [project-specific GitHub discussions page][project-discussions].
[litestar-org]: https://github.com/litestar-org
[contributing]: https://docs.advanced-alchemy.litestar.dev/latest/contribution-guide.html
[discord]: https://discord.gg/litestar
[litestar-discussions]: https://github.com/orgs/litestar-org/discussions
[project-discussions]: https://github.com/litestar-org/advanced-alchemy/discussions
[project-docs]: https://docs.advanced-alchemy.litestar.dev
[install-guide]: https://docs.advanced-alchemy.litestar.dev/latest/#installation
[fastapi-example]: https://github.com/litestar-org/advanced-alchemy/blob/main/examples/fastapi/fastapi_service.py
[flask-example]: https://github.com/litestar-org/advanced-alchemy/blob/main/examples/flask/flask_services.py
[litestar-example]: https://github.com/litestar-org/advanced-alchemy/blob/main/examples/litestar/litestar_service.py
[standalone-example]: https://github.com/litestar-org/advanced-alchemy/blob/main/examples/standalone.py
python-advanced-alchemy-1.4.1/advanced_alchemy/ 0000775 0000000 0000000 00000000000 15003544734 0021475 5 ustar 00root root 0000000 0000000 python-advanced-alchemy-1.4.1/advanced_alchemy/__init__.py 0000664 0000000 0000000 00000000570 15003544734 0023610 0 ustar 00root root 0000000 0000000 from advanced_alchemy import (
alembic,
base,
cli,
config,
exceptions,
extensions,
filters,
mixins,
operations,
service,
types,
utils,
)
__all__ = (
"alembic",
"base",
"cli",
"config",
"exceptions",
"extensions",
"filters",
"mixins",
"operations",
"service",
"types",
"utils",
)
python-advanced-alchemy-1.4.1/advanced_alchemy/__main__.py 0000664 0000000 0000000 00000000366 15003544734 0023574 0 ustar 00root root 0000000 0000000 from advanced_alchemy.cli import add_migration_commands as build_cli_interface
def run_cli() -> None: # pragma: no cover
"""Advanced Alchemy CLI"""
build_cli_interface()()
if __name__ == "__main__": # pragma: no cover
run_cli()
python-advanced-alchemy-1.4.1/advanced_alchemy/__metadata__.py 0000664 0000000 0000000 00000001067 15003544734 0024427 0 ustar 00root root 0000000 0000000 """Metadata for the Project."""
from importlib.metadata import PackageNotFoundError, metadata, version # pragma: no cover
__all__ = ("__project__", "__version__") # pragma: no cover
try: # pragma: no cover
__version__ = version("advanced_alchemy")
"""Version of the project."""
__project__ = metadata("advanced_alchemy")["Name"]
"""Name of the project."""
except PackageNotFoundError: # pragma: no cover
__version__ = "0.0.1"
__project__ = "Advanced Alchemy"
finally: # pragma: no cover
del version, PackageNotFoundError, metadata
python-advanced-alchemy-1.4.1/advanced_alchemy/_listeners.py 0000664 0000000 0000000 00000047315 15003544734 0024230 0 ustar 00root root 0000000 0000000 # ruff: noqa: BLE001, C901, PLR0914, PLR0915
"""Application ORM configuration."""
import asyncio
import contextvars
import datetime
import logging
from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast
from sqlalchemy import event
from sqlalchemy.inspection import inspect
if TYPE_CHECKING:
from sqlalchemy.orm import Session, UOWTransaction
from advanced_alchemy.types.file_object import FileObjectSessionTracker, StorageRegistry
_active_file_operations: set[asyncio.Task[Any]] = set()
"""Stores active file operations to prevent them from being garbage collected."""
# Context variable to hold the session tracker instance for the current session context
_current_session_tracker: contextvars.ContextVar[Optional["FileObjectSessionTracker"]] = contextvars.ContextVar(
"_current_session_tracker",
default=None,
)
# Context variable to track if we're in an async context
_is_async_context: contextvars.ContextVar[bool] = contextvars.ContextVar(
"_is_async_context",
default=False,
)
logger = logging.getLogger("advanced_alchemy")
def set_async_context(is_async: bool = True) -> Optional[contextvars.Token[bool]]:
"""Set the async context flag.
Args:
is_async: Whether the context is async.
Returns:
The token for the async context.
"""
return _is_async_context.set(is_async)
def reset_async_context(token: contextvars.Token[bool]) -> None:
"""Reset the async context flag using the provided token."""
_is_async_context.reset(token)
def is_async_context() -> bool:
return _is_async_context.get()
def _get_session_tracker(create: bool = True) -> Optional["FileObjectSessionTracker"]:
from advanced_alchemy.types.file_object import FileObjectSessionTracker
tracker = _current_session_tracker.get()
if tracker is None and create:
tracker = FileObjectSessionTracker()
_current_session_tracker.set(tracker)
return tracker
def _inspect_attribute_changes(
instance: Any,
tracker: "FileObjectSessionTracker",
) -> None:
from advanced_alchemy.types.file_object import FileObject, StoredObject
from advanced_alchemy.types.mutables import MutableList
state = inspect(instance)
if not state:
return
mapper = state.mapper
if not mapper:
return
for attr_name, attr in mapper.column_attrs.items():
if not isinstance(attr.expression.type, StoredObject):
continue
is_multiple = getattr(attr.expression.type, "multiple", False)
try:
attr_state = state.attrs[attr_name]
except KeyError:
continue
history = attr_state.history
# Handle single FileObject attribute
if not is_multiple:
current_value: Optional[FileObject] = history.added[0] if history.added else None
original_value: Optional[FileObject] = history.deleted[0] if history.deleted else None
if current_value:
pending_content = getattr(current_value, "_pending_source_content", None)
pending_source_path = getattr(current_value, "_pending_source_path", None)
if pending_content is not None:
tracker.add_pending_save(current_value, pending_content)
elif pending_source_path is not None:
tracker.add_pending_save(current_value, pending_source_path)
if original_value and original_value.path:
tracker.add_pending_delete(original_value)
continue
# --- Multiple FileObjects Logic (v4 - Prioritize _pending_removed) ---
items_to_delete: set[FileObject] = set()
items_to_save: dict[FileObject, Any] = {}
current_list_instance: Optional[MutableList[FileObject]] = getattr(instance, attr_name, None)
original_list_from_history: Optional[MutableList[FileObject]] = history.deleted[0] if history.deleted else None
current_list_from_history: Optional[MutableList[FileObject]] = history.added[0] if history.added else None
# 1. Deletions from Mutations (Primary source: _pending_removed set)
if isinstance(current_list_instance, MutableList):
removed_items_internal: set[FileObject] = getattr(
current_list_instance, "_pending_removed", set[FileObject]()
)
valid_removed_internal = {item for item in removed_items_internal if item and item.path}
if valid_removed_internal:
logger.debug(
"[Multiple-Mutation] Found %d valid items in internal _pending_removed set.",
len(valid_removed_internal),
)
items_to_delete.update(valid_removed_internal)
# 2. Deletions from Replacements (Secondary source: history)
if original_list_from_history: # Indicates list replacement
logger.debug("[Multiple-Replacement] Processing list replacement via history.")
original_items_set = {item for item in original_list_from_history if item.path}
current_items_set = (
{item for item in current_list_from_history if item.path}
if current_list_from_history
else set[FileObject]()
)
removed_due_to_replacement = original_items_set - current_items_set
if removed_due_to_replacement:
logger.debug(
"[Multiple-Replacement] Found %d items removed via replacement.", len(removed_due_to_replacement)
)
items_to_delete.update(removed_due_to_replacement)
# 3. Determine items to save
# Saves from pending appends (Mutation or New)
if isinstance(current_list_instance, MutableList):
pending_append = getattr(current_list_instance, "_pending_append", [])
if pending_append:
logger.debug("[Multiple-Mutation] Found %d items in _pending_append list.", len(pending_append))
for item in pending_append:
pending_content = getattr(item, "_pending_content", None)
pending_source_path = getattr(item, "_pending_source_path", None)
if pending_content is not None:
items_to_save[item] = pending_content
elif pending_source_path is not None:
items_to_save[item] = pending_source_path
# Saves from newly added list items (New Instance or Replacement)
if current_list_from_history:
log_prefix = "[Multiple-New]" if not original_list_from_history else "[Multiple-Replacement]"
logger.debug(
"%s Checking current list from history (%d items) for pending saves.",
log_prefix,
len(current_list_from_history),
)
for item in current_list_from_history:
pending_content = getattr(item, "_pending_source_content", None)
pending_source_path = getattr(item, "_pending_source_path", None)
if pending_content is not None and item not in items_to_save:
logger.debug("%s Found pending content for %r", log_prefix, item.filename)
items_to_save[item] = pending_content
elif pending_source_path is not None and item not in items_to_save:
logger.debug("%s Found pending source path for %r", log_prefix, item.filename)
items_to_save[item] = pending_source_path
# 4. Finalize MutableList state (if applicable)
if isinstance(current_list_instance, MutableList):
finalize_method = getattr(current_list_instance, "_finalize_pending", None)
if finalize_method:
logger.debug("[Multiple] Calling _finalize_pending on list instance.")
finalize_method()
# 5. Schedule all collected operations
if items_to_delete:
logger.debug("[Multiple] Scheduling %d items for deletion.", len(items_to_delete))
for item_to_delete in items_to_delete:
tracker.add_pending_delete(item_to_delete)
if items_to_save:
logger.debug("[Multiple] Scheduling %d items for saving.", len(items_to_save))
for item_to_save, data in items_to_save.items():
tracker.add_pending_save(item_to_save, data)
class FileObjectListener: # pragma: no cover
"""Manages FileObject persistence actions during SQLAlchemy Session transactions.
This listener hooks into the SQLAlchemy Session event lifecycle to automatically
handle the saving and deletion of files associated with `FileObject` attributes
mapped using the `StoredObject` type.
How it Works:
1. **Event Registration (`setup_file_object_listeners`):**
This listener's methods are registered to be called during specific phases
of a Session's lifecycle (`before_flush`, `after_commit`, `after_rollback`).
2. **Tracking Changes (`before_flush`):**
* Before SQLAlchemy writes changes to the database (`flush`), this method
is triggered.
* It inspects objects within the session that are:
* `session.new`: Newly added to the session.
* `session.dirty`: Modified within the session.
* `session.deleted`: Marked for deletion.
* For each object, it checks attributes mapped with `StoredObject`.
* Using SQLAlchemy's attribute history, it identifies:
* New `FileObject` instances (or those with pending content/paths) that need saving.
* Old `FileObject` instances that have been replaced or belong to deleted objects and need deleting.
* These intended file operations (saves and deletes) are recorded in a
`FileObjectSessionTracker` specific to the current session context.
3. **Executing Operations (`after_commit`):**
* If the session transaction successfully commits, this method is called.
* It retrieves the `FileObjectSessionTracker` for the session.
* It instructs the tracker to execute all the pending file save and delete operations
using the appropriate storage backend.
* The tracker is then cleared.
4. **Discarding Operations (`after_rollback`):**
* If the session transaction is rolled back, this method is called.
* It retrieves the tracker and instructs it to discard all pending operations,
as the database changes they corresponded to were also discarded.
* The tracker is then cleared.
**Synchronous vs. Asynchronous Handling:**
* The listener needs to know if it's operating within a standard synchronous
SQLAlchemy Session or an `AsyncSession`.
* The `set_async_context(True)` function should be called before using an
`AsyncSession` to set a flag (using `contextvars`).
* The `is_async_context()` function checks this flag.
* In `after_commit` and `after_rollback`, if `is_async_context()` is true,
the file operations (tracker commit/rollback) are scheduled to run
asynchronously using `asyncio.create_task`. Otherwise, they are executed
synchronously.
This ensures that file operations align correctly with the database transaction
and are performed efficiently whether using sync or async sessions.
"""
@classmethod
def _is_listener_enabled(cls, session: "Session") -> bool:
enable_listener = True # Enabled by default
session_info = getattr(session, "info", {})
if "enable_file_object_listener" in session_info:
return bool(session_info["enable_file_object_listener"])
# Type hint for the list of potential option sources
options_sources: list[Optional[Union[Callable[[], dict[str, Any]], dict[str, Any]]]] = []
if session.bind:
options_sources.append(getattr(session.bind, "execution_options", None))
sync_engine = getattr(session.bind, "sync_engine", None)
if sync_engine:
options_sources.append(getattr(sync_engine, "execution_options", None))
options_sources.append(getattr(session, "execution_options", None))
for options_source in options_sources:
if options_source is None:
continue
options: Optional[dict[str, Any]] = None
if callable(options_source):
try:
result = options_source()
if isinstance(result, dict): # pyright: ignore
options = result
except Exception as e:
logger.debug("Error calling execution_options source: %s", e)
else:
# If not None and not callable, assume it's the dict based on type hint
options = options_source
# Only perform the 'in' check if we successfully got a dictionary
if options is not None and "enable_file_object_listener" in options:
enable_listener = bool(options["enable_file_object_listener"])
break
return enable_listener
@classmethod
def _process_commit(cls, tracker: "FileObjectSessionTracker") -> None:
"""Processes pending operations after a commit."""
try:
if is_async_context():
import asyncio
async def _do_async_commit() -> None:
try:
await tracker.commit_async()
except Exception as e:
# Using %s for cleaner logging of exception causes
logger.debug("An error occurred while committing a file object: %s", e.__cause__)
finally:
_current_session_tracker.set(None)
# Store the task reference, even if not awaited here
t = asyncio.create_task(_do_async_commit())
_active_file_operations.add(t)
t.add_done_callback(lambda _: _active_file_operations.remove(t))
else:
tracker.commit()
_current_session_tracker.set(None)
except Exception:
_current_session_tracker.set(None)
@classmethod
def _process_rollback(cls, tracker: "FileObjectSessionTracker") -> None:
"""Processes pending operations after a rollback."""
try:
if is_async_context():
import asyncio
async def _do_async_rollback() -> None:
try:
await tracker.rollback_async()
except Exception as e:
logger.debug("An error occurred during async FileObject rollback: %s", e.__cause__)
finally:
_current_session_tracker.set(None)
# Store the task reference, even if not awaited here
t = asyncio.create_task(_do_async_rollback())
_active_file_operations.add(t)
t.add_done_callback(lambda _: _active_file_operations.remove(t))
else:
tracker.rollback()
_current_session_tracker.set(None)
except Exception:
_current_session_tracker.set(None)
@classmethod
def before_flush(cls, session: "Session", flush_context: "UOWTransaction", instances: Optional[object]) -> None:
"""Track FileObject changes before a flush."""
from advanced_alchemy.types.file_object import StoredObject
if not cls._is_listener_enabled(session):
return
tracker = _get_session_tracker(create=True)
if not tracker:
return
for instance in session.new:
_inspect_attribute_changes(instance, tracker)
for instance in session.dirty:
_inspect_attribute_changes(instance, tracker)
for instance in session.deleted:
state = inspect(instance)
if not state:
continue
mapper = state.mapper
if not mapper:
continue
# Avoid inspecting if no StoredObject columns exist
has_stored_object = any(
isinstance(attr.expression.type, StoredObject) for attr in mapper.column_attrs.values()
)
if not has_stored_object:
continue
tracker = cls._process_pending_operations(tracker, instance, mapper)
@classmethod
def _process_pending_operations(
cls, tracker: "FileObjectSessionTracker", instance: Any, mapper: Any
) -> "FileObjectSessionTracker":
from advanced_alchemy.types.file_object import FileObject, StoredObject
from advanced_alchemy.types.mutables import MutableList
for attr_name, attr in mapper.column_attrs.items():
if isinstance(attr.expression.type, StoredObject):
is_multiple = getattr(attr.expression.type, "multiple", False)
original_value: Any = getattr(instance, attr_name, None)
if original_value is None:
continue
if not is_multiple:
tracker.add_pending_delete(original_value)
elif isinstance(original_value, (list, MutableList)):
for item in original_value: # pyright: ignore
tracker.add_pending_delete(cast("FileObject", item))
return tracker
@classmethod
def after_commit(cls, session: "Session") -> None:
"""Process file operations after a successful commit."""
tracker = _get_session_tracker(create=False)
if tracker:
cls._process_commit(tracker)
@classmethod
def after_rollback(cls, session: "Session") -> None:
"""Clean up pending file operations after a rollback."""
tracker = _get_session_tracker(create=False)
if tracker:
cls._process_rollback(tracker)
def setup_file_object_listeners(registry: Optional["StorageRegistry"] = None) -> None: # noqa: ARG001
"""Registers the FileObject event listeners globally."""
from sqlalchemy.event import contains
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
listeners = {
"before_flush": FileObjectListener.before_flush,
"after_commit": FileObjectListener.after_commit,
"after_rollback": FileObjectListener.after_rollback,
}
# Register for sync Session
for event_name, listener_func in listeners.items():
if not contains(Session, event_name, listener_func): # type: ignore[arg-type]
event.listen(Session, event_name, listener_func) # type: ignore[arg-type]
async_listeners_to_register = {
"after_commit": FileObjectListener.after_commit,
"after_rollback": FileObjectListener.after_rollback,
}
for event_name, listener_func in async_listeners_to_register.items():
if hasattr(AsyncSession, event_name) and not contains(AsyncSession, event_name, listener_func):
event.listen(AsyncSession, event_name, listener_func)
set_async_context(False)
# Existing listener (keep it)
def touch_updated_timestamp(session: "Session", *_: Any) -> None: # pragma: no cover
"""Set timestamp on update.
Called from SQLAlchemy's
:meth:`before_flush ` event to bump the ``updated``
timestamp on modified instances.
Args:
session: The sync :class:`Session ` instance that underlies the async
session.
"""
for instance in session.dirty:
state = inspect(instance)
if not state or not hasattr(state.mapper.class_, "updated_at"):
continue
updated_at_attr = state.attrs.get("updated_at")
if updated_at_attr and not updated_at_attr.history.has_changes():
instance.updated_at = datetime.datetime.now(datetime.timezone.utc)
python-advanced-alchemy-1.4.1/advanced_alchemy/_serialization.py 0000664 0000000 0000000 00000005624 15003544734 0025072 0 ustar 00root root 0000000 0000000 # ruff: noqa: PLR6301
import datetime
import enum
from typing import Any
from typing_extensions import runtime_checkable
from advanced_alchemy.exceptions import MissingDependencyError
try:
from pydantic import BaseModel # type: ignore
PYDANTIC_INSTALLED = True
except ImportError:
from typing import ClassVar, Protocol
@runtime_checkable
class BaseModel(Protocol): # type: ignore[no-redef]
"""Placeholder Implementation"""
model_fields: ClassVar[dict[str, Any]]
def model_dump_json(self, *args: Any, **kwargs: Any) -> str:
"""Placeholder for pydantic.BaseModel.model_dump_json
Returns:
The JSON representation of the model.
"""
msg = "pydantic"
raise MissingDependencyError(msg)
PYDANTIC_INSTALLED = False # pyright: ignore[reportConstantRedefinition]
def _type_to_string(value: Any) -> str: # pragma: no cover
if isinstance(value, datetime.datetime):
return convert_datetime_to_gmt_iso(value)
if isinstance(value, datetime.date):
return convert_date_to_iso(value)
if isinstance(value, enum.Enum):
return str(value.value)
if PYDANTIC_INSTALLED and isinstance(value, BaseModel):
return value.model_dump_json()
try:
val = str(value)
except Exception as exc:
raise TypeError from exc
return val
try:
from msgspec.json import Decoder, Encoder
encoder, decoder = Encoder(enc_hook=_type_to_string), Decoder()
decode_json = decoder.decode
def encode_json(data: Any) -> str: # pragma: no cover
return encoder.encode(data).decode("utf-8")
except ImportError:
try:
from orjson import OPT_NAIVE_UTC, OPT_SERIALIZE_NUMPY, OPT_SERIALIZE_UUID
from orjson import dumps as _encode_json
from orjson import loads as decode_json # type: ignore[no-redef,assignment]
def encode_json(data: Any) -> str: # pragma: no cover
return _encode_json(
data, default=_type_to_string, option=OPT_SERIALIZE_NUMPY | OPT_NAIVE_UTC | OPT_SERIALIZE_UUID
).decode("utf-8") # type: ignore[no-any-return]
except ImportError:
from json import dumps as encode_json # type: ignore[assignment] # noqa: F401
from json import loads as decode_json # type: ignore[assignment] # noqa: F401
def convert_datetime_to_gmt_iso(dt: datetime.datetime) -> str: # pragma: no cover
"""Handle datetime serialization for nested timestamps.
Returns:
str: The ISO 8601 formatted datetime string.
"""
if not dt.tzinfo:
dt = dt.replace(tzinfo=datetime.timezone.utc)
return dt.isoformat().replace("+00:00", "Z")
def convert_date_to_iso(dt: datetime.date) -> str: # pragma: no cover
"""Handle datetime serialization for nested timestamps.
Returns:
str: The ISO 8601 formatted date string.
"""
return dt.isoformat()
python-advanced-alchemy-1.4.1/advanced_alchemy/alembic/ 0000775 0000000 0000000 00000000000 15003544734 0023071 5 ustar 00root root 0000000 0000000 python-advanced-alchemy-1.4.1/advanced_alchemy/alembic/__init__.py 0000664 0000000 0000000 00000000000 15003544734 0025170 0 ustar 00root root 0000000 0000000 python-advanced-alchemy-1.4.1/advanced_alchemy/alembic/commands.py 0000664 0000000 0000000 00000032323 15003544734 0025247 0 ustar 00root root 0000000 0000000 import sys
from typing import TYPE_CHECKING, Any, Optional, TextIO, Union
from advanced_alchemy.config.asyncio import SQLAlchemyAsyncConfig
from alembic import command as migration_command
from alembic.config import Config as _AlembicCommandConfig
from alembic.ddl.impl import DefaultImpl
if TYPE_CHECKING:
import os
from argparse import Namespace
from collections.abc import Mapping
from pathlib import Path
from sqlalchemy import Engine
from sqlalchemy.ext.asyncio import AsyncEngine
from advanced_alchemy.config.sync import SQLAlchemySyncConfig
from alembic.runtime.environment import ProcessRevisionDirectiveFn
from alembic.script.base import Script
class AlembicSpannerImpl(DefaultImpl):
"""Alembic implementation for Spanner."""
__dialect__ = "spanner+spanner"
class AlembicDuckDBImpl(DefaultImpl):
"""Alembic implementation for DuckDB."""
__dialect__ = "duckdb"
class AlembicCommandConfig(_AlembicCommandConfig):
def __init__(
self,
engine: "Union[Engine, AsyncEngine]",
version_table_name: str,
bind_key: "Optional[str]" = None,
file_: "Union[str, os.PathLike[str], None]" = None,
ini_section: str = "alembic",
output_buffer: "Optional[TextIO]" = None,
stdout: "TextIO" = sys.stdout,
cmd_opts: "Optional[Namespace]" = None,
config_args: "Optional[Mapping[str, Any]]" = None,
attributes: "Optional[dict[str, Any]]" = None,
template_directory: "Optional[Path]" = None,
version_table_schema: "Optional[str]" = None,
render_as_batch: bool = True,
compare_type: bool = False,
user_module_prefix: "Optional[str]" = "sa.",
) -> None:
"""Initialize the AlembicCommandConfig.
Args:
engine (sqlalchemy.engine.Engine | sqlalchemy.ext.asyncio.AsyncEngine): The SQLAlchemy engine instance.
version_table_name (str): The name of the version table.
bind_key (str | None): The bind key for the metadata.
file_ (str | os.PathLike[str] | None): The file path for the alembic configuration.
ini_section (str): The ini section name.
output_buffer (typing.TextIO | None): The output buffer for alembic commands.
stdout (typing.TextIO): The standard output stream.
cmd_opts (argparse.Namespace | None): Command line options.
config_args (typing.Mapping[str, typing.Any] | None): Additional configuration arguments.
attributes (dict[str, typing.Any] | None): Additional attributes for the configuration.
template_directory (pathlib.Path | None): The directory for alembic templates.
version_table_schema (str | None): The schema for the version table.
render_as_batch (bool): Whether to render migrations as batch.
compare_type (bool): Whether to compare types during migrations.
user_module_prefix (str | None): The prefix for user modules.
"""
self.template_directory = template_directory
self.bind_key = bind_key
self.version_table_name = version_table_name
self.version_table_pk = engine.dialect.name != "spanner+spanner"
self.version_table_schema = version_table_schema
self.render_as_batch = render_as_batch
self.user_module_prefix = user_module_prefix
self.compare_type = compare_type
self.engine = engine
self.db_url = engine.url.render_as_string(hide_password=False)
if config_args is None:
config_args = {}
super().__init__(file_, ini_section, output_buffer, stdout, cmd_opts, config_args, attributes)
def get_template_directory(self) -> str:
"""Return the directory where Alembic setup templates are found.
This method is used by the alembic ``init`` and ``list_templates``
commands.
"""
if self.template_directory is not None:
return str(self.template_directory)
return super().get_template_directory()
class AlembicCommands:
def __init__(self, sqlalchemy_config: "Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig]") -> None:
"""Initialize the AlembicCommands.
Args:
sqlalchemy_config (SQLAlchemyAsyncConfig | SQLAlchemySyncConfig): The SQLAlchemy configuration.
"""
self.sqlalchemy_config = sqlalchemy_config
self.config = self._get_alembic_command_config()
def upgrade(
self,
revision: str = "head",
sql: bool = False,
tag: "Optional[str]" = None,
) -> None:
"""Upgrade the database to a specified revision.
Args:
revision (str): The target revision to upgrade to.
sql (bool): If True, generate SQL script instead of applying changes.
tag (str | None): An optional tag to apply to the migration.
"""
return migration_command.upgrade(config=self.config, revision=revision, tag=tag, sql=sql)
def downgrade(
self,
revision: str = "head",
sql: bool = False,
tag: "Optional[str]" = None,
) -> None:
"""Downgrade the database to a specified revision.
Args:
revision (str): The target revision to downgrade to.
sql (bool): If True, generate SQL script instead of applying changes.
tag (str | None): An optional tag to apply to the migration.
"""
return migration_command.downgrade(config=self.config, revision=revision, tag=tag, sql=sql)
def check(self) -> None:
"""Check for pending upgrade operations.
This method checks if there are any pending upgrade operations
that need to be applied to the database.
"""
return migration_command.check(config=self.config)
def current(self, verbose: bool = False) -> None:
"""Display the current revision of the database.
Args:
verbose (bool): If True, display detailed information.
"""
return migration_command.current(self.config, verbose=verbose)
def edit(self, revision: str) -> None:
"""Edit the revision script using the system editor.
Args:
revision (str): The revision identifier to edit.
"""
return migration_command.edit(config=self.config, rev=revision)
def ensure_version(self, sql: bool = False) -> None:
"""Ensure the alembic version table exists.
Args:
sql (bool): If True, generate SQL script instead of applying changes.
"""
return migration_command.ensure_version(config=self.config, sql=sql)
def heads(self, verbose: bool = False, resolve_dependencies: bool = False) -> None:
"""Show current available heads in the script directory.
Args:
verbose (bool): If True, display detailed information.
resolve_dependencies (bool): If True, resolve dependencies between heads.
"""
return migration_command.heads(config=self.config, verbose=verbose, resolve_dependencies=resolve_dependencies)
def history(
self,
rev_range: "Optional[str]" = None,
verbose: bool = False,
indicate_current: bool = False,
) -> None:
"""List changeset scripts in chronological order.
Args:
rev_range (str | None): The revision range to display.
verbose (bool): If True, display detailed information.
indicate_current (bool): If True, indicate the current revision.
"""
return migration_command.history(
config=self.config,
rev_range=rev_range,
verbose=verbose,
indicate_current=indicate_current,
)
def merge(
self,
revisions: str,
message: "Optional[str]" = None,
branch_label: "Optional[str]" = None,
rev_id: "Optional[str]" = None,
) -> "Union[Script, None]":
"""Merge two revisions together.
Args:
revisions (str): The revisions to merge.
message (str | None): The commit message for the merge.
branch_label (str | None): The branch label for the merge.
rev_id (str | None): The revision ID for the merge.
Returns:
Script | None: The resulting script from the merge.
"""
return migration_command.merge(
config=self.config,
revisions=revisions,
message=message,
branch_label=branch_label,
rev_id=rev_id,
)
def revision(
self,
message: "Optional[str]" = None,
autogenerate: bool = False,
sql: bool = False,
head: str = "head",
splice: bool = False,
branch_label: "Optional[str]" = None,
version_path: "Optional[str]" = None,
rev_id: "Optional[str]" = None,
depends_on: "Optional[str]" = None,
process_revision_directives: "Optional[ProcessRevisionDirectiveFn]" = None,
) -> "Union[Script, list[Optional[Script]], None]":
"""Create a new revision file.
Args:
message (str | None): The commit message for the revision.
autogenerate (bool): If True, autogenerate the revision script.
sql (bool): If True, generate SQL script instead of applying changes.
head (str): The head revision to base the new revision on.
splice (bool): If True, create a splice revision.
branch_label (str | None): The branch label for the revision.
version_path (str | None): The path for the version file.
rev_id (str | None): The revision ID for the new revision.
depends_on (str | None): The revisions this revision depends on.
process_revision_directives (ProcessRevisionDirectiveFn | None): A function to process revision directives.
Returns:
Script | List[Script | None] | None: The resulting script(s) from the revision.
"""
return migration_command.revision(
config=self.config,
message=message,
autogenerate=autogenerate,
sql=sql,
head=head,
splice=splice,
branch_label=branch_label,
version_path=version_path,
rev_id=rev_id,
depends_on=depends_on,
process_revision_directives=process_revision_directives,
)
def show(
self,
rev: Any,
) -> None:
"""Show the revision(s) denoted by the given symbol.
Args:
rev (Any): The revision symbol to display.
"""
return migration_command.show(config=self.config, rev=rev)
def init(
self,
directory: str,
package: bool = False,
multidb: bool = False,
) -> None:
"""Initialize a new scripts directory.
Args:
directory (str): The directory to initialize.
package (bool): If True, create a package.
multidb (bool): If True, initialize for multiple databases.
"""
template = "sync"
if isinstance(self.sqlalchemy_config, SQLAlchemyAsyncConfig):
template = "asyncio"
if multidb:
template = f"{template}-multidb"
msg = "Multi database Alembic configurations are not currently supported."
raise NotImplementedError(msg)
return migration_command.init(
config=self.config,
directory=directory,
template=template,
package=package,
)
def list_templates(self) -> None:
"""List available templates.
This method lists all available templates for alembic initialization.
"""
return migration_command.list_templates(config=self.config)
def stamp(
self,
revision: str,
sql: bool = False,
tag: "Optional[str]" = None,
purge: bool = False,
) -> None:
"""Stamp the revision table with the given revision.
Args:
revision (str): The revision to stamp.
sql (bool): If True, generate SQL script instead of applying changes.
tag (str | None): An optional tag to apply to the migration.
purge (bool): If True, purge the revision history.
"""
return migration_command.stamp(config=self.config, revision=revision, sql=sql, tag=tag, purge=purge)
def _get_alembic_command_config(self) -> "AlembicCommandConfig":
"""Get the Alembic command configuration.
Returns:
AlembicCommandConfig: The configuration for Alembic commands.
"""
kwargs: dict[str, Any] = {}
if self.sqlalchemy_config.alembic_config.script_config:
kwargs["file_"] = self.sqlalchemy_config.alembic_config.script_config
if self.sqlalchemy_config.alembic_config.template_path:
kwargs["template_directory"] = self.sqlalchemy_config.alembic_config.template_path
kwargs.update(
{
"engine": self.sqlalchemy_config.get_engine(),
"version_table_name": self.sqlalchemy_config.alembic_config.version_table_name,
},
)
self.config = AlembicCommandConfig(**kwargs)
self.config.set_main_option("script_location", self.sqlalchemy_config.alembic_config.script_location)
return self.config
python-advanced-alchemy-1.4.1/advanced_alchemy/alembic/templates/ 0000775 0000000 0000000 00000000000 15003544734 0025067 5 ustar 00root root 0000000 0000000 python-advanced-alchemy-1.4.1/advanced_alchemy/alembic/templates/__init__.py 0000664 0000000 0000000 00000000000 15003544734 0027166 0 ustar 00root root 0000000 0000000 python-advanced-alchemy-1.4.1/advanced_alchemy/alembic/templates/asyncio/ 0000775 0000000 0000000 00000000000 15003544734 0026534 5 ustar 00root root 0000000 0000000 python-advanced-alchemy-1.4.1/advanced_alchemy/alembic/templates/asyncio/__init__.py 0000664 0000000 0000000 00000000000 15003544734 0030633 0 ustar 00root root 0000000 0000000 python-advanced-alchemy-1.4.1/advanced_alchemy/alembic/templates/asyncio/alembic.ini.mako 0000664 0000000 0000000 00000005001 15003544734 0031553 0 ustar 00root root 0000000 0000000 # Advanced Alchemy Alembic Asyncio Config
[alembic]
prepend_sys_path = src:.
# path to migration scripts
script_location = migrations
# template used to generate migration files
file_template = %%(year)d-%%(month).2d-%%(day).2d_%%(slug)s_%%(rev)s
# This is not required to be set when running through `advanced_alchemy`
# sqlalchemy.url = driver://user:pass@localhost/dbname
# timezone to use when rendering the date
# within the migration file as well as the filename.
# string value is passed to dateutil.tz.gettz()
# leave blank for localtime
# timezone = UTC
# max length of characters to apply to the
# "slug" field
truncate_slug_length = 40
# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false
# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false
# version location specification; this defaults
# to alembic/versions. When using multiple version
# directories, initial revisions must be specified with --version-path
# version_locations = %(here)s/bar %(here)s/bat alembic/versions
# version path separator; As mentioned above, this is the character used to split
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
# Valid values for version_path_separator are:
#
# version_path_separator = :
# version_path_separator = ;
# version_path_separator = space
version_path_separator = os # Use os.pathsep. Default configuration used for new projects.
# set to 'true' to search source files recursively
# in each "version_locations" directory
# new in Alembic version 1.10
# recursive_version_locations = false
# the output encoding used when revision files
# are written from script.py.mako
output_encoding = utf-8
# [post_write_hooks]
# This section defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples
# format using "black" - use the console_scripts runner,
# against the "black" entrypoint
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -l 79 REVISION_SCRIPT_FILENAME
# lint with attempts to fix using "ruff" - use the exec runner, execute a binary
# hooks = ruff
# ruff.type = exec
# ruff.executable = %(here)s/.venv/bin/ruff
# ruff.options = --fix REVISION_SCRIPT_FILENAME
python-advanced-alchemy-1.4.1/advanced_alchemy/alembic/templates/asyncio/env.py 0000664 0000000 0000000 00000006431 15003544734 0027702 0 ustar 00root root 0000000 0000000 import asyncio
from typing import TYPE_CHECKING, cast
from sqlalchemy import pool
from sqlalchemy.ext.asyncio import AsyncEngine, async_engine_from_config
from advanced_alchemy.base import metadata_registry
from alembic import context
from alembic.autogenerate import rewriter
if TYPE_CHECKING:
from sqlalchemy.engine import Connection
from advanced_alchemy.alembic.commands import AlembicCommandConfig
__all__ = ("do_run_migrations", "run_migrations_offline", "run_migrations_online")
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config: "AlembicCommandConfig" = context.config # type: ignore
writer = rewriter.Rewriter()
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode.
This configures the context with just a URL
and not an Engine, though an Engine is acceptable
here as well. By skipping the Engine creation
we don't even need a DBAPI to be available.
Calls to context.execute() here emit the given string to the
script output.
"""
context.configure(
url=config.db_url,
target_metadata=metadata_registry.get(config.bind_key),
literal_binds=True,
dialect_opts={"paramstyle": "named"},
compare_type=config.compare_type,
version_table=config.version_table_name,
version_table_pk=config.version_table_pk,
user_module_prefix=config.user_module_prefix,
render_as_batch=config.render_as_batch,
process_revision_directives=writer,
)
with context.begin_transaction():
context.run_migrations()
def do_run_migrations(connection: "Connection") -> None:
"""Run migrations."""
context.configure(
connection=connection,
target_metadata=metadata_registry.get(config.bind_key),
compare_type=config.compare_type,
version_table=config.version_table_name,
version_table_pk=config.version_table_pk,
user_module_prefix=config.user_module_prefix,
render_as_batch=config.render_as_batch,
process_revision_directives=writer,
)
with context.begin_transaction():
context.run_migrations()
async def run_migrations_online() -> None:
"""Run migrations in 'online' mode.
In this scenario we need to create an Engine and associate a
connection with the context.
Raises:
RuntimeError: If the engine cannot be created from the config.
"""
configuration = config.get_section(config.config_ini_section) or {}
configuration["sqlalchemy.url"] = config.db_url
connectable = cast(
"AsyncEngine",
config.engine
or async_engine_from_config(
configuration,
prefix="sqlalchemy.",
poolclass=pool.NullPool,
future=True,
),
)
if connectable is None: # pyright: ignore[reportUnnecessaryComparison]
msg = "Could not get engine from config. Please ensure your `alembic.ini` according to the official Alembic documentation."
raise RuntimeError(
msg,
)
async with connectable.connect() as connection:
await connection.run_sync(do_run_migrations)
await connectable.dispose()
if context.is_offline_mode():
run_migrations_offline()
else:
asyncio.run(run_migrations_online())
python-advanced-alchemy-1.4.1/advanced_alchemy/alembic/templates/asyncio/script.py.mako 0000664 0000000 0000000 00000003440 15003544734 0031341 0 ustar 00root root 0000000 0000000 """${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
import warnings
from typing import TYPE_CHECKING
import sqlalchemy as sa
from alembic import op
from advanced_alchemy.types import EncryptedString, EncryptedText, GUID, ORA_JSONB, DateTimeUTC, StoredObject, PasswordHash
from sqlalchemy import Text # noqa: F401
${imports if imports else ""}
if TYPE_CHECKING:
from collections.abc import Sequence
__all__ = ["downgrade", "upgrade", "schema_upgrades", "schema_downgrades", "data_upgrades", "data_downgrades"]
sa.GUID = GUID
sa.DateTimeUTC = DateTimeUTC
sa.ORA_JSONB = ORA_JSONB
sa.EncryptedString = EncryptedString
sa.EncryptedText = EncryptedText
sa.StoredObject = StoredObject
# revision identifiers, used by Alembic.
revision = ${repr(up_revision)}
down_revision = ${repr(down_revision)}
branch_labels = ${repr(branch_labels)}
depends_on = ${repr(depends_on)}
def upgrade() -> None:
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=UserWarning)
with op.get_context().autocommit_block():
schema_upgrades()
data_upgrades()
def downgrade() -> None:
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=UserWarning)
with op.get_context().autocommit_block():
data_downgrades()
schema_downgrades()
def schema_upgrades() -> None:
"""schema upgrade migrations go here."""
${upgrades if upgrades else "pass"}
def schema_downgrades() -> None:
"""schema downgrade migrations go here."""
${downgrades if downgrades else "pass"}
def data_upgrades() -> None:
"""Add any optional data upgrade migrations here!"""
def data_downgrades() -> None:
"""Add any optional data downgrade migrations here!"""
python-advanced-alchemy-1.4.1/advanced_alchemy/alembic/templates/sync/ 0000775 0000000 0000000 00000000000 15003544734 0026043 5 ustar 00root root 0000000 0000000 python-advanced-alchemy-1.4.1/advanced_alchemy/alembic/templates/sync/__init__.py 0000664 0000000 0000000 00000000000 15003544734 0030142 0 ustar 00root root 0000000 0000000 python-advanced-alchemy-1.4.1/advanced_alchemy/alembic/templates/sync/alembic.ini.mako 0000664 0000000 0000000 00000005002 15003544734 0031063 0 ustar 00root root 0000000 0000000 # Advanced Alchemy Alembic Sync Config
[alembic]
prepend_sys_path = src:.
# path to migration scripts
script_location = migrations
# template used to generate migration files
file_template = %%(year)d-%%(month).2d-%%(day).2d_%%(slug)s_%%(rev)s
# This is not required to be set when running through the `advanced_alchemy`
# sqlalchemy.url = driver://user:pass@localhost/dbname
# timezone to use when rendering the date
# within the migration file as well as the filename.
# string value is passed to dateutil.tz.gettz()
# leave blank for localtime
# timezone = UTC
# max length of characters to apply to the
# "slug" field
truncate_slug_length = 40
# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false
# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false
# version location specification; this defaults
# to alembic/versions. When using multiple version
# directories, initial revisions must be specified with --version-path
# version_locations = %(here)s/bar %(here)s/bat alembic/versions
# version path separator; As mentioned above, this is the character used to split
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
# Valid values for version_path_separator are:
#
# version_path_separator = :
# version_path_separator = ;
# version_path_separator = space
version_path_separator = os # Use os.pathsep. Default configuration used for new projects.
# set to 'true' to search source files recursively
# in each "version_locations" directory
# new in Alembic version 1.10
# recursive_version_locations = false
# the output encoding used when revision files
# are written from script.py.mako
output_encoding = utf-8
# [post_write_hooks]
# This section defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples
# format using "black" - use the console_scripts runner,
# against the "black" entrypoint
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -l 79 REVISION_SCRIPT_FILENAME
# lint with attempts to fix using "ruff" - use the exec runner, execute a binary
# hooks = ruff
# ruff.type = exec
# ruff.executable = %(here)s/.venv/bin/ruff
# ruff.options = --fix REVISION_SCRIPT_FILENAME
python-advanced-alchemy-1.4.1/advanced_alchemy/alembic/templates/sync/env.py 0000664 0000000 0000000 00000006255 15003544734 0027215 0 ustar 00root root 0000000 0000000 from typing import TYPE_CHECKING, cast
from sqlalchemy import Engine, engine_from_config, pool
from advanced_alchemy.base import metadata_registry
from alembic import context
from alembic.autogenerate import rewriter
if TYPE_CHECKING:
from sqlalchemy.engine import Connection
from advanced_alchemy.alembic.commands import AlembicCommandConfig
__all__ = ("do_run_migrations", "run_migrations_offline", "run_migrations_online")
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config: "AlembicCommandConfig" = context.config # type: ignore
writer = rewriter.Rewriter()
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode.
This configures the context with just a URL
and not an Engine, though an Engine is acceptable
here as well. By skipping the Engine creation
we don't even need a DBAPI to be available.
Calls to context.execute() here emit the given string to the
script output.
"""
context.configure(
url=config.db_url,
target_metadata=metadata_registry.get(config.bind_key),
literal_binds=True,
dialect_opts={"paramstyle": "named"},
compare_type=config.compare_type,
version_table=config.version_table_name,
version_table_pk=config.version_table_pk,
user_module_prefix=config.user_module_prefix,
render_as_batch=config.render_as_batch,
process_revision_directives=writer,
)
with context.begin_transaction():
context.run_migrations()
def do_run_migrations(connection: "Connection") -> None:
"""Run migrations."""
context.configure(
connection=connection,
target_metadata=metadata_registry.get(config.bind_key),
compare_type=config.compare_type,
version_table=config.version_table_name,
version_table_pk=config.version_table_pk,
user_module_prefix=config.user_module_prefix,
render_as_batch=config.render_as_batch,
process_revision_directives=writer,
)
with context.begin_transaction():
context.run_migrations()
def run_migrations_online() -> None:
"""Run migrations in 'online' mode.
In this scenario we need to create an Engine and associate a
connection with the context.
Raises:
RuntimeError: If the engine cannot be created from the config.
"""
configuration = config.get_section(config.config_ini_section) or {}
configuration["sqlalchemy.url"] = config.db_url
connectable = cast(
"Engine",
config.engine
or engine_from_config(
configuration,
prefix="sqlalchemy.",
poolclass=pool.NullPool,
future=True,
),
)
if connectable is None: # pyright: ignore[reportUnnecessaryComparison]
msg = "Could not get engine from config. Please ensure your `alembic.ini` according to the official Alembic documentation."
raise RuntimeError(
msg,
)
with connectable.connect() as connection:
do_run_migrations(connection=connection)
connectable.dispose()
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()
python-advanced-alchemy-1.4.1/advanced_alchemy/alembic/templates/sync/script.py.mako 0000664 0000000 0000000 00000003553 15003544734 0030655 0 ustar 00root root 0000000 0000000 """${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
import warnings
from typing import TYPE_CHECKING
import sqlalchemy as sa
from alembic import op
from advanced_alchemy.types import EncryptedString, EncryptedText, GUID, ORA_JSONB, DateTimeUTC, StoredObject, PasswordHash
from sqlalchemy import Text # noqa: F401
${imports if imports else ""}
if TYPE_CHECKING:
from collections.abc import Sequence
__all__ = ["downgrade", "upgrade", "schema_upgrades", "schema_downgrades", "data_upgrades", "data_downgrades"]
sa.GUID = GUID
sa.DateTimeUTC = DateTimeUTC
sa.ORA_JSONB = ORA_JSONB
sa.EncryptedString = EncryptedString
sa.EncryptedText = EncryptedText
sa.StoredObject = StoredObject
# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: str | None = ${repr(down_revision)}
branch_labels: str | Sequence[str] | None = ${repr(branch_labels)}
depends_on: str | Sequence[str] | None = ${repr(depends_on)}
def upgrade() -> None:
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=UserWarning)
with op.get_context().autocommit_block():
schema_upgrades()
data_upgrades()
def downgrade() -> None:
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=UserWarning)
with op.get_context().autocommit_block():
data_downgrades()
schema_downgrades()
def schema_upgrades() -> None:
"""schema upgrade migrations go here."""
${upgrades if upgrades else "pass"}
def schema_downgrades() -> None:
"""schema downgrade migrations go here."""
${downgrades if downgrades else "pass"}
def data_upgrades() -> None:
"""Add any optional data upgrade migrations here!"""
def data_downgrades() -> None:
"""Add any optional data downgrade migrations here!"""
python-advanced-alchemy-1.4.1/advanced_alchemy/alembic/utils.py 0000664 0000000 0000000 00000011575 15003544734 0024614 0 ustar 00root root 0000000 0000000 from contextlib import AbstractAsyncContextManager, AbstractContextManager
from pathlib import Path
from typing import TYPE_CHECKING, Union
from sqlalchemy import Engine, MetaData, Table
from typing_extensions import TypeIs
from advanced_alchemy.exceptions import MissingDependencyError
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession
from sqlalchemy.orm import DeclarativeBase, Session
__all__ = ("drop_all", "dump_tables")
async def drop_all(engine: "Union[AsyncEngine, Engine]", version_table_name: str, metadata: MetaData) -> None:
"""Drop all tables in the database.
Args:
engine: The database engine.
version_table_name: The name of the version table.
metadata: The metadata object containing the tables to drop.
Raises:
MissingDependencyError: If the `rich` package is not installed.
"""
try:
from rich import get_console
except ImportError as e: # pragma: no cover
msg = "rich"
raise MissingDependencyError(msg, install_package="cli") from e
console = get_console()
def _is_sync(engine: "Union[Engine, AsyncEngine]") -> "TypeIs[Engine]":
return isinstance(engine, Engine)
def _drop_tables_sync(engine: Engine) -> None:
console.rule("[bold red]Connecting to database backend.")
with engine.begin() as db:
console.rule("[bold red]Dropping the db", align="left")
metadata.drop_all(db)
console.rule("[bold red]Dropping the version table", align="left")
Table(version_table_name, metadata).drop(db, checkfirst=True)
console.rule("[bold yellow]Successfully dropped all objects", align="left")
async def _drop_tables_async(engine: "AsyncEngine") -> None:
console.rule("[bold red]Connecting to database backend.", align="left")
async with engine.begin() as db:
console.rule("[bold red]Dropping the db", align="left")
await db.run_sync(metadata.drop_all)
console.rule("[bold red]Dropping the version table", align="left")
await db.run_sync(Table(version_table_name, metadata).drop, checkfirst=True)
console.rule("[bold yellow]Successfully dropped all objects", align="left")
if _is_sync(engine):
return _drop_tables_sync(engine)
return await _drop_tables_async(engine)
async def dump_tables(
dump_dir: Path,
session: "Union[AbstractContextManager[Session], AbstractAsyncContextManager[AsyncSession]]",
models: "list[type[DeclarativeBase]]",
) -> None:
from types import new_class
from advanced_alchemy._serialization import encode_json
try:
from rich import get_console
except ImportError as e: # pragma: no cover
msg = "rich"
raise MissingDependencyError(msg, install_package="cli") from e
console = get_console()
def _is_sync(
session: "Union[AbstractAsyncContextManager[AsyncSession], AbstractContextManager[Session]]",
) -> "TypeIs[AbstractContextManager[Session]]":
return isinstance(session, AbstractContextManager)
def _dump_table_sync(session: "AbstractContextManager[Session]") -> None:
from advanced_alchemy.repository import SQLAlchemySyncRepository
with session as _session:
for model in models:
json_path = dump_dir / f"{model.__tablename__}.json"
console.rule(
f"[yellow bold]Dumping table '{json_path.stem}' to '{json_path}'",
style="yellow",
align="left",
)
repo = new_class(
"repo",
(SQLAlchemySyncRepository,),
exec_body=lambda ns, model=model: ns.setdefault("model_type", model), # type: ignore[misc]
)
json_path.write_text(encode_json([row.to_dict() for row in repo(session=_session).list()]))
async def _dump_table_async(session: "AbstractAsyncContextManager[AsyncSession]") -> None:
from advanced_alchemy.repository import SQLAlchemyAsyncRepository
async with session as _session:
for model in models:
json_path = dump_dir / f"{model.__tablename__}.json"
console.rule(
f"[yellow bold]Dumping table '{json_path.stem}' to '{json_path}'",
style="yellow",
align="left",
)
repo = new_class(
"repo",
(SQLAlchemyAsyncRepository,),
exec_body=lambda ns, model=model: ns.setdefault("model_type", model), # type: ignore[misc]
)
json_path.write_text(encode_json([row.to_dict() for row in await repo(session=_session).list()]))
dump_dir.mkdir(exist_ok=True)
if _is_sync(session):
return _dump_table_sync(session)
return await _dump_table_async(session)
python-advanced-alchemy-1.4.1/advanced_alchemy/base.py 0000664 0000000 0000000 00000037403 15003544734 0022770 0 ustar 00root root 0000000 0000000 # ruff: noqa: TC004
"""Common base classes for SQLAlchemy declarative models."""
import contextlib
import datetime
import re
from collections.abc import Iterator
from typing import TYPE_CHECKING, Any, Optional, Protocol, Union, cast, runtime_checkable
from uuid import UUID
from sqlalchemy import Date, MetaData, String
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlalchemy.orm import (
DeclarativeBase,
Mapper,
declared_attr,
)
from sqlalchemy.orm import (
registry as SQLAlchemyRegistry, # noqa: N812
)
from sqlalchemy.orm.decl_base import (
_TableArgsType as TableArgsType, # pyright: ignore[reportPrivateUsage]
)
from sqlalchemy.types import TypeEngine
from typing_extensions import Self, TypeVar
from advanced_alchemy.mixins import (
AuditColumns,
BigIntPrimaryKey,
NanoIDPrimaryKey,
UUIDPrimaryKey,
UUIDv6PrimaryKey,
UUIDv7PrimaryKey,
)
from advanced_alchemy.types import GUID, DateTimeUTC, FileObject, FileObjectList, JsonB, StoredObject
from advanced_alchemy.utils.dataclass import DataclassProtocol
if TYPE_CHECKING:
from sqlalchemy.sql import FromClause
from sqlalchemy.sql.schema import (
_NamingSchemaParameter as NamingSchemaParameter, # pyright: ignore[reportPrivateUsage]
)
__all__ = (
"AdvancedDeclarativeBase",
"BasicAttributes",
"BigIntAuditBase",
"BigIntBase",
"BigIntBaseT",
"CommonTableAttributes",
"ModelProtocol",
"NanoIDAuditBase",
"NanoIDBase",
"NanoIDBaseT",
"SQLQuery",
"TableArgsType",
"UUIDAuditBase",
"UUIDBase",
"UUIDBaseT",
"UUIDv6AuditBase",
"UUIDv6Base",
"UUIDv6BaseT",
"UUIDv7AuditBase",
"UUIDv7Base",
"UUIDv7BaseT",
"convention",
"create_registry",
"merge_table_arguments",
"metadata_registry",
"orm_registry",
"table_name_regexp",
)
UUIDBaseT = TypeVar("UUIDBaseT", bound="UUIDBase")
"""Type variable for :class:`UUIDBase`."""
BigIntBaseT = TypeVar("BigIntBaseT", bound="BigIntBase")
"""Type variable for :class:`BigIntBase`."""
UUIDv6BaseT = TypeVar("UUIDv6BaseT", bound="UUIDv6Base")
"""Type variable for :class:`UUIDv6Base`."""
UUIDv7BaseT = TypeVar("UUIDv7BaseT", bound="UUIDv7Base")
"""Type variable for :class:`UUIDv7Base`."""
NanoIDBaseT = TypeVar("NanoIDBaseT", bound="NanoIDBase")
"""Type variable for :class:`NanoIDBase`."""
convention: "NamingSchemaParameter" = {
"ix": "ix_%(column_0_label)s",
"uq": "uq_%(table_name)s_%(column_0_name)s",
"ck": "ck_%(table_name)s_%(constraint_name)s",
"fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
"pk": "pk_%(table_name)s",
}
"""Templates for automated constraint name generation."""
table_name_regexp = re.compile(r"((?<=[a-z0-9])[A-Z]|(?!^)[A-Z](?=[a-z]))")
"""Regular expression for table name"""
def merge_table_arguments(cls: type[DeclarativeBase], table_args: Optional[TableArgsType] = None) -> TableArgsType:
"""Merge Table Arguments.
This function helps merge table arguments when using mixins that include their own table args,
making it easier to append additional information such as comments or constraints to the model.
Args:
cls (type[:class:`sqlalchemy.orm.DeclarativeBase`]): The model that will get the table args.
table_args (:class:`TableArgsType`, optional): Additional information to add to table_args.
Returns:
:class:`TableArgsType`: Merged table arguments.
"""
args: list[Any] = []
kwargs: dict[str, Any] = {}
mixin_table_args = (getattr(super(base_cls, cls), "__table_args__", None) for base_cls in cls.__bases__) # pyright: ignore[reportUnknownParameter,reportUnknownArgumentType,reportArgumentType]
for arg_to_merge in (*mixin_table_args, table_args):
if arg_to_merge:
if isinstance(arg_to_merge, tuple):
last_positional_arg = arg_to_merge[-1] # pyright: ignore[reportUnknownVariableType]
args.extend(arg_to_merge[:-1]) # pyright: ignore[reportUnknownArgumentType]
if isinstance(last_positional_arg, dict):
kwargs.update(last_positional_arg) # pyright: ignore[reportUnknownArgumentType]
else:
args.append(last_positional_arg)
else:
kwargs.update(arg_to_merge) # pyright: ignore
if args:
if kwargs:
return (*args, kwargs)
return tuple(args)
return kwargs
@runtime_checkable
class ModelProtocol(Protocol):
"""The base SQLAlchemy model protocol.
Attributes:
__table__ (:class:`sqlalchemy.sql.FromClause`): The table associated with the model.
__mapper__ (:class:`sqlalchemy.orm.Mapper`): The mapper for the model.
__name__ (str): The name of the model.
"""
if TYPE_CHECKING:
__table__: FromClause
__mapper__: Mapper[Any]
__name__: str
def to_dict(self, exclude: Optional[set[str]] = None) -> dict[str, Any]:
"""Convert model to dictionary.
Returns:
Dict[str, Any]: A dict representation of the model
"""
...
class BasicAttributes:
"""Basic attributes for SQLAlchemy tables and queries.
Provides a method to convert the model to a dictionary representation.
Methods:
to_dict: Converts the model to a dictionary, excluding specified fields. :no-index:
"""
if TYPE_CHECKING:
__name__: str
__table__: FromClause
__mapper__: Mapper[Any]
def to_dict(self, exclude: Optional[set[str]] = None) -> dict[str, Any]:
"""Convert model to dictionary.
Returns:
Dict[str, Any]: A dict representation of the model
"""
exclude = {"sa_orm_sentinel", "_sentinel"}.union(self._sa_instance_state.unloaded).union(exclude or []) # type: ignore[attr-defined]
return {
field: getattr(self, field)
for field in self.__mapper__.columns.keys() # noqa: SIM118
if field not in exclude
}
class CommonTableAttributes(BasicAttributes):
"""Common attributes for SQLAlchemy tables.
Inherits from :class:`BasicAttributes` and provides a mechanism to infer table names from class names.
Attributes:
__tablename__ (str): The inferred table name.
"""
if TYPE_CHECKING:
__tablename__: str
else:
@declared_attr.directive
def __tablename__(cls) -> str:
"""Infer table name from class name.
Returns:
str: The inferred table name.
"""
return table_name_regexp.sub(r"_\1", cls.__name__).lower()
def create_registry(
custom_annotation_map: Optional[dict[Any, Union[type[TypeEngine[Any]], TypeEngine[Any]]]] = None,
) -> SQLAlchemyRegistry:
"""Create a new SQLAlchemy registry.
Args:
custom_annotation_map (dict, optional): Custom type annotations to use for the registry.
Returns:
:class:`sqlalchemy.orm.registry`: A new SQLAlchemy registry with the specified type annotations.
"""
import uuid as core_uuid
meta = MetaData(naming_convention=convention)
type_annotation_map: dict[Any, Union[type[TypeEngine[Any]], TypeEngine[Any]]] = {
UUID: GUID,
core_uuid.UUID: GUID,
datetime.datetime: DateTimeUTC,
datetime.date: Date,
dict: JsonB,
dict[str, Any]: JsonB,
dict[str, str]: JsonB,
DataclassProtocol: JsonB,
FileObject: StoredObject,
FileObjectList: StoredObject,
}
with contextlib.suppress(ImportError):
from pydantic import AnyHttpUrl, AnyUrl, EmailStr, IPvAnyAddress, IPvAnyInterface, IPvAnyNetwork, Json
type_annotation_map.update(
{
EmailStr: String,
AnyUrl: String,
AnyHttpUrl: String,
Json: JsonB,
IPvAnyAddress: String,
IPvAnyInterface: String,
IPvAnyNetwork: String,
}
)
with contextlib.suppress(ImportError):
from msgspec import Struct
type_annotation_map[Struct] = JsonB
if custom_annotation_map is not None:
type_annotation_map.update(custom_annotation_map)
return SQLAlchemyRegistry(metadata=meta, type_annotation_map=type_annotation_map)
orm_registry = create_registry()
class MetadataRegistry:
"""A registry for metadata.
Provides methods to get and set metadata for different bind keys.
Methods:
get: Retrieves the metadata for a given bind key.
set: Sets the metadata for a given bind key.
"""
_instance: Optional["MetadataRegistry"] = None
_registry: dict[Union[str, None], MetaData] = {None: orm_registry.metadata}
def __new__(cls) -> Self:
if cls._instance is None:
cls._instance = super().__new__(cls)
return cast("Self", cls._instance)
def get(self, bind_key: Optional[str] = None) -> MetaData:
"""Get the metadata for the given bind key.
Args:
bind_key (Optional[str]): The bind key for the metadata.
Returns:
:class:`sqlalchemy.MetaData`: The metadata for the given bind key.
"""
return self._registry.setdefault(bind_key, MetaData(naming_convention=convention))
def set(self, bind_key: Optional[str], metadata: MetaData) -> None:
"""Set the metadata for the given bind key.
Args:
bind_key (Optional[str]): The bind key for the metadata.
metadata (:class:`sqlalchemy.MetaData`): The metadata to set.
"""
self._registry[bind_key] = metadata
def __iter__(self) -> Iterator[Union[str, None]]:
return iter(self._registry)
def __getitem__(self, bind_key: Union[str, None]) -> MetaData:
return self._registry[bind_key]
def __setitem__(self, bind_key: Union[str, None], metadata: MetaData) -> None:
self._registry[bind_key] = metadata
def __contains__(self, bind_key: Union[str, None]) -> bool:
return bind_key in self._registry
metadata_registry = MetadataRegistry()
class AdvancedDeclarativeBase(DeclarativeBase):
"""A subclass of declarative base that allows for overriding of the registry.
Inherits from :class:`sqlalchemy.orm.DeclarativeBase`.
Attributes:
registry (:class:`sqlalchemy.orm.registry`): The registry for the declarative base.
__metadata_registry__ (:class:`~advanced_alchemy.base.MetadataRegistry`): The metadata registry.
__bind_key__ (Optional[:class:`str`]): The bind key for the metadata.
"""
registry = orm_registry
__abstract__ = True
__metadata_registry__: MetadataRegistry = MetadataRegistry()
__bind_key__: Optional[str] = None
def __init_subclass__(cls, **kwargs: Any) -> None:
bind_key = getattr(cls, "__bind_key__", None)
if bind_key is not None:
cls.metadata = cls.__metadata_registry__.get(bind_key)
elif None not in cls.__metadata_registry__ and getattr(cls, "metadata", None) is not None:
cls.__metadata_registry__[None] = cls.metadata
super().__init_subclass__(**kwargs)
class UUIDBase(UUIDPrimaryKey, CommonTableAttributes, AdvancedDeclarativeBase, AsyncAttrs):
"""Base for all SQLAlchemy declarative models with UUID v4 primary keys.
.. seealso::
:class:`CommonTableAttributes`
:class:`advanced_alchemy.mixins.UUIDPrimaryKey`
:class:`AdvancedDeclarativeBase`
:class:`AsyncAttrs`
"""
__abstract__ = True
class UUIDAuditBase(CommonTableAttributes, UUIDPrimaryKey, AuditColumns, AdvancedDeclarativeBase, AsyncAttrs):
"""Base for declarative models with UUID v4 primary keys and audit columns.
.. seealso::
:class:`CommonTableAttributes`
:class:`advanced_alchemy.mixins.UUIDPrimaryKey`
:class:`advanced_alchemy.mixins.AuditColumns`
:class:`AdvancedDeclarativeBase`
:class:`AsyncAttrs`
"""
__abstract__ = True
class UUIDv6Base(UUIDv6PrimaryKey, CommonTableAttributes, AdvancedDeclarativeBase, AsyncAttrs):
"""Base for all SQLAlchemy declarative models with UUID v6 primary keys.
.. seealso::
:class:`advanced_alchemy.mixins.UUIDv6PrimaryKey`
:class:`CommonTableAttributes`
:class:`AdvancedDeclarativeBase`
:class:`AsyncAttrs`
"""
__abstract__ = True
class UUIDv6AuditBase(CommonTableAttributes, UUIDv6PrimaryKey, AuditColumns, AdvancedDeclarativeBase, AsyncAttrs):
"""Base for declarative models with UUID v6 primary keys and audit columns.
.. seealso::
:class:`CommonTableAttributes`
:class:`advanced_alchemy.mixins.UUIDv6PrimaryKey`
:class:`advanced_alchemy.mixins.AuditColumns`
:class:`AdvancedDeclarativeBase`
:class:`AsyncAttrs`
"""
__abstract__ = True
class UUIDv7Base(UUIDv7PrimaryKey, CommonTableAttributes, AdvancedDeclarativeBase, AsyncAttrs):
"""Base for all SQLAlchemy declarative models with UUID v7 primary keys.
.. seealso::
:class:`advanced_alchemy.mixins.UUIDv7PrimaryKey`
:class:`CommonTableAttributes`
:class:`AdvancedDeclarativeBase`
:class:`AsyncAttrs`
"""
__abstract__ = True
class UUIDv7AuditBase(CommonTableAttributes, UUIDv7PrimaryKey, AuditColumns, AdvancedDeclarativeBase, AsyncAttrs):
"""Base for declarative models with UUID v7 primary keys and audit columns.
.. seealso::
:class:`CommonTableAttributes`
:class:`advanced_alchemy.mixins.UUIDv7PrimaryKey`
:class:`advanced_alchemy.mixins.AuditColumns`
:class:`AdvancedDeclarativeBase`
:class:`AsyncAttrs`
"""
__abstract__ = True
class NanoIDBase(NanoIDPrimaryKey, CommonTableAttributes, AdvancedDeclarativeBase, AsyncAttrs):
"""Base for all SQLAlchemy declarative models with Nano ID primary keys.
.. seealso::
:class:`advanced_alchemy.mixins.NanoIDPrimaryKey`
:class:`CommonTableAttributes`
:class:`AdvancedDeclarativeBase`
:class:`AsyncAttrs`
"""
__abstract__ = True
class NanoIDAuditBase(CommonTableAttributes, NanoIDPrimaryKey, AuditColumns, AdvancedDeclarativeBase, AsyncAttrs):
"""Base for declarative models with Nano ID primary keys and audit columns.
.. seealso::
:class:`CommonTableAttributes`
:class:`advanced_alchemy.mixins.NanoIDPrimaryKey`
:class:`advanced_alchemy.mixins.AuditColumns`
:class:`AdvancedDeclarativeBase`
:class:`AsyncAttrs`
"""
__abstract__ = True
class BigIntBase(BigIntPrimaryKey, CommonTableAttributes, AdvancedDeclarativeBase, AsyncAttrs):
"""Base for all SQLAlchemy declarative models with BigInt primary keys.
.. seealso::
:class:`advanced_alchemy.mixins.BigIntPrimaryKey`
:class:`CommonTableAttributes`
:class:`AdvancedDeclarativeBase`
:class:`AsyncAttrs`
"""
__abstract__ = True
class BigIntAuditBase(CommonTableAttributes, BigIntPrimaryKey, AuditColumns, AdvancedDeclarativeBase, AsyncAttrs):
"""Base for declarative models with BigInt primary keys and audit columns.
.. seealso::
:class:`CommonTableAttributes`
:class:`advanced_alchemy.mixins.BigIntPrimaryKey`
:class:`advanced_alchemy.mixins.AuditColumns`
:class:`AdvancedDeclarativeBase`
:class:`AsyncAttrs`
"""
__abstract__ = True
class DefaultBase(CommonTableAttributes, AdvancedDeclarativeBase, AsyncAttrs):
"""Base for all SQLAlchemy declarative models. No primary key is added.
.. seealso::
:class:`CommonTableAttributes`
:class:`AdvancedDeclarativeBase`
:class:`AsyncAttrs`
"""
__abstract__ = True
class SQLQuery(BasicAttributes, AdvancedDeclarativeBase, AsyncAttrs):
"""Base for all SQLAlchemy custom mapped objects.
.. seealso::
:class:`BasicAttributes`
:class:`AdvancedDeclarativeBase`
:class:`AsyncAttrs`
"""
__abstract__ = True
__allow_unmapped__ = True
python-advanced-alchemy-1.4.1/advanced_alchemy/cli.py 0000664 0000000 0000000 00000042131 15003544734 0022617 0 ustar 00root root 0000000 0000000 import sys
from collections.abc import Sequence
from pathlib import Path
from typing import TYPE_CHECKING, Optional, Union, cast
if TYPE_CHECKING:
from click import Group
from advanced_alchemy.config import SQLAlchemyAsyncConfig, SQLAlchemySyncConfig
from alembic.migration import MigrationContext
from alembic.operations.ops import MigrationScript, UpgradeOps
__all__ = ("add_migration_commands", "get_alchemy_group")
def get_alchemy_group() -> "Group":
"""Get the Advanced Alchemy CLI group.
Raises:
MissingDependencyError: If the `click` package is not installed.
Returns:
The Advanced Alchemy CLI group.
"""
from advanced_alchemy.exceptions import MissingDependencyError
try:
import rich_click as click
except ImportError:
try:
import click # type: ignore[no-redef]
except ImportError as e:
raise MissingDependencyError(package="click", install_package="cli") from e
@click.group(name="alchemy")
@click.option(
"--config",
help="Dotted path to SQLAlchemy config(s) (e.g. 'myapp.config.alchemy_configs')",
required=True,
type=str,
)
@click.pass_context
def alchemy_group(ctx: "click.Context", config: str) -> None:
"""Advanced Alchemy CLI commands."""
from rich import get_console
from advanced_alchemy.utils import module_loader
console = get_console()
ctx.ensure_object(dict)
try:
config_instance = module_loader.import_string(config)
if isinstance(config_instance, Sequence):
ctx.obj["configs"] = config_instance
else:
ctx.obj["configs"] = [config_instance]
except ImportError as e:
console.print(f"[red]Error loading config: {e}[/]")
ctx.exit(1)
return alchemy_group
def add_migration_commands(database_group: Optional["Group"] = None) -> "Group": # noqa: C901, PLR0915
"""Add migration commands to the database group.
Args:
database_group: The database group to add the commands to.
Raises:
MissingDependencyError: If the `click` package is not installed.
Returns:
The database group with the migration commands added.
"""
from advanced_alchemy.exceptions import MissingDependencyError
try:
import rich_click as click
except ImportError:
try:
import click # type: ignore[no-redef]
except ImportError as e:
raise MissingDependencyError(package="click", install_package="cli") from e
from rich import get_console
console = get_console()
if database_group is None:
database_group = get_alchemy_group()
bind_key_option = click.option(
"--bind-key",
help="Specify which SQLAlchemy config to use by bind key",
type=str,
default=None,
)
verbose_option = click.option(
"--verbose",
help="Enable verbose output.",
type=bool,
default=False,
is_flag=True,
)
no_prompt_option = click.option(
"--no-prompt",
help="Do not prompt for confirmation before executing the command.",
type=bool,
default=False,
required=False,
show_default=True,
is_flag=True,
)
def get_config_by_bind_key(
ctx: "click.Context", bind_key: Optional[str]
) -> "Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig]":
"""Get the SQLAlchemy config for the specified bind key.
Args:
ctx: The click context.
bind_key: The bind key to get the config for.
Returns:
The SQLAlchemy config for the specified bind key.
"""
configs = ctx.obj["configs"]
if bind_key is None:
return cast("Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig]", configs[0])
for config in configs:
if config.bind_key == bind_key:
return cast("Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig]", config)
console.print(f"[red]No config found for bind key: {bind_key}[/]")
sys.exit(1)
@database_group.command(
name="show-current-revision",
help="Shows the current revision for the database.",
)
@bind_key_option
@verbose_option
def show_database_revision(bind_key: Optional[str], verbose: bool) -> None: # pyright: ignore[reportUnusedFunction]
"""Show current database revision."""
from advanced_alchemy.alembic.commands import AlembicCommands
ctx = click.get_current_context()
console.rule("[yellow]Listing current revision[/]", align="left")
sqlalchemy_config = get_config_by_bind_key(ctx, bind_key)
alembic_commands = AlembicCommands(sqlalchemy_config=sqlalchemy_config)
alembic_commands.current(verbose=verbose)
@database_group.command(
name="downgrade",
help="Downgrade database to a specific revision.",
)
@bind_key_option
@click.option("--sql", type=bool, help="Generate SQL output for offline migrations.", default=False, is_flag=True)
@click.option(
"--tag",
help="an arbitrary 'tag' that can be intercepted by custom env.py scripts via the .EnvironmentContext.get_tag_argument method.",
type=str,
default=None,
)
@no_prompt_option
@click.argument(
"revision",
type=str,
default="-1",
)
def downgrade_database( # pyright: ignore[reportUnusedFunction]
bind_key: Optional[str], revision: str, sql: bool, tag: Optional[str], no_prompt: bool
) -> None:
"""Downgrade the database to the latest revision."""
from rich.prompt import Confirm
from advanced_alchemy.alembic.commands import AlembicCommands
ctx = click.get_current_context()
console.rule("[yellow]Starting database downgrade process[/]", align="left")
input_confirmed = (
True
if no_prompt
else Confirm.ask(f"Are you sure you want to downgrade the database to the `{revision}` revision?")
)
if input_confirmed:
sqlalchemy_config = get_config_by_bind_key(ctx, bind_key)
alembic_commands = AlembicCommands(sqlalchemy_config=sqlalchemy_config)
alembic_commands.downgrade(revision=revision, sql=sql, tag=tag)
@database_group.command(
name="upgrade",
help="Upgrade database to a specific revision.",
)
@bind_key_option
@click.option("--sql", type=bool, help="Generate SQL output for offline migrations.", default=False, is_flag=True)
@click.option(
"--tag",
help="an arbitrary 'tag' that can be intercepted by custom env.py scripts via the .EnvironmentContext.get_tag_argument method.",
type=str,
default=None,
)
@no_prompt_option
@click.argument(
"revision",
type=str,
default="head",
)
def upgrade_database( # pyright: ignore[reportUnusedFunction]
bind_key: Optional[str], revision: str, sql: bool, tag: Optional[str], no_prompt: bool
) -> None:
"""Upgrade the database to the latest revision."""
from rich.prompt import Confirm
from advanced_alchemy.alembic.commands import AlembicCommands
ctx = click.get_current_context()
console.rule("[yellow]Starting database upgrade process[/]", align="left")
input_confirmed = (
True
if no_prompt
else Confirm.ask(f"[bold]Are you sure you want migrate the database to the `{revision}` revision?[/]")
)
if input_confirmed:
sqlalchemy_config = get_config_by_bind_key(ctx, bind_key)
alembic_commands = AlembicCommands(sqlalchemy_config=sqlalchemy_config)
alembic_commands.upgrade(revision=revision, sql=sql, tag=tag)
@database_group.command(
help="Stamp the revision table with the given revision",
)
@click.argument("revision", type=str)
@bind_key_option
def stamp(bind_key: Optional[str], revision: str) -> None: # pyright: ignore[reportUnusedFunction]
"""Stamp the revision table with the given revision."""
from advanced_alchemy.alembic.commands import AlembicCommands
ctx = click.get_current_context()
sqlalchemy_config = get_config_by_bind_key(ctx, bind_key)
alembic_commands = AlembicCommands(sqlalchemy_config=sqlalchemy_config)
alembic_commands.stamp(revision=revision)
@database_group.command(
name="init",
help="Initialize migrations for the project.",
)
@bind_key_option
@click.argument(
"directory",
default=None,
required=False,
)
@click.option("--multidb", is_flag=True, default=False, help="Support multiple databases")
@click.option("--package", is_flag=True, default=True, help="Create `__init__.py` for created folder")
@no_prompt_option
def init_alembic( # pyright: ignore[reportUnusedFunction]
bind_key: Optional[str], directory: Optional[str], multidb: bool, package: bool, no_prompt: bool
) -> None:
"""Initialize the database migrations."""
from rich.prompt import Confirm
from advanced_alchemy.alembic.commands import AlembicCommands
ctx = click.get_current_context()
console.rule("[yellow]Initializing database migrations.", align="left")
input_confirmed = (
True if no_prompt else Confirm.ask("[bold]Are you sure you want initialize migrations for the project?[/]")
)
if input_confirmed:
configs = [get_config_by_bind_key(ctx, bind_key)] if bind_key is not None else ctx.obj["configs"]
for config in configs:
directory = config.alembic_config.script_location if directory is None else directory
alembic_commands = AlembicCommands(sqlalchemy_config=config)
alembic_commands.init(directory=cast("str", directory), multidb=multidb, package=package)
@database_group.command(
name="make-migrations",
help="Create a new migration revision.",
)
@bind_key_option
@click.option("-m", "--message", default=None, help="Revision message")
@click.option(
"--autogenerate/--no-autogenerate", default=True, help="Automatically populate revision with detected changes"
)
@click.option("--sql", is_flag=True, default=False, help="Export to `.sql` instead of writing to the database.")
@click.option("--head", default="head", help="Specify head revision to use as base for new revision.")
@click.option(
"--splice", is_flag=True, default=False, help='Allow a non-head revision as the "head" to splice onto'
)
@click.option("--branch-label", default=None, help="Specify a branch label to apply to the new revision")
@click.option("--version-path", default=None, help="Specify specific path from config for version file")
@click.option("--rev-id", default=None, help="Specify a ID to use for revision.")
@no_prompt_option
def create_revision( # pyright: ignore[reportUnusedFunction]
bind_key: Optional[str],
message: Optional[str],
autogenerate: bool,
sql: bool,
head: str,
splice: bool,
branch_label: Optional[str],
version_path: Optional[str],
rev_id: Optional[str],
no_prompt: bool,
) -> None:
"""Create a new database revision."""
from rich.prompt import Prompt
from advanced_alchemy.alembic.commands import AlembicCommands
def process_revision_directives(
context: "MigrationContext", # noqa: ARG001
revision: tuple[str], # noqa: ARG001
directives: list["MigrationScript"],
) -> None:
"""Handle revision directives."""
if autogenerate and cast("UpgradeOps", directives[0].upgrade_ops).is_empty():
console.rule(
"[magenta]The generation of a migration file is being skipped because it would result in an empty file.",
style="magenta",
align="left",
)
console.rule(
"[magenta]More information can be found here. https://alembic.sqlalchemy.org/en/latest/autogenerate.html#what-does-autogenerate-detect-and-what-does-it-not-detect",
style="magenta",
align="left",
)
console.rule(
"[magenta]If you intend to create an empty migration file, use the --no-autogenerate option.",
style="magenta",
align="left",
)
directives.clear()
ctx = click.get_current_context()
console.rule("[yellow]Starting database upgrade process[/]", align="left")
if message is None:
message = "autogenerated" if no_prompt else Prompt.ask("Please enter a message describing this revision")
sqlalchemy_config = get_config_by_bind_key(ctx, bind_key)
alembic_commands = AlembicCommands(sqlalchemy_config=sqlalchemy_config)
alembic_commands.revision(
message=message,
autogenerate=autogenerate,
sql=sql,
head=head,
splice=splice,
branch_label=branch_label,
version_path=version_path,
rev_id=rev_id,
process_revision_directives=process_revision_directives, # type: ignore[arg-type]
)
@database_group.command(name="drop-all", help="Drop all tables from the database.")
@bind_key_option
@no_prompt_option
def drop_all(bind_key: Optional[str], no_prompt: bool) -> None: # pyright: ignore[reportUnusedFunction]
"""Drop all tables from the database."""
from anyio import run
from rich.prompt import Confirm
from advanced_alchemy.alembic.utils import drop_all
from advanced_alchemy.base import metadata_registry
ctx = click.get_current_context()
console.rule("[yellow]Dropping all tables from the database[/]", align="left")
input_confirmed = no_prompt or Confirm.ask(
"[bold red]Are you sure you want to drop all tables from the database?"
)
async def _drop_all(
configs: "Sequence[Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig]]",
) -> None:
for config in configs:
engine = config.get_engine()
await drop_all(engine, config.alembic_config.version_table_name, metadata_registry.get(config.bind_key))
if input_confirmed:
configs = [get_config_by_bind_key(ctx, bind_key)] if bind_key is not None else ctx.obj["configs"]
run(_drop_all, configs)
@database_group.command(name="dump-data", help="Dump specified tables from the database to JSON files.")
@bind_key_option
@click.option(
"--table",
"table_names",
help="Name of the table to dump. Multiple tables can be specified. Use '*' to dump all tables.",
type=str,
required=True,
multiple=True,
)
@click.option(
"--dir",
"dump_dir",
help="Directory to save the JSON files. Defaults to WORKDIR/fixtures",
type=click.Path(path_type=Path),
default=Path.cwd() / "fixtures",
required=False,
)
def dump_table_data(bind_key: Optional[str], table_names: tuple[str, ...], dump_dir: Path) -> None: # pyright: ignore[reportUnusedFunction]
"""Dump table data to JSON files."""
from anyio import run
from rich.prompt import Confirm
from advanced_alchemy.alembic.utils import dump_tables
from advanced_alchemy.base import metadata_registry, orm_registry
ctx = click.get_current_context()
all_tables = "*" in table_names
if all_tables and not Confirm.ask(
"[yellow bold]You have specified '*'. Are you sure you want to dump all tables from the database?",
):
return console.rule("[red bold]No data was dumped.", style="red", align="left")
async def _dump_tables() -> None:
configs = [get_config_by_bind_key(ctx, bind_key)] if bind_key is not None else ctx.obj["configs"]
for config in configs:
target_tables = set(metadata_registry.get(config.bind_key).tables)
if not all_tables:
for table_name in set(table_names) - target_tables:
console.rule(
f"[red bold]Skipping table '{table_name}' because it is not available in the default registry",
style="red",
align="left",
)
target_tables.intersection_update(table_names)
else:
console.rule("[yellow bold]Dumping all tables", style="yellow", align="left")
models = [
mapper.class_ for mapper in orm_registry.mappers if mapper.class_.__table__.name in target_tables
]
await dump_tables(dump_dir, config.get_session(), models)
console.rule("[green bold]Data dump complete", align="left")
return run(_dump_tables)
return database_group
python-advanced-alchemy-1.4.1/advanced_alchemy/config/ 0000775 0000000 0000000 00000000000 15003544734 0022742 5 ustar 00root root 0000000 0000000 python-advanced-alchemy-1.4.1/advanced_alchemy/config/__init__.py 0000664 0000000 0000000 00000001757 15003544734 0025065 0 ustar 00root root 0000000 0000000 from advanced_alchemy.config.asyncio import AlembicAsyncConfig, AsyncSessionConfig, SQLAlchemyAsyncConfig
from advanced_alchemy.config.common import (
ConnectionT,
EngineT,
GenericAlembicConfig,
GenericSessionConfig,
GenericSQLAlchemyConfig,
SessionMakerT,
SessionT,
)
from advanced_alchemy.config.engine import EngineConfig
from advanced_alchemy.config.sync import AlembicSyncConfig, SQLAlchemySyncConfig, SyncSessionConfig
from advanced_alchemy.config.types import CommitStrategy, TypeDecodersSequence, TypeEncodersMap
__all__ = (
"AlembicAsyncConfig",
"AlembicSyncConfig",
"AsyncSessionConfig",
"CommitStrategy",
"ConnectionT",
"EngineConfig",
"EngineT",
"GenericAlembicConfig",
"GenericSQLAlchemyConfig",
"GenericSessionConfig",
"SQLAlchemyAsyncConfig",
"SQLAlchemySyncConfig",
"SessionMakerT",
"SessionT",
"SyncSessionConfig",
"TypeDecodersSequence",
"TypeEncodersMap",
)
python-advanced-alchemy-1.4.1/advanced_alchemy/config/asyncio.py 0000664 0000000 0000000 00000006567 15003544734 0024777 0 ustar 00root root 0000000 0000000 from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Callable, Optional, Union
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
from advanced_alchemy._listeners import set_async_context
from advanced_alchemy.config.common import (
GenericAlembicConfig,
GenericSessionConfig,
GenericSQLAlchemyConfig,
)
from advanced_alchemy.utils.dataclass import Empty
if TYPE_CHECKING:
from typing import Callable
from sqlalchemy.orm import Session
from advanced_alchemy.utils.dataclass import EmptyType
__all__ = (
"AlembicAsyncConfig",
"AsyncSessionConfig",
"SQLAlchemyAsyncConfig",
)
@dataclass
class AsyncSessionConfig(GenericSessionConfig[AsyncConnection, AsyncEngine, AsyncSession]):
"""SQLAlchemy async session config."""
sync_session_class: "Optional[Union[type[Session], EmptyType]]" = Empty
"""A :class:`Session ` subclass or other callable which will be used to construct the
:class:`Session ` which will be proxied. This parameter may be used to provide custom
:class:`Session ` subclasses. Defaults to the
:attr:`AsyncSession.sync_session_class ` class-level
attribute."""
@dataclass
class AlembicAsyncConfig(GenericAlembicConfig):
"""Configuration for an Async Alembic's Config class.
.. seealso::
https://alembic.sqlalchemy.org/en/latest/api/config.html
"""
@dataclass
class SQLAlchemyAsyncConfig(GenericSQLAlchemyConfig[AsyncEngine, AsyncSession, async_sessionmaker[AsyncSession]]):
"""Async SQLAlchemy Configuration.
Note:
The alembic configuration options are documented in the Alembic documentation.
"""
create_engine_callable: "Callable[[str], AsyncEngine]" = create_async_engine
"""Callable that creates an :class:`AsyncEngine ` instance or instance of its
subclass.
"""
session_config: AsyncSessionConfig = field(default_factory=AsyncSessionConfig) # pyright: ignore[reportIncompatibleVariableOverride]
"""Configuration options for the :class:`async_sessionmaker`."""
session_maker_class: "type[async_sessionmaker[AsyncSession]]" = async_sessionmaker # pyright: ignore[reportIncompatibleVariableOverride]
"""Sessionmaker class to use."""
alembic_config: "AlembicAsyncConfig" = field(default_factory=AlembicAsyncConfig)
"""Configuration for the SQLAlchemy Alembic migrations.
The configuration options are documented in the Alembic documentation.
"""
def __hash__(self) -> int:
return super().__hash__()
def __eq__(self, other: object) -> bool:
return super().__eq__(other)
@asynccontextmanager
async def get_session(
self,
) -> AsyncGenerator[AsyncSession, None]:
"""Get a session from the session maker.
Yields:
AsyncGenerator[AsyncSession, None]: An async context manager that yields an AsyncSession.
"""
session_maker = self.create_session_maker()
set_async_context(True) # Set context for standalone usage
async with session_maker() as session:
yield session
python-advanced-alchemy-1.4.1/advanced_alchemy/config/common.py 0000664 0000000 0000000 00000034046 15003544734 0024613 0 ustar 00root root 0000000 0000000 from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, Optional, Union, cast
from typing_extensions import TypeVar
from advanced_alchemy.base import metadata_registry
from advanced_alchemy.config.engine import EngineConfig
from advanced_alchemy.exceptions import ImproperConfigurationError
from advanced_alchemy.utils.dataclass import Empty, simple_asdict
if TYPE_CHECKING:
from sqlalchemy import Connection, Engine, MetaData
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, AsyncSession, async_sessionmaker
from sqlalchemy.orm import Mapper, Query, Session, sessionmaker
from sqlalchemy.orm.session import JoinTransactionMode
from sqlalchemy.sql import TableClause
from advanced_alchemy.utils.dataclass import EmptyType
__all__ = (
"ALEMBIC_TEMPLATE_PATH",
"ConnectionT",
"EngineT",
"GenericAlembicConfig",
"GenericSQLAlchemyConfig",
"GenericSessionConfig",
"SessionMakerT",
"SessionT",
)
ALEMBIC_TEMPLATE_PATH = f"{Path(__file__).parent.parent}/alembic/templates"
"""Path to the Alembic templates."""
ConnectionT = TypeVar("ConnectionT", bound="Union[Connection, AsyncConnection]")
"""Type variable for SQLAlchemy connection types.
.. seealso::
:class:`sqlalchemy.Connection`
:class:`sqlalchemy.ext.asyncio.AsyncConnection`
"""
EngineT = TypeVar("EngineT", bound="Union[Engine, AsyncEngine]")
"""Type variable for a SQLAlchemy engine.
.. seealso::
:class:`sqlalchemy.Engine`
:class:`sqlalchemy.ext.asyncio.AsyncEngine`
"""
SessionT = TypeVar("SessionT", bound="Union[Session, AsyncSession]")
"""Type variable for a SQLAlchemy session.
.. seealso::
:class:`sqlalchemy.Session`
:class:`sqlalchemy.ext.asyncio.AsyncSession`
"""
SessionMakerT = TypeVar("SessionMakerT", bound="Union[sessionmaker[Session], async_sessionmaker[AsyncSession]]")
"""Type variable for a SQLAlchemy sessionmaker.
.. seealso::
:class:`sqlalchemy.orm.sessionmaker`
:class:`sqlalchemy.ext.asyncio.async_sessionmaker`
"""
@dataclass
class GenericSessionConfig(Generic[ConnectionT, EngineT, SessionT]):
"""SQLAlchemy async session config.
Types:
ConnectionT: :class:`sqlalchemy.Connection` | :class:`sqlalchemy.ext.asyncio.AsyncConnection`
EngineT: :class:`sqlalchemy.Engine` | :class:`sqlalchemy.ext.asyncio.AsyncEngine`
SessionT: :class:`sqlalchemy.Session` | :class:`sqlalchemy.ext.asyncio.AsyncSession`
"""
autobegin: "Union[bool, EmptyType]" = Empty
"""Automatically start transactions when database access is requested by an operation.
Bool or :class:`Empty `
"""
autoflush: "Union[bool, EmptyType]" = Empty
"""When ``True``, all query operations will issue a flush call to this :class:`Session `
before proceeding"""
bind: "Optional[Union[EngineT, ConnectionT, EmptyType]]" = Empty
"""The :class:`Engine ` or :class:`Connection ` that new
:class:`Session ` objects will be bound to."""
binds: "Optional[Union[dict[Union[type[Any], Mapper[Any], TableClause, str], Union[EngineT, ConnectionT]], EmptyType]]" = Empty
"""A dictionary which may specify any number of :class:`Engine ` or :class:`Connection
` objects as the source of connectivity for SQL operations on a per-entity basis. The
keys of the dictionary consist of any series of mapped classes, arbitrary Python classes that are bases for mapped
classes, :class:`Table ` objects and :class:`Mapper ` objects. The
values of the dictionary are then instances of :class:`Engine ` or less commonly
:class:`Connection ` objects."""
class_: "Union[type[SessionT], EmptyType]" = Empty
"""Class to use in order to create new :class:`Session ` objects."""
expire_on_commit: "Union[bool, EmptyType]" = Empty
"""If ``True``, all instances will be expired after each commit."""
info: "Optional[Union[dict[str, Any], EmptyType]]" = Empty
"""Optional dictionary of information that will be available via the
:attr:`Session.info `"""
join_transaction_mode: "Union[JoinTransactionMode, EmptyType]" = Empty
"""Describes the transactional behavior to take when a given bind is a Connection that has already begun a
transaction outside the scope of this Session; in other words the
:attr:`Connection.in_transaction() ` method returns True."""
query_cls: "Optional[Union[type[Query], EmptyType]]" = Empty # pyright: ignore[reportMissingTypeArgument]
"""Class which should be used to create new Query objects, as returned by the
:attr:`Session.query() ` method."""
twophase: "Union[bool, EmptyType]" = Empty
"""When ``True``, all transactions will be started as a "two phase" transaction, i.e. using the "two phase"
semantics of the database in use along with an XID. During a :attr:`commit() `, after
:attr:`flush() ` has been issued for all attached databases, the
:attr:`TwoPhaseTransaction.prepare() ` method on each database`s
:class:`TwoPhaseTransaction ` will be called. This allows each database to
roll back the entire transaction, before each transaction is committed."""
@dataclass
class GenericSQLAlchemyConfig(Generic[EngineT, SessionT, SessionMakerT]):
"""Common SQLAlchemy Configuration.
Types:
EngineT: :class:`sqlalchemy.Engine` or :class:`sqlalchemy.ext.asyncio.AsyncEngine`
SessionT: :class:`sqlalchemy.Session` or :class:`sqlalchemy.ext.asyncio.AsyncSession`
SessionMakerT: :class:`sqlalchemy.orm.sessionmaker` or :class:`sqlalchemy.ext.asyncio.async_sessionmaker`
"""
create_engine_callable: "Callable[[str], EngineT]"
"""Callable that creates an :class:`AsyncEngine ` instance or instance of its
subclass.
"""
session_config: "GenericSessionConfig[Any, Any, Any]"
"""Configuration options for either the :class:`async_sessionmaker `
or :class:`sessionmaker `.
"""
session_maker_class: "type[Union[sessionmaker[Session], async_sessionmaker[AsyncSession]]]"
"""Sessionmaker class to use.
.. seealso::
:class:`sqlalchemy.orm.sessionmaker`
:class:`sqlalchemy.ext.asyncio.async_sessionmaker`
"""
connection_string: "Optional[str]" = field(default=None)
"""Database connection string in one of the formats supported by SQLAlchemy.
Notes:
- For async connections, the connection string must include the correct async prefix.
e.g. ``'postgresql+asyncpg://...'`` instead of ``'postgresql://'``, and for sync connections its the opposite.
"""
engine_config: "EngineConfig" = field(default_factory=EngineConfig)
"""Configuration for the SQLAlchemy engine.
The configuration options are documented in the SQLAlchemy documentation.
"""
session_maker: "Optional[Callable[[], SessionT]]" = None
"""Callable that returns a session.
If provided, the plugin will use this rather than instantiate a sessionmaker.
"""
engine_instance: "Optional[EngineT]" = None
"""Optional engine to use.
If set, the plugin will use the provided instance rather than instantiate an engine.
"""
create_all: bool = False
"""If true, all models are automatically created on engine creation."""
metadata: "Optional[MetaData]" = None
"""Optional metadata to use.
If set, the plugin will use the provided instance rather than the default metadata."""
bind_key: "Optional[str]" = None
"""Bind key to register a metadata to a specific engine configuration."""
enable_touch_updated_timestamp_listener: bool = True
"""Enable Created/Updated Timestamp event listener.
This is a listener that will update ``created_at`` and ``updated_at`` columns on record modification.
Disable if you plan to bring your own update mechanism for these columns"""
enable_file_object_listener: bool = True
"""Enable FileObject listener.
This is a listener that will automatically save and delete :class:`FileObject ` instances when they are saved or deleted.
Disable if you plan to bring your own save/delete mechanism for these columns"""
_SESSION_SCOPE_KEY_REGISTRY: "ClassVar[set[str]]" = field(init=False, default=cast("set[str]", set()))
"""Internal counter for ensuring unique identification of session scope keys in the class."""
_ENGINE_APP_STATE_KEY_REGISTRY: "ClassVar[set[str]]" = field(init=False, default=cast("set[str]", set()))
"""Internal counter for ensuring unique identification of engine app state keys in the class."""
_SESSIONMAKER_APP_STATE_KEY_REGISTRY: "ClassVar[set[str]]" = field(init=False, default=cast("set[str]", set()))
"""Internal counter for ensuring unique identification of sessionmaker state keys in the class."""
def __post_init__(self) -> None:
if self.connection_string is not None and self.engine_instance is not None:
msg = "Only one of 'connection_string' or 'engine_instance' can be provided."
raise ImproperConfigurationError(msg)
if self.metadata is None:
self.metadata = metadata_registry.get(self.bind_key)
else:
metadata_registry.set(self.bind_key, self.metadata)
if self.enable_touch_updated_timestamp_listener:
from sqlalchemy import event
from sqlalchemy.orm import Session
from advanced_alchemy._listeners import touch_updated_timestamp
event.listen(Session, "before_flush", touch_updated_timestamp)
if self.enable_file_object_listener:
from advanced_alchemy._listeners import setup_file_object_listeners
setup_file_object_listeners()
def __hash__(self) -> int: # pragma: no cover
return hash(
(
self.__class__.__qualname__,
self.connection_string,
self.engine_config.__class__.__qualname__,
self.bind_key,
)
)
def __eq__(self, other: object) -> bool:
return self.__hash__() == other.__hash__()
@property
def engine_config_dict(self) -> dict[str, Any]:
"""Return the engine configuration as a dict.
Returns:
A string keyed dict of config kwargs for the SQLAlchemy :func:`sqlalchemy.get_engine`
function.
"""
return simple_asdict(self.engine_config, exclude_empty=True)
@property
def session_config_dict(self) -> dict[str, Any]:
"""Return the session configuration as a dict.
Returns:
A string keyed dict of config kwargs for the SQLAlchemy :class:`sqlalchemy.orm.sessionmaker`
class.
"""
return simple_asdict(self.session_config, exclude_empty=True)
def get_engine(self) -> EngineT:
"""Return an engine. If none exists yet, create one.
Raises:
ImproperConfigurationError: if neither `connection_string` nor `engine_instance` are provided.
Returns:
:class:`sqlalchemy.Engine` or :class:`sqlalchemy.ext.asyncio.AsyncEngine` instance used by the plugin.
"""
if self.engine_instance:
return self.engine_instance
if self.connection_string is None:
msg = "One of 'connection_string' or 'engine_instance' must be provided."
raise ImproperConfigurationError(msg)
engine_config = self.engine_config_dict
try:
self.engine_instance = self.create_engine_callable(self.connection_string, **engine_config)
except TypeError:
# likely due to a dialect that doesn't support json type
del engine_config["json_deserializer"]
del engine_config["json_serializer"]
self.engine_instance = self.create_engine_callable(self.connection_string, **engine_config)
return self.engine_instance
def create_session_maker(self) -> "Callable[[], SessionT]": # pragma: no cover
"""Get a session maker. If none exists yet, create one.
Returns:
:class:`sqlalchemy.orm.sessionmaker` or :class:`sqlalchemy.ext.asyncio.async_sessionmaker` factory used by the plugin.
"""
if self.session_maker:
return self.session_maker
session_kws = self.session_config_dict
if session_kws.get("bind") is None:
session_kws["bind"] = self.get_engine()
self.session_maker = cast("Callable[[], SessionT]", self.session_maker_class(**session_kws))
return self.session_maker
@dataclass
class GenericAlembicConfig:
"""Configuration for Alembic's :class:`Config `.
For details see: https://alembic.sqlalchemy.org/en/latest/api/config.html
"""
script_config: str = "alembic.ini"
"""A path to the Alembic configuration file such as ``alembic.ini``. If left unset, the default configuration
will be used.
"""
version_table_name: str = "alembic_versions"
"""Configure the name of the table used to hold the applied alembic revisions.
Defaults to ``alembic_versions``.
"""
version_table_schema: "Optional[str]" = None
"""Configure the schema to use for the alembic revisions revisions.
If unset, it defaults to connection's default schema."""
script_location: str = "migrations"
"""A path to save generated migrations.
"""
user_module_prefix: "Optional[str]" = "sa."
"""User module prefix."""
render_as_batch: bool = True
"""Render as batch."""
compare_type: bool = False
"""Compare type."""
template_path: str = ALEMBIC_TEMPLATE_PATH
"""Template path."""
python-advanced-alchemy-1.4.1/advanced_alchemy/config/engine.py 0000664 0000000 0000000 00000026173 15003544734 0024572 0 ustar 00root root 0000000 0000000 from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, Literal, Optional, Union
from advanced_alchemy._serialization import decode_json, encode_json
from advanced_alchemy.utils.dataclass import Empty
if TYPE_CHECKING:
from collections.abc import Mapping
from typing import Any
from sqlalchemy.engine.interfaces import IsolationLevel
from sqlalchemy.pool import Pool
from typing_extensions import TypeAlias
from advanced_alchemy.utils.dataclass import EmptyType
_EchoFlagType: "TypeAlias" = 'Union[None, bool, Literal["debug"]]'
_ParamStyle = Literal["qmark", "numeric", "named", "format", "pyformat", "numeric_dollar"]
__all__ = ("EngineConfig",)
@dataclass
class EngineConfig:
"""Configuration for SQLAlchemy's Engine.
This class provides configuration options for SQLAlchemy engine creation.
See: https://docs.sqlalchemy.org/en/20/core/engines.html
"""
connect_args: "Union[dict[Any, Any], EmptyType]" = Empty
"""A dictionary of arguments which will be passed directly to the DBAPI's ``connect()`` method as keyword arguments.
"""
echo: "Union[_EchoFlagType, EmptyType]" = Empty
"""If ``True``, the Engine will log all statements as well as a ``repr()`` of their parameter lists to the default
log handler, which defaults to ``sys.stdout`` for output. If set to the string "debug", result rows will be printed
to the standard output as well. The echo attribute of Engine can be modified at any time to turn logging on and off;
direct control of logging is also available using the standard Python logging module.
"""
echo_pool: "Union[_EchoFlagType, EmptyType]" = Empty
"""If ``True``, the connection pool will log informational output such as when connections are invalidated as well
as when connections are recycled to the default log handler, which defaults to sys.stdout for output. If set to the
string "debug", the logging will include pool checkouts and checkins. Direct control of logging is also available
using the standard Python logging module."""
enable_from_linting: "Union[bool, EmptyType]" = Empty
"""Defaults to True. Will emit a warning if a given SELECT statement is found to have un-linked FROM elements which
would cause a cartesian product."""
execution_options: "Union[Mapping[str, Any], EmptyType]" = Empty
"""Dictionary execution options which will be applied to all connections. See
:attr:`Connection.execution_options() ` for details."""
hide_parameters: "Union[bool, EmptyType]" = Empty
"""Boolean, when set to ``True``, SQL statement parameters will not be displayed in INFO logging nor will they be
formatted into the string representation of :class:`StatementError ` objects."""
insertmanyvalues_page_size: "Union[int, EmptyType]" = Empty
"""Number of rows to format into an INSERT statement when the statement uses โinsertmanyvaluesโ mode, which is a
paged form of bulk insert that is used for many backends when using executemany execution typically in conjunction
with RETURNING. Defaults to 1000, but may also be subject to dialect-specific limiting factors which may override
this value on a per-statement basis."""
isolation_level: "Union[IsolationLevel, EmptyType]" = Empty
"""Optional string name of an isolation level which will be set on all new connections unconditionally. Isolation
levels are typically some subset of the string names "SERIALIZABLE", "REPEATABLE READ", "READ COMMITTED",
"READ UNCOMMITTED" and "AUTOCOMMIT" based on backend."""
json_deserializer: "Callable[[str], Any]" = decode_json
"""For dialects that support the :class:`JSON ` datatype, this is a Python callable that will
convert a JSON string to a Python object. By default, this is set to Litestar's
:attr:`decode_json() <.serialization.decode_json>` function."""
json_serializer: "Callable[[Any], str]" = encode_json
"""For dialects that support the JSON datatype, this is a Python callable that will render a given object as JSON.
By default, Litestar's :attr:`encode_json() <.serialization.encode_json>` is used."""
label_length: "Optional[Union[int, EmptyType]]" = Empty
"""Optional integer value which limits the size of dynamically generated column labels to that many characters. If
less than 6, labels are generated as โ_(counter)โ. If ``None``, the value of ``dialect.max_identifier_length``,
which may be affected via the
:attr:`get_engine.max_identifier_length parameter `, is
used instead. The value of
:attr:`get_engine.label_length ` may not be larger than that of
:attr:`get_engine.max_identifier_length `."""
logging_name: "Union[str, EmptyType]" = Empty
"""String identifier which will be used within the โnameโ field of logging records generated within the
โsqlalchemy.engineโ logger. Defaults to a hexstring of the object`s id."""
max_identifier_length: "Optional[Union[int, EmptyType]]" = Empty
"""Override the max_identifier_length determined by the dialect. if ``None`` or ``0``, has no effect. This is the
database`s configured maximum number of characters that may be used in a SQL identifier such as a table name, column
name, or label name. All dialects determine this value automatically, however in the case of a new database version
for which this value has changed but SQLAlchemy`s dialect has not been adjusted, the value may be passed here."""
max_overflow: "Union[int, EmptyType]" = Empty
"""The number of connections to allow in connection pool โoverflowโ, that is connections that can be opened above
and beyond the pool_size setting, which defaults to five. This is only used with
:class:`QueuePool `."""
module: "Optional[Union[Any, EmptyType]]" = Empty
"""Reference to a Python module object (the module itself, not its string name). Specifies an alternate DBAPI module
to be used by the engine`s dialect. Each sub-dialect references a specific DBAPI which will be imported before first
connect. This parameter causes the import to be bypassed, and the given module to be used instead. Can be used for
testing of DBAPIs as well as to inject โmockโ DBAPI implementations into the
:class:`Engine `."""
paramstyle: "Optional[Union[_ParamStyle, EmptyType]]" = Empty
"""The paramstyle to use when rendering bound parameters. This style defaults to the one recommended by the DBAPI
itself, which is retrieved from the ``.paramstyle`` attribute of the DBAPI. However, most DBAPIs accept more than
one paramstyle, and in particular it may be desirable to change a โnamedโ paramstyle into a โpositionalโ one, or
vice versa. When this attribute is passed, it should be one of the values "qmark", "numeric", "named", "format" or
"pyformat", and should correspond to a parameter style known to be supported by the DBAPI in use."""
pool: "Optional[Union[Pool, EmptyType]]" = Empty
"""An already-constructed instance of :class:`Pool `, such as a
:class:`QueuePool ` instance. If non-None, this pool will be used directly as the
underlying connection pool for the engine, bypassing whatever connection parameters are present in the URL argument.
For information on constructing connection pools manually, see
`Connection Pooling `_."""
poolclass: "Optional[Union[type[Pool], EmptyType]]" = Empty
"""A :class:`Pool ` subclass, which will be used to create a connection pool instance using
the connection parameters given in the URL. Note this differs from pool in that you don`t actually instantiate the
pool in this case, you just indicate what type of pool to be used."""
pool_logging_name: "Union[str, EmptyType]" = Empty
"""String identifier which will be used within the โnameโ field of logging records generated within the
โsqlalchemy.poolโ logger. Defaults to a hexstring of the object`s id."""
pool_pre_ping: "Union[bool, EmptyType]" = Empty
"""If True will enable the connection pool โpre-pingโ feature that tests connections for liveness upon each
checkout."""
pool_size: "Union[int, EmptyType]" = Empty
"""The number of connections to keep open inside the connection pool. This used with
:class:`QueuePool ` as well as
:class:`SingletonThreadPool `. With
:class:`QueuePool `, a pool_size setting of ``0`` indicates no limit; to disable pooling,
set ``poolclass`` to :class:`NullPool ` instead."""
pool_recycle: "Union[int, EmptyType]" = Empty
"""This setting causes the pool to recycle connections after the given number of seconds has passed. It defaults to
``-1``, or no timeout. For example, setting to ``3600`` means connections will be recycled after one hour. Note that
MySQL in particular will disconnect automatically if no activity is detected on a connection for eight hours
(although this is configurable with the MySQLDB connection itself and the server configuration as well)."""
pool_reset_on_return: 'Union[Literal["rollback", "commit"], EmptyType]' = Empty
"""Set the :attr:`Pool.reset_on_return ` object, which can be set to the values ``"rollback"``, ``"commit"``, or
``None``."""
pool_timeout: "Union[int, EmptyType]" = Empty
"""Number of seconds to wait before giving up on getting a connection from the pool. This is only used with
:class:`QueuePool `. This can be a float but is subject to the limitations of Python time
functions which may not be reliable in the tens of milliseconds."""
pool_use_lifo: "Union[bool, EmptyType]" = Empty
"""Use LIFO (last-in-first-out) when retrieving connections from :class:`QueuePool `
instead of FIFO (first-in-first-out). Using LIFO, a server-side timeout scheme can reduce the number of connections
used during non-peak periods of use. When planning for server-side timeouts, ensure that a recycle or pre-ping
strategy is in use to gracefully handle stale connections."""
plugins: "Union[list[str], EmptyType]" = Empty
"""String list of plugin names to load. See :class:`CreateEnginePlugin ` for
background."""
query_cache_size: "Union[int, EmptyType]" = Empty
"""Size of the cache used to cache the SQL string form of queries. Set to zero to disable caching.
See :attr:`query_cache_size ` for more info.
"""
use_insertmanyvalues: "Union[bool, EmptyType]" = Empty
"""``True`` by default, use the โinsertmanyvaluesโ execution style for INSERT..RETURNING statements by default."""
python-advanced-alchemy-1.4.1/advanced_alchemy/config/sync.py 0000664 0000000 0000000 00000005355 15003544734 0024300 0 ustar 00root root 0000000 0000000 """Sync SQLAlchemy configuration module."""
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
from sqlalchemy import Connection, Engine, create_engine
from sqlalchemy.orm import Session, sessionmaker
from advanced_alchemy._listeners import set_async_context
from advanced_alchemy.config.common import GenericAlembicConfig, GenericSessionConfig, GenericSQLAlchemyConfig
if TYPE_CHECKING:
from collections.abc import Generator
from typing import Callable
__all__ = (
"AlembicSyncConfig",
"SQLAlchemySyncConfig",
"SyncSessionConfig",
)
@dataclass
class SyncSessionConfig(GenericSessionConfig[Connection, Engine, Session]):
"""Configuration for synchronous SQLAlchemy sessions."""
@dataclass
class AlembicSyncConfig(GenericAlembicConfig):
"""Configuration for Alembic's synchronous migrations.
For details see: https://alembic.sqlalchemy.org/en/latest/api/config.html
"""
@dataclass
class SQLAlchemySyncConfig(GenericSQLAlchemyConfig[Engine, Session, sessionmaker[Session]]):
"""Synchronous SQLAlchemy Configuration.
Note:
The alembic configuration options are documented in the Alembic documentation.
"""
create_engine_callable: "Callable[[str], Engine]" = create_engine
"""Callable that creates an :class:`Engine ` instance or instance of its subclass."""
session_config: SyncSessionConfig = field(default_factory=SyncSessionConfig) # pyright: ignore[reportIncompatibleVariableOverride]
"""Configuration options for the :class:`sessionmaker`."""
session_maker_class: type[sessionmaker[Session]] = sessionmaker # pyright: ignore[reportIncompatibleVariableOverride]
"""Sessionmaker class to use."""
alembic_config: AlembicSyncConfig = field(default_factory=AlembicSyncConfig)
"""Configuration for the SQLAlchemy Alembic migrations.
The configuration options are documented in the Alembic documentation.
"""
def __hash__(self) -> int:
return super().__hash__()
def __eq__(self, other: object) -> bool:
return super().__eq__(other)
@contextmanager
def get_session(self) -> "Generator[Session, None, None]":
"""Get a session context manager.
Yields:
Generator[sqlalchemy.orm.Session, None, None]: A context manager yielding an active SQLAlchemy Session.
Examples:
Using the session context manager:
>>> with config.get_session() as session:
... session.execute(...)
"""
session_maker = self.create_session_maker()
set_async_context(False) # Set context for standalone usage
with session_maker() as session:
yield session
python-advanced-alchemy-1.4.1/advanced_alchemy/config/types.py 0000664 0000000 0000000 00000001475 15003544734 0024467 0 ustar 00root root 0000000 0000000 """Type aliases and constants used in the package config."""
from collections.abc import Mapping, Sequence
from typing import Any, Callable, Literal
from typing_extensions import TypeAlias
TypeEncodersMap: TypeAlias = Mapping[Any, Callable[[Any], Any]]
"""Type alias for a mapping of type encoders.
Maps types to their encoder functions.
"""
TypeDecodersSequence: TypeAlias = Sequence[tuple[Callable[[Any], bool], Callable[[Any, Any], Any]]]
"""Type alias for a sequence of type decoders.
Each tuple contains a type check predicate and its corresponding decoder function.
"""
CommitStrategy: TypeAlias = Literal["always", "match_status"]
"""Commit strategy for SQLAlchemy sessions.
Values:
always: Always commit the session after operations
match_status: Only commit if the HTTP status code indicates success
"""
python-advanced-alchemy-1.4.1/advanced_alchemy/exceptions.py 0000664 0000000 0000000 00000030521 15003544734 0024231 0 ustar 00root root 0000000 0000000 import re
from collections.abc import Generator
from contextlib import contextmanager
from typing import Any, Callable, Optional, TypedDict, Union, cast
from sqlalchemy.exc import IntegrityError as SQLAlchemyIntegrityError
from sqlalchemy.exc import InvalidRequestError as SQLAlchemyInvalidRequestError
from sqlalchemy.exc import MultipleResultsFound, SQLAlchemyError, StatementError
__all__ = (
"AdvancedAlchemyError",
"DuplicateKeyError",
"ErrorMessages",
"ForeignKeyError",
"ImproperConfigurationError",
"IntegrityError",
"MissingDependencyError",
"MultipleResultsFoundError",
"NotFoundError",
"RepositoryError",
"SerializationError",
"wrap_sqlalchemy_exception",
)
DUPLICATE_KEY_REGEXES = {
"postgresql": [
re.compile(
r"^.*duplicate\s+key.*\"(?P[^\"]+)\"\s*\n.*Key\s+\((?P.*)\)=\((?P.*)\)\s+already\s+exists.*$",
),
re.compile(r"^.*duplicate\s+key.*\"(?P[^\"]+)\"\s*\n.*$"),
],
"sqlite": [
re.compile(r"^.*columns?(?P[^)]+)(is|are)\s+not\s+unique$"),
re.compile(r"^.*UNIQUE\s+constraint\s+failed:\s+(?P.+)$"),
re.compile(r"^.*PRIMARY\s+KEY\s+must\s+be\s+unique.*$"),
],
"mysql": [
re.compile(r"^.*\b1062\b.*Duplicate entry '(?P.*)' for key '(?P[^']+)'.*$"),
re.compile(r"^.*\b1062\b.*Duplicate entry \\'(?P.*)\\' for key \\'(?P.+)\\'.*$"),
],
"oracle": [],
"spanner+spanner": [],
"duckdb": [],
"mssql": [],
"bigquery": [],
"cockroach": [],
}
FOREIGN_KEY_REGEXES = {
"postgresql": [
re.compile(
r".*on table \"(?P
.+)\" violates check constraint (?P.+)"),
],
"sqlite": [],
"mysql": [],
"oracle": [],
"spanner+spanner": [],
"duckdb": [],
"mssql": [],
"bigquery": [],
"cockroach": [],
}
class AdvancedAlchemyError(Exception):
"""Base exception class from which all Advanced Alchemy exceptions inherit."""
detail: str
def __init__(self, *args: Any, detail: str = "") -> None:
"""Initialize ``AdvancedAlchemyException``.
Args:
*args: args are converted to :class:`str` before passing to :class:`Exception`
detail: detail of the exception.
"""
str_args = [str(arg) for arg in args if arg]
if not detail:
if str_args:
detail, *str_args = str_args
elif hasattr(self, "detail"):
detail = self.detail
self.detail = detail
super().__init__(*str_args)
def __repr__(self) -> str:
if self.detail:
return f"{self.__class__.__name__} - {self.detail}"
return self.__class__.__name__
def __str__(self) -> str:
return " ".join((*self.args, self.detail)).strip()
class MissingDependencyError(AdvancedAlchemyError, ImportError):
"""Missing optional dependency.
This exception is raised when a module depends on a dependency that has not been installed.
Args:
package: Name of the missing package.
install_package: Optional alternative package name to install.
"""
def __init__(self, package: str, install_package: Optional[str] = None) -> None:
super().__init__(
f"Package {package!r} is not installed but required. You can install it by running "
f"'pip install advanced_alchemy[{install_package or package}]' to install advanced_alchemy with the required extra "
f"or 'pip install {install_package or package}' to install the package separately",
)
class ImproperConfigurationError(AdvancedAlchemyError):
"""Improper Configuration error.
This exception is raised when there is an issue with the configuration of a module.
Args:
*args: Variable length argument list passed to parent class.
detail: Detailed error message.
"""
class SerializationError(AdvancedAlchemyError):
"""Encoding or decoding error.
This exception is raised when serialization or deserialization of an object fails.
Args:
*args: Variable length argument list passed to parent class.
detail: Detailed error message.
"""
class RepositoryError(AdvancedAlchemyError):
"""Base repository exception type.
Args:
*args: Variable length argument list passed to parent class.
detail: Detailed error message.
"""
class IntegrityError(RepositoryError):
"""Data integrity error.
Args:
*args: Variable length argument list passed to parent class.
detail: Detailed error message.
"""
class DuplicateKeyError(IntegrityError):
"""Duplicate key error.
Args:
*args: Variable length argument list passed to parent class.
detail: Detailed error message.
"""
class ForeignKeyError(IntegrityError):
"""Foreign key error.
Args:
*args: Variable length argument list passed to parent class.
detail: Detailed error message.
"""
class NotFoundError(RepositoryError):
"""Not found error.
This exception is raised when a requested resource is not found.
Args:
*args: Variable length argument list passed to parent class.
detail: Detailed error message.
"""
class MultipleResultsFoundError(RepositoryError):
"""Multiple results found error.
This exception is raised when a single result was expected but multiple were found.
Args:
*args: Variable length argument list passed to parent class.
detail: Detailed error message.
"""
class InvalidRequestError(RepositoryError):
"""Invalid request error.
This exception is raised when SQLAlchemy is unable to complete the request due to a runtime error
Args:
*args: Variable length argument list passed to parent class.
detail: Detailed error message.
"""
class ErrorMessages(TypedDict, total=False):
duplicate_key: Union[str, Callable[[Exception], str]]
integrity: Union[str, Callable[[Exception], str]]
foreign_key: Union[str, Callable[[Exception], str]]
multiple_rows: Union[str, Callable[[Exception], str]]
check_constraint: Union[str, Callable[[Exception], str]]
other: Union[str, Callable[[Exception], str]]
not_found: Union[str, Callable[[Exception], str]]
def _get_error_message(error_messages: ErrorMessages, key: str, exc: Exception) -> str:
template: Union[str, Callable[[Exception], str]] = error_messages.get(key, f"{key} error: {exc}") # type: ignore[assignment]
if callable(template): # pyright: ignore[reportUnknownArgumentType]
template = template(exc) # pyright: ignore[reportUnknownVariableType]
return template # pyright: ignore[reportUnknownVariableType]
@contextmanager
def wrap_sqlalchemy_exception( # noqa: C901, PLR0915
error_messages: Optional[ErrorMessages] = None,
dialect_name: Optional[str] = None,
wrap_exceptions: bool = True,
) -> Generator[None, None, None]:
"""Do something within context to raise a ``RepositoryError`` chained
from an original ``SQLAlchemyError``.
>>> try:
... with wrap_sqlalchemy_exception():
... raise SQLAlchemyError("Original Exception")
... except RepositoryError as exc:
... print(
... f"caught repository exception from {type(exc.__context__)}"
... )
caught repository exception from
Args:
error_messages: Error messages to use for the exception.
dialect_name: The name of the dialect to use for the exception.
wrap_exceptions: Wrap SQLAlchemy exceptions in a ``RepositoryError``. When set to ``False``, the original exception will be raised.
Raises:
NotFoundError: Raised when no rows matched the specified data.
MultipleResultsFound: Raised when multiple rows matched the specified data.
IntegrityError: Raised when an integrity error occurs.
InvalidRequestError: Raised when an invalid request was made to SQLAlchemy.
RepositoryError: Raised for other SQLAlchemy errors.
AttributeError: Raised when an attribute error occurs during processing.
SQLAlchemyError: Raised for general SQLAlchemy errors.
StatementError: Raised when there is an issue processing the statement.
MultipleResultsFoundError: Raised when multiple rows matched the specified data.
"""
try:
yield
except NotFoundError as exc:
if wrap_exceptions is False:
raise
if error_messages is not None:
msg = _get_error_message(error_messages=error_messages, key="not_found", exc=exc)
else:
msg = "No rows matched the specified data"
raise NotFoundError(detail=msg) from exc
except MultipleResultsFound as exc:
if wrap_exceptions is False:
raise
if error_messages is not None:
msg = _get_error_message(error_messages=error_messages, key="multiple_rows", exc=exc)
else:
msg = "Multiple rows matched the specified data"
raise MultipleResultsFoundError(detail=msg) from exc
except SQLAlchemyIntegrityError as exc:
if wrap_exceptions is False:
raise
if error_messages is not None and dialect_name is not None:
keys_to_regex = {
"duplicate_key": (DUPLICATE_KEY_REGEXES.get(dialect_name, []), DuplicateKeyError),
"check_constraint": (CHECK_CONSTRAINT_REGEXES.get(dialect_name, []), IntegrityError),
"foreign_key": (FOREIGN_KEY_REGEXES.get(dialect_name, []), ForeignKeyError),
}
detail = " - ".join(str(exc_arg) for exc_arg in exc.orig.args) if exc.orig.args else "" # type: ignore[union-attr] # pyright: ignore[reportArgumentType,reportOptionalMemberAccess]
for key, (regexes, exception) in keys_to_regex.items():
for regex in regexes:
if (match := regex.findall(detail)) and match[0]:
raise exception(
detail=_get_error_message(error_messages=error_messages, key=key, exc=exc),
) from exc
raise IntegrityError(
detail=_get_error_message(error_messages=error_messages, key="integrity", exc=exc),
) from exc
raise IntegrityError(detail=f"An integrity error occurred: {exc}") from exc
except SQLAlchemyInvalidRequestError as exc:
if wrap_exceptions is False:
raise
raise InvalidRequestError(detail="An invalid request was made.") from exc
except StatementError as exc:
if wrap_exceptions is False:
raise
raise IntegrityError(
detail=cast("str", getattr(exc.orig, "detail", "There was an issue processing the statement."))
) from exc
except SQLAlchemyError as exc:
if wrap_exceptions is False:
raise
if error_messages is not None:
msg = _get_error_message(error_messages=error_messages, key="other", exc=exc)
else:
msg = f"An exception occurred: {exc}"
raise RepositoryError(detail=msg) from exc
except AttributeError as exc:
if wrap_exceptions is False:
raise
if error_messages is not None:
msg = _get_error_message(error_messages=error_messages, key="other", exc=exc)
else:
msg = f"An attribute error occurred during processing: {exc}"
raise RepositoryError(detail=msg) from exc
python-advanced-alchemy-1.4.1/advanced_alchemy/extensions/ 0000775 0000000 0000000 00000000000 15003544734 0023674 5 ustar 00root root 0000000 0000000 python-advanced-alchemy-1.4.1/advanced_alchemy/extensions/__init__.py 0000664 0000000 0000000 00000000000 15003544734 0025773 0 ustar 00root root 0000000 0000000 python-advanced-alchemy-1.4.1/advanced_alchemy/extensions/fastapi/ 0000775 0000000 0000000 00000000000 15003544734 0025323 5 ustar 00root root 0000000 0000000 python-advanced-alchemy-1.4.1/advanced_alchemy/extensions/fastapi/__init__.py 0000664 0000000 0000000 00000002367 15003544734 0027444 0 ustar 00root root 0000000 0000000 """FastAPI extension for Advanced Alchemy.
This module provides FastAPI integration for Advanced Alchemy, including session management,
database migrations, and service utilities.
"""
from advanced_alchemy import base, exceptions, filters, mixins, operations, repository, service, types, utils
from advanced_alchemy.alembic.commands import AlembicCommands
from advanced_alchemy.config import AlembicAsyncConfig, AlembicSyncConfig, AsyncSessionConfig, SyncSessionConfig
from advanced_alchemy.extensions.fastapi import providers
from advanced_alchemy.extensions.fastapi.cli import get_database_migration_plugin
from advanced_alchemy.extensions.fastapi.config import EngineConfig, SQLAlchemyAsyncConfig, SQLAlchemySyncConfig
from advanced_alchemy.extensions.fastapi.extension import AdvancedAlchemy, assign_cli_group
__all__ = (
"AdvancedAlchemy",
"AlembicAsyncConfig",
"AlembicCommands",
"AlembicSyncConfig",
"AsyncSessionConfig",
"EngineConfig",
"SQLAlchemyAsyncConfig",
"SQLAlchemySyncConfig",
"SyncSessionConfig",
"assign_cli_group",
"base",
"exceptions",
"filters",
"get_database_migration_plugin",
"mixins",
"operations",
"providers",
"repository",
"service",
"types",
"utils",
)
python-advanced-alchemy-1.4.1/advanced_alchemy/extensions/fastapi/cli.py 0000664 0000000 0000000 00000003012 15003544734 0026440 0 ustar 00root root 0000000 0000000 from typing import TYPE_CHECKING, Optional, cast
try:
import rich_click as click
except ImportError:
import click # type: ignore[no-redef]
from advanced_alchemy.cli import add_migration_commands
if TYPE_CHECKING:
from fastapi import FastAPI
from advanced_alchemy.extensions.fastapi.extension import AdvancedAlchemy
def get_database_migration_plugin(app: "FastAPI") -> "AdvancedAlchemy": # pragma: no cover
"""Retrieve the Advanced Alchemy extension from a FastAPI application instance.
Args:
app: The FastAPI application instance.
Raises:
ImproperConfigurationError: If the Advanced Alchemy extension is not properly configured.
Returns:
The Advanced Alchemy extension instance.
"""
from advanced_alchemy.exceptions import ImproperConfigurationError
extension = cast("Optional[AdvancedAlchemy]", getattr(app.state, "advanced_alchemy", None))
if extension is None:
msg = "Failed to initialize database CLI. The Advanced Alchemy extension is not properly configured."
raise ImproperConfigurationError(msg)
return extension
def register_database_commands(app: "FastAPI") -> click.Group: # pragma: no cover
@click.group(name="database")
@click.pass_context
def database_group(ctx: click.Context) -> None:
"""Manage SQLAlchemy database components."""
ctx.ensure_object(dict)
ctx.obj["configs"] = get_database_migration_plugin(app).config
add_migration_commands(database_group)
return database_group
python-advanced-alchemy-1.4.1/advanced_alchemy/extensions/fastapi/config.py 0000664 0000000 0000000 00000000310 15003544734 0027134 0 ustar 00root root 0000000 0000000 from advanced_alchemy.extensions.starlette import EngineConfig, SQLAlchemyAsyncConfig, SQLAlchemySyncConfig
__all__ = (
"EngineConfig",
"SQLAlchemyAsyncConfig",
"SQLAlchemySyncConfig",
)
python-advanced-alchemy-1.4.1/advanced_alchemy/extensions/fastapi/extension.py 0000664 0000000 0000000 00000013414 15003544734 0027714 0 ustar 00root root 0000000 0000000 from collections.abc import Sequence
from typing import (
TYPE_CHECKING,
Any,
Optional,
Union,
overload,
)
from advanced_alchemy.extensions.fastapi.cli import register_database_commands
from advanced_alchemy.extensions.fastapi.config import SQLAlchemyAsyncConfig, SQLAlchemySyncConfig
from advanced_alchemy.extensions.starlette import AdvancedAlchemy as StarletteAdvancedAlchemy
from advanced_alchemy.service import (
Empty,
EmptyType,
ErrorMessages,
LoadSpec,
ModelT,
)
if TYPE_CHECKING:
from collections.abc import AsyncGenerator, Callable, Generator, Sequence
from fastapi import FastAPI
from sqlalchemy import Select
from advanced_alchemy import filters
from advanced_alchemy.extensions.fastapi.config import SQLAlchemyAsyncConfig, SQLAlchemySyncConfig
from advanced_alchemy.extensions.fastapi.providers import (
AsyncServiceT_co,
DependencyDefaults,
FilterConfig,
SyncServiceT_co,
)
__all__ = ("AdvancedAlchemy",)
def assign_cli_group(app: "FastAPI") -> None: # pragma: no cover
try:
from fastapi_cli.cli import app as fastapi_cli_app # pyright: ignore[reportUnknownVariableType]
from typer.main import get_group
except ImportError:
print("FastAPI CLI is not installed. Skipping CLI registration.") # noqa: T201
return
click_app = get_group(fastapi_cli_app) # pyright: ignore[reportUnknownArgumentType]
click_app.add_command(register_database_commands(app))
class AdvancedAlchemy(StarletteAdvancedAlchemy):
"""AdvancedAlchemy integration for FastAPI applications.
This class manages SQLAlchemy sessions and engine lifecycle within a FastAPI application.
It provides middleware for handling transactions based on commit strategies.
"""
def __init__(
self,
config: "Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig, Sequence[Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig]]]",
app: "Optional[FastAPI]" = None,
) -> None:
super().__init__(config, app)
@overload
def provide_service(
self,
service_class: type["AsyncServiceT_co"], # pyright: ignore
/,
key: "Optional[str]" = None,
statement: "Optional[Select[tuple[ModelT]]]" = None,
error_messages: "Optional[Union[ErrorMessages, EmptyType]]" = Empty,
load: "Optional[LoadSpec]" = None,
execution_options: "Optional[dict[str, Any]]" = None,
uniquify: Optional[bool] = None,
count_with_window_function: Optional[bool] = None,
) -> "Callable[..., AsyncGenerator[AsyncServiceT_co, None]]": ...
@overload
def provide_service(
self,
service_class: type["SyncServiceT_co"], # pyright: ignore
/,
key: "Optional[str]" = None,
statement: "Optional[Select[tuple[ModelT]]]" = None,
error_messages: "Optional[Union[ErrorMessages, EmptyType]]" = Empty,
load: "Optional[LoadSpec]" = None,
execution_options: "Optional[dict[str, Any]]" = None,
uniquify: Optional[bool] = None,
count_with_window_function: Optional[bool] = None,
) -> "Callable[..., Generator[SyncServiceT_co, None, None]]": ...
def provide_service( # pragma: no cover
self,
service_class: type[Union["AsyncServiceT_co", "SyncServiceT_co"]],
/,
key: "Optional[str]" = None,
statement: "Optional[Select[tuple[ModelT]]]" = None,
error_messages: "Optional[Union[ErrorMessages, EmptyType]]" = Empty,
load: "Optional[LoadSpec]" = None,
execution_options: "Optional[dict[str, Any]]" = None,
uniquify: Optional[bool] = None,
count_with_window_function: Optional[bool] = None,
) -> "Callable[..., Union[AsyncGenerator[AsyncServiceT_co, None], Generator[SyncServiceT_co, None, None]]]":
"""Provides a service instance for dependency injection.
Args:
service_class: The service class to provide.
key: Optional key for the service.
statement: Optional SQLAlchemy statement.
error_messages: Optional error messages.
load: Optional load specification.
execution_options: Optional execution options.
uniquify: Optional flag to uniquify the service.
count_with_window_function: Optional flag to use window function for counting.
Returns:
A callable that returns an async generator for async services or a generator for sync services.
"""
from advanced_alchemy.extensions.fastapi.providers import provide_service as _provide_service
return _provide_service(
service_class,
extension=self,
key=key,
statement=statement,
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=uniquify,
count_with_window_function=count_with_window_function,
)
@staticmethod
def provide_filters( # pragma: no cover
config: "FilterConfig",
/,
dep_defaults: "Optional[DependencyDefaults]" = None,
) -> "Callable[..., list[filters.FilterTypes]]":
"""Provides filters for dependency injection.
Args:
config: The filters to provide.
dep_defaults: Optional key for the filters.
Returns:
A callable that returns an async generator for async filters or a generator for sync filters.
"""
from advanced_alchemy.extensions.fastapi.providers import DEPENDENCY_DEFAULTS
from advanced_alchemy.extensions.fastapi.providers import provide_filters as _provide_filters
if dep_defaults is None:
dep_defaults = DEPENDENCY_DEFAULTS
return _provide_filters(config, dep_defaults=dep_defaults)
python-advanced-alchemy-1.4.1/advanced_alchemy/extensions/fastapi/providers.py 0000664 0000000 0000000 00000063525 15003544734 0027725 0 ustar 00root root 0000000 0000000 # pyright: ignore
"""Application dependency providers generators for FastAPI.
This module contains functions to create dependency providers for filters,
similar to the Litestar extension, but tailored for FastAPI.
"""
import datetime
import inspect
from collections.abc import AsyncGenerator, Generator
from typing import (
TYPE_CHECKING,
Annotated,
Any,
Callable,
Literal,
NamedTuple,
Optional,
TypeVar,
Union,
cast,
overload,
)
from uuid import UUID
from fastapi import Depends, Query
from fastapi.exceptions import RequestValidationError
from sqlalchemy import Select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from typing_extensions import NotRequired, TypedDict
from advanced_alchemy.extensions.fastapi.extension import AdvancedAlchemy
from advanced_alchemy.filters import (
BeforeAfter,
CollectionFilter,
FilterTypes,
LimitOffset,
NotInCollectionFilter,
OrderBy,
SearchFilter,
)
from advanced_alchemy.service import (
Empty,
EmptyType,
ErrorMessages,
LoadSpec,
ModelT,
SQLAlchemyAsyncRepositoryService,
SQLAlchemySyncRepositoryService,
)
from advanced_alchemy.utils.singleton import SingletonMeta
from advanced_alchemy.utils.text import camelize
if TYPE_CHECKING:
from advanced_alchemy.extensions.fastapi import SQLAlchemyAsyncConfig, SQLAlchemySyncConfig
T = TypeVar("T")
DTorNone = Optional[datetime.datetime]
StringOrNone = Optional[str]
UuidOrNone = Optional[str] # FastAPI doesn't automatically parse UUIDs from query params like Litestar
IntOrNone = Optional[int]
BooleanOrNone = Optional[bool]
SortOrder = Literal["asc", "desc"]
SortOrderOrNone = Optional[SortOrder]
FilterConfigValues = Union[
bool, str, list[str], type[Union[str, int]]
] # Simplified compared to Litestar's UUID/int flexibility for now
AsyncServiceT_co = TypeVar("AsyncServiceT_co", bound=SQLAlchemyAsyncRepositoryService[Any], covariant=True)
SyncServiceT_co = TypeVar("SyncServiceT_co", bound=SQLAlchemySyncRepositoryService[Any], covariant=True)
HashableValue = Union[str, int, float, bool, None]
HashableType = Union[HashableValue, tuple[Any, ...], tuple[tuple[str, Any], ...], tuple[HashableValue, ...]]
class FieldNameType(NamedTuple):
"""Type for field name and associated type information.
This allows for specifying both the field name and the expected type for filter values.
"""
name: str
"""Name of the field to filter on."""
type_hint: type[Any] = str
"""Type of the filter value. Defaults to str."""
class DependencyDefaults:
"""Default values for dependency generation."""
CREATED_FILTER_DEPENDENCY_KEY: str = "created_filter"
"""Key for the created filter dependency."""
ID_FILTER_DEPENDENCY_KEY: str = "id_filter"
"""Key for the id filter dependency."""
LIMIT_OFFSET_FILTER_DEPENDENCY_KEY: str = "limit_offset_filter"
"""Key for the limit offset dependency."""
UPDATED_FILTER_DEPENDENCY_KEY: str = "updated_filter"
"""Key for the updated filter dependency."""
ORDER_BY_FILTER_DEPENDENCY_KEY: str = "order_by_filter"
"""Key for the order by dependency."""
SEARCH_FILTER_DEPENDENCY_KEY: str = "search_filter"
"""Key for the search filter dependency."""
DEFAULT_PAGINATION_SIZE: int = 20
"""Default pagination size."""
DEPENDENCY_DEFAULTS = DependencyDefaults()
class DependencyCache(metaclass=SingletonMeta):
"""Simple dependency cache for the application. This is used to help memoize dependencies that are generated dynamically."""
def __init__(self) -> None:
self.dependencies: dict[int, Callable[[Any], list[FilterTypes]]] = {}
def add_dependencies(self, key: int, dependencies: Callable[[Any], list[FilterTypes]]) -> None:
self.dependencies[key] = dependencies
def get_dependencies(self, key: int) -> Optional[Callable[[Any], list[FilterTypes]]]:
return self.dependencies.get(key)
dep_cache = DependencyCache()
class FilterConfig(TypedDict):
"""Configuration for generating dynamic filters for FastAPI."""
id_filter: NotRequired[type[Union[UUID, int, str]]]
"""Indicates that the id filter should be enabled."""
id_field: NotRequired[str]
"""The field on the model that stored the primary key or identifier. Defaults to 'id'."""
sort_field: NotRequired[Union[str, set[str]]]
"""The default field(s) to use for the sort filter."""
sort_order: NotRequired[SortOrder]
"""The default order to use for the sort filter. Defaults to 'desc'."""
pagination_type: NotRequired[Literal["limit_offset"]]
"""When set, pagination is enabled based on the type specified."""
pagination_size: NotRequired[int]
"""The size of the pagination. Defaults to `DEFAULT_PAGINATION_SIZE`."""
search: NotRequired[Union[str, set[str]]]
"""Fields to enable search on. Can be a comma-separated string or a set of field names."""
search_ignore_case: NotRequired[bool]
"""When set, search is case insensitive by default. Defaults to False."""
created_at: NotRequired[bool]
"""When set, created_at filter is enabled. Defaults to 'created_at' field."""
updated_at: NotRequired[bool]
"""When set, updated_at filter is enabled. Defaults to 'updated_at' field."""
not_in_fields: NotRequired[Union[FieldNameType, set[FieldNameType]]]
"""Fields that support not-in collection filters. Can be a single field or a set of fields with type information."""
in_fields: NotRequired[Union[FieldNameType, set[FieldNameType]]]
"""Fields that support in-collection filters. Can be a single field or a set of fields with type information."""
@overload
def provide_service(
service_class: type["AsyncServiceT_co"],
/,
extension: AdvancedAlchemy,
key: Optional[str] = None,
statement: Optional[Select[tuple[ModelT]]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
count_with_window_function: Optional[bool] = None,
) -> Callable[..., AsyncGenerator[AsyncServiceT_co, None]]: ...
@overload
def provide_service(
service_class: type["SyncServiceT_co"],
/,
extension: AdvancedAlchemy,
key: Optional[str] = None,
statement: Optional[Select[tuple[ModelT]]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
count_with_window_function: Optional[bool] = None,
) -> Callable[..., Generator[SyncServiceT_co, None, None]]: ...
def provide_service(
service_class: type[Union["AsyncServiceT_co", "SyncServiceT_co"]],
/,
extension: AdvancedAlchemy,
key: Optional[str] = None,
statement: Optional[Select[tuple[ModelT]]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
count_with_window_function: Optional[bool] = None,
) -> Callable[..., Union[AsyncGenerator[AsyncServiceT_co, None], Generator[SyncServiceT_co, None, None]]]:
"""Create a dependency provider for a service.
Returns:
A dependency provider for the service.
"""
if issubclass(service_class, SQLAlchemyAsyncRepositoryService) or service_class is SQLAlchemyAsyncRepositoryService: # type: ignore[comparison-overlap]
async def provide_async_service(
db_session: AsyncSession = Depends(extension.provide_session(key)), # noqa: B008
) -> AsyncGenerator[AsyncServiceT_co, None]: # type: ignore[union-attr,unused-ignore]
async with service_class.new( # type: ignore[union-attr,unused-ignore]
session=db_session, # type: ignore[arg-type, unused-ignore]
statement=statement,
config=cast("Optional[SQLAlchemyAsyncConfig]", extension.get_config(key)), # type: ignore[arg-type]
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=uniquify,
count_with_window_function=count_with_window_function,
) as service:
yield service
return provide_async_service
def provide_sync_service(
db_session: Session = Depends(extension.provide_session(key)), # noqa: B008
) -> Generator[SyncServiceT_co, None, None]:
with service_class.new(
session=db_session, # type: ignore[arg-type, unused-ignore]
statement=statement,
config=cast("Optional[SQLAlchemySyncConfig]", extension.get_config(key)),
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=uniquify,
count_with_window_function=count_with_window_function,
) as service:
yield service
return provide_sync_service
def provide_filters(
config: FilterConfig,
dep_defaults: DependencyDefaults = DEPENDENCY_DEFAULTS,
) -> Callable[..., list[FilterTypes]]:
"""Create FastAPI dependency providers for filters based on the provided configuration.
Returns:
A FastAPI dependency provider function that aggregates multiple filter dependencies.
"""
# Check if any filters are actually requested in the config
filter_keys = {
"id_filter",
"created_at",
"updated_at",
"pagination_type",
"search",
"sort_field",
"not_in_fields",
"in_fields",
}
has_filters = False
for key in filter_keys:
value = config.get(key)
if value is not None and value is not False and value != []:
has_filters = True
break
if not has_filters:
return list
# Calculate cache key using hashable version of config
cache_key = hash(_make_hashable(config))
# Check cache first
cached_dep = dep_cache.get_dependencies(cache_key)
if cached_dep is not None:
return cached_dep
dep = _create_filter_aggregate_function_fastapi(config, dep_defaults)
dep_cache.add_dependencies(cache_key, dep)
return dep
def _make_hashable(value: Any) -> HashableType:
"""Convert a value into a hashable type.
This function converts any value into a hashable type by:
- Converting dictionaries to sorted tuples of (key, value) pairs
- Converting lists and sets to sorted tuples
- Preserving primitive types (str, int, float, bool, None)
- Converting any other type to its string representation
Args:
value: Any value that needs to be made hashable.
Returns:
A hashable version of the value.
"""
if isinstance(value, dict):
# Convert dict to tuple of tuples with sorted keys
items = []
for k in sorted(value.keys()): # pyright: ignore
v = value[k] # pyright: ignore
items.append((str(k), _make_hashable(v))) # pyright: ignore
return tuple(items) # pyright: ignore
if isinstance(value, (list, set)):
hashable_items = [_make_hashable(item) for item in value] # pyright: ignore
filtered_items = [item for item in hashable_items if item is not None] # pyright: ignore
return tuple(sorted(filtered_items, key=str))
if isinstance(value, (str, int, float, bool, type(None))):
return value
return str(value)
def _create_filter_aggregate_function_fastapi( # noqa: C901, PLR0915
config: FilterConfig,
dep_defaults: DependencyDefaults = DEPENDENCY_DEFAULTS,
) -> Callable[..., list[FilterTypes]]:
"""Create a FastAPI dependency provider function that aggregates multiple filter dependencies.
Returns:
A FastAPI dependency provider function that aggregates multiple filter dependencies.
"""
params: list[inspect.Parameter] = []
annotations: dict[str, Any] = {}
# Add id filter providers
if (id_filter := config.get("id_filter", False)) is not False:
def provide_id_filter( # pyright: ignore[reportUnknownParameterType]
ids: Annotated[ # type: ignore
Optional[list[id_filter]], # pyright: ignore
Query(
alias="ids",
required=False,
description="IDs to filter by.",
),
] = None,
) -> Optional[CollectionFilter[id_filter]]: # type: ignore
return CollectionFilter[id_filter](field_name=config.get("id_field", "id"), values=ids) if ids else None # type: ignore
params.append(
inspect.Parameter(
name=dep_defaults.ID_FILTER_DEPENDENCY_KEY,
kind=inspect.Parameter.KEYWORD_ONLY,
annotation=Annotated[Optional[CollectionFilter[id_filter]], Depends(provide_id_filter)], # type: ignore
)
)
annotations[dep_defaults.ID_FILTER_DEPENDENCY_KEY] = Annotated[
Optional[CollectionFilter[id_filter]], Depends(provide_id_filter) # type: ignore
]
# Add created_at filter providers
if config.get("created_at", False):
def provide_created_at_filter(
before: Annotated[
Optional[str],
Query(
alias="createdBefore",
description="Filter by created date before this timestamp.",
json_schema_extra={"format": "date-time"},
),
] = None,
after: Annotated[
Optional[str],
Query(
alias="createdAfter",
description="Filter by created date after this timestamp.",
json_schema_extra={"format": "date-time"},
),
] = None,
) -> Optional[BeforeAfter]:
before_dt = None
after_dt = None
# Validate both parameters regardless of endpoint path
if before is not None:
try:
before_dt = datetime.datetime.fromisoformat(before.replace("Z", "+00:00"))
except (ValueError, TypeError, AttributeError) as e:
raise RequestValidationError(
errors=[{"loc": ["query", "createdBefore"], "msg": "Invalid date format"}]
) from e
if after is not None:
try:
after_dt = datetime.datetime.fromisoformat(after.replace("Z", "+00:00"))
except (ValueError, TypeError, AttributeError) as e:
raise RequestValidationError(
errors=[{"loc": ["query", "createdAfter"], "msg": "Invalid date format"}]
) from e
return (
BeforeAfter(field_name="created_at", before=before_dt, after=after_dt)
if before_dt or after_dt
else None # pyright: ignore
)
param_name = dep_defaults.CREATED_FILTER_DEPENDENCY_KEY
params.append(
inspect.Parameter(
name=param_name,
kind=inspect.Parameter.KEYWORD_ONLY,
annotation=Annotated[Optional[BeforeAfter], Depends(provide_created_at_filter)],
)
)
annotations[param_name] = Annotated[Optional[BeforeAfter], Depends(provide_created_at_filter)]
# Add updated_at filter providers
if config.get("updated_at", False):
def provide_updated_at_filter(
before: Annotated[
Optional[str],
Query(
alias="updatedBefore",
description="Filter by updated date before this timestamp.",
json_schema_extra={"format": "date-time"},
),
] = None,
after: Annotated[
Optional[str],
Query(
alias="updatedAfter",
description="Filter by updated date after this timestamp.",
json_schema_extra={"format": "date-time"},
),
] = None,
) -> Optional[BeforeAfter]:
before_dt = None
after_dt = None
# Validate both parameters regardless of endpoint path
if before is not None:
try:
before_dt = datetime.datetime.fromisoformat(before.replace("Z", "+00:00"))
except (ValueError, TypeError, AttributeError) as e:
raise RequestValidationError(
errors=[{"loc": ["query", "updatedBefore"], "msg": "Invalid date format"}]
) from e
if after is not None:
try:
after_dt = datetime.datetime.fromisoformat(after.replace("Z", "+00:00"))
except (ValueError, TypeError, AttributeError) as e:
raise RequestValidationError(
errors=[{"loc": ["query", "updatedAfter"], "msg": "Invalid date format"}]
) from e
return (
BeforeAfter(field_name="updated_at", before=before_dt, after=after_dt)
if before_dt or after_dt
else None # pyright: ignore
)
param_name = dep_defaults.UPDATED_FILTER_DEPENDENCY_KEY
params.append(
inspect.Parameter(
name=param_name,
kind=inspect.Parameter.KEYWORD_ONLY,
annotation=Annotated[Optional[BeforeAfter], Depends(provide_updated_at_filter)],
)
)
annotations[param_name] = Annotated[Optional[BeforeAfter], Depends(provide_updated_at_filter)]
# Add pagination filter providers
if config.get("pagination_type") == "limit_offset":
def provide_limit_offset_pagination(
current_page: Annotated[
int,
Query(
ge=1,
alias="currentPage",
description="Page number for pagination.",
),
] = 1,
page_size: Annotated[
int,
Query(
ge=1,
alias="pageSize",
description="Number of items per page.",
),
] = config.get("pagination_size", dep_defaults.DEFAULT_PAGINATION_SIZE),
) -> LimitOffset:
return LimitOffset(limit=page_size, offset=page_size * (current_page - 1))
param_name = dep_defaults.LIMIT_OFFSET_FILTER_DEPENDENCY_KEY
params.append(
inspect.Parameter(
name=param_name,
kind=inspect.Parameter.KEYWORD_ONLY,
annotation=Annotated[LimitOffset, Depends(provide_limit_offset_pagination)],
)
)
annotations[param_name] = Annotated[LimitOffset, Depends(provide_limit_offset_pagination)]
# Add search filter providers
if search_fields := config.get("search"):
def provide_search_filter(
search_string: Annotated[
Optional[str],
Query(
required=False,
alias="searchString",
description="Search term.",
),
] = None,
ignore_case: Annotated[
Optional[bool],
Query(
required=False,
alias="searchIgnoreCase",
description="Whether search should be case-insensitive.",
),
] = config.get("search_ignore_case", False),
) -> SearchFilter:
field_names = set(search_fields.split(",")) if isinstance(search_fields, str) else search_fields
return SearchFilter(
field_name=field_names,
value=search_string, # type: ignore[arg-type]
ignore_case=ignore_case or False,
)
param_name = dep_defaults.SEARCH_FILTER_DEPENDENCY_KEY
params.append(
inspect.Parameter(
name=param_name,
kind=inspect.Parameter.KEYWORD_ONLY,
annotation=Annotated[Optional[SearchFilter], Depends(provide_search_filter)],
)
)
annotations[param_name] = Annotated[Optional[SearchFilter], Depends(provide_search_filter)]
# Add sort filter providers
if sort_field := config.get("sort_field"):
sort_order_default = config.get("sort_order", "desc")
def provide_order_by(
field_name: Annotated[
str,
Query(
alias="orderBy",
description="Field to order by.",
required=False,
),
] = sort_field, # type: ignore[assignment]
sort_order: Annotated[
Optional[SortOrder],
Query(
alias="sortOrder",
description="Sort order ('asc' or 'desc').",
required=False,
),
] = sort_order_default,
) -> OrderBy:
return OrderBy(field_name=field_name, sort_order=sort_order or sort_order_default)
param_name = dep_defaults.ORDER_BY_FILTER_DEPENDENCY_KEY
params.append(
inspect.Parameter(
name=param_name,
kind=inspect.Parameter.KEYWORD_ONLY,
annotation=Annotated[OrderBy, Depends(provide_order_by)],
)
)
annotations[param_name] = Annotated[OrderBy, Depends(provide_order_by)]
# Add not_in filter providers
if not_in_fields := config.get("not_in_fields"):
not_in_fields = {not_in_fields} if isinstance(not_in_fields, (str, FieldNameType)) else not_in_fields
for field_def in not_in_fields:
def create_not_in_filter_provider( # pyright: ignore
local_field_name: str,
local_field_type: type[Any],
) -> Callable[..., Optional[NotInCollectionFilter[field_def.type_hint]]]: # type: ignore
def provide_not_in_filter( # pyright: ignore
values: Annotated[ # type: ignore
Optional[set[local_field_type]], # pyright: ignore
Query(
alias=camelize(f"{local_field_name}_not_in"),
description=f"Filter {local_field_name} not in values",
),
] = None,
) -> Optional[NotInCollectionFilter[local_field_type]]: # type: ignore
return NotInCollectionFilter(field_name=local_field_name, values=values) if values else None # pyright: ignore
return provide_not_in_filter # pyright: ignore
provider = create_not_in_filter_provider(field_def.name, field_def.type_hint) # pyright: ignore
param_name = f"{field_def.name}_not_in_filter"
params.append(
inspect.Parameter(
name=param_name,
kind=inspect.Parameter.KEYWORD_ONLY,
annotation=Annotated[Optional[NotInCollectionFilter[field_def.type_hint]], Depends(provider)], # type: ignore
)
)
annotations[param_name] = Annotated[Optional[NotInCollectionFilter[field_def.type_hint]], Depends(provider)] # type: ignore
# Add in filter providers
if in_fields := config.get("in_fields"):
in_fields = {in_fields} if isinstance(in_fields, (str, FieldNameType)) else in_fields
for field_def in in_fields:
def create_in_filter_provider( # pyright: ignore
local_field_name: str,
local_field_type: type[Any],
) -> Callable[..., Optional[CollectionFilter[field_def.type_hint]]]: # type: ignore
def provide_in_filter( # pyright: ignore
values: Annotated[ # type: ignore
Optional[set[local_field_type]], # pyright: ignore
Query(
alias=camelize(f"{local_field_name}_in"),
description=f"Filter {local_field_name} in values",
),
] = None,
) -> Optional[CollectionFilter[local_field_type]]: # type: ignore
return CollectionFilter(field_name=local_field_name, values=values) if values else None # pyright: ignore
return provide_in_filter # pyright: ignore
provider = create_in_filter_provider(field_def.name, field_def.type_hint) # type: ignore
param_name = f"{field_def.name}_in_filter"
params.append(
inspect.Parameter(
name=param_name,
kind=inspect.Parameter.KEYWORD_ONLY,
annotation=Annotated[Optional[CollectionFilter[field_def.type_hint]], Depends(provider)], # type: ignore
)
)
annotations[param_name] = Annotated[Optional[CollectionFilter[field_def.type_hint]], Depends(provider)] # type: ignore
_aggregate_filter_function.__signature__ = inspect.Signature( # type: ignore
parameters=params,
return_annotation=Annotated[list[FilterTypes], Depends(_aggregate_filter_function)],
)
return _aggregate_filter_function
def _aggregate_filter_function(**kwargs: Any) -> list[FilterTypes]:
filters: list[FilterTypes] = []
for filter_value in kwargs.values():
if filter_value is None:
continue
if isinstance(filter_value, list):
filters.extend(cast("list[FilterTypes]", filter_value))
elif isinstance(filter_value, SearchFilter) and filter_value.value is None: # pyright: ignore # noqa: SIM114
continue # type: ignore
elif isinstance(filter_value, OrderBy) and filter_value.field_name is None: # pyright: ignore
continue # type: ignore
else:
filters.append(cast("FilterTypes", filter_value))
return filters
python-advanced-alchemy-1.4.1/advanced_alchemy/extensions/flask/ 0000775 0000000 0000000 00000000000 15003544734 0024774 5 ustar 00root root 0000000 0000000 python-advanced-alchemy-1.4.1/advanced_alchemy/extensions/flask/__init__.py 0000664 0000000 0000000 00000002326 15003544734 0027110 0 ustar 00root root 0000000 0000000 """Flask extension for Advanced Alchemy.
This module provides Flask integration for Advanced Alchemy, including session management,
database migrations, and service utilities.
"""
from advanced_alchemy import base, exceptions, filters, mixins, operations, repository, service, types, utils
from advanced_alchemy.alembic.commands import AlembicCommands
from advanced_alchemy.config import AlembicAsyncConfig, AlembicSyncConfig, AsyncSessionConfig, SyncSessionConfig
from advanced_alchemy.extensions.flask.cli import get_database_migration_plugin
from advanced_alchemy.extensions.flask.config import EngineConfig, SQLAlchemyAsyncConfig, SQLAlchemySyncConfig
from advanced_alchemy.extensions.flask.extension import AdvancedAlchemy
from advanced_alchemy.extensions.flask.utils import FlaskServiceMixin
__all__ = (
"AdvancedAlchemy",
"AlembicAsyncConfig",
"AlembicCommands",
"AlembicSyncConfig",
"AsyncSessionConfig",
"EngineConfig",
"FlaskServiceMixin",
"SQLAlchemyAsyncConfig",
"SQLAlchemySyncConfig",
"SyncSessionConfig",
"base",
"exceptions",
"filters",
"get_database_migration_plugin",
"mixins",
"operations",
"repository",
"service",
"types",
"utils",
)
python-advanced-alchemy-1.4.1/advanced_alchemy/extensions/flask/cli.py 0000664 0000000 0000000 00000003245 15003544734 0026121 0 ustar 00root root 0000000 0000000 """Command-line interface utilities for Flask integration.
This module provides CLI commands for database management in Flask applications.
"""
from contextlib import suppress
from typing import TYPE_CHECKING, cast
from flask.cli import with_appcontext
from advanced_alchemy.cli import add_migration_commands
try:
import rich_click as click
except ImportError:
import click # type: ignore[no-redef]
if TYPE_CHECKING:
from flask import Flask
from advanced_alchemy.extensions.flask.extension import AdvancedAlchemy
def get_database_migration_plugin(app: "Flask") -> "AdvancedAlchemy":
"""Retrieve the Advanced Alchemy extension from the Flask application.
Args:
app: The :class:`flask.Flask` application instance.
Returns:
:class:`AdvancedAlchemy`: The Advanced Alchemy extension instance.
Raises:
:exc:`advanced_alchemy.exceptions.ImproperConfigurationError`: If the extension is not found.
"""
from advanced_alchemy.exceptions import ImproperConfigurationError
with suppress(KeyError):
return cast("AdvancedAlchemy", app.extensions["advanced_alchemy"])
msg = "Failed to initialize database migrations. The Advanced Alchemy extension is not properly configured."
raise ImproperConfigurationError(msg)
@click.group(name="database")
@with_appcontext
def database_group() -> None:
"""Manage SQLAlchemy database components.
This command group provides database management commands like migrations.
"""
ctx = click.get_current_context()
app = ctx.obj.load_app()
ctx.obj = {"app": app, "configs": get_database_migration_plugin(app).config}
add_migration_commands(database_group)
python-advanced-alchemy-1.4.1/advanced_alchemy/extensions/flask/config.py 0000664 0000000 0000000 00000024601 15003544734 0026616 0 ustar 00root root 0000000 0000000 """Configuration classes for Flask integration.
This module provides configuration classes for integrating SQLAlchemy with Flask applications,
including both synchronous and asynchronous database configurations.
"""
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union, cast
from click import echo
from flask import g, has_request_context
from sqlalchemy.exc import OperationalError
from typing_extensions import Literal
from advanced_alchemy._serialization import decode_json, encode_json
from advanced_alchemy.base import metadata_registry
from advanced_alchemy.config import EngineConfig as _EngineConfig
from advanced_alchemy.config.asyncio import SQLAlchemyAsyncConfig as _SQLAlchemyAsyncConfig
from advanced_alchemy.config.sync import SQLAlchemySyncConfig as _SQLAlchemySyncConfig
from advanced_alchemy.exceptions import ImproperConfigurationError
from advanced_alchemy.service import schema_dump
if TYPE_CHECKING:
from flask import Flask, Response
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from advanced_alchemy.utils.portals import Portal
__all__ = ("EngineConfig", "SQLAlchemyAsyncConfig", "SQLAlchemySyncConfig")
ConfigT = TypeVar("ConfigT", bound="Union[SQLAlchemySyncConfig, SQLAlchemyAsyncConfig]")
def serializer(value: "Any") -> str:
"""Serialize JSON field values.
Calls the `:func:schema_dump` function to convert the value to a built-in before encoding.
Args:
value: Any JSON serializable value.
Returns:
str: JSON string representation of the value.
"""
return encode_json(schema_dump(value))
@dataclass
class EngineConfig(_EngineConfig):
"""Configuration for SQLAlchemy's Engine.
This class extends the base EngineConfig with Flask-specific JSON serialization options.
For details see: https://docs.sqlalchemy.org/en/20/core/engines.html
Attributes:
json_deserializer: Callable for converting JSON strings to Python objects.
json_serializer: Callable for converting Python objects to JSON strings.
"""
json_deserializer: "Callable[[str], Any]" = decode_json
"""For dialects that support the :class:`~sqlalchemy.types.JSON` datatype, this is a Python callable that will
convert a JSON string to a Python object."""
json_serializer: "Callable[[Any], str]" = serializer
"""For dialects that support the JSON datatype, this is a Python callable that will render a given object as JSON."""
@dataclass
class SQLAlchemySyncConfig(_SQLAlchemySyncConfig):
"""Flask-specific synchronous SQLAlchemy configuration.
Attributes:
app: The Flask application instance.
commit_mode: The commit mode to use for database sessions.
"""
app: "Optional[Flask]" = None
"""The Flask application instance."""
commit_mode: Literal["manual", "autocommit", "autocommit_include_redirect"] = "manual"
"""The commit mode to use for database sessions."""
def create_session_maker(self) -> "Callable[[], Session]":
"""Get a session maker. If none exists yet, create one.
Returns:
Callable[[], Session]: Session factory used by the plugin.
"""
if self.session_maker:
return self.session_maker
session_kws = self.session_config_dict
if self.engine_instance is None:
self.engine_instance = self.get_engine()
if session_kws.get("bind") is None:
session_kws["bind"] = self.engine_instance
self.session_maker = self.session_maker_class(**session_kws)
return self.session_maker
def init_app(self, app: "Flask", portal: "Optional[Portal]" = None) -> None:
"""Initialize the Flask application with this configuration.
Args:
app: The Flask application instance.
portal: The portal to use for thread-safe communication. Unused in synchronous configurations.
"""
self.app = app
self.bind_key = self.bind_key or "default"
if self.create_all:
self.create_all_metadata()
if self.commit_mode != "manual":
self._setup_session_handling(app)
def _setup_session_handling(self, app: "Flask") -> None:
"""Set up the session handling for the Flask application.
Args:
app: The Flask application instance.
"""
@app.after_request
def handle_db_session(response: "Response") -> "Response": # pyright: ignore[reportUnusedFunction]
"""Commit the session if the response meets the commit criteria."""
if not has_request_context():
return response
db_session = cast("Optional[Session]", g.pop(f"advanced_alchemy_session_{self.bind_key}", None))
if db_session is not None:
if (self.commit_mode == "autocommit" and 200 <= response.status_code < 300) or ( # noqa: PLR2004
self.commit_mode == "autocommit_include_redirect" and 200 <= response.status_code < 400 # noqa: PLR2004
):
db_session.commit()
db_session.close()
return response
def close_engines(self, portal: "Portal") -> None:
"""Close the engines.
Args:
portal: The portal to use for thread-safe communication.
"""
if self.engine_instance is not None:
self.engine_instance.dispose()
def create_all_metadata(self) -> None: # pragma: no cover
"""Create all metadata tables in the database."""
if self.engine_instance is None:
self.engine_instance = self.get_engine()
with self.engine_instance.begin() as conn:
try:
metadata_registry.get(None if self.bind_key == "default" else self.bind_key).create_all(conn)
except OperationalError as exc:
echo(f" * Could not create target metadata. Reason: {exc}")
else:
echo(" * Created target metadata.")
@dataclass
class SQLAlchemyAsyncConfig(_SQLAlchemyAsyncConfig):
"""Flask-specific asynchronous SQLAlchemy configuration.
Attributes:
app: The Flask application instance.
commit_mode: The commit mode to use for database sessions.
"""
app: "Optional[Flask]" = None
"""The Flask application instance."""
commit_mode: Literal["manual", "autocommit", "autocommit_include_redirect"] = "manual"
"""The commit mode to use for database sessions."""
def create_session_maker(self) -> "Callable[[], AsyncSession]":
"""Get a session maker. If none exists yet, create one.
Returns:
Callable[[], AsyncSession]: Session factory used by the plugin.
"""
if self.session_maker:
return self.session_maker
session_kws = self.session_config_dict
if self.engine_instance is None:
self.engine_instance = self.get_engine()
if session_kws.get("bind") is None:
session_kws["bind"] = self.engine_instance
self.session_maker = self.session_maker_class(**session_kws)
return self.session_maker
def init_app(self, app: "Flask", portal: "Optional[Portal]" = None) -> None:
"""Initialize the Flask application with this configuration.
Args:
app: The Flask application instance.
portal: The portal to use for thread-safe communication.
Raises:
ImproperConfigurationError: If portal is not provided for async configuration.
"""
self.app = app
self.bind_key = self.bind_key or "default"
if portal is None:
msg = "Portal is required for asynchronous configurations"
raise ImproperConfigurationError(msg)
if self.create_all:
_ = portal.call(self.create_all_metadata)
self._setup_session_handling(app, portal)
def _setup_session_handling(self, app: "Flask", portal: "Portal") -> None:
"""Set up the session handling for the Flask application.
Args:
app: The Flask application instance.
portal: The portal to use for thread-safe communication.
"""
@app.after_request
def handle_db_session(response: "Response") -> "Response": # pyright: ignore[reportUnusedFunction]
"""Commit the session if the response meets the commit criteria."""
if not has_request_context():
return response
db_session = cast("Optional[AsyncSession]", g.pop(f"advanced_alchemy_session_{self.bind_key}", None))
if db_session is not None:
p = getattr(db_session, "_session_portal", None) or portal
if (self.commit_mode == "autocommit" and 200 <= response.status_code < 300) or ( # noqa: PLR2004
self.commit_mode == "autocommit_include_redirect" and 200 <= response.status_code < 400 # noqa: PLR2004
):
_ = p.call(db_session.commit)
_ = p.call(db_session.close)
return response
@app.teardown_appcontext
def close_db_session(_: "Optional[BaseException]" = None) -> None: # pyright: ignore[reportUnusedFunction]
"""Close the session at the end of the request."""
db_session = cast("Optional[AsyncSession]", g.pop(f"advanced_alchemy_session_{self.bind_key}", None))
if db_session is not None:
p = getattr(db_session, "_session_portal", None) or portal
_ = p.call(db_session.close)
def close_engines(self, portal: "Portal") -> None:
"""Close the engines.
Args:
portal: The portal to use for thread-safe communication.
"""
if self.engine_instance is not None:
_ = portal.call(self.engine_instance.dispose)
async def create_all_metadata(self) -> None: # pragma: no cover
"""Create all metadata tables in the database."""
if self.engine_instance is None:
self.engine_instance = self.get_engine()
async with self.engine_instance.begin() as conn:
try:
await conn.run_sync(
metadata_registry.get(None if self.bind_key == "default" else self.bind_key).create_all
)
await conn.commit()
except OperationalError as exc:
echo(f" * Could not create target metadata. Reason: {exc}")
else:
echo(" * Created target metadata.")
python-advanced-alchemy-1.4.1/advanced_alchemy/extensions/flask/extension.py 0000664 0000000 0000000 00000020334 15003544734 0027364 0 ustar 00root root 0000000 0000000 # ruff: noqa: SLF001, ARG001
"""Flask extension for Advanced Alchemy."""
from collections.abc import Generator, Sequence
from contextlib import contextmanager, suppress
from typing import TYPE_CHECKING, Callable, Optional, Union, cast
from flask import g
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from advanced_alchemy._listeners import set_async_context
from advanced_alchemy.exceptions import ImproperConfigurationError
from advanced_alchemy.extensions.flask.cli import database_group
from advanced_alchemy.extensions.flask.config import SQLAlchemyAsyncConfig, SQLAlchemySyncConfig
from advanced_alchemy.utils.portals import Portal, PortalProvider
if TYPE_CHECKING:
from flask import Flask
class AdvancedAlchemy:
"""Flask extension for Advanced Alchemy."""
__slots__ = (
"_config",
"_has_async_config",
"_session_makers",
"portal_provider",
)
def __init__(
self,
config: "Union[SQLAlchemySyncConfig, SQLAlchemyAsyncConfig, Sequence[Union[SQLAlchemySyncConfig, SQLAlchemyAsyncConfig]]]",
app: "Optional[Flask]" = None,
*,
portal_provider: "Optional[PortalProvider]" = None,
) -> None:
"""Initialize the extension."""
self.portal_provider = portal_provider if portal_provider is not None else PortalProvider()
self._config = config if isinstance(config, Sequence) else [config]
self._has_async_config = any(isinstance(c, SQLAlchemyAsyncConfig) for c in self.config)
self._session_makers: dict[str, Callable[..., Union[AsyncSession, Session]]] = {}
if app is not None:
self.init_app(app)
@property
def portal(self) -> "Portal":
"""Get the portal."""
return self.portal_provider.portal
@property
def config(self) -> "Sequence[Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig]]":
"""Get the SQLAlchemy configuration(s)."""
return self._config
@property
def is_async_enabled(self) -> bool:
"""Return True if any of the database configs are async."""
return self._has_async_config
def _is_async(self, bind_key: str) -> bool:
"""Check if the config for the given bind key is async.
Args:
bind_key: The bind key to check.
Returns:
True if the config is async, False otherwise.
"""
config = next((c for c in self.config if (c.bind_key or "default") == bind_key), None)
return isinstance(config, SQLAlchemyAsyncConfig)
def init_app(self, app: "Flask") -> None:
"""Initialize the Flask application.
Args:
app: The Flask application to initialize.
Raises:
ImproperConfigurationError: If the extension is already registered on the Flask application.
"""
if "advanced_alchemy" in app.extensions:
msg = "Advanced Alchemy extension is already registered on this Flask application."
raise ImproperConfigurationError(msg)
if self._has_async_config:
self.portal_provider.start()
# Create tables for async configs
for cfg in self._config:
if isinstance(cfg, SQLAlchemyAsyncConfig):
self.portal_provider.portal.call(cfg.create_all_metadata)
# Register shutdown handler for the portal
@app.teardown_appcontext
def shutdown_portal(exception: "Optional[BaseException]" = None) -> None: # pyright: ignore[reportUnusedFunction]
"""Stop the portal when the application shuts down."""
if not app.debug: # Don't stop portal in debug mode
with suppress(Exception):
self.portal_provider.stop()
# Initialize each config with the app
for config in self.config:
config.init_app(app, self.portal_provider.portal)
bind_key = config.bind_key if config.bind_key is not None else "default"
session_maker = config.create_session_maker()
self._session_makers[bind_key] = session_maker
# Register session cleanup only
app.teardown_appcontext(self._teardown_appcontext)
app.extensions["advanced_alchemy"] = self
app.cli.add_command(database_group)
def _teardown_appcontext(self, exception: "Optional[BaseException]" = None) -> None:
"""Clean up resources when the application context ends."""
for key in list(g):
if key.startswith("advanced_alchemy_session_"):
session = getattr(g, key)
if isinstance(session, AsyncSession):
# Close async sessions through the portal
with suppress(ImproperConfigurationError):
self.portal_provider.portal.call(session.close)
else:
session.close()
delattr(g, key)
def get_session(self, bind_key: str = "default") -> "Union[AsyncSession, Session]":
"""Get a new session from the configured session factory.
Args:
bind_key: The bind key to use for the session.
Returns:
A new session from the configured session factory.
Raises:
ImproperConfigurationError: If no session maker is found for the bind key.
"""
if bind_key == "default" and len(self.config) == 1:
bind_key = self.config[0].bind_key if self.config[0].bind_key is not None else "default"
session_key = f"advanced_alchemy_session_{bind_key}"
if hasattr(g, session_key):
return cast("Union[AsyncSession, Session]", getattr(g, session_key))
session_maker = self._session_makers.get(bind_key)
if session_maker is None:
msg = f'No session maker found for bind key "{bind_key}"'
raise ImproperConfigurationError(msg)
session = session_maker()
set_async_context(self._is_async(bind_key))
if self._has_async_config:
# Ensure portal is started
if not self.portal_provider.is_running:
self.portal_provider.start()
setattr(session, "_session_portal", self.portal_provider.portal)
setattr(g, session_key, session)
return session
def get_async_session(self, bind_key: str = "default") -> AsyncSession:
"""Get an async session from the configured session factory.
Args:
bind_key: The bind key to use for the session.
Raises:
ImproperConfigurationError: If the session is not an async session.
Returns:
An async session from the configured session factory.
"""
session = self.get_session(bind_key)
if not isinstance(session, AsyncSession):
msg = f"Expected async session for bind key {bind_key}, but got {type(session)}"
raise ImproperConfigurationError(msg)
return session
def get_sync_session(self, bind_key: str = "default") -> Session:
"""Get a sync session from the configured session factory.
Args:
bind_key: The bind key to use for the session.
Raises:
ImproperConfigurationError: If the session is not a sync session.
Returns:
A sync session from the configured session factory.
"""
session = self.get_session(bind_key)
if not isinstance(session, Session):
msg = f"Expected sync session for bind key {bind_key}, but got {type(session)}"
raise ImproperConfigurationError(msg)
return session
@contextmanager
def with_session( # pragma: no cover (more on this later)
self, bind_key: str = "default"
) -> "Generator[Union[AsyncSession, Session], None, None]":
"""Provide a transactional scope around a series of operations.
Args:
bind_key: The bind key to use for the session.
Yields:
A session.
"""
session = self.get_session(bind_key)
try:
yield session
finally:
if isinstance(session, AsyncSession):
with suppress(ImproperConfigurationError):
self.portal_provider.portal.call(session.close)
else:
session.close()
python-advanced-alchemy-1.4.1/advanced_alchemy/extensions/flask/utils.py 0000664 0000000 0000000 00000002245 15003544734 0026511 0 ustar 00root root 0000000 0000000 """Flask-specific service classes.
This module provides Flask-specific service mixins and utilities for integrating
with the Advanced Alchemy service layer.
"""
from typing import Any
from flask import Response, current_app
from advanced_alchemy.extensions.flask.config import serializer
class FlaskServiceMixin:
"""Flask service mixin.
This mixin provides Flask-specific functionality for services.
"""
def jsonify(
self,
data: Any,
*args: Any,
status_code: int = 200,
**kwargs: Any,
) -> Response:
"""Convert data to a Flask JSON response.
Args:
data: Data to serialize to JSON.
*args: Additional positional arguments passed to Flask's response class.
status_code: HTTP status code for the response. Defaults to 200.
**kwargs: Additional keyword arguments passed to Flask's response class.
Returns:
:class:`flask.Response`: A Flask response with JSON content type.
"""
return current_app.response_class(
serializer(data),
status=status_code,
mimetype="application/json",
)
python-advanced-alchemy-1.4.1/advanced_alchemy/extensions/litestar/ 0000775 0000000 0000000 00000000000 15003544734 0025523 5 ustar 00root root 0000000 0000000 python-advanced-alchemy-1.4.1/advanced_alchemy/extensions/litestar/__init__.py 0000664 0000000 0000000 00000005302 15003544734 0027634 0 ustar 00root root 0000000 0000000 from advanced_alchemy import base, exceptions, filters, mixins, operations, repository, service, types, utils
from advanced_alchemy.alembic.commands import AlembicCommands
from advanced_alchemy.config import AlembicAsyncConfig, AlembicSyncConfig, AsyncSessionConfig, SyncSessionConfig
from advanced_alchemy.extensions.litestar import providers
from advanced_alchemy.extensions.litestar.cli import get_database_migration_plugin
from advanced_alchemy.extensions.litestar.dto import SQLAlchemyDTO, SQLAlchemyDTOConfig
from advanced_alchemy.extensions.litestar.plugins import (
EngineConfig,
SQLAlchemyAsyncConfig,
SQLAlchemyInitPlugin,
SQLAlchemyPlugin,
SQLAlchemySerializationPlugin,
SQLAlchemySyncConfig,
)
from advanced_alchemy.extensions.litestar.plugins.init.config.asyncio import (
autocommit_before_send_handler as async_autocommit_before_send_handler,
)
from advanced_alchemy.extensions.litestar.plugins.init.config.asyncio import (
autocommit_handler_maker as async_autocommit_handler_maker,
)
from advanced_alchemy.extensions.litestar.plugins.init.config.asyncio import (
default_before_send_handler as async_default_before_send_handler,
)
from advanced_alchemy.extensions.litestar.plugins.init.config.asyncio import (
default_handler_maker as async_default_handler_maker,
)
from advanced_alchemy.extensions.litestar.plugins.init.config.sync import (
autocommit_before_send_handler as sync_autocommit_before_send_handler,
)
from advanced_alchemy.extensions.litestar.plugins.init.config.sync import (
autocommit_handler_maker as sync_autocommit_handler_maker,
)
from advanced_alchemy.extensions.litestar.plugins.init.config.sync import (
default_before_send_handler as sync_default_before_send_handler,
)
from advanced_alchemy.extensions.litestar.plugins.init.config.sync import (
default_handler_maker as sync_default_handler_maker,
)
__all__ = (
"AlembicAsyncConfig",
"AlembicCommands",
"AlembicSyncConfig",
"AsyncSessionConfig",
"EngineConfig",
"SQLAlchemyAsyncConfig",
"SQLAlchemyDTO",
"SQLAlchemyDTOConfig",
"SQLAlchemyInitPlugin",
"SQLAlchemyPlugin",
"SQLAlchemySerializationPlugin",
"SQLAlchemySyncConfig",
"SyncSessionConfig",
"async_autocommit_before_send_handler",
"async_autocommit_handler_maker",
"async_default_before_send_handler",
"async_default_handler_maker",
"base",
"exceptions",
"filters",
"get_database_migration_plugin",
"mixins",
"operations",
"providers",
"repository",
"service",
"sync_autocommit_before_send_handler",
"sync_autocommit_handler_maker",
"sync_default_before_send_handler",
"sync_default_handler_maker",
"types",
"utils",
)
python-advanced-alchemy-1.4.1/advanced_alchemy/extensions/litestar/_utils.py 0000664 0000000 0000000 00000003637 15003544734 0027405 0 ustar 00root root 0000000 0000000 from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from litestar.types import Scope
__all__ = (
"delete_aa_scope_state",
"get_aa_scope_state",
"set_aa_scope_state",
)
_SCOPE_NAMESPACE = "_aa_connection_state"
def get_aa_scope_state(scope: "Scope", key: str, default: Any = None, pop: bool = False) -> Any:
"""Get an internal value from connection scope state.
Note:
If called with a default value, this method behaves like to `dict.set_default()`, both setting the key in the
namespace to the default value, and returning it.
If called without a default value, the method behaves like `dict.get()`, returning ``None`` if the key does not
exist.
Args:
scope: The connection scope.
key: Key to get from internal namespace in scope state.
default: Default value to return.
pop: Boolean flag dictating whether the value should be deleted from the state.
Returns:
Value mapped to ``key`` in internal connection scope namespace.
"""
namespace = scope.setdefault(_SCOPE_NAMESPACE, {}) # type: ignore[misc]
return namespace.pop(key, default) if pop else namespace.get(key, default) # pyright: ignore[reportUnknownVariableType,reportUnknownMemberType]
def set_aa_scope_state(scope: "Scope", key: str, value: Any) -> None:
"""Set an internal value in connection scope state.
Args:
scope: The connection scope.
key: Key to set under internal namespace in scope state.
value: Value for key.
"""
scope.setdefault(_SCOPE_NAMESPACE, {})[key] = value # type: ignore[misc]
def delete_aa_scope_state(scope: "Scope", key: str) -> None:
"""Delete an internal value from connection scope state.
Args:
scope: The connection scope.
key: Key to set under internal namespace in scope state.
"""
del scope.setdefault(_SCOPE_NAMESPACE, {})[key] # type: ignore[misc]
python-advanced-alchemy-1.4.1/advanced_alchemy/extensions/litestar/cli.py 0000664 0000000 0000000 00000002714 15003544734 0026650 0 ustar 00root root 0000000 0000000 from contextlib import suppress
from typing import TYPE_CHECKING
from litestar.cli._utils import LitestarGroup # pyright: ignore
from advanced_alchemy.cli import add_migration_commands
try:
import rich_click as click
except ImportError:
import click # type: ignore[no-redef]
if TYPE_CHECKING:
from litestar import Litestar
from advanced_alchemy.extensions.litestar.plugins import SQLAlchemyInitPlugin
def get_database_migration_plugin(app: "Litestar") -> "SQLAlchemyInitPlugin":
"""Retrieve a database migration plugin from the Litestar application's plugins.
Args:
app: The Litestar application
Returns:
The database migration plugin
Raises:
ImproperConfigurationError: If the database migration plugin is not found
"""
from advanced_alchemy.exceptions import ImproperConfigurationError
from advanced_alchemy.extensions.litestar.plugins import SQLAlchemyInitPlugin
with suppress(KeyError):
return app.plugins.get(SQLAlchemyInitPlugin)
msg = "Failed to initialize database migrations. The required plugin (SQLAlchemyPlugin or SQLAlchemyInitPlugin) is missing."
raise ImproperConfigurationError(msg)
@click.group(cls=LitestarGroup, name="database")
def database_group(ctx: "click.Context") -> None:
"""Manage SQLAlchemy database components."""
ctx.obj = {"app": ctx.obj, "configs": get_database_migration_plugin(ctx.obj.app).config}
add_migration_commands(database_group)
python-advanced-alchemy-1.4.1/advanced_alchemy/extensions/litestar/dto.py 0000664 0000000 0000000 00000045374 15003544734 0026700 0 ustar 00root root 0000000 0000000 from collections.abc import Collection, Generator
from collections.abc import Set as AbstractSet
from dataclasses import asdict, dataclass, field, replace
from functools import singledispatchmethod
from typing import (
Any,
ClassVar,
Generic,
Literal,
Optional,
Union,
)
from litestar.dto.base_dto import AbstractDTO
from litestar.dto.config import DTOConfig
from litestar.dto.data_structures import DTOFieldDefinition
from litestar.dto.field import DTO_FIELD_META_KEY, DTOField, Mark
from litestar.types.empty import Empty
from litestar.typing import FieldDefinition
from litestar.utils.signature import ParsedSignature
from sqlalchemy import Column, inspect, orm, sql
from sqlalchemy.ext.associationproxy import AssociationProxy, AssociationProxyExtensionType
from sqlalchemy.ext.hybrid import HybridExtensionType, hybrid_property
from sqlalchemy.orm import (
ColumnProperty,
CompositeProperty,
DeclarativeBase,
DynamicMapped,
InspectionAttr,
InstrumentedAttribute,
Mapped,
MappedColumn,
NotExtension,
QueryableAttribute,
Relationship,
RelationshipDirection,
RelationshipProperty,
WriteOnlyMapped,
)
from sqlalchemy.sql.expression import ColumnClause, Label
from typing_extensions import TypeAlias, TypeVar
from advanced_alchemy.exceptions import ImproperConfigurationError
__all__ = ("SQLAlchemyDTO",)
T = TypeVar("T", bound="Union[DeclarativeBase, Collection[DeclarativeBase]]")
ElementType: TypeAlias = Union[
"Column[Any]", "RelationshipProperty[Any]", "CompositeProperty[Any]", "ColumnClause[Any]", "Label[Any]"
]
SQLA_NS = {**vars(orm), **vars(sql)}
@dataclass(frozen=True)
class SQLAlchemyDTOConfig(DTOConfig):
"""Additional controls for the generated SQLAlchemy DTO."""
exclude: AbstractSet[Union[str, InstrumentedAttribute[Any]]] = field(default_factory=set) # type: ignore[assignment] # pyright: ignore[reportIncompatibleVariableOverride]
"""Explicitly exclude fields from the generated DTO.
If exclude is specified, all fields not specified in exclude will be included by default.
Notes:
- The field names are dot-separated paths to nested fields, e.g. ``"address.street"`` will
exclude the ``"street"`` field from a nested ``"address"`` model.
- 'exclude' mutually exclusive with 'include' - specifying both values will raise an
``ImproperlyConfiguredException``.
"""
include: AbstractSet[Union[str, InstrumentedAttribute[Any]]] = field(default_factory=set) # type: ignore[assignment] # pyright: ignore[reportIncompatibleVariableOverride]
"""Explicitly include fields in the generated DTO.
If include is specified, all fields not specified in include will be excluded by default.
Notes:
- The field names are dot-separated paths to nested fields, e.g. ``"address.street"`` will
include the ``"street"`` field from a nested ``"address"`` model.
- 'include' mutually exclusive with 'exclude' - specifying both values will raise an
``ImproperlyConfiguredException``.
"""
rename_fields: dict[Union[str, InstrumentedAttribute[Any]], str] = field(default_factory=dict) # type: ignore[assignment] # pyright: ignore[reportIncompatibleVariableOverride]
"""Mapping of field names, to new name."""
include_implicit_fields: Union[bool, Literal["hybrid-only"]] = True
"""Fields that are implicitly mapped are included.
Turning this off will lead to exclude all fields not using ``Mapped`` annotation,
When setting this to ``hybrid-only``, all implicitly mapped fields are excluded
with the exception for hybrid properties.
"""
def __post_init__(self) -> None:
super().__post_init__()
object.__setattr__(
self, "exclude", {f.key if isinstance(f, InstrumentedAttribute) else f for f in self.exclude}
)
object.__setattr__(
self, "include", {f.key if isinstance(f, InstrumentedAttribute) else f for f in self.include}
)
object.__setattr__(
self,
"rename_fields",
{f.key if isinstance(f, InstrumentedAttribute) else f: v for f, v in self.rename_fields.items()},
)
class SQLAlchemyDTO(AbstractDTO[T], Generic[T]):
"""Support for domain modelling with SQLAlchemy."""
config: ClassVar[SQLAlchemyDTOConfig]
@staticmethod
def _ensure_sqla_dto_config(config: Union[DTOConfig, SQLAlchemyDTOConfig]) -> SQLAlchemyDTOConfig:
if not isinstance(config, SQLAlchemyDTOConfig):
return SQLAlchemyDTOConfig(**asdict(config))
return config
def __init_subclass__(cls, **kwargs: Any) -> None:
super().__init_subclass__(**kwargs)
if hasattr(cls, "config"):
cls.config = cls._ensure_sqla_dto_config(cls.config) # pyright: ignore[reportIncompatibleVariableOverride]
@singledispatchmethod
@classmethod
def handle_orm_descriptor(
cls,
extension_type: Union[NotExtension, AssociationProxyExtensionType, HybridExtensionType],
orm_descriptor: InspectionAttr,
key: str,
model_type_hints: dict[str, FieldDefinition],
model_name: str,
) -> list[DTOFieldDefinition]:
msg = f"Unsupported extension type: {extension_type}"
raise NotImplementedError(msg)
@handle_orm_descriptor.register(NotExtension)
@classmethod
def _(
cls,
extension_type: NotExtension,
key: str,
orm_descriptor: InspectionAttr,
model_type_hints: dict[str, FieldDefinition],
model_name: str,
) -> list[DTOFieldDefinition]:
if not isinstance(orm_descriptor, QueryableAttribute): # pragma: no cover
msg = f"Unexpected descriptor type for '{extension_type}': '{orm_descriptor}'"
raise NotImplementedError(msg)
elem: ElementType
if isinstance(
orm_descriptor.property, # pyright: ignore[reportUnknownMemberType]
ColumnProperty, # pragma: no cover
):
if not isinstance(
orm_descriptor.property.expression, # pyright: ignore[reportUnknownMemberType]
(Column, ColumnClause, Label),
):
msg = f"Expected 'Column', got: '{orm_descriptor.property.expression}, {type(orm_descriptor.property.expression)}'" # pyright: ignore[reportUnknownMemberType]
raise NotImplementedError(msg)
elem = orm_descriptor.property.expression # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
elif isinstance(orm_descriptor.property, (RelationshipProperty, CompositeProperty)): # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
elem = orm_descriptor.property # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
else: # pragma: no cover
msg = f"Unhandled property type: '{orm_descriptor.property}'" # pyright: ignore[reportUnknownMemberType]
raise NotImplementedError(msg)
default, default_factory = _detect_defaults(elem) # pyright: ignore
try:
if (field_definition := model_type_hints[key]).origin in {
Mapped,
WriteOnlyMapped,
DynamicMapped,
Relationship,
}:
(field_definition,) = field_definition.inner_types
else: # pragma: no cover
msg = f"Expected 'Mapped' origin, got: '{field_definition.origin}'"
raise NotImplementedError(msg)
except KeyError:
field_definition = parse_type_from_element(elem, orm_descriptor) # pyright: ignore
dto_field = elem.info.get(DTO_FIELD_META_KEY) if hasattr(elem, "info") else None # pyright: ignore
if dto_field is None and isinstance(orm_descriptor, InstrumentedAttribute) and hasattr(orm_descriptor, "info"): # pyright: ignore
dto_field = orm_descriptor.info.get(DTO_FIELD_META_KEY) # pyright: ignore
if dto_field is None:
dto_field = DTOField()
return [
DTOFieldDefinition.from_field_definition(
field_definition=replace(
field_definition,
name=key,
default=default,
),
default_factory=default_factory,
dto_field=dto_field,
model_name=model_name,
),
]
@handle_orm_descriptor.register(AssociationProxyExtensionType)
@classmethod
def _(
cls,
extension_type: AssociationProxyExtensionType,
key: str,
orm_descriptor: InspectionAttr,
model_type_hints: dict[str, FieldDefinition],
model_name: str,
) -> list[DTOFieldDefinition]:
if not isinstance(orm_descriptor, AssociationProxy): # pragma: no cover
msg = f"Unexpected descriptor type '{orm_descriptor}' for '{extension_type}'"
raise NotImplementedError(msg)
if (field_definition := model_type_hints[key]).origin is AssociationProxy:
(field_definition,) = field_definition.inner_types
else: # pragma: no cover
msg = f"Expected 'AssociationProxy' origin, got: '{field_definition.origin}'"
raise NotImplementedError(msg)
return [
DTOFieldDefinition.from_field_definition(
field_definition=replace(
field_definition,
name=key,
default=Empty,
),
default_factory=None,
dto_field=orm_descriptor.info.get(DTO_FIELD_META_KEY, DTOField(mark=Mark.READ_ONLY)),
model_name=model_name,
),
]
@handle_orm_descriptor.register(HybridExtensionType)
@classmethod
def _(
cls,
extension_type: HybridExtensionType,
key: str,
orm_descriptor: InspectionAttr,
model_type_hints: dict[str, FieldDefinition],
model_name: str,
) -> list[DTOFieldDefinition]:
if not isinstance(orm_descriptor, hybrid_property):
msg = f"Unexpected descriptor type '{orm_descriptor}' for '{extension_type}'"
raise NotImplementedError(msg)
getter_sig = ParsedSignature.from_fn(orm_descriptor.fget, {}) # pyright: ignore[reportUnknownArgumentType,reportUnknownMemberType,reportAttributeAccessIssue]
field_defs = [
DTOFieldDefinition.from_field_definition(
field_definition=replace(
getter_sig.return_type,
name=orm_descriptor.__name__,
default=Empty,
),
default_factory=None,
dto_field=orm_descriptor.info.get(DTO_FIELD_META_KEY, DTOField(mark=Mark.READ_ONLY)),
model_name=model_name,
),
]
if orm_descriptor.fset is not None: # pyright: ignore[reportUnknownMemberType]
setter_sig = ParsedSignature.from_fn(orm_descriptor.fset, {}) # pyright: ignore[reportUnknownArgumentType,reportUnknownMemberType]
field_defs.append(
DTOFieldDefinition.from_field_definition(
field_definition=replace(
next(iter(setter_sig.parameters.values())),
name=orm_descriptor.__name__,
default=Empty,
),
default_factory=None,
dto_field=orm_descriptor.info.get(DTO_FIELD_META_KEY, DTOField(mark=Mark.WRITE_ONLY)),
model_name=model_name,
),
)
return field_defs
@classmethod
def generate_field_definitions(cls, model_type: type[DeclarativeBase]) -> Generator[DTOFieldDefinition, None, None]:
"""Generate DTO field definitions from a SQLAlchemy model.
Args:
model_type (typing.Type[sqlalchemy.orm.DeclarativeBase]): The SQLAlchemy model type to generate field definitions from.
Yields:
collections.abc.Generator[litestar.dto.data_structures.DTOFieldDefinition, None, None]: A generator yielding DTO field definitions.
Raises:
RuntimeError: If the mapper cannot be found for the model type.
"""
if (mapper := inspect(model_type)) is None: # pragma: no cover # pyright: ignore[reportUnnecessaryComparison]
msg = "Unexpected `None` value for mapper." # type: ignore[unreachable]
raise RuntimeError(msg)
# includes SQLAlchemy names and other mapped class names in the forward reference resolution namespace
namespace = {**SQLA_NS, **{m.class_.__name__: m.class_ for m in mapper.registry.mappers if m is not mapper}}
model_type_hints = cls.get_model_type_hints(model_type, namespace=namespace)
model_name = model_type.__name__
include_implicit_fields = cls.config.include_implicit_fields
# the same hybrid property descriptor can be included in `all_orm_descriptors` multiple times, once
# for each method name it is bound to. We only need to see it once, so track views of it here.
seen_hybrid_descriptors: set[hybrid_property] = set() # pyright: ignore[reportUnknownVariableType,reportMissingTypeArgument]
skipped_descriptors: set[str] = set()
for composite_property in mapper.composites: # pragma: no cover
for attr in composite_property.attrs:
if isinstance(attr, (MappedColumn, Column)):
skipped_descriptors.add(attr.name)
elif isinstance(attr, str):
skipped_descriptors.add(attr)
for key, orm_descriptor in mapper.all_orm_descriptors.items():
if is_hybrid_property := isinstance(orm_descriptor, hybrid_property):
if orm_descriptor in seen_hybrid_descriptors:
continue
seen_hybrid_descriptors.add(orm_descriptor) # pyright: ignore[reportUnknownMemberType]
if key in skipped_descriptors:
continue
should_skip_descriptor = False
dto_field: Optional[DTOField] = None
if hasattr(orm_descriptor, "property"): # pyright: ignore[reportUnknownArgumentType]
dto_field = orm_descriptor.property.info.get(DTO_FIELD_META_KEY) # pyright: ignore
# Case 1
is_field_marked_not_private = dto_field and dto_field.mark is not Mark.PRIVATE # pyright: ignore[reportUnknownVariableType,reportUnknownMemberType]
# Case 2
should_exclude_anything_implicit = not include_implicit_fields and key not in model_type_hints
# Case 3
should_exclude_non_hybrid_only = (
not is_hybrid_property and include_implicit_fields == "hybrid-only" and key not in model_type_hints
)
# Descriptor is marked with with either Mark.READ_ONLY or Mark.WRITE_ONLY (see Case 1):
# - always include it regardless of anything else.
# Descriptor is not marked:
# - It's implicit BUT config excludes anything implicit (see Case 2): exclude
# - It's implicit AND not hybrid BUT config includes hybrid-only implicit descriptors (Case 3): exclude
should_skip_descriptor = not is_field_marked_not_private and (
should_exclude_anything_implicit or should_exclude_non_hybrid_only
)
if should_skip_descriptor:
continue
yield from cls.handle_orm_descriptor(
orm_descriptor.extension_type,
key,
orm_descriptor,
model_type_hints,
model_name,
)
@classmethod
def detect_nested_field(cls, field_definition: FieldDefinition) -> bool:
return field_definition.is_subclass_of(DeclarativeBase)
def _detect_defaults(elem: ElementType) -> tuple[Any, Any]:
default: Any = Empty
default_factory: Any = None # pyright:ignore
if sqla_default := getattr(elem, "default", None):
if sqla_default.is_scalar:
default = sqla_default.arg
elif sqla_default.is_callable:
def default_factory(d: Any = sqla_default) -> Any:
return d.arg({})
elif sqla_default.is_sequence or sqla_default.is_sentinel:
# SQLAlchemy sequences represent server side defaults
# so we cannot infer a reasonable default value for
# them on the client side
pass
else:
msg = "Unexpected default type"
raise ValueError(msg)
elif (isinstance(elem, RelationshipProperty) and detect_nullable_relationship(elem)) or getattr(
elem, "nullable", False
):
default = None
return default, default_factory
def parse_type_from_element(elem: ElementType, orm_descriptor: InspectionAttr) -> FieldDefinition: # noqa: PLR0911
"""Parses a type from a SQLAlchemy element.
Args:
elem: The SQLAlchemy element to parse.
orm_descriptor: The attribute `elem` was extracted from.
Raises:
ImproperConfigurationError: If the type cannot be parsed.
Returns:
FieldDefinition: The parsed type.
"""
if isinstance(elem, Column):
if elem.nullable:
return FieldDefinition.from_annotation(Optional[elem.type.python_type])
return FieldDefinition.from_annotation(elem.type.python_type)
if isinstance(elem, RelationshipProperty):
if elem.direction in {RelationshipDirection.ONETOMANY, RelationshipDirection.MANYTOMANY}:
collection_type = FieldDefinition.from_annotation(elem.collection_class or list) # pyright: ignore[reportUnknownMemberType]
return FieldDefinition.from_annotation(collection_type.safe_generic_origin[elem.mapper.class_])
if detect_nullable_relationship(elem):
return FieldDefinition.from_annotation(Optional[elem.mapper.class_])
return FieldDefinition.from_annotation(elem.mapper.class_)
if isinstance(elem, CompositeProperty):
return FieldDefinition.from_annotation(elem.composite_class)
if isinstance(orm_descriptor, InstrumentedAttribute):
return FieldDefinition.from_annotation(orm_descriptor.type.python_type)
msg = f"Unable to parse type from element '{elem}'. Consider adding a type hint."
raise ImproperConfigurationError(msg)
def detect_nullable_relationship(elem: RelationshipProperty[Any]) -> bool:
"""Detects if a relationship is nullable.
This attempts to decide if we should allow a ``None`` default value for a relationship by looking at the
foreign key fields. If all foreign key fields are nullable, then we allow a ``None`` default value.
Args:
elem: The relationship to check.
Returns:
bool: ``True`` if the relationship is nullable, ``False`` otherwise.
"""
return elem.direction == RelationshipDirection.MANYTOONE and all(c.nullable for c in elem.local_columns)
python-advanced-alchemy-1.4.1/advanced_alchemy/extensions/litestar/exception_handler.py 0000664 0000000 0000000 00000003305 15003544734 0031571 0 ustar 00root root 0000000 0000000 from typing import TYPE_CHECKING, Any
from litestar.connection import Request
from litestar.connection.base import AuthT, StateT, UserT
from litestar.exceptions import (
ClientException,
HTTPException,
InternalServerException,
NotFoundException,
)
from litestar.exceptions.responses import (
create_debug_response, # pyright: ignore[reportUnknownVariableType]
create_exception_response, # pyright: ignore[reportUnknownVariableType]
)
from litestar.response import Response
from litestar.status_codes import (
HTTP_409_CONFLICT,
)
from advanced_alchemy.exceptions import (
DuplicateKeyError,
ForeignKeyError,
IntegrityError,
NotFoundError,
RepositoryError,
)
if TYPE_CHECKING:
from litestar.connection import Request
from litestar.connection.base import AuthT, StateT, UserT
from litestar.response import Response
class ConflictError(ClientException):
"""Request conflict with the current state of the target resource."""
status_code: int = HTTP_409_CONFLICT
def exception_to_http_response(request: "Request[UserT, AuthT, StateT]", exc: "RepositoryError") -> "Response[Any]":
"""Handler for all exceptions subclassed from HTTPException."""
if isinstance(exc, NotFoundError):
http_exc: type[HTTPException] = NotFoundException
elif isinstance(exc, (DuplicateKeyError, IntegrityError, ForeignKeyError)):
http_exc = ConflictError
else:
http_exc = InternalServerException
if request.app.debug:
return create_debug_response(request, exc) # pyright: ignore[reportUnknownVariableType]
return create_exception_response(request, http_exc(detail=str(exc.detail))) # pyright: ignore[reportUnknownVariableType]
python-advanced-alchemy-1.4.1/advanced_alchemy/extensions/litestar/plugins/ 0000775 0000000 0000000 00000000000 15003544734 0027204 5 ustar 00root root 0000000 0000000 python-advanced-alchemy-1.4.1/advanced_alchemy/extensions/litestar/plugins/__init__.py 0000664 0000000 0000000 00000003306 15003544734 0031317 0 ustar 00root root 0000000 0000000 from collections.abc import Sequence
from typing import Union
from litestar.config.app import AppConfig
from litestar.plugins import InitPluginProtocol
from advanced_alchemy.extensions.litestar.plugins import _slots_base
from advanced_alchemy.extensions.litestar.plugins.init import (
EngineConfig,
SQLAlchemyAsyncConfig,
SQLAlchemyInitPlugin,
SQLAlchemySyncConfig,
)
from advanced_alchemy.extensions.litestar.plugins.serialization import SQLAlchemySerializationPlugin
class SQLAlchemyPlugin(InitPluginProtocol, _slots_base.SlotsBase):
"""A plugin that provides SQLAlchemy integration."""
def __init__(
self,
config: Union[
SQLAlchemyAsyncConfig, SQLAlchemySyncConfig, Sequence[Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig]]
],
) -> None:
"""Initialize ``SQLAlchemyPlugin``.
Args:
config: configure DB connection and hook handlers and dependencies.
"""
self._config = config if isinstance(config, Sequence) else [config]
@property
def config(
self,
) -> Sequence[Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig]]:
return self._config
def on_app_init(self, app_config: AppConfig) -> AppConfig:
"""Configure application for use with SQLAlchemy.
Args:
app_config: The :class:`AppConfig <.config.app.AppConfig>` instance.
"""
app_config.plugins.extend([SQLAlchemyInitPlugin(config=self._config), SQLAlchemySerializationPlugin()])
return app_config
__all__ = (
"EngineConfig",
"SQLAlchemyAsyncConfig",
"SQLAlchemyInitPlugin",
"SQLAlchemyPlugin",
"SQLAlchemySerializationPlugin",
"SQLAlchemySyncConfig",
)
python-advanced-alchemy-1.4.1/advanced_alchemy/extensions/litestar/plugins/_slots_base.py 0000664 0000000 0000000 00000000434 15003544734 0032054 0 ustar 00root root 0000000 0000000 """Base class that aggregates slots for all SQLAlchemy plugins.
See: https://stackoverflow.com/questions/53060607/python-3-6-5-multiple-bases-have-instance-lay-out-conflict-when-multi-inherit
"""
class SlotsBase:
__slots__ = (
"_config",
"_type_dto_map",
)
python-advanced-alchemy-1.4.1/advanced_alchemy/extensions/litestar/plugins/init/ 0000775 0000000 0000000 00000000000 15003544734 0030147 5 ustar 00root root 0000000 0000000 python-advanced-alchemy-1.4.1/advanced_alchemy/extensions/litestar/plugins/init/__init__.py 0000664 0000000 0000000 00000000542 15003544734 0032261 0 ustar 00root root 0000000 0000000 from advanced_alchemy.extensions.litestar.plugins.init.config import (
EngineConfig,
SQLAlchemyAsyncConfig,
SQLAlchemySyncConfig,
)
from advanced_alchemy.extensions.litestar.plugins.init.plugin import SQLAlchemyInitPlugin
__all__ = (
"EngineConfig",
"SQLAlchemyAsyncConfig",
"SQLAlchemyInitPlugin",
"SQLAlchemySyncConfig",
)
python-advanced-alchemy-1.4.1/advanced_alchemy/extensions/litestar/plugins/init/config/ 0000775 0000000 0000000 00000000000 15003544734 0031414 5 ustar 00root root 0000000 0000000 python-advanced-alchemy-1.4.1/advanced_alchemy/extensions/litestar/plugins/init/config/__init__.py 0000664 0000000 0000000 00000000567 15003544734 0033535 0 ustar 00root root 0000000 0000000 from advanced_alchemy.extensions.litestar.plugins.init.config.asyncio import SQLAlchemyAsyncConfig
from advanced_alchemy.extensions.litestar.plugins.init.config.engine import EngineConfig
from advanced_alchemy.extensions.litestar.plugins.init.config.sync import SQLAlchemySyncConfig
__all__ = (
"EngineConfig",
"SQLAlchemyAsyncConfig",
"SQLAlchemySyncConfig",
)
python-advanced-alchemy-1.4.1/advanced_alchemy/extensions/litestar/plugins/init/config/asyncio.py 0000664 0000000 0000000 00000027031 15003544734 0033436 0 ustar 00root root 0000000 0000000 from collections.abc import AsyncGenerator, Coroutine
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union, cast
from litestar.cli._utils import console # pyright: ignore
from litestar.constants import HTTP_RESPONSE_START
from sqlalchemy.exc import OperationalError
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession
from advanced_alchemy.base import metadata_registry
from advanced_alchemy.config.asyncio import SQLAlchemyAsyncConfig as _SQLAlchemyAsyncConfig
from advanced_alchemy.extensions.litestar._utils import (
delete_aa_scope_state,
get_aa_scope_state,
set_aa_scope_state,
)
from advanced_alchemy.extensions.litestar.plugins.init.config.common import (
SESSION_SCOPE_KEY,
SESSION_TERMINUS_ASGI_EVENTS,
)
from advanced_alchemy.extensions.litestar.plugins.init.config.engine import EngineConfig
if TYPE_CHECKING:
from collections.abc import AsyncGenerator, Coroutine
from litestar import Litestar
from litestar.datastructures.state import State
from litestar.types import BeforeMessageSendHookHandler, Message, Scope
# noinspection PyUnresolvedReferences
__all__ = (
"SQLAlchemyAsyncConfig",
"autocommit_before_send_handler",
"autocommit_handler_maker",
"default_before_send_handler",
"default_handler_maker",
)
def default_handler_maker(
session_scope_key: str = SESSION_SCOPE_KEY,
) -> "Callable[[Message, Scope], Coroutine[Any, Any, None]]":
"""Set up the handler to issue a transaction commit or rollback based on specified status codes
Args:
session_scope_key: The key to use within the application state
Returns:
The handler callable
"""
async def handler(message: "Message", scope: "Scope") -> None:
"""Handle commit/rollback, closing and cleaning up sessions before sending.
Args:
message: ASGI-``Message``
scope: An ASGI-``Scope``
"""
session = cast("Optional[AsyncSession]", get_aa_scope_state(scope, session_scope_key))
if session and message["type"] in SESSION_TERMINUS_ASGI_EVENTS:
await session.close()
delete_aa_scope_state(scope, session_scope_key)
return handler
default_before_send_handler = default_handler_maker()
def autocommit_handler_maker(
commit_on_redirect: bool = False,
extra_commit_statuses: Optional[set[int]] = None,
extra_rollback_statuses: Optional[set[int]] = None,
session_scope_key: str = SESSION_SCOPE_KEY,
) -> "Callable[[Message, Scope], Coroutine[Any, Any, None]]":
"""Set up the handler to issue a transaction commit or rollback based on specified status codes
Args:
commit_on_redirect: Issue a commit when the response status is a redirect (``3XX``)
extra_commit_statuses: A set of additional status codes that trigger a commit
extra_rollback_statuses: A set of additional status codes that trigger a rollback
session_scope_key: The key to use within the application state
Raises:
ValueError: If the extra commit statuses and extra rollback statuses share any status codes
Returns:
The handler callable
"""
if extra_commit_statuses is None:
extra_commit_statuses = set()
if extra_rollback_statuses is None:
extra_rollback_statuses = set()
if len(extra_commit_statuses & extra_rollback_statuses) > 0:
msg = "Extra rollback statuses and commit statuses must not share any status codes"
raise ValueError(msg)
commit_range = range(200, 400 if commit_on_redirect else 300)
async def handler(message: "Message", scope: "Scope") -> None:
"""Handle commit/rollback, closing and cleaning up sessions before sending.
Args:
message: ASGI-``litestar.types.Message``
scope: An ASGI-``litestar.types.Scope``
"""
session = cast("Optional[AsyncSession]", get_aa_scope_state(scope, session_scope_key))
try:
if session is not None and message["type"] == HTTP_RESPONSE_START:
if (message["status"] in commit_range or message["status"] in extra_commit_statuses) and message[
"status"
] not in extra_rollback_statuses:
await session.commit()
else:
await session.rollback()
finally:
if session and message["type"] in SESSION_TERMINUS_ASGI_EVENTS:
await session.close()
delete_aa_scope_state(scope, session_scope_key)
return handler
autocommit_before_send_handler = autocommit_handler_maker()
@dataclass
class SQLAlchemyAsyncConfig(_SQLAlchemyAsyncConfig):
"""Litestar Async SQLAlchemy Configuration."""
before_send_handler: Optional[
Union["BeforeMessageSendHookHandler", Literal["autocommit", "autocommit_include_redirects"]]
] = None
"""Handler to call before the ASGI message is sent.
The handler should handle closing the session stored in the ASGI scope, if it's still open, and committing and
uncommitted data.
"""
engine_dependency_key: str = "db_engine"
"""Key to use for the dependency injection of database engines."""
session_dependency_key: str = "db_session"
"""Key to use for the dependency injection of database sessions."""
engine_app_state_key: str = "db_engine"
"""Key under which to store the SQLAlchemy engine in the application :class:`State `
instance.
"""
session_maker_app_state_key: str = "session_maker_class"
"""Key under which to store the SQLAlchemy :class:`sessionmaker ` in the application
:class:`State ` instance.
"""
session_scope_key: str = SESSION_SCOPE_KEY
"""Key under which to store the SQLAlchemy scope in the application."""
engine_config: EngineConfig = field(default_factory=EngineConfig) # pyright: ignore[reportIncompatibleVariableOverride]
"""Configuration for the SQLAlchemy engine.
The configuration options are documented in the SQLAlchemy documentation.
"""
set_default_exception_handler: bool = True
"""Sets the default exception handler on application start."""
def _ensure_unique(self, registry_name: str, key: str, new_key: Optional[str] = None, _iter: int = 0) -> str:
new_key = new_key if new_key is not None else key
if new_key in getattr(self.__class__, registry_name, {}):
_iter += 1
new_key = self._ensure_unique(registry_name, key, f"{key}_{_iter}", _iter)
return new_key
def __post_init__(self) -> None:
self.session_scope_key = self._ensure_unique("_SESSION_SCOPE_KEY_REGISTRY", self.session_scope_key)
self.engine_app_state_key = self._ensure_unique("_ENGINE_APP_STATE_KEY_REGISTRY", self.engine_app_state_key)
self.session_maker_app_state_key = self._ensure_unique(
"_SESSIONMAKER_APP_STATE_KEY_REGISTRY",
self.session_maker_app_state_key,
)
self.__class__._SESSION_SCOPE_KEY_REGISTRY.add(self.session_scope_key) # noqa: SLF001
self.__class__._ENGINE_APP_STATE_KEY_REGISTRY.add(self.engine_app_state_key) # noqa: SLF001
self.__class__._SESSIONMAKER_APP_STATE_KEY_REGISTRY.add(self.session_maker_app_state_key) # noqa: SLF001
if self.before_send_handler is None:
self.before_send_handler = default_handler_maker(session_scope_key=self.session_scope_key)
if self.before_send_handler == "autocommit":
self.before_send_handler = autocommit_handler_maker(session_scope_key=self.session_scope_key)
if self.before_send_handler == "autocommit_include_redirects":
self.before_send_handler = autocommit_handler_maker(
session_scope_key=self.session_scope_key,
commit_on_redirect=True,
)
super().__post_init__()
def create_session_maker(self) -> "Callable[[], AsyncSession]":
"""Get a session maker. If none exists yet, create one.
Returns:
Session factory used by the plugin.
"""
if self.session_maker:
return self.session_maker
session_kws = self.session_config_dict
if session_kws.get("bind") is None:
session_kws["bind"] = self.get_engine()
return self.session_maker_class(**session_kws) # pyright: ignore[reportUnknownVariableType,reportUnknownMemberType]
@asynccontextmanager
async def lifespan(
self,
app: "Litestar",
) -> "AsyncGenerator[None, None]":
deps = self.create_app_state_items()
app.state.update(deps)
try:
if self.create_all:
await self.create_all_metadata(app)
yield
finally:
if self.engine_dependency_key in deps:
engine = deps[self.engine_dependency_key]
if hasattr(engine, "dispose"):
await cast("AsyncEngine", engine).dispose()
def provide_engine(self, state: "State") -> "AsyncEngine":
"""Create an engine instance.
Args:
state: The ``Litestar.state`` instance.
Returns:
An engine instance.
"""
return cast("AsyncEngine", state.get(self.engine_app_state_key))
def provide_session(self, state: "State", scope: "Scope") -> "AsyncSession":
"""Create a session instance.
Args:
state: The ``Litestar.state`` instance.
scope: The current connection's scope.
Returns:
A session instance.
"""
# Import locally to avoid potential circular dependency issues at module level
from advanced_alchemy._listeners import set_async_context
session = cast("Optional[AsyncSession]", get_aa_scope_state(scope, self.session_scope_key))
if session is None:
session_maker = cast("Callable[[], AsyncSession]", state[self.session_maker_app_state_key])
session = session_maker()
set_aa_scope_state(scope, self.session_scope_key, session)
set_async_context(True) # Set context before yielding
return session
@property
def signature_namespace(self) -> dict[str, Any]:
"""Return the plugin's signature namespace.
Returns:
A string keyed dict of names to be added to the namespace for signature forward reference resolution.
"""
return {"AsyncEngine": AsyncEngine, "AsyncSession": AsyncSession}
async def create_all_metadata(self, app: "Litestar") -> None:
"""Create all metadata
Args:
app (Litestar): The ``Litestar`` instance
"""
async with self.get_engine().begin() as conn:
try:
await conn.run_sync(metadata_registry.get(self.bind_key).create_all)
except OperationalError as exc:
console.print(f"[bold red] * Could not create target metadata. Reason: {exc}")
def create_app_state_items(self) -> dict[str, Any]:
"""Key/value pairs to be stored in application state.
Returns:
A dictionary of key/value pairs to be stored in application state.
"""
return {
self.engine_app_state_key: self.get_engine(),
self.session_maker_app_state_key: self.create_session_maker(),
}
def update_app_state(self, app: "Litestar") -> None:
"""Set the app state with engine and session.
Args:
app: The ``Litestar`` instance.
"""
app.state.update(self.create_app_state_items())
python-advanced-alchemy-1.4.1/advanced_alchemy/extensions/litestar/plugins/init/config/common.py 0000664 0000000 0000000 00000000521 15003544734 0033254 0 ustar 00root root 0000000 0000000 from litestar.constants import HTTP_DISCONNECT, HTTP_RESPONSE_START, WEBSOCKET_CLOSE, WEBSOCKET_DISCONNECT
SESSION_SCOPE_KEY = "_sqlalchemy_db_session"
"""Session scope key."""
SESSION_TERMINUS_ASGI_EVENTS = {HTTP_RESPONSE_START, HTTP_DISCONNECT, WEBSOCKET_DISCONNECT, WEBSOCKET_CLOSE}
"""ASGI events that terminate a session scope."""
python-advanced-alchemy-1.4.1/advanced_alchemy/extensions/litestar/plugins/init/config/engine.py 0000664 0000000 0000000 00000002230 15003544734 0033230 0 ustar 00root root 0000000 0000000 from dataclasses import dataclass
from typing import Any, Callable
from litestar.serialization import decode_json, encode_json
from advanced_alchemy.config import EngineConfig as _EngineConfig
__all__ = ("EngineConfig",)
def serializer(value: Any) -> str:
"""Serialize JSON field values.
Args:
value: Any json serializable value.
Returns:
JSON string.
"""
return encode_json(value).decode("utf-8")
@dataclass
class EngineConfig(_EngineConfig):
"""Configuration for SQLAlchemy's :class:`Engine `.
For details see: https://docs.sqlalchemy.org/en/20/core/engines.html
"""
json_deserializer: Callable[[str], Any] = decode_json
"""For dialects that support the :class:`JSON ` datatype, this is a Python callable that will
convert a JSON string to a Python object. By default, this is set to Litestar's decode_json function."""
json_serializer: Callable[[Any], str] = serializer
"""For dialects that support the JSON datatype, this is a Python callable that will render a given object as JSON.
By default, Litestar's encode_json function is used."""
python-advanced-alchemy-1.4.1/advanced_alchemy/extensions/litestar/plugins/init/config/sync.py 0000664 0000000 0000000 00000026245 15003544734 0032753 0 ustar 00root root 0000000 0000000 from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union, cast
from litestar.cli._utils import console # pyright: ignore
from litestar.constants import HTTP_RESPONSE_START
from sqlalchemy import Engine
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import Session
from advanced_alchemy.base import metadata_registry
from advanced_alchemy.config.sync import SQLAlchemySyncConfig as _SQLAlchemySyncConfig
from advanced_alchemy.extensions.litestar._utils import (
delete_aa_scope_state,
get_aa_scope_state,
set_aa_scope_state,
)
from advanced_alchemy.extensions.litestar.plugins.init.config.common import (
SESSION_SCOPE_KEY,
SESSION_TERMINUS_ASGI_EVENTS,
)
from advanced_alchemy.extensions.litestar.plugins.init.config.engine import EngineConfig
if TYPE_CHECKING:
from collections.abc import AsyncGenerator
from litestar import Litestar
from litestar.datastructures.state import State
from litestar.types import BeforeMessageSendHookHandler, Message, Scope
__all__ = (
"SQLAlchemySyncConfig",
"autocommit_before_send_handler",
"autocommit_handler_maker",
"default_before_send_handler",
"default_handler_maker",
)
def default_handler_maker(
session_scope_key: str = SESSION_SCOPE_KEY,
) -> "Callable[[Message, Scope], None]":
"""Set up the handler to issue a transaction commit or rollback based on specified status codes
Args:
session_scope_key: The key to use within the application state
Returns:
The handler callable
"""
def handler(message: "Message", scope: "Scope") -> None:
"""Handle commit/rollback, closing and cleaning up sessions before sending.
Args:
message: ASGI-``Message``
scope: An ASGI-``Scope``
Returns:
None
"""
session = cast("Optional[Session]", get_aa_scope_state(scope, session_scope_key))
if session and message["type"] in SESSION_TERMINUS_ASGI_EVENTS:
session.close()
delete_aa_scope_state(scope, session_scope_key)
return handler
default_before_send_handler = default_handler_maker()
def autocommit_handler_maker(
commit_on_redirect: bool = False,
extra_commit_statuses: "Optional[set[int]]" = None,
extra_rollback_statuses: "Optional[set[int]]" = None,
session_scope_key: str = SESSION_SCOPE_KEY,
) -> "Callable[[Message, Scope], None]":
"""Set up the handler to issue a transaction commit or rollback based on specified status codes
Args:
commit_on_redirect: Issue a commit when the response status is a redirect (``3XX``)
extra_commit_statuses: A set of additional status codes that trigger a commit
extra_rollback_statuses: A set of additional status codes that trigger a rollback
session_scope_key: The key to use within the application state
Raises:
ValueError: If extra rollback statuses and commit statuses share any status codes
Returns:
The handler callable
"""
if extra_commit_statuses is None:
extra_commit_statuses = set()
if extra_rollback_statuses is None:
extra_rollback_statuses = set()
if len(extra_commit_statuses & extra_rollback_statuses) > 0:
msg = "Extra rollback statuses and commit statuses must not share any status codes"
raise ValueError(msg)
commit_range = range(200, 400 if commit_on_redirect else 300)
def handler(message: "Message", scope: "Scope") -> None:
"""Handle commit/rollback, closing and cleaning up sessions before sending.
Args:
message: ASGI-``Message``
scope: An ASGI-``Scope``
"""
session = cast("Optional[Session]", get_aa_scope_state(scope, session_scope_key))
try:
if session is not None and message["type"] == HTTP_RESPONSE_START:
if (message["status"] in commit_range or message["status"] in extra_commit_statuses) and message[
"status"
] not in extra_rollback_statuses:
session.commit()
else:
session.rollback()
finally:
if session and message["type"] in SESSION_TERMINUS_ASGI_EVENTS:
session.close()
delete_aa_scope_state(scope, session_scope_key)
return handler
autocommit_before_send_handler = autocommit_handler_maker()
@dataclass
class SQLAlchemySyncConfig(_SQLAlchemySyncConfig):
"""Litestar Sync SQLAlchemy Configuration."""
before_send_handler: Optional[
Union["BeforeMessageSendHookHandler", Literal["autocommit", "autocommit_include_redirects"]]
] = None
"""Handler to call before the ASGI message is sent.
The handler should handle closing the session stored in the ASGI scope, if it's still open, and committing and
uncommitted data.
"""
engine_dependency_key: str = "db_engine"
"""Key to use for the dependency injection of database engines."""
session_dependency_key: str = "db_session"
"""Key to use for the dependency injection of database sessions."""
engine_app_state_key: str = "db_engine"
"""Key under which to store the SQLAlchemy engine in the application :class:`State <.datastructures.State>`
instance.
"""
session_maker_app_state_key: str = "session_maker_class"
"""Key under which to store the SQLAlchemy :class:`sessionmaker ` in the application
:class:`State <.datastructures.State>` instance.
"""
session_scope_key: str = SESSION_SCOPE_KEY
"""Key under which to store the SQLAlchemy scope in the application."""
engine_config: EngineConfig = field(default_factory=EngineConfig) # pyright: ignore[reportIncompatibleVariableOverride]
"""Configuration for the SQLAlchemy engine.
The configuration options are documented in the SQLAlchemy documentation.
"""
set_default_exception_handler: bool = True
"""Sets the default exception handler on application start."""
def _ensure_unique(self, registry_name: str, key: str, new_key: Optional[str] = None, _iter: int = 0) -> str:
new_key = new_key if new_key is not None else key
if new_key in getattr(self.__class__, registry_name, {}):
_iter += 1
new_key = self._ensure_unique(registry_name, key, f"{key}_{_iter}", _iter)
return new_key
def __post_init__(self) -> None:
self.session_scope_key = self._ensure_unique("_SESSION_SCOPE_KEY_REGISTRY", self.session_scope_key)
self.engine_app_state_key = self._ensure_unique("_ENGINE_APP_STATE_KEY_REGISTRY", self.engine_app_state_key)
self.session_maker_app_state_key = self._ensure_unique(
"_SESSIONMAKER_APP_STATE_KEY_REGISTRY",
self.session_maker_app_state_key,
)
self.__class__._SESSION_SCOPE_KEY_REGISTRY.add(self.session_scope_key) # noqa: SLF001
self.__class__._ENGINE_APP_STATE_KEY_REGISTRY.add(self.engine_app_state_key) # noqa: SLF001
self.__class__._SESSIONMAKER_APP_STATE_KEY_REGISTRY.add(self.session_maker_app_state_key) # noqa: SLF001
if self.before_send_handler is None:
self.before_send_handler = default_handler_maker(session_scope_key=self.session_scope_key)
if self.before_send_handler == "autocommit":
self.before_send_handler = autocommit_handler_maker(session_scope_key=self.session_scope_key)
if self.before_send_handler == "autocommit_include_redirects":
self.before_send_handler = autocommit_handler_maker(
session_scope_key=self.session_scope_key,
commit_on_redirect=True,
)
super().__post_init__()
def create_session_maker(self) -> "Callable[[], Session]":
"""Get a session maker. If none exists yet, create one.
Returns:
Session factory used by the plugin.
"""
if self.session_maker:
return self.session_maker
session_kws = self.session_config_dict
if session_kws.get("bind") is None:
session_kws["bind"] = self.get_engine()
return self.session_maker_class(**session_kws)
@asynccontextmanager
async def lifespan(
self,
app: "Litestar",
) -> "AsyncGenerator[None, None]":
deps = self.create_app_state_items()
app.state.update(deps)
try:
if self.create_all:
self.create_all_metadata(app)
yield
finally:
if self.engine_dependency_key in deps:
engine = deps[self.engine_dependency_key]
if hasattr(engine, "dispose"):
cast("Engine", engine).dispose()
def provide_engine(self, state: "State") -> "Engine":
"""Create an engine instance.
Args:
state: The ``Litestar.state`` instance.
Returns:
An engine instance.
"""
return cast("Engine", state.get(self.engine_app_state_key))
def provide_session(self, state: "State", scope: "Scope") -> "Session":
"""Create a session instance.
Args:
state: The ``Litestar.state`` instance.
scope: The current connection's scope.
Returns:
A session instance.
"""
# Import locally to avoid potential circular dependency issues at module level
from advanced_alchemy._listeners import set_async_context
session = cast("Optional[Session]", get_aa_scope_state(scope, self.session_scope_key))
if session is None:
session_maker = cast("Callable[[], Session]", state[self.session_maker_app_state_key])
session = session_maker()
set_aa_scope_state(scope, self.session_scope_key, session)
set_async_context(False) # Set context before yielding
return session
@property
def signature_namespace(self) -> "dict[str, Any]":
"""Return the plugin's signature namespace.
Returns:
A string keyed dict of names to be added to the namespace for signature forward reference resolution.
"""
return {"Engine": Engine, "Session": Session}
def create_all_metadata(self, app: "Litestar") -> None:
"""Create all metadata
Args:
app (Litestar): The ``Litestar`` instance
"""
with self.get_engine().begin() as conn:
try:
metadata_registry.get(self.bind_key).create_all(bind=conn)
except OperationalError as exc:
console.print(f"[bold red] * Could not create target metadata. Reason: {exc}")
def create_app_state_items(self) -> "dict[str, Any]":
"""Key/value pairs to be stored in application state.
Returns:
A dictionary of key/value pairs to be stored in application state.
"""
return {
self.engine_app_state_key: self.get_engine(),
self.session_maker_app_state_key: self.create_session_maker(),
}
def update_app_state(self, app: "Litestar") -> None:
"""Set the app state with engine and session.
Args:
app: The ``Litestar`` instance.
"""
app.state.update(self.create_app_state_items())
python-advanced-alchemy-1.4.1/advanced_alchemy/extensions/litestar/plugins/init/plugin.py 0000664 0000000 0000000 00000015250 15003544734 0032022 0 ustar 00root root 0000000 0000000 import contextlib
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, Union, cast
from litestar.di import Provide
from litestar.dto import DTOData
from litestar.params import Dependency, Parameter
from litestar.plugins import CLIPlugin, InitPluginProtocol
from sqlalchemy.ext.asyncio import AsyncSession, async_scoped_session
from sqlalchemy.orm import Session, scoped_session
from advanced_alchemy.exceptions import ImproperConfigurationError, RepositoryError
from advanced_alchemy.extensions.litestar.exception_handler import exception_to_http_response
from advanced_alchemy.extensions.litestar.plugins import _slots_base
from advanced_alchemy.filters import (
BeforeAfter,
CollectionFilter,
ComparisonFilter,
ExistsFilter,
FilterGroup,
FilterMap,
FilterTypes,
InAnyFilter,
LimitOffset,
LogicalOperatorMap,
MultiFilter,
NotExistsFilter,
NotInCollectionFilter,
NotInSearchFilter,
OnBeforeAfter,
OrderBy,
SearchFilter,
StatementFilter,
StatementTypeT,
)
from advanced_alchemy.service import ModelDictListT, ModelDictT, ModelDTOT, ModelOrRowMappingT, ModelT, OffsetPagination
if TYPE_CHECKING:
from click import Group
from litestar.config.app import AppConfig
from litestar.types import BeforeMessageSendHookHandler
from advanced_alchemy.extensions.litestar.plugins.init.config import SQLAlchemyAsyncConfig, SQLAlchemySyncConfig
__all__ = ("SQLAlchemyInitPlugin",)
signature_namespace_values: dict[str, Any] = {
"BeforeAfter": BeforeAfter,
"OnBeforeAfter": OnBeforeAfter,
"CollectionFilter": CollectionFilter,
"LimitOffset": LimitOffset,
"OrderBy": OrderBy,
"SearchFilter": SearchFilter,
"NotInCollectionFilter": NotInCollectionFilter,
"NotInSearchFilter": NotInSearchFilter,
"FilterTypes": FilterTypes,
"OffsetPagination": OffsetPagination,
"ExistsFilter": ExistsFilter,
"Parameter": Parameter,
"Dependency": Dependency,
"DTOData": DTOData,
"Sequence": Sequence,
"ModelT": ModelT,
"ModelDictT": ModelDictT,
"ModelDTOT": ModelDTOT,
"ModelDictListT": ModelDictListT,
"ModelOrRowMappingT": ModelOrRowMappingT,
"Session": Session,
"scoped_session": scoped_session,
"AsyncSession": AsyncSession,
"async_scoped_session": async_scoped_session,
"FilterGroup": FilterGroup,
"NotExistsFilter": NotExistsFilter,
"MultiFilter": MultiFilter,
"ComparisonFilter": ComparisonFilter,
"StatementTypeT": StatementTypeT,
"StatementFilter": StatementFilter,
"LogicalOperatorMap": LogicalOperatorMap,
"InAnyFilter": InAnyFilter,
"FilterMap": FilterMap,
}
class SQLAlchemyInitPlugin(InitPluginProtocol, CLIPlugin, _slots_base.SlotsBase):
"""SQLAlchemy application lifecycle configuration."""
def __init__(
self,
config: Union[
"SQLAlchemyAsyncConfig",
"SQLAlchemySyncConfig",
"Sequence[Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig]]",
],
) -> None:
"""Initialize ``SQLAlchemyPlugin``.
Args:
config: configure DB connection and hook handlers and dependencies.
"""
self._config = config
@property
def config(self) -> "Sequence[Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig]]":
return self._config if isinstance(self._config, Sequence) else [self._config]
def on_cli_init(self, cli: "Group") -> None:
from advanced_alchemy.extensions.litestar.cli import database_group
cli.add_command(database_group)
def _validate_config(self) -> None:
configs = self._config if isinstance(self._config, Sequence) else [self._config]
scope_keys = {config.session_scope_key for config in configs}
engine_keys = {config.engine_dependency_key for config in configs}
session_keys = {config.session_dependency_key for config in configs}
if len(configs) > 1 and any(len(i) != len(configs) for i in (scope_keys, engine_keys, session_keys)):
raise ImproperConfigurationError(
detail="When using multiple configurations, please ensure the `session_dependency_key` and `engine_dependency_key` settings are unique across all configs. Additionally, iF you are using a custom `before_send` handler, ensure `session_scope_key` is unique.",
)
def on_app_init(self, app_config: "AppConfig") -> "AppConfig":
"""Configure application for use with SQLAlchemy.
Args:
app_config: The :class:`AppConfig <.config.app.AppConfig>` instance.
"""
self._validate_config()
with contextlib.suppress(ImportError):
from asyncpg.pgproto import pgproto # pyright: ignore[reportMissingImports]
signature_namespace_values.update({"pgproto.UUID": pgproto.UUID})
app_config.type_encoders = {pgproto.UUID: str, **(app_config.type_encoders or {})}
with contextlib.suppress(ImportError):
import uuid_utils # pyright: ignore[reportMissingImports]
signature_namespace_values.update({"uuid_utils.UUID": uuid_utils.UUID}) # pyright: ignore[reportUnknownMemberType]
app_config.type_encoders = {uuid_utils.UUID: str, **(app_config.type_encoders or {})} # pyright: ignore[reportUnknownMemberType]
app_config.type_decoders = [
(lambda x: x is uuid_utils.UUID, lambda t, v: t(str(v))), # pyright: ignore[reportUnknownMemberType]
*(app_config.type_decoders or []),
]
configure_exception_handler = False
for config in self.config:
if config.set_default_exception_handler:
configure_exception_handler = True
signature_namespace_values.update(config.signature_namespace)
app_config.lifespan.append(config.lifespan) # pyright: ignore[reportUnknownMemberType]
app_config.dependencies.update(
{
config.engine_dependency_key: Provide(config.provide_engine, sync_to_thread=False),
config.session_dependency_key: Provide(config.provide_session, sync_to_thread=False),
},
)
app_config.before_send.append(cast("BeforeMessageSendHookHandler", config.before_send_handler))
app_config.signature_namespace.update(signature_namespace_values)
if configure_exception_handler and not any(
isinstance(exc, int) or issubclass(exc, RepositoryError)
for exc in app_config.exception_handlers # pyright: ignore[reportUnknownMemberType]
):
app_config.exception_handlers.update({RepositoryError: exception_to_http_response}) # pyright: ignore[reportUnknownMemberType]
return app_config
python-advanced-alchemy-1.4.1/advanced_alchemy/extensions/litestar/plugins/serialization.py 0000664 0000000 0000000 00000002646 15003544734 0032443 0 ustar 00root root 0000000 0000000 from typing import Any
from litestar.plugins import SerializationPlugin
from litestar.typing import FieldDefinition
from sqlalchemy.orm import DeclarativeBase
from advanced_alchemy.extensions.litestar.dto import SQLAlchemyDTO
from advanced_alchemy.extensions.litestar.plugins import _slots_base
class SQLAlchemySerializationPlugin(SerializationPlugin, _slots_base.SlotsBase):
def __init__(self) -> None:
self._type_dto_map: dict[type[DeclarativeBase], type[SQLAlchemyDTO[Any]]] = {}
def supports_type(self, field_definition: FieldDefinition) -> bool:
return (
field_definition.is_collection and field_definition.has_inner_subclass_of(DeclarativeBase)
) or field_definition.is_subclass_of(DeclarativeBase)
def create_dto_for_type(self, field_definition: FieldDefinition) -> type[SQLAlchemyDTO[Any]]:
# assumes that the type is a container of SQLAlchemy models or a single SQLAlchemy model
annotation = next(
(
inner_type.annotation
for inner_type in field_definition.inner_types
if inner_type.is_subclass_of(DeclarativeBase)
),
field_definition.annotation,
)
if annotation in self._type_dto_map:
return self._type_dto_map[annotation]
self._type_dto_map[annotation] = dto_type = SQLAlchemyDTO[annotation] # type:ignore[valid-type]
return dto_type
python-advanced-alchemy-1.4.1/advanced_alchemy/extensions/litestar/providers.py 0000664 0000000 0000000 00000074005 15003544734 0030120 0 ustar 00root root 0000000 0000000 # ruff: noqa: B008
"""Application dependency providers generators.
This module contains functions to create dependency providers for services and filters.
You should not have modify this module very often and should only be invoked under normal usage.
"""
import datetime
import inspect
from collections.abc import AsyncGenerator, Callable, Generator
from typing import (
TYPE_CHECKING,
Any,
Literal,
NamedTuple,
Optional,
TypedDict,
TypeVar,
Union,
cast,
overload,
)
from uuid import UUID
from litestar.di import Provide
from litestar.params import Dependency, Parameter
from typing_extensions import NotRequired
from advanced_alchemy.filters import (
BeforeAfter,
CollectionFilter,
FilterTypes,
LimitOffset,
NotInCollectionFilter,
OrderBy,
SearchFilter,
)
from advanced_alchemy.service import (
Empty,
EmptyType,
ErrorMessages,
LoadSpec,
ModelT,
SQLAlchemyAsyncRepositoryService,
SQLAlchemySyncRepositoryService,
)
from advanced_alchemy.utils.singleton import SingletonMeta
from advanced_alchemy.utils.text import camelize
if TYPE_CHECKING:
from sqlalchemy import Select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from advanced_alchemy.extensions.litestar.plugins.init.config.asyncio import SQLAlchemyAsyncConfig
from advanced_alchemy.extensions.litestar.plugins.init.config.sync import SQLAlchemySyncConfig
DTorNone = Optional[datetime.datetime]
StringOrNone = Optional[str]
UuidOrNone = Optional[UUID]
IntOrNone = Optional[int]
BooleanOrNone = Optional[bool]
SortOrder = Literal["asc", "desc"]
SortOrderOrNone = Optional[SortOrder]
AsyncServiceT_co = TypeVar("AsyncServiceT_co", bound=SQLAlchemyAsyncRepositoryService[Any], covariant=True)
SyncServiceT_co = TypeVar("SyncServiceT_co", bound=SQLAlchemySyncRepositoryService[Any], covariant=True)
HashableValue = Union[str, int, float, bool, None]
HashableType = Union[HashableValue, tuple[Any, ...], tuple[tuple[str, Any], ...], tuple[HashableValue, ...]]
class DependencyDefaults:
FILTERS_DEPENDENCY_KEY: str = "filters"
"""Key for the filters dependency."""
CREATED_FILTER_DEPENDENCY_KEY: str = "created_filter"
"""Key for the created filter dependency."""
ID_FILTER_DEPENDENCY_KEY: str = "id_filter"
"""Key for the id filter dependency."""
LIMIT_OFFSET_FILTER_DEPENDENCY_KEY: str = "limit_offset_filter"
"""Key for the limit offset dependency."""
UPDATED_FILTER_DEPENDENCY_KEY: str = "updated_filter"
"""Key for the updated filter dependency."""
ORDER_BY_FILTER_DEPENDENCY_KEY: str = "order_by_filter"
"""Key for the order by dependency."""
SEARCH_FILTER_DEPENDENCY_KEY: str = "search_filter"
"""Key for the search filter dependency."""
DEFAULT_PAGINATION_SIZE: int = 20
"""Default pagination size."""
DEPENDENCY_DEFAULTS = DependencyDefaults()
class FieldNameType(NamedTuple):
"""Type for field name and associated type information.
This allows for specifying both the field name and the expected type for filter values.
"""
name: str
"""Name of the field to filter on."""
type_hint: type[Any] = str
"""Type of the filter value. Defaults to str."""
class FilterConfig(TypedDict):
"""Configuration for generating dynamic filters."""
id_filter: NotRequired[type[Union[UUID, int, str]]]
"""Indicates that the id filter should be enabled. When set, the type specified will be used for the :class:`CollectionFilter`."""
id_field: NotRequired[str]
"""The field on the model that stored the primary key or identifier."""
sort_field: NotRequired[str]
"""The default field to use for the sort filter."""
sort_order: NotRequired[SortOrder]
"""The default order to use for the sort filter."""
pagination_type: NotRequired[Literal["limit_offset"]]
"""When set, pagination is enabled based on the type specified."""
pagination_size: NotRequired[int]
"""The size of the pagination. Defaults to `DEFAULT_PAGINATION_SIZE`."""
search: NotRequired[Union[str, set[str], list[str]]]
"""Fields to enable search on. Can be a comma-separated string or a set of field names."""
search_ignore_case: NotRequired[bool]
"""When set, search is case insensitive by default."""
created_at: NotRequired[bool]
"""When set, created_at filter is enabled."""
updated_at: NotRequired[bool]
"""When set, updated_at filter is enabled."""
not_in_fields: NotRequired[Union[FieldNameType, set[FieldNameType], list[Union[str, FieldNameType]]]]
"""Fields that support not-in collection filters. Can be a single field or a set of fields with type information."""
in_fields: NotRequired[Union[FieldNameType, set[FieldNameType], list[Union[str, FieldNameType]]]]
"""Fields that support in-collection filters. Can be a single field or a set of fields with type information."""
class DependencyCache(metaclass=SingletonMeta):
"""Simple dependency cache for the application. This is used to help memoize dependencies that are generated dynamically."""
def __init__(self) -> None:
self.dependencies: dict[Union[int, str], dict[str, Provide]] = {}
def add_dependencies(self, key: Union[int, str], dependencies: dict[str, Provide]) -> None:
self.dependencies[key] = dependencies
def get_dependencies(self, key: Union[int, str]) -> Optional[dict[str, Provide]]:
return self.dependencies.get(key)
dep_cache = DependencyCache()
@overload
def create_service_provider(
service_class: type["AsyncServiceT_co"],
/,
statement: "Optional[Select[tuple[ModelT]]]" = None,
config: "Optional[SQLAlchemyAsyncConfig]" = None,
error_messages: "Optional[Union[ErrorMessages, EmptyType]]" = Empty,
load: "Optional[LoadSpec]" = None,
execution_options: "Optional[dict[str, Any]]" = None,
uniquify: Optional[bool] = None,
count_with_window_function: Optional[bool] = None,
) -> Callable[..., AsyncGenerator[AsyncServiceT_co, None]]: ...
@overload
def create_service_provider(
service_class: type["SyncServiceT_co"],
/,
statement: "Optional[Select[tuple[ModelT]]]" = None,
config: "Optional[SQLAlchemySyncConfig]" = None,
error_messages: "Optional[Union[ErrorMessages, EmptyType]]" = Empty,
load: "Optional[LoadSpec]" = None,
execution_options: "Optional[dict[str, Any]]" = None,
uniquify: Optional[bool] = None,
count_with_window_function: Optional[bool] = None,
) -> Callable[..., Generator[SyncServiceT_co, None, None]]: ...
def create_service_provider(
service_class: type[Union["AsyncServiceT_co", "SyncServiceT_co"]],
/,
statement: "Optional[Select[tuple[ModelT]]]" = None,
config: "Optional[Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig]]" = None,
error_messages: "Optional[Union[ErrorMessages, EmptyType]]" = Empty,
load: "Optional[LoadSpec]" = None,
execution_options: "Optional[dict[str, Any]]" = None,
uniquify: Optional[bool] = None,
count_with_window_function: Optional[bool] = None,
) -> Callable[..., Union["AsyncGenerator[AsyncServiceT_co, None]", "Generator[SyncServiceT_co,None, None]"]]:
"""Create a dependency provider for a service with a configurable session key.
Args:
service_class: The service class inheriting from SQLAlchemyAsyncRepositoryService or SQLAlchemySyncRepositoryService.
statement: An optional SQLAlchemy Select statement to scope the service.
config: An optional SQLAlchemy configuration object.
error_messages: Optional custom error messages for the service.
load: Optional LoadSpec for eager loading relationships.
execution_options: Optional dictionary of execution options for SQLAlchemy.
uniquify: Optional flag to uniquify results.
count_with_window_function: Optional flag to use window function for counting.
Returns:
A dependency provider function suitable for Litestar's DI system.
"""
session_dependency_key = config.session_dependency_key if config else "db_session"
if issubclass(service_class, SQLAlchemyAsyncRepositoryService) or service_class is SQLAlchemyAsyncRepositoryService: # type: ignore[comparison-overlap]
session_type_annotation = "Optional[AsyncSession]"
return_type_annotation = AsyncGenerator[service_class, None] # type: ignore[valid-type]
async def provide_service_async(*args: Any, **kwargs: Any) -> "AsyncGenerator[AsyncServiceT_co, None]":
db_session = cast("Optional[AsyncSession]", args[0] if args else kwargs.get(session_dependency_key))
async with service_class.new( # type: ignore[union-attr]
session=db_session, # type: ignore[arg-type]
statement=statement,
config=cast("Optional[SQLAlchemyAsyncConfig]", config), # type: ignore[arg-type]
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=uniquify,
count_with_window_function=count_with_window_function,
) as service:
yield service
session_param = inspect.Parameter(
name=session_dependency_key,
kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
default=Dependency(skip_validation=True),
annotation=session_type_annotation,
)
provider_signature = inspect.Signature(
parameters=[session_param],
return_annotation=return_type_annotation,
)
provide_service_async.__signature__ = provider_signature # type: ignore[attr-defined]
provide_service_async.__annotations__ = {
session_dependency_key: session_type_annotation,
"return": return_type_annotation,
}
return provide_service_async
session_type_annotation = "Optional[Session]"
return_type_annotation = Generator[service_class, None, None] # type: ignore[misc,assignment,valid-type]
def provide_service_sync(*args: Any, **kwargs: Any) -> "Generator[SyncServiceT_co, None, None]":
db_session = cast("Optional[Session]", args[0] if args else kwargs.get(session_dependency_key))
with service_class.new(
session=db_session,
statement=statement,
config=cast("Optional[SQLAlchemySyncConfig]", config),
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=uniquify,
count_with_window_function=count_with_window_function,
) as service:
yield service
session_param = inspect.Parameter(
name=session_dependency_key,
kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
default=Dependency(skip_validation=True),
annotation=session_type_annotation,
)
provider_signature = inspect.Signature(
parameters=[session_param],
return_annotation=return_type_annotation,
)
provide_service_sync.__signature__ = provider_signature # type: ignore[attr-defined]
provide_service_sync.__annotations__ = {
session_dependency_key: session_type_annotation,
"return": return_type_annotation,
}
return provide_service_sync
def create_service_dependencies(
service_class: type[Union["AsyncServiceT_co", "SyncServiceT_co"]],
/,
key: str,
statement: "Optional[Select[tuple[ModelT]]]" = None,
config: "Optional[Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig]]" = None,
error_messages: "Optional[Union[ErrorMessages, EmptyType]]" = Empty,
load: "Optional[LoadSpec]" = None,
execution_options: "Optional[dict[str, Any]]" = None,
filters: "Optional[FilterConfig]" = None,
uniquify: Optional[bool] = None,
count_with_window_function: Optional[bool] = None,
dep_defaults: "DependencyDefaults" = DEPENDENCY_DEFAULTS,
) -> dict[str, Provide]:
"""Create a dependency provider for the combined filter function.
Args:
key: The key to use for the dependency provider.
service_class: The service class to create a dependency provider for.
statement: The statement to use for the service.
config: The configuration to use for the service.
error_messages: The error messages to use for the service.
load: The load spec to use for the service.
execution_options: The execution options to use for the service.
filters: The filter configuration to use for the service.
uniquify: Whether to uniquify the service.
count_with_window_function: Whether to count with a window function.
dep_defaults: The dependency defaults to use for the service.
Returns:
A dictionary of dependency providers for the service.
"""
if issubclass(service_class, SQLAlchemyAsyncRepositoryService) or service_class is SQLAlchemyAsyncRepositoryService: # type: ignore[comparison-overlap]
svc = create_service_provider( # type: ignore[type-var,misc,unused-ignore]
service_class,
statement,
cast("Optional[SQLAlchemyAsyncConfig]", config),
error_messages,
load,
execution_options,
uniquify,
count_with_window_function,
)
deps = {key: Provide(svc)}
else:
svc = create_service_provider( # type: ignore[assignment]
service_class,
statement,
cast("Optional[SQLAlchemySyncConfig]", config),
error_messages,
load,
execution_options,
uniquify,
count_with_window_function,
)
deps = {key: Provide(svc, sync_to_thread=False)}
if filters:
deps.update(create_filter_dependencies(filters, dep_defaults))
return deps
def create_filter_dependencies(
config: FilterConfig, dep_defaults: DependencyDefaults = DEPENDENCY_DEFAULTS
) -> dict[str, Provide]:
"""Create a dependency provider for the combined filter function.
Args:
config: FilterConfig instance with desired settings.
dep_defaults: Dependency defaults to use for the filter dependencies
Returns:
A dependency provider function for the combined filter function.
"""
cache_key = hash(_make_hashable(config))
deps = dep_cache.get_dependencies(cache_key)
if deps is not None:
return deps
deps = _create_statement_filters(config, dep_defaults)
dep_cache.add_dependencies(cache_key, deps)
return deps
def _make_hashable(value: Any) -> HashableType:
"""Convert a value into a hashable type.
This function converts any value into a hashable type by:
- Converting dictionaries to sorted tuples of (key, value) pairs
- Converting lists and sets to sorted tuples
- Preserving primitive types (str, int, float, bool, None)
- Converting any other type to its string representation
Args:
value: Any value that needs to be made hashable.
Returns:
A hashable version of the value.
"""
if isinstance(value, dict):
# Convert dict to tuple of tuples with sorted keys
items = []
for k in sorted(value.keys()): # pyright: ignore
v = value[k] # pyright: ignore
items.append((str(k), _make_hashable(v))) # pyright: ignore
return tuple(items) # pyright: ignore
if isinstance(value, (list, set)):
hashable_items = [_make_hashable(item) for item in value] # pyright: ignore
filtered_items = [item for item in hashable_items if item is not None] # pyright: ignore
return tuple(sorted(filtered_items, key=str))
if isinstance(value, (str, int, float, bool, type(None))):
return value
return str(value)
def _create_statement_filters( # noqa: C901
config: FilterConfig, dep_defaults: DependencyDefaults = DEPENDENCY_DEFAULTS
) -> dict[str, Provide]:
"""Create filter dependencies based on configuration.
Args:
config (FilterConfig): Configuration dictionary specifying which filters to enable
dep_defaults (DependencyDefaults): Dependency defaults to use for the filter dependencies
Returns:
dict[str, Provide]: Dictionary of filter provider functions
"""
filters: dict[str, Provide] = {}
if config.get("id_filter", False):
def provide_id_filter( # pyright: ignore[reportUnknownParameterType]
ids: Optional[list[str]] = Parameter(query="ids", default=None, required=False),
) -> CollectionFilter: # pyright: ignore[reportMissingTypeArgument]
return CollectionFilter(field_name=config.get("id_field", "id"), values=ids)
filters[dep_defaults.ID_FILTER_DEPENDENCY_KEY] = Provide(provide_id_filter, sync_to_thread=False) # pyright: ignore[reportUnknownArgumentType]
if config.get("created_at", False):
def provide_created_filter(
before: DTorNone = Parameter(query="createdBefore", default=None, required=False),
after: DTorNone = Parameter(query="createdAfter", default=None, required=False),
) -> BeforeAfter:
return BeforeAfter("created_at", before, after)
filters[dep_defaults.CREATED_FILTER_DEPENDENCY_KEY] = Provide(provide_created_filter, sync_to_thread=False)
if config.get("updated_at", False):
def provide_updated_filter(
before: DTorNone = Parameter(query="updatedBefore", default=None, required=False),
after: DTorNone = Parameter(query="updatedAfter", default=None, required=False),
) -> BeforeAfter:
return BeforeAfter("updated_at", before, after)
filters[dep_defaults.UPDATED_FILTER_DEPENDENCY_KEY] = Provide(provide_updated_filter, sync_to_thread=False)
if config.get("pagination_type") == "limit_offset":
def provide_limit_offset_pagination(
current_page: int = Parameter(ge=1, query="currentPage", default=1, required=False),
page_size: int = Parameter(
query="pageSize",
ge=1,
default=config.get("pagination_size", dep_defaults.DEFAULT_PAGINATION_SIZE),
required=False,
),
) -> LimitOffset:
return LimitOffset(page_size, page_size * (current_page - 1))
filters[dep_defaults.LIMIT_OFFSET_FILTER_DEPENDENCY_KEY] = Provide(
provide_limit_offset_pagination, sync_to_thread=False
)
if search_fields := config.get("search"):
def provide_search_filter(
search_string: StringOrNone = Parameter(
title="Field to search",
query="searchString",
default=None,
required=False,
),
ignore_case: BooleanOrNone = Parameter(
title="Search should be case sensitive",
query="searchIgnoreCase",
default=config.get("search_ignore_case", False),
required=False,
),
) -> SearchFilter:
# Handle both string and set input types for search fields
field_names = set(search_fields.split(",")) if isinstance(search_fields, str) else set(search_fields)
return SearchFilter(
field_name=field_names,
value=search_string, # type: ignore[arg-type]
ignore_case=ignore_case or False,
)
filters[dep_defaults.SEARCH_FILTER_DEPENDENCY_KEY] = Provide(provide_search_filter, sync_to_thread=False)
if sort_field := config.get("sort_field"):
def provide_order_by(
field_name: StringOrNone = Parameter(
title="Order by field",
query="orderBy",
default=sort_field,
required=False,
),
sort_order: SortOrderOrNone = Parameter(
title="Field to search",
query="sortOrder",
default=config.get("sort_order", "desc"),
required=False,
),
) -> OrderBy:
return OrderBy(field_name=field_name, sort_order=sort_order) # type: ignore[arg-type]
filters[dep_defaults.ORDER_BY_FILTER_DEPENDENCY_KEY] = Provide(provide_order_by, sync_to_thread=False)
# Add not_in filter providers
if not_in_fields := config.get("not_in_fields"):
# Get all field names, handling both strings and FieldNameType objects
not_in_fields = {not_in_fields} if isinstance(not_in_fields, (str, FieldNameType)) else not_in_fields
for field_def in not_in_fields:
field_def = FieldNameType(name=field_def, type_hint=str) if isinstance(field_def, str) else field_def
def create_not_in_filter_provider( # pyright: ignore
field_name: FieldNameType,
) -> Callable[..., Optional[NotInCollectionFilter[field_def.type_hint]]]: # type: ignore
def provide_not_in_filter( # pyright: ignore
values: Optional[list[field_name.type_hint]] = Parameter( # type: ignore
query=camelize(f"{field_name.name}_not_in"), default=None, required=False
),
) -> Optional[NotInCollectionFilter[field_name.type_hint]]: # type: ignore
return (
NotInCollectionFilter[field_name.type_hint](field_name=field_name.name, values=values) # type: ignore
if values
else None
)
return provide_not_in_filter # pyright: ignore
provider = create_not_in_filter_provider(field_def) # pyright: ignore
filters[f"{field_def.name}_not_in_filter"] = Provide(provider, sync_to_thread=False) # pyright: ignore
# Add in filter providers
if in_fields := config.get("in_fields"):
# Get all field names, handling both strings and FieldNameType objects
in_fields = {in_fields} if isinstance(in_fields, (str, FieldNameType)) else in_fields
for field_def in in_fields:
field_def = FieldNameType(name=field_def, type_hint=str) if isinstance(field_def, str) else field_def
def create_in_filter_provider( # pyright: ignore
field_name: FieldNameType,
) -> Callable[..., Optional[CollectionFilter[field_def.type_hint]]]: # type: ignore # pyright: ignore
def provide_in_filter( # pyright: ignore
values: Optional[list[field_name.type_hint]] = Parameter( # type: ignore # pyright: ignore
query=camelize(f"{field_name.name}_in"), default=None, required=False
),
) -> Optional[CollectionFilter[field_name.type_hint]]: # type: ignore # pyright: ignore
return (
CollectionFilter[field_name.type_hint](field_name=field_name.name, values=values) # type: ignore # pyright: ignore
if values
else None
)
return provide_in_filter # pyright: ignore
provider = create_in_filter_provider(field_def) # type: ignore
filters[f"{field_def.name}_in_filter"] = Provide(provider, sync_to_thread=False) # pyright: ignore
if filters:
filters[dep_defaults.FILTERS_DEPENDENCY_KEY] = Provide(
_create_filter_aggregate_function(config), sync_to_thread=False
)
return filters
def _create_filter_aggregate_function(config: FilterConfig) -> Callable[..., list[FilterTypes]]: # noqa: C901, PLR0915
"""Create a filter function based on the provided configuration.
Args:
config: The filter configuration.
Returns:
A function that returns a list of filters based on the configuration.
"""
parameters: dict[str, inspect.Parameter] = {}
annotations: dict[str, Any] = {}
# Build parameters based on config
if cls := config.get("id_filter"):
parameters["id_filter"] = inspect.Parameter(
name="id_filter",
kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
default=Dependency(skip_validation=True),
annotation=CollectionFilter[cls], # type: ignore[valid-type]
)
annotations["id_filter"] = CollectionFilter[cls] # type: ignore[valid-type]
if config.get("created_at"):
parameters["created_filter"] = inspect.Parameter(
name="created_filter",
kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
default=Dependency(skip_validation=True),
annotation=BeforeAfter,
)
annotations["created_filter"] = BeforeAfter
if config.get("updated_at"):
parameters["updated_filter"] = inspect.Parameter(
name="updated_filter",
kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
default=Dependency(skip_validation=True),
annotation=BeforeAfter,
)
annotations["updated_filter"] = BeforeAfter
if config.get("search"):
parameters["search_filter"] = inspect.Parameter(
name="search_filter",
kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
default=Dependency(skip_validation=True),
annotation=SearchFilter,
)
annotations["search_filter"] = SearchFilter
if config.get("pagination_type") == "limit_offset":
parameters["limit_offset_filter"] = inspect.Parameter(
name="limit_offset_filter",
kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
default=Dependency(skip_validation=True),
annotation=LimitOffset,
)
annotations["limit_offset_filter"] = LimitOffset
if config.get("sort_field"):
parameters["order_by_filter"] = inspect.Parameter(
name="order_by_filter",
kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
default=Dependency(skip_validation=True),
annotation=OrderBy,
)
annotations["order_by_filter"] = OrderBy
# Add parameters for not_in filters
if not_in_fields := config.get("not_in_fields"):
for field_def in not_in_fields:
field_def = FieldNameType(name=field_def, type_hint=str) if isinstance(field_def, str) else field_def
parameters[f"{field_def.name}_not_in_filter"] = inspect.Parameter(
name=f"{field_def.name}_not_in_filter",
kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
default=Dependency(skip_validation=True),
annotation=NotInCollectionFilter[field_def.type_hint], # type: ignore
)
annotations[f"{field_def.name}_not_in_filter"] = NotInCollectionFilter[field_def.type_hint] # type: ignore
# Add parameters for in filters
if in_fields := config.get("in_fields"):
for field_def in in_fields:
field_def = FieldNameType(name=field_def, type_hint=str) if isinstance(field_def, str) else field_def
parameters[f"{field_def.name}_in_filter"] = inspect.Parameter(
name=f"{field_def.name}_in_filter",
kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
default=Dependency(skip_validation=True),
annotation=CollectionFilter[field_def.type_hint], # type: ignore
)
annotations[f"{field_def.name}_in_filter"] = CollectionFilter[field_def.type_hint] # type: ignore
def provide_filters(**kwargs: FilterTypes) -> list[FilterTypes]:
"""Provide filter dependencies based on configuration.
Args:
**kwargs: Filter parameters dynamically provided based on configuration.
Returns:
list[FilterTypes]: List of configured filters.
"""
filters: list[FilterTypes] = []
if id_filter := kwargs.get("id_filter"):
filters.append(id_filter)
if created_filter := kwargs.get("created_filter"):
filters.append(created_filter)
if limit_offset := kwargs.get("limit_offset_filter"):
filters.append(limit_offset)
if updated_filter := kwargs.get("updated_filter"):
filters.append(updated_filter)
if (
(search_filter := cast("Optional[SearchFilter]", kwargs.get("search_filter")))
and search_filter is not None # pyright: ignore[reportUnnecessaryComparison]
and search_filter.field_name is not None # pyright: ignore[reportUnnecessaryComparison]
and search_filter.value is not None # pyright: ignore[reportUnnecessaryComparison]
):
filters.append(search_filter)
if (
(order_by := cast("Optional[OrderBy]", kwargs.get("order_by_filter")))
and order_by is not None # pyright: ignore[reportUnnecessaryComparison]
and order_by.field_name is not None # pyright: ignore[reportUnnecessaryComparison]
):
filters.append(order_by)
# Add not_in filters
if not_in_fields := config.get("not_in_fields"):
# Get all field names, handling both strings and FieldNameType objects
not_in_fields = {not_in_fields} if isinstance(not_in_fields, (str, FieldNameType)) else not_in_fields
for field_def in not_in_fields:
field_def = FieldNameType(name=field_def, type_hint=str) if isinstance(field_def, str) else field_def
filter_ = kwargs.get(f"{field_def.name}_not_in_filter")
if filter_ is not None:
filters.append(filter_)
# Add in filters
if in_fields := config.get("in_fields"):
# Get all field names, handling both strings and FieldNameType objects
in_fields = {in_fields} if isinstance(in_fields, (str, FieldNameType)) else in_fields
for field_def in in_fields:
field_def = FieldNameType(name=field_def, type_hint=str) if isinstance(field_def, str) else field_def
filter_ = kwargs.get(f"{field_def.name}_in_filter")
if filter_ is not None:
filters.append(filter_)
return filters
# Set both signature and annotations
provide_filters.__signature__ = inspect.Signature( # type: ignore
parameters=list(parameters.values()),
return_annotation=list[FilterTypes],
)
provide_filters.__annotations__ = annotations
provide_filters.__annotations__["return"] = list[FilterTypes]
return provide_filters
python-advanced-alchemy-1.4.1/advanced_alchemy/extensions/sanic/ 0000775 0000000 0000000 00000000000 15003544734 0024771 5 ustar 00root root 0000000 0000000 python-advanced-alchemy-1.4.1/advanced_alchemy/extensions/sanic/__init__.py 0000664 0000000 0000000 00000001541 15003544734 0027103 0 ustar 00root root 0000000 0000000 from advanced_alchemy import base, exceptions, filters, mixins, operations, repository, service, types, utils
from advanced_alchemy.alembic.commands import AlembicCommands
from advanced_alchemy.config import (
AlembicAsyncConfig,
AlembicSyncConfig,
AsyncSessionConfig,
SyncSessionConfig,
)
from advanced_alchemy.extensions.sanic.config import EngineConfig, SQLAlchemyAsyncConfig, SQLAlchemySyncConfig
from advanced_alchemy.extensions.sanic.extension import AdvancedAlchemy
__all__ = (
"AdvancedAlchemy",
"AlembicAsyncConfig",
"AlembicCommands",
"AlembicSyncConfig",
"AsyncSessionConfig",
"EngineConfig",
"SQLAlchemyAsyncConfig",
"SQLAlchemySyncConfig",
"SyncSessionConfig",
"base",
"exceptions",
"filters",
"mixins",
"operations",
"repository",
"service",
"types",
"utils",
)
python-advanced-alchemy-1.4.1/advanced_alchemy/extensions/sanic/config.py 0000664 0000000 0000000 00000055344 15003544734 0026623 0 ustar 00root root 0000000 0000000 """Configuration classes for Sanic integration.
This module provides configuration classes for integrating SQLAlchemy with Sanic applications,
including both synchronous and asynchronous database configurations.
"""
import asyncio
import contextlib
from dataclasses import dataclass, field
from typing import Any, Callable, Optional, cast
from click import echo
from sanic import HTTPResponse, Request, Sanic
from sqlalchemy.exc import OperationalError
from advanced_alchemy.exceptions import ImproperConfigurationError
try:
from sanic_ext import Extend
SANIC_INSTALLED = True
except ModuleNotFoundError: # pragma: no cover
SANIC_INSTALLED = False # pyright: ignore[reportConstantRedefinition]
Extend = type("Extend", (), {}) # type: ignore
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker
from sqlalchemy.orm import Session, sessionmaker
from typing_extensions import Literal
from advanced_alchemy._listeners import set_async_context
from advanced_alchemy._serialization import decode_json, encode_json
from advanced_alchemy.base import metadata_registry
from advanced_alchemy.config import EngineConfig as _EngineConfig
from advanced_alchemy.config.asyncio import SQLAlchemyAsyncConfig as _SQLAlchemyAsyncConfig
from advanced_alchemy.config.sync import SQLAlchemySyncConfig as _SQLAlchemySyncConfig
from advanced_alchemy.service import schema_dump
def _make_unique_context_key(app: "Sanic[Any, Any]", key: str) -> str: # pragma: no cover
"""Generates a unique context key for the Sanic application.
Ensures that the key does not already exist in the application's state.
Args:
app (sanic.Sanic): The Sanic application instance.
key (str): The base key name.
Returns:
str: A unique key name.
"""
i = 0
while True:
if not hasattr(app.ctx, key):
return key
key = f"{key}_{i}"
i += i
def serializer(value: Any) -> str:
"""Serialize JSON field values.
Args:
value: Any JSON serializable value.
Returns:
str: JSON string representation of the value.
"""
return encode_json(schema_dump(value))
@dataclass
class EngineConfig(_EngineConfig):
"""Configuration for SQLAlchemy's Engine.
This class extends the base EngineConfig with Sanic-specific JSON serialization options.
For details see: https://docs.sqlalchemy.org/en/20/core/engines.html
Attributes:
json_deserializer: Callable for converting JSON strings to Python objects.
json_serializer: Callable for converting Python objects to JSON strings.
"""
json_deserializer: Callable[[str], Any] = decode_json
"""For dialects that support the :class:`~sqlalchemy.types.JSON` datatype, this is a Python callable that will
convert a JSON string to a Python object. But default, this uses the built-in serializers."""
json_serializer: Callable[[Any], str] = serializer
"""For dialects that support the JSON datatype, this is a Python callable that will render a given object as JSON.
By default, By default, the built-in serializer is used."""
@dataclass
class SQLAlchemyAsyncConfig(_SQLAlchemyAsyncConfig):
"""SQLAlchemy Async config for Sanic."""
_app: "Optional[Sanic[Any, Any]]" = None
"""The Sanic application instance."""
commit_mode: Literal["manual", "autocommit", "autocommit_include_redirect"] = "manual"
"""The commit mode to use for database sessions."""
engine_key: str = "db_engine"
"""Key to use for the dependency injection of database engines."""
session_key: str = "db_session"
"""Key to use for the dependency injection of database sessions."""
session_maker_key: str = "session_maker_class"
"""Key under which to store the SQLAlchemy :class:`sessionmaker ` in the application state instance.
"""
engine_config: EngineConfig = field(default_factory=EngineConfig) # pyright: ignore[reportIncompatibleVariableOverride]
"""Configuration for the SQLAlchemy engine.
The configuration options are documented in the SQLAlchemy documentation.
"""
async def create_all_metadata(self) -> None: # pragma: no cover
"""Create all metadata tables in the database."""
if self.engine_instance is None:
self.engine_instance = self.get_engine()
async with self.engine_instance.begin() as conn:
try:
await conn.run_sync(
metadata_registry.get(None if self.bind_key == "default" else self.bind_key).create_all
)
await conn.commit()
except OperationalError as exc:
echo(f" * Could not create target metadata. Reason: {exc}")
else:
echo(" * Created target metadata.")
@property
def app(self) -> "Sanic[Any, Any]":
"""The Sanic application instance.
Raises:
ImproperConfigurationError: If the application is not initialized.
"""
if self._app is None:
msg = "The Sanic application instance is not set."
raise ImproperConfigurationError(msg)
return self._app
def init_app(self, app: "Sanic[Any, Any]", bootstrap: "Extend") -> None: # pyright: ignore[reportUnknownParameterType,reportInvalidTypeForm]
"""Initialize the Sanic application with this configuration.
Args:
app: The Sanic application instance.
bootstrap: The Sanic extension bootstrap.
"""
self._app = app
self.bind_key = self.bind_key or "default"
_ = self.create_session_maker()
self.session_key = _make_unique_context_key(app, f"advanced_alchemy_async_session_{self.session_key}")
self.engine_key = _make_unique_context_key(app, f"advanced_alchemy_async_engine_{self.engine_key}")
self.session_maker_key = _make_unique_context_key(
app, f"advanced_alchemy_async_session_maker_{self.session_maker_key}"
)
self.startup(bootstrap) # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType]
def startup(self, bootstrap: "Extend") -> None: # pyright: ignore[reportUnknownParameterType,reportInvalidTypeForm]
"""Initialize the Sanic application with this configuration.
Args:
bootstrap: The Sanic extension bootstrap.
"""
@self.app.before_server_start # pyright: ignore[reportUnknownMemberType]
async def on_startup(_: Any) -> None: # pyright: ignore[reportUnusedFunction]
setattr(self.app.ctx, self.engine_key, self.get_engine()) # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType,reportOptionalMemberAccess]
setattr(self.app.ctx, self.session_maker_key, self.create_session_maker()) # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType,reportOptionalMemberAccess]
bootstrap.add_dependency( # pyright: ignore[reportUnknownMemberType]
AsyncEngine,
self.get_engine_from_request,
)
bootstrap.add_dependency( # pyright: ignore[reportUnknownMemberType]
async_sessionmaker[AsyncSession],
self.get_sessionmaker_from_request,
)
bootstrap.add_dependency( # pyright: ignore[reportUnknownMemberType]
AsyncSession,
self.get_session_from_request,
)
await self.on_startup()
@self.app.after_server_stop # pyright: ignore[reportUnknownMemberType]
async def on_shutdown(_: Any) -> None: # pyright: ignore[reportUnusedFunction]
if self.engine_instance is not None:
await self.engine_instance.dispose()
if hasattr(self.app.ctx, self.engine_key): # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType,reportOptionalMemberAccess]
delattr(self.app.ctx, self.engine_key) # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType,reportOptionalMemberAccess]
if hasattr(self.app.ctx, self.session_maker_key): # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType,reportOptionalMemberAccess]
delattr(self.app.ctx, self.session_maker_key) # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType,reportOptionalMemberAccess]
@self.app.middleware("request") # pyright: ignore[reportUnknownMemberType]
async def on_request(request: Request) -> None: # pyright: ignore[reportUnusedFunction]
session = cast("Optional[AsyncSession]", getattr(request.ctx, self.session_key, None))
if session is None:
setattr(request.ctx, self.session_key, self.get_session())
set_async_context(True)
@self.app.middleware("response") # type: ignore[arg-type]
async def on_response(request: Request, response: HTTPResponse) -> None: # pyright: ignore[reportUnusedFunction]
session = cast("Optional[AsyncSession]", getattr(request.ctx, self.session_key, None))
if session is not None:
await self.session_handler(session=session, request=request, response=response)
async def on_startup(self) -> None:
"""Initialize the Sanic application with this configuration."""
if self.create_all:
await self.create_all_metadata()
def create_session_maker(self) -> Callable[[], "AsyncSession"]:
"""Get a session maker. If none exists yet, create one.
Returns:
Callable[[], Session]: Session factory used by the plugin.
"""
if self.session_maker:
return self.session_maker
session_kws = self.session_config_dict
if self.engine_instance is None:
self.engine_instance = self.get_engine()
if session_kws.get("bind") is None:
session_kws["bind"] = self.engine_instance
self.session_maker = self.session_maker_class(**session_kws)
return self.session_maker
async def session_handler(
self, session: "AsyncSession", request: "Request", response: "HTTPResponse"
) -> None: # pragma: no cover
"""Handles the session after a request is processed.
Applies the commit strategy and ensures the session is closed.
Args:
session (sqlalchemy.ext.asyncio.AsyncSession):
The database session.
request (sanic.Request):
The incoming HTTP request.
response (sanic.HTTPResponse):
The outgoing HTTP response.
"""
try:
if (self.commit_mode == "autocommit" and 200 <= response.status < 300) or ( # noqa: PLR2004
self.commit_mode == "autocommit_include_redirect" and 200 <= response.status < 400 # noqa: PLR2004
):
await session.commit()
else:
await session.rollback()
finally:
await session.close()
with contextlib.suppress(AttributeError, KeyError):
delattr(request.ctx, self.session_key)
def get_engine_from_request(self, request: "Request") -> AsyncEngine:
"""Retrieve the engine from the request context.
Args:
request (sanic.Request): The incoming request.
Returns:
AsyncEngine: The SQLAlchemy engine.
"""
return cast("AsyncEngine", getattr(request.app.ctx, self.engine_key, self.get_engine())) # pragma: no cover
def get_sessionmaker_from_request(self, request: "Request") -> async_sessionmaker[AsyncSession]:
"""Retrieve the session maker from the request context.
Args:
request (sanic.Request): The incoming request.
Returns:
SessionMakerT: The session maker.
"""
return cast(
"async_sessionmaker[AsyncSession]", getattr(request.app.ctx, self.session_maker_key, None)
) # pragma: no cover
def get_session_from_request(self, request: Request) -> AsyncSession:
"""Retrieve the session from the request context.
Args:
request (sanic.Request): The incoming request.
Returns:
SessionT: The session associated with the request.
"""
return cast("AsyncSession", getattr(request.ctx, self.session_key, None)) # pragma: no cover
async def close_engine(self) -> None: # pragma: no cover
"""Close the engine."""
if self.engine_instance is not None:
await self.engine_instance.dispose()
async def on_shutdown(self) -> None: # pragma: no cover
"""Handles the shutdown event by disposing of the SQLAlchemy engine.
Ensures that all connections are properly closed during application shutdown.
"""
await self.close_engine()
if hasattr(self.app.ctx, self.engine_key): # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType,reportOptionalMemberAccess]
delattr(self.app.ctx, self.engine_key) # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType,reportOptionalMemberAccess]
if hasattr(self.app.ctx, self.session_maker_key): # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType,reportOptionalMemberAccess]
delattr(self.app.ctx, self.session_maker_key) # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType,reportOptionalMemberAccess]
@dataclass
class SQLAlchemySyncConfig(_SQLAlchemySyncConfig):
"""SQLAlchemy Sync config for Starlette."""
_app: "Optional[Sanic[Any, Any]]" = None
"""The Sanic application instance."""
commit_mode: Literal["manual", "autocommit", "autocommit_include_redirect"] = "manual"
"""The commit mode to use for database sessions."""
engine_key: str = "db_engine"
"""Key to use for the dependency injection of database engines."""
session_key: str = "db_session"
"""Key to use for the dependency injection of database sessions."""
session_maker_key: str = "session_maker_class"
"""Key under which to store the SQLAlchemy :class:`sessionmaker ` in the application state instance.
"""
engine_config: EngineConfig = field(default_factory=EngineConfig) # pyright: ignore[reportIncompatibleVariableOverride]
"""Configuration for the SQLAlchemy engine.
The configuration options are documented in the SQLAlchemy documentation.
"""
@property
def app(self) -> "Sanic[Any, Any]":
"""The Sanic application instance.
Raises:
ImproperConfigurationError: If the application is not initialized.
"""
if self._app is None:
msg = "The Sanic application instance is not set."
raise ImproperConfigurationError(msg)
return self._app
async def create_all_metadata(self) -> None: # pragma: no cover
"""Create all metadata tables in the database."""
if self.engine_instance is None:
self.engine_instance = self.get_engine()
with self.engine_instance.begin() as conn:
try:
loop = asyncio.get_event_loop()
await loop.run_in_executor(
None, metadata_registry.get(None if self.bind_key == "default" else self.bind_key).create_all, conn
)
except OperationalError as exc:
echo(f" * Could not create target metadata. Reason: {exc}")
def init_app(self, app: "Sanic[Any, Any]", bootstrap: "Extend") -> None: # pyright: ignore[reportUnknownParameterType,reportInvalidTypeForm]
"""Initialize the Sanic application with this configuration.
Args:
app: The Sanic application instance.
bootstrap: The Sanic extension bootstrap.
"""
self._app = app
self.bind_key = self.bind_key or "default"
_ = self.create_session_maker()
self.session_key = _make_unique_context_key(app, f"advanced_alchemy_sync_session_{self.session_key}")
self.engine_key = _make_unique_context_key(app, f"advanced_alchemy_sync_engine_{self.engine_key}")
self.session_maker_key = _make_unique_context_key(
app, f"advanced_alchemy_sync_session_maker_{self.session_maker_key}"
)
self.startup(bootstrap) # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType]
def startup(self, bootstrap: "Extend") -> None: # pyright: ignore[reportUnknownParameterType,reportInvalidTypeForm]
"""Initialize the Sanic application with this configuration.
Args:
bootstrap: The Sanic extension bootstrap.
"""
@self.app.before_server_start # pyright: ignore[reportUnknownMemberType]
async def on_startup(_: Any) -> None: # pyright: ignore[reportUnusedFunction]
setattr(self.app.ctx, self.engine_key, self.get_engine()) # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType,reportOptionalMemberAccess]
setattr(self.app.ctx, self.session_maker_key, self.create_session_maker()) # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType,reportOptionalMemberAccess]
bootstrap.add_dependency( # pyright: ignore[reportUnknownMemberType]
AsyncEngine,
self.get_engine_from_request,
)
bootstrap.add_dependency( # pyright: ignore[reportUnknownMemberType]
sessionmaker[Session],
self.get_sessionmaker_from_request,
)
bootstrap.add_dependency( # pyright: ignore[reportUnknownMemberType]
AsyncSession,
self.get_session_from_request,
)
await self.on_startup()
@self.app.after_server_stop # pyright: ignore[reportUnknownMemberType]
async def on_shutdown(_: Any) -> None: # pyright: ignore[reportUnusedFunction]
await self.on_shutdown()
@self.app.middleware("request") # pyright: ignore[reportUnknownMemberType]
async def on_request(request: Request) -> None: # pyright: ignore[reportUnusedFunction]
session = cast("Optional[Session]", getattr(request.ctx, self.session_key, None))
if session is None:
setattr(request.ctx, self.session_key, self.get_session())
set_async_context(False)
@self.app.middleware("response") # type: ignore[arg-type]
async def on_response(request: Request, response: HTTPResponse) -> None: # pyright: ignore[reportUnusedFunction]
session = cast("Optional[Session]", getattr(request.ctx, self.session_key, None))
if session is not None:
await self.session_handler(session=session, request=request, response=response)
async def on_startup(self) -> None:
"""Initialize the Sanic application with this configuration."""
if self.create_all:
await self.create_all_metadata()
def create_session_maker(self) -> Callable[[], "Session"]:
"""Get a session maker. If none exists yet, create one.
Returns:
Callable[[], Session]: Session factory used by the plugin.
"""
if self.session_maker:
return self.session_maker
session_kws = self.session_config_dict
if self.engine_instance is None:
self.engine_instance = self.get_engine()
if session_kws.get("bind") is None:
session_kws["bind"] = self.engine_instance
self.session_maker = self.session_maker_class(**session_kws)
return self.session_maker
async def session_handler(
self, session: "Session", request: "Request", response: "HTTPResponse"
) -> None: # pragma: no cover
"""Handles the session after a request is processed.
Applies the commit strategy and ensures the session is closed.
Args:
session (sqlalchemy.orm.Session):
The database session.
request (sanic.Request):
The incoming HTTP request.
response (sanic.HTTPResponse):
The outgoing HTTP response.
"""
loop = asyncio.get_event_loop()
try:
if (self.commit_mode == "autocommit" and 200 <= response.status < 300) or ( # noqa: PLR2004
self.commit_mode == "autocommit_include_redirect" and 200 <= response.status < 400 # noqa: PLR2004
):
await loop.run_in_executor(None, session.commit)
else:
await loop.run_in_executor(None, session.rollback)
finally:
await loop.run_in_executor(None, session.close)
with contextlib.suppress(AttributeError, KeyError):
delattr(request.ctx, self.session_key)
def get_engine_from_request(self, request: Request) -> "AsyncEngine":
"""Retrieve the engine from the request context.
Args:
request (sanic.Request): The incoming request.
Returns:
AsyncEngine: The SQLAlchemy engine.
"""
return cast("AsyncEngine", getattr(request.app.ctx, self.engine_key, self.get_engine())) # pragma: no cover
def get_sessionmaker_from_request(self, request: Request) -> sessionmaker[Session]:
"""Retrieve the session maker from the request context.
Args:
request (sanic.Request): The incoming request.
Returns:
SessionMakerT: The session maker.
"""
return cast("sessionmaker[Session]", getattr(request.app.ctx, self.session_maker_key, None)) # pragma: no cover
def get_session_from_request(self, request: Request) -> "Session":
"""Retrieve the session from the request context.
Args:
request (sanic.Request): The incoming request.
Returns:
SessionT: The session associated with the request.
"""
return cast("Session", getattr(request.ctx, self.session_key, None)) # pragma: no cover
async def close_engine(self) -> None: # pragma: no cover
"""Close the engine."""
if self.engine_instance is not None:
loop = asyncio.get_event_loop()
await loop.run_in_executor(None, self.engine_instance.dispose)
async def on_shutdown(self) -> None: # pragma: no cover
"""Handles the shutdown event by disposing of the SQLAlchemy engine.
Ensures that all connections are properly closed during application shutdown.
"""
await self.close_engine()
if hasattr(self.app.ctx, self.engine_key): # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType,reportOptionalMemberAccess]
delattr(self.app.ctx, self.engine_key) # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType,reportOptionalMemberAccess]
if hasattr(self.app.ctx, self.session_maker_key): # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType,reportOptionalMemberAccess]
delattr(self.app.ctx, self.session_maker_key) # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType,reportOptionalMemberAccess]
python-advanced-alchemy-1.4.1/advanced_alchemy/extensions/sanic/extension.py 0000664 0000000 0000000 00000031261 15003544734 0027362 0 ustar 00root root 0000000 0000000 # ruff: noqa: PLR0904
from collections.abc import AsyncGenerator, Generator, Sequence
from contextlib import asynccontextmanager, contextmanager
from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast, overload
from sanic import Request, Sanic
from advanced_alchemy.exceptions import ImproperConfigurationError, MissingDependencyError
from advanced_alchemy.extensions.sanic.config import SQLAlchemyAsyncConfig, SQLAlchemySyncConfig
try:
from sanic_ext import Extend
from sanic_ext.extensions.base import Extension
SANIC_INSTALLED = True
except ModuleNotFoundError: # pragma: no cover
SANIC_INSTALLED = False # pyright: ignore[reportConstantRedefinition]
Extension = type("Extension", (), {}) # type: ignore
Extend = type("Extend", (), {}) # type: ignore
if TYPE_CHECKING:
from sanic import Sanic
from sqlalchemy import Engine
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession
from sqlalchemy.orm import Session
__all__ = ("AdvancedAlchemy",)
class AdvancedAlchemy(Extension): # type: ignore[no-untyped-call] # pyright: ignore[reportGeneralTypeIssues,reportUntypedBaseClass]
"""Sanic extension for integrating Advanced Alchemy with SQLAlchemy.
Args:
config: One or more configurations for SQLAlchemy.
app: The Sanic application instance.
"""
name = "AdvancedAlchemy"
def __init__(
self,
*,
sqlalchemy_config: Union[
"SQLAlchemyAsyncConfig",
"SQLAlchemySyncConfig",
Sequence[Union["SQLAlchemyAsyncConfig", "SQLAlchemySyncConfig"]],
],
sanic_app: Optional["Sanic[Any, Any]"] = None,
) -> None:
if not SANIC_INSTALLED: # pragma: no cover
msg = "Could not locate either Sanic or Sanic Extensions. Both libraries must be installed to use Advanced Alchemy. Try: pip install sanic[ext]"
raise MissingDependencyError(msg)
self._config = sqlalchemy_config if isinstance(sqlalchemy_config, Sequence) else [sqlalchemy_config]
self._mapped_configs: dict[str, Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig]] = self.map_configs()
self._app = sanic_app
self._initialized = False
if self._app is not None:
self.register(self._app)
def register(self, sanic_app: "Sanic[Any, Any]") -> None:
"""Initialize the extension with the given Sanic app."""
self._app = sanic_app
Extend.register(self) # pyright: ignore[reportUnknownMemberType,reportAttributeAccessIssue]
self._initialized = True
@property
def sanic_app(self) -> "Sanic[Any, Any]":
"""The Sanic app.
Raises:
ImproperConfigurationError: If the app is not initialized.
"""
if self._app is None: # pragma: no cover
msg = "AdvancedAlchemy has not been initialized with a Sanic app."
raise ImproperConfigurationError(msg)
return self._app
@property
def sqlalchemy_config(self) -> Sequence[Union["SQLAlchemyAsyncConfig", "SQLAlchemySyncConfig"]]:
"""Current Advanced Alchemy configuration."""
return self._config
def startup(self, bootstrap: "Extend") -> None: # pyright: ignore[reportUnknownParameterType,reportInvalidTypeForm]
"""Advanced Alchemy Sanic extension startup hook.
Args:
bootstrap (sanic_ext.Extend): The Sanic extension bootstrap.
"""
for config in self.sqlalchemy_config:
config.init_app(self.sanic_app, bootstrap) # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType]
def map_configs(self) -> dict[str, Union["SQLAlchemyAsyncConfig", "SQLAlchemySyncConfig"]]:
"""Maps the configs to the session bind keys.
Returns:
A dictionary mapping bind keys to SQLAlchemy configurations.
"""
mapped_configs: dict[str, Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig]] = {}
for config in self.sqlalchemy_config:
if config.bind_key is None:
config.bind_key = "default"
mapped_configs[config.bind_key] = config
return mapped_configs
def get_config(self, key: Optional[str] = None) -> Union["SQLAlchemyAsyncConfig", "SQLAlchemySyncConfig"]:
"""Get the config for the given key.
Returns:
The config for the given key.
Raises:
ImproperConfigurationError: If the config is not found.
"""
if key is None:
key = "default"
if key == "default" and len(self.sqlalchemy_config) == 1:
key = self.sqlalchemy_config[0].bind_key or "default"
config = self._mapped_configs.get(key)
if config is None: # pragma: no cover
msg = f"Config with key {key} not found"
raise ImproperConfigurationError(msg)
return config
def get_async_config(self, key: Optional[str] = None) -> "SQLAlchemyAsyncConfig":
"""Get the async config for the given key.
Returns:
The async config for the given key.
Raises:
ImproperConfigurationError: If the config is not an async config.
"""
config = self.get_config(key)
if not isinstance(config, SQLAlchemyAsyncConfig): # pragma: no cover
msg = "Expected an async config, but got a sync config"
raise ImproperConfigurationError(msg)
return config
def get_sync_config(self, key: Optional[str] = None) -> "SQLAlchemySyncConfig":
"""Get the sync config for the given key.
Returns:
The sync config for the given key.
Raises:
ImproperConfigurationError: If the config is not an sync config.
"""
config = self.get_config(key)
if not isinstance(config, SQLAlchemySyncConfig): # pragma: no cover
msg = "Expected a sync config, but got an async config"
raise ImproperConfigurationError(msg)
return config
@asynccontextmanager
async def with_async_session(
self, key: Optional[str] = None
) -> AsyncGenerator["AsyncSession", None]: # pragma: no cover
"""Context manager for getting an async session.
Yields:
An AsyncSession instance.
"""
config = self.get_async_config(key)
async with config.get_session() as session:
yield session
@contextmanager
def with_sync_session(self, key: Optional[str] = None) -> Generator["Session", None]: # pragma: no cover
"""Context manager for getting a sync session.
Yields:
A Session instance.
"""
config = self.get_sync_config(key)
with config.get_session() as session:
yield session
@overload
@staticmethod
def _get_session_from_request(request: "Request", config: "SQLAlchemyAsyncConfig") -> "AsyncSession": ...
@overload
@staticmethod
def _get_session_from_request(request: "Request", config: "SQLAlchemySyncConfig") -> "Session": ...
@staticmethod
def _get_session_from_request(
request: "Request",
config: Union["SQLAlchemyAsyncConfig", "SQLAlchemySyncConfig"], # pragma: no cover
) -> Union["Session", "AsyncSession"]: # pragma: no cover
"""Get the session for the request and config.
Returns:
The session for the request and config.
"""
session = getattr(request.ctx, config.session_key, None)
if session is None:
setattr(request.ctx, config.session_key, config.get_session())
return cast("Union[Session, AsyncSession]", session)
def get_session(
self, request: "Request", key: Optional[str] = None
) -> Union["Session", "AsyncSession"]: # pragma: no cover
"""Get the session for the given key.
Returns:
The session for the given key.
"""
config = self.get_config(key)
return self._get_session_from_request(request, config)
def get_async_session(self, request: "Request", key: Optional[str] = None) -> "AsyncSession": # pragma: no cover
"""Get the async session for the given key.
Returns:
The async session for the given key.
"""
config = self.get_async_config(key)
return self._get_session_from_request(request, config)
def get_sync_session(self, request: "Request", key: Optional[str] = None) -> "Session": # pragma: no cover
"""Get the sync session for the given key.
Returns:
The sync session for the given key.
"""
config = self.get_sync_config(key)
return self._get_session_from_request(request, config)
def provide_session(
self, key: Optional[str] = None
) -> Callable[["Request"], Union["Session", "AsyncSession"]]: # pragma: no cover
"""Get session provider for the given key.
Returns:
The session provider for the given key.
"""
config = self.get_config(key)
def _get_session(request: "Request") -> Union["Session", "AsyncSession"]:
return self._get_session_from_request(request, config)
return _get_session
def provide_async_session(
self, key: Optional[str] = None
) -> Callable[["Request"], "AsyncSession"]: # pragma: no cover
"""Get async session provider for the given key.
Returns:
The async session provider for the given key.
"""
config = self.get_async_config(key)
def _get_session(request: Request) -> "AsyncSession":
return self._get_session_from_request(request, config)
return _get_session
def provide_sync_session(self, key: Optional[str] = None) -> Callable[[Request], "Session"]: # pragma: no cover
"""Get sync session provider for the given key.
Returns:
The sync session provider for the given key.
"""
config = self.get_sync_config(key)
def _get_session(request: Request) -> "Session":
return self._get_session_from_request(request, config)
return _get_session
def get_engine(self, key: Optional[str] = None) -> Union["Engine", "AsyncEngine"]: # pragma: no cover
"""Get the engine for the given key.
Returns:
The engine for the given key.
"""
config = self.get_config(key)
return config.get_engine()
def get_async_engine(self, key: Optional[str] = None) -> "AsyncEngine": # pragma: no cover
"""Get the async engine for the given key.
Returns:
The async engine for the given key.
"""
config = self.get_async_config(key)
return config.get_engine()
def get_sync_engine(self, key: Optional[str] = None) -> "Engine": # pragma: no cover
"""Get the sync engine for the given key.
Returns:
The sync engine for the given key.
"""
config = self.get_sync_config(key)
return config.get_engine()
def provide_engine(
self, key: Optional[str] = None
) -> Callable[[], Union["Engine", "AsyncEngine"]]: # pragma: no cover
"""Get the engine for the given key.
Returns:
A callable that returns the engine.
"""
config = self.get_config(key)
def _get_engine() -> Union["Engine", "AsyncEngine"]:
return config.get_engine()
return _get_engine
def provide_async_engine(self, key: Optional[str] = None) -> Callable[[], "AsyncEngine"]: # pragma: no cover
"""Get the async engine for the given key.
Returns:
A callable that returns the engine.
"""
config = self.get_async_config(key)
def _get_engine() -> "AsyncEngine":
return config.get_engine()
return _get_engine
def provide_sync_engine(self, key: Optional[str] = None) -> Callable[[], "Engine"]: # pragma: no cover
"""Get the sync engine for the given key.
Returns:
A callable that returns the engine.
"""
config = self.get_sync_config(key)
def _get_engine() -> "Engine":
return config.get_engine()
return _get_engine
def add_session_dependency(
self, session_type: type[Union["Session", "AsyncSession"]], key: Optional[str] = None
) -> None:
"""Add a session dependency to the Sanic app."""
self.sanic_app.ext.add_dependency(session_type, self.provide_session(key)) # pyright: ignore[reportUnknownMemberType]
def add_engine_dependency(
self, engine_type: type[Union["Engine", "AsyncEngine"]], key: Optional[str] = None
) -> None:
"""Add an engine dependency to the Sanic app."""
self.sanic_app.ext.add_dependency(engine_type, self.provide_engine(key)) # pyright: ignore[reportUnknownMemberType]
python-advanced-alchemy-1.4.1/advanced_alchemy/extensions/starlette/ 0000775 0000000 0000000 00000000000 15003544734 0025703 5 ustar 00root root 0000000 0000000 python-advanced-alchemy-1.4.1/advanced_alchemy/extensions/starlette/__init__.py 0000664 0000000 0000000 00000001774 15003544734 0030025 0 ustar 00root root 0000000 0000000 """Starlette extension for Advanced Alchemy.
This module provides Starlette integration for Advanced Alchemy, including session management and service utilities.
"""
from advanced_alchemy import base, exceptions, filters, mixins, operations, repository, service, types, utils
from advanced_alchemy.alembic.commands import AlembicCommands
from advanced_alchemy.config import AlembicAsyncConfig, AlembicSyncConfig, AsyncSessionConfig, SyncSessionConfig
from advanced_alchemy.extensions.starlette.config import EngineConfig, SQLAlchemyAsyncConfig, SQLAlchemySyncConfig
from advanced_alchemy.extensions.starlette.extension import AdvancedAlchemy
__all__ = (
"AdvancedAlchemy",
"AlembicAsyncConfig",
"AlembicCommands",
"AlembicSyncConfig",
"AsyncSessionConfig",
"EngineConfig",
"SQLAlchemyAsyncConfig",
"SQLAlchemySyncConfig",
"SyncSessionConfig",
"base",
"exceptions",
"filters",
"mixins",
"operations",
"repository",
"service",
"types",
"utils",
)
python-advanced-alchemy-1.4.1/advanced_alchemy/extensions/starlette/config.py 0000664 0000000 0000000 00000035677 15003544734 0027544 0 ustar 00root root 0000000 0000000 """Configuration classes for Starlette integration.
This module provides configuration classes for integrating SQLAlchemy with Starlette applications,
including both synchronous and asynchronous database configurations.
"""
import contextlib
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Callable, Optional, cast
from click import echo
from sqlalchemy.exc import OperationalError
from starlette.concurrency import run_in_threadpool
from starlette.middleware.base import BaseHTTPMiddleware
from typing_extensions import Literal
from advanced_alchemy._serialization import decode_json, encode_json
from advanced_alchemy.base import metadata_registry
from advanced_alchemy.config import EngineConfig as _EngineConfig
from advanced_alchemy.config.asyncio import SQLAlchemyAsyncConfig as _SQLAlchemyAsyncConfig
from advanced_alchemy.config.sync import SQLAlchemySyncConfig as _SQLAlchemySyncConfig
from advanced_alchemy.service import schema_dump
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from starlette.applications import Starlette
from starlette.middleware.base import RequestResponseEndpoint
from starlette.requests import Request
from starlette.responses import Response
def _make_unique_state_key(app: "Starlette", key: str) -> str: # pragma: no cover
"""Generates a unique state key for the Starlette application.
Ensures that the key does not already exist in the application's state.
Args:
app (starlette.applications.Starlette): The Starlette application instance.
key (str): The base key name.
Returns:
str: A unique key name.
"""
i = 0
while True:
if not hasattr(app.state, key):
return key
key = f"{key}_{i}"
i += i
def serializer(value: Any) -> str:
"""Serialize JSON field values.
Args:
value: Any JSON serializable value.
Returns:
str: JSON string representation of the value.
"""
return encode_json(schema_dump(value))
@dataclass
class EngineConfig(_EngineConfig):
"""Configuration for SQLAlchemy's Engine.
This class extends the base EngineConfig with Starlette-specific JSON serialization options.
For details see: https://docs.sqlalchemy.org/en/20/core/engines.html
Attributes:
json_deserializer: Callable for converting JSON strings to Python objects.
json_serializer: Callable for converting Python objects to JSON strings.
"""
json_deserializer: Callable[[str], Any] = decode_json
"""For dialects that support the :class:`~sqlalchemy.types.JSON` datatype, this is a Python callable that will
convert a JSON string to a Python object. But default, this uses the built-in serializers."""
json_serializer: Callable[[Any], str] = serializer
"""For dialects that support the JSON datatype, this is a Python callable that will render a given object as JSON.
By default, By default, the built-in serializer is used."""
@dataclass
class SQLAlchemyAsyncConfig(_SQLAlchemyAsyncConfig):
"""SQLAlchemy Async config for Starlette."""
app: "Optional[Starlette]" = None
"""The Starlette application instance."""
commit_mode: Literal["manual", "autocommit", "autocommit_include_redirect"] = "manual"
"""The commit mode to use for database sessions."""
engine_key: str = "db_engine"
"""Key to use for the dependency injection of database engines."""
session_key: str = "db_session"
"""Key to use for the dependency injection of database sessions."""
session_maker_key: str = "session_maker_class"
"""Key under which to store the SQLAlchemy :class:`sessionmaker ` in the application state instance.
"""
engine_config: EngineConfig = field(default_factory=EngineConfig) # pyright: ignore[reportIncompatibleVariableOverride]
"""Configuration for the SQLAlchemy engine.
The configuration options are documented in the SQLAlchemy documentation.
"""
async def create_all_metadata(self) -> None: # pragma: no cover
"""Create all metadata tables in the database."""
if self.engine_instance is None:
self.engine_instance = self.get_engine()
async with self.engine_instance.begin() as conn:
try:
await conn.run_sync(
metadata_registry.get(None if self.bind_key == "default" else self.bind_key).create_all
)
await conn.commit()
except OperationalError as exc:
echo(f" * Could not create target metadata. Reason: {exc}")
else:
echo(" * Created target metadata.")
def init_app(self, app: "Starlette") -> None:
"""Initialize the Starlette application with this configuration.
Args:
app: The Starlette application instance.
"""
self.app = app
self.bind_key = self.bind_key or "default"
_ = self.create_session_maker()
self.session_key = _make_unique_state_key(app, f"advanced_alchemy_async_session_{self.session_key}")
self.engine_key = _make_unique_state_key(app, f"advanced_alchemy_async_engine_{self.engine_key}")
self.session_maker_key = _make_unique_state_key(
app, f"advanced_alchemy_async_session_maker_{self.session_maker_key}"
)
app.add_middleware(BaseHTTPMiddleware, dispatch=self.middleware_dispatch)
async def on_startup(self) -> None:
"""Initialize the Starlette application with this configuration."""
if self.create_all:
await self.create_all_metadata()
def create_session_maker(self) -> Callable[[], "AsyncSession"]:
"""Get a session maker. If none exists yet, create one.
Returns:
Callable[[], Session]: Session factory used by the plugin.
"""
if self.session_maker:
return self.session_maker
session_kws = self.session_config_dict
if self.engine_instance is None:
self.engine_instance = self.get_engine()
if session_kws.get("bind") is None:
session_kws["bind"] = self.engine_instance
self.session_maker = self.session_maker_class(**session_kws)
return self.session_maker
async def session_handler(
self, session: "AsyncSession", request: "Request", response: "Response"
) -> None: # pragma: no cover
"""Handles the session after a request is processed.
Applies the commit strategy and ensures the session is closed.
Args:
session (sqlalchemy.ext.asyncio.AsyncSession):
The database session.
request (starlette.requests.Request):
The incoming HTTP request.
response (starlette.responses.Response):
The outgoing HTTP response.
"""
try:
if (self.commit_mode == "autocommit" and 200 <= response.status_code < 300) or ( # noqa: PLR2004
self.commit_mode == "autocommit_include_redirect" and 200 <= response.status_code < 400 # noqa: PLR2004
):
await session.commit()
else:
await session.rollback()
finally:
await session.close()
with contextlib.suppress(AttributeError, KeyError):
delattr(request.state, self.session_key)
async def middleware_dispatch(
self, request: "Request", call_next: "RequestResponseEndpoint"
) -> "Response": # pragma: no cover
"""Middleware dispatch function to handle requests and responses.
Processes the request, invokes the next middleware or route handler, and
applies the session handler after the response is generated.
Args:
request (starlette.requests.Request): The incoming HTTP request.
call_next (starlette.middleware.base.RequestResponseEndpoint):
The next middleware or route handler.
Returns:
starlette.responses.Response: The HTTP response.
"""
response = await call_next(request)
session = cast("Optional[AsyncSession]", getattr(request.state, self.session_key, None))
if session is not None:
await self.session_handler(session=session, request=request, response=response)
return response
async def close_engine(self) -> None: # pragma: no cover
"""Close the engine."""
if self.engine_instance is not None:
await self.engine_instance.dispose()
async def on_shutdown(self) -> None: # pragma: no cover
"""Handles the shutdown event by disposing of the SQLAlchemy engine.
Ensures that all connections are properly closed during application shutdown.
"""
await self.close_engine()
if self.app is not None:
with contextlib.suppress(AttributeError, KeyError):
delattr(self.app.state, self.engine_key)
delattr(self.app.state, self.session_maker_key)
delattr(self.app.state, self.session_key)
@dataclass
class SQLAlchemySyncConfig(_SQLAlchemySyncConfig):
"""SQLAlchemy Sync config for Starlette."""
app: "Optional[Starlette]" = None
"""The Starlette application instance."""
commit_mode: Literal["manual", "autocommit", "autocommit_include_redirect"] = "manual"
"""The commit mode to use for database sessions."""
engine_key: str = "db_engine"
"""Key to use for the dependency injection of database engines."""
session_key: str = "db_session"
"""Key to use for the dependency injection of database sessions."""
session_maker_key: str = "session_maker_class"
"""Key under which to store the SQLAlchemy :class:`sessionmaker ` in the application state instance.
"""
engine_config: EngineConfig = field(default_factory=EngineConfig) # pyright: ignore[reportIncompatibleVariableOverride]
"""Configuration for the SQLAlchemy engine.
The configuration options are documented in the SQLAlchemy documentation.
"""
async def create_all_metadata(self) -> None: # pragma: no cover
"""Create all metadata tables in the database."""
if self.engine_instance is None:
self.engine_instance = self.get_engine()
with self.engine_instance.begin() as conn:
try:
await run_in_threadpool(
metadata_registry.get(None if self.bind_key == "default" else self.bind_key).create_all, conn
)
except OperationalError as exc:
echo(f" * Could not create target metadata. Reason: {exc}")
def init_app(self, app: "Starlette") -> None:
"""Initialize the Starlette application with this configuration.
Args:
app: The Starlette application instance.
"""
self.app = app
self.bind_key = self.bind_key or "default"
self.session_key = _make_unique_state_key(app, f"advanced_alchemy_sync_session_{self.session_key}")
self.engine_key = _make_unique_state_key(app, f"advanced_alchemy_sync_engine_{self.engine_key}")
self.session_maker_key = _make_unique_state_key(
app, f"advanced_alchemy_sync_session_maker_{self.session_maker_key}"
)
_ = self.create_session_maker()
app.add_middleware(BaseHTTPMiddleware, dispatch=self.middleware_dispatch)
async def on_startup(self) -> None:
"""Initialize the Starlette application with this configuration."""
if self.create_all:
await self.create_all_metadata()
def create_session_maker(self) -> Callable[[], "Session"]:
"""Get a session maker. If none exists yet, create one.
Returns:
Callable[[], Session]: Session factory used by the plugin.
"""
if self.session_maker:
return self.session_maker
session_kws = self.session_config_dict
if self.engine_instance is None:
self.engine_instance = self.get_engine()
if session_kws.get("bind") is None:
session_kws["bind"] = self.engine_instance
self.session_maker = self.session_maker_class(**session_kws)
return self.session_maker
async def session_handler(
self, session: "Session", request: "Request", response: "Response"
) -> None: # pragma: no cover
"""Handles the session after a request is processed.
Applies the commit strategy and ensures the session is closed.
Args:
session (sqlalchemy.orm.Session | sqlalchemy.ext.asyncio.AsyncSession):
The database session.
request (starlette.requests.Request):
The incoming HTTP request.
response (starlette.responses.Response):
The outgoing HTTP response.
"""
try:
if (self.commit_mode == "autocommit" and 200 <= response.status_code < 300) or ( # noqa: PLR2004
self.commit_mode == "autocommit_include_redirect" and 200 <= response.status_code < 400 # noqa: PLR2004
):
await run_in_threadpool(session.commit)
else:
await run_in_threadpool(session.rollback)
finally:
await run_in_threadpool(session.close)
with contextlib.suppress(AttributeError, KeyError):
delattr(request.state, self.session_key)
async def middleware_dispatch(
self, request: "Request", call_next: "RequestResponseEndpoint"
) -> "Response": # pragma: no cover
"""Middleware dispatch function to handle requests and responses.
Processes the request, invokes the next middleware or route handler, and
applies the session handler after the response is generated.
Args:
request (starlette.requests.Request): The incoming HTTP request.
call_next (starlette.middleware.base.RequestResponseEndpoint):
The next middleware or route handler.
Returns:
starlette.responses.Response: The HTTP response.
"""
response = await call_next(request)
session = cast("Optional[Session]", getattr(request.state, self.session_key, None))
if session is not None:
await self.session_handler(session=session, request=request, response=response)
return response
async def close_engine(self) -> None: # pragma: no cover
"""Close the engines."""
if self.engine_instance is not None:
await run_in_threadpool(self.engine_instance.dispose)
async def on_shutdown(self) -> None: # pragma: no cover
"""Handles the shutdown event by disposing of the SQLAlchemy engine.
Ensures that all connections are properly closed during application shutdown.
"""
await self.close_engine()
if self.app is not None:
with contextlib.suppress(AttributeError, KeyError):
delattr(self.app.state, self.engine_key)
delattr(self.app.state, self.session_maker_key)
delattr(self.app.state, self.session_key)
python-advanced-alchemy-1.4.1/advanced_alchemy/extensions/starlette/extension.py 0000664 0000000 0000000 00000035071 15003544734 0030277 0 ustar 00root root 0000000 0000000 import contextlib
from collections.abc import AsyncGenerator, Callable, Generator, Sequence
from contextlib import asynccontextmanager, contextmanager
from typing import (
TYPE_CHECKING,
Any,
Optional,
Union,
cast,
overload,
)
from sqlalchemy.ext.asyncio import AsyncSession
from starlette.requests import Request
from advanced_alchemy._listeners import set_async_context
from advanced_alchemy.exceptions import ImproperConfigurationError
from advanced_alchemy.extensions.starlette.config import SQLAlchemyAsyncConfig, SQLAlchemySyncConfig
if TYPE_CHECKING:
from sqlalchemy import Engine
from sqlalchemy.ext.asyncio import AsyncEngine
from sqlalchemy.orm import Session
from starlette.applications import Starlette
class AdvancedAlchemy:
"""AdvancedAlchemy integration for Starlette applications.
This class manages SQLAlchemy sessions and engine lifecycle within a Starlette application.
It provides middleware for handling transactions based on commit strategies.
Args:
config (advanced_alchemy.config.asyncio.SQLAlchemyAsyncConfig | advanced_alchemy.config.sync.SQLAlchemySyncConfig):
The SQLAlchemy configuration.
app (starlette.applications.Starlette | None):
The Starlette application instance. Defaults to None.
"""
def __init__(
self,
config: Union[
SQLAlchemyAsyncConfig, SQLAlchemySyncConfig, Sequence[Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig]]
],
app: Optional["Starlette"] = None,
) -> None:
self._config = config if isinstance(config, Sequence) else [config]
self._mapped_configs: dict[str, Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig]] = self.map_configs()
self._app = cast("Optional[Starlette]", None)
if app is not None:
self.init_app(app)
@property
def config(self) -> Sequence[Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig]]:
"""Current Advanced Alchemy configuration."""
return self._config
def init_app(self, app: "Starlette") -> None:
"""Initializes the Starlette application with SQLAlchemy engine and sessionmaker.
Sets up middleware and shutdown handlers for managing the database engine.
Args:
app (starlette.applications.Starlette): The Starlette application instance.
Raises:
advanced_alchemy.exceptions.ImproperConfigurationError:
If the application is not initialized.
"""
self._app = app
unique_bind_keys = {config.bind_key for config in self.config}
if len(unique_bind_keys) != len(self.config): # pragma: no cover
msg = "Please ensure that each config has a unique name for the `bind_key` attribute. The default is `default` and can only be bound to a single engine."
raise ImproperConfigurationError(msg)
for config in self.config:
config.init_app(app)
app.state.advanced_alchemy = self
original_lifespan = app.router.lifespan_context
@asynccontextmanager
async def wrapped_lifespan(app: "Starlette") -> AsyncGenerator[Any, None]: # pragma: no cover
async with self.lifespan(app), original_lifespan(app) as state:
yield state
app.router.lifespan_context = wrapped_lifespan
@asynccontextmanager
async def lifespan(self, app: "Starlette") -> AsyncGenerator[Any, None]: # pragma: no cover
"""Context manager for lifespan events.
Args:
app: The starlette application.
Yields:
None
"""
await self.on_startup()
try:
yield
finally:
await self.on_shutdown()
@property
def app(self) -> "Starlette": # pragma: no cover
"""Returns the Starlette application instance.
Raises:
advanced_alchemy.exceptions.ImproperConfigurationError:
If the application is not initialized.
Returns:
starlette.applications.Starlette: The Starlette application instance.
"""
if self._app is None: # pragma: no cover
msg = "Application not initialized. Did you forget to call init_app?"
raise ImproperConfigurationError(msg)
return self._app
async def on_startup(self) -> None: # pragma: no cover
"""Initializes the database."""
for config in self.config:
await config.on_startup()
async def on_shutdown(self) -> None: # pragma: no cover
"""Handles the shutdown event by disposing of the SQLAlchemy engine.
Ensures that all connections are properly closed during application shutdown.
"""
for config in self.config:
await config.on_shutdown()
with contextlib.suppress(AttributeError, KeyError):
delattr(self.app.state, "advanced_alchemy")
def map_configs(self) -> dict[str, Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig]]:
"""Maps the configs to the session bind keys.
Returns:
A dictionary of config bind keys to configs.
"""
mapped_configs: dict[str, Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig]] = {}
for config in self.config:
if config.bind_key is None:
config.bind_key = "default"
mapped_configs[config.bind_key] = config
return mapped_configs
def get_config(self, key: Optional[str] = None) -> Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig]:
"""Get the config for the given key.
Args:
key: The key to get the config for.
Raises:
advanced_alchemy.exceptions.ImproperConfigurationError:
If the config is not found.
Returns:
The config for the given key.
"""
if key is None:
key = "default"
if key == "default" and len(self.config) == 1:
key = self.config[0].bind_key or "default"
config = self._mapped_configs.get(key)
if config is None: # pragma: no cover
msg = f"Config with key {key} not found"
raise ImproperConfigurationError(msg)
return config
def get_async_config(self, key: Optional[str] = None) -> SQLAlchemyAsyncConfig:
"""Get the async config for the given key.
Raises:
advanced_alchemy.exceptions.ImproperConfigurationError:
If the config is not found.
Returns:
The async config for the given key.
"""
config = self.get_config(key)
if not isinstance(config, SQLAlchemyAsyncConfig): # pragma: no cover
msg = "Expected an async config, but got a sync config"
raise ImproperConfigurationError(msg)
return config
def get_sync_config(self, key: Optional[str] = None) -> SQLAlchemySyncConfig:
"""Get the sync config for the given key.
Raises:
advanced_alchemy.exceptions.ImproperConfigurationError:
If the config is not found.
Returns:
The sync config for the given key.
"""
config = self.get_config(key)
if not isinstance(config, SQLAlchemySyncConfig): # pragma: no cover
msg = "Expected a sync config, but got an async config"
raise ImproperConfigurationError(msg)
return config
@asynccontextmanager
async def with_async_session(
self, key: Optional[str] = None
) -> AsyncGenerator["AsyncSession", None]: # pragma: no cover
"""Context manager for getting an async session.
Yields:
The async session for the given key.
"""
config = self.get_async_config(key)
async with config.get_session() as session:
yield session
@contextmanager
def with_sync_session(self, key: Optional[str] = None) -> Generator["Session", None]: # pragma: no cover
"""Context manager for getting a sync session.
Yields:
The sync session for the given key.
"""
config = self.get_sync_config(key)
with config.get_session() as session:
yield session
@overload
@staticmethod
def _get_session_from_request(request: Request, config: SQLAlchemyAsyncConfig) -> "AsyncSession": ...
@overload
@staticmethod
def _get_session_from_request(request: Request, config: SQLAlchemySyncConfig) -> "Session": ...
@staticmethod
def _get_session_from_request(
request: Request,
config: Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig], # pragma: no cover
) -> Union["Session", "AsyncSession"]: # pragma: no cover
"""Get the session for the given key.
Args:
request: The request object.
config: The config object.
Returns:
The session for the given key.
"""
session = getattr(request.state, config.session_key, None)
if session is None:
session = config.create_session_maker()()
setattr(request.state, config.session_key, session)
set_async_context(isinstance(session, AsyncSession))
return session
def get_session(
self, request: Request, key: Optional[str] = None
) -> Union["Session", "AsyncSession"]: # pragma: no cover
"""Get the session for the given key.
Args:
request: The request object.
key: The key to get the session for.
Returns:
The session for the given key.
"""
config = self.get_config(key)
return self._get_session_from_request(request, config)
def get_async_session(self, request: Request, key: Optional[str] = None) -> "AsyncSession": # pragma: no cover
"""Get the async session for the given key.
Args:
request: The request object.
key: The key to get the session for.
Returns:
The async session for the given key.
"""
config = self.get_async_config(key)
return self._get_session_from_request(request, config)
def get_sync_session(self, request: Request, key: Optional[str] = None) -> "Session": # pragma: no cover
"""Get the sync session for the given key.
Args:
request: The request object.
key: The key to get the session for.
Returns:
The sync session for the given key.
"""
config = self.get_sync_config(key)
return self._get_session_from_request(request, config)
def provide_session(
self, key: Optional[str] = None
) -> Callable[[Request], Union["Session", "AsyncSession"]]: # pragma: no cover
"""Get the session for the given key.
Args:
key: The key to get the session for.
Returns:
The session for the given key.
"""
config = self.get_config(key)
def _get_session(request: Request) -> Union["Session", "AsyncSession"]:
set_async_context(isinstance(config, SQLAlchemyAsyncConfig))
return self._get_session_from_request(request, config)
return _get_session
def provide_async_session(
self, key: Optional[str] = None
) -> Callable[[Request], "AsyncSession"]: # pragma: no cover
"""Get the async session for the given key.
Args:
key: The key to get the session for.
Returns:
The async session for the given key.
"""
config = self.get_async_config(key)
def _get_session(request: Request) -> "AsyncSession":
set_async_context(True)
return self._get_session_from_request(request, config)
return _get_session
def provide_sync_session(self, key: Optional[str] = None) -> Callable[[Request], "Session"]: # pragma: no cover
"""Get the sync session for the given key.
Args:
key: The key to get the session for.
Returns:
The sync session for the given key.
"""
config = self.get_sync_config(key)
def _get_session(request: Request) -> "Session":
set_async_context(False)
return self._get_session_from_request(request, config)
return _get_session
def get_engine(self, key: Optional[str] = None) -> Union["Engine", "AsyncEngine"]: # pragma: no cover
"""Get the engine for the given key.
Args:
key: The key to get the engine for.
Returns:
The engine for the given key.
"""
config = self.get_config(key)
return config.get_engine()
def get_async_engine(self, key: Optional[str] = None) -> "AsyncEngine": # pragma: no cover
"""Get the async engine for the given key.
Args:
key: The key to get the engine for.
Returns:
The async engine for the given key.
"""
config = self.get_async_config(key)
return config.get_engine()
def get_sync_engine(self, key: Optional[str] = None) -> "Engine": # pragma: no cover
"""Get the sync engine for the given key.
Args:
key: The key to get the engine for.
Returns:
The sync engine for the given key.
"""
config = self.get_sync_config(key)
return config.get_engine()
def provide_engine(
self, key: Optional[str] = None
) -> Callable[[], Union["Engine", "AsyncEngine"]]: # pragma: no cover
"""Get the engine for the given key.
Args:
key: The key to get the engine for.
Returns:
The engine for the given key.
"""
config = self.get_config(key)
def _get_engine() -> Union["Engine", "AsyncEngine"]:
return config.get_engine()
return _get_engine
def provide_async_engine(self, key: Optional[str] = None) -> Callable[[], "AsyncEngine"]: # pragma: no cover
"""Get the async engine for the given key.
Args:
key: The key to get the engine for.
Returns:
The async engine for the given key.
"""
config = self.get_async_config(key)
def _get_engine() -> "AsyncEngine":
return config.get_engine()
return _get_engine
def provide_sync_engine(self, key: Optional[str] = None) -> Callable[[], "Engine"]: # pragma: no cover
"""Get the sync engine for the given key.
Args:
key: The key to get the engine for.
Returns:
The sync engine for the given key.
"""
config = self.get_sync_config(key)
def _get_engine() -> "Engine":
return config.get_engine()
return _get_engine
python-advanced-alchemy-1.4.1/advanced_alchemy/filters.py 0000664 0000000 0000000 00000114744 15003544734 0023532 0 ustar 00root root 0000000 0000000 """SQLAlchemy filter constructs for advanced query operations.
This module provides a comprehensive collection of filter datastructures designed to
enhance SQLAlchemy query construction. It implements type-safe, reusable filter patterns
for common database query operations.
Features:
Type-safe filter construction, datetime range filtering, collection-based filtering,
pagination support, search operations, and customizable ordering.
Note:
All filter classes implement the :class:`StatementFilter` ABC, ensuring consistent
interface across different filter types.
See Also:
- :class:`sqlalchemy.sql.expression.Select`: Core SQLAlchemy select expression
- :class:`sqlalchemy.orm.Query`: SQLAlchemy ORM query interface
- :mod:`advanced_alchemy.base`: Base model definitions
"""
import datetime
import logging
from abc import ABC, abstractmethod
from collections.abc import Collection
from dataclasses import dataclass
from operator import attrgetter
from typing import (
Any,
Callable,
ClassVar,
Generic,
Literal,
Optional,
Union,
cast,
)
from sqlalchemy import (
BinaryExpression,
ColumnElement,
Date,
Delete,
Select,
Update,
and_,
any_,
exists,
false,
not_,
or_,
select,
text,
true,
)
from sqlalchemy.orm import InstrumentedAttribute
from sqlalchemy.sql import operators as op
from sqlalchemy.sql.dml import ReturningDelete, ReturningUpdate
from typing_extensions import TypeAlias, TypedDict, TypeVar
from advanced_alchemy.base import ModelProtocol
__all__ = (
"BeforeAfter",
"CollectionFilter",
"ComparisonFilter",
"ExistsFilter",
"FilterGroup",
"FilterMap",
"FilterTypes",
"InAnyFilter",
"LimitOffset",
"LogicalOperatorMap",
"MultiFilter",
"NotExistsFilter",
"NotInCollectionFilter",
"NotInSearchFilter",
"OnBeforeAfter",
"OrderBy",
"PaginationFilter",
"SearchFilter",
"StatementFilter",
"StatementFilterT",
"StatementTypeT",
)
T = TypeVar("T")
ModelT = TypeVar("ModelT", bound=ModelProtocol)
StatementFilterT = TypeVar("StatementFilterT", bound="StatementFilter")
StatementTypeT = TypeVar(
"StatementTypeT",
bound=Union[
ReturningDelete[tuple[Any]], ReturningUpdate[tuple[Any]], Select[tuple[Any]], Select[Any], Update, Delete
],
)
logger = logging.getLogger("advanced_alchemy")
# Define TypedDicts for filter and logical maps
class FilterMap(TypedDict):
before_after: "type[BeforeAfter]"
on_before_after: "type[OnBeforeAfter]"
collection: "type[CollectionFilter[Any]]"
not_in_collection: "type[NotInCollectionFilter[Any]]"
limit_offset: "type[LimitOffset]"
order_by: "type[OrderBy]"
search: "type[SearchFilter]"
not_in_search: "type[NotInSearchFilter]"
comparison: "type[ComparisonFilter]"
exists: "type[ExistsFilter]"
not_exists: "type[NotExistsFilter]"
filter_group: "type[FilterGroup]"
class LogicalOperatorMap(TypedDict):
and_: Callable[..., ColumnElement[bool]]
or_: Callable[..., ColumnElement[bool]]
class StatementFilter(ABC):
"""Abstract base class for SQLAlchemy statement filters.
This class defines the interface for all filter types in the system. Each filter
implementation must provide a method to append its filtering logic to an existing
SQLAlchemy statement.
"""
@abstractmethod
def append_to_statement(
self, statement: StatementTypeT, model: type[ModelT], *args: Any, **kwargs: Any
) -> StatementTypeT:
"""Append filter conditions to a SQLAlchemy statement.
Args:
statement: The SQLAlchemy statement to modify
model: The SQLAlchemy model class
*args: Additional positional arguments
**kwargs: Additional keyword arguments
Returns:
StatementTypeT: Modified SQLAlchemy statement with filter conditions applied
Raises:
NotImplementedError: If the concrete class doesn't implement this method
Note:
This method must be implemented by all concrete filter classes.
See Also:
:meth:`sqlalchemy.sql.expression.Select.where`: SQLAlchemy where clause
"""
return statement
@staticmethod
def _get_instrumented_attr(model: Any, key: Union[str, InstrumentedAttribute[Any]]) -> InstrumentedAttribute[Any]:
"""Get SQLAlchemy instrumented attribute from model.
Args:
model: SQLAlchemy model class or instance
key: Attribute name or instrumented attribute
Returns:
InstrumentedAttribute[Any]: SQLAlchemy instrumented attribute
See Also:
:class:`sqlalchemy.orm.attributes.InstrumentedAttribute`: SQLAlchemy attribute
"""
if isinstance(key, str):
return cast("InstrumentedAttribute[Any]", getattr(model, key))
return key
@dataclass
class BeforeAfter(StatementFilter):
"""DateTime range filter with exclusive bounds.
This filter creates date/time range conditions using < and > operators,
excluding the boundary values.
If either `before` or `after` is None, that boundary condition is not applied.
See Also:
---------
:class:`OnBeforeAfter` : Inclusive datetime range filtering
"""
field_name: str
"""Name of the model attribute to filter on."""
before: Optional[datetime.datetime]
"""Filter results where field is earlier than this value."""
after: Optional[datetime.datetime]
"""Filter results where field is later than this value."""
def append_to_statement(self, statement: StatementTypeT, model: type[ModelT]) -> StatementTypeT:
"""Apply datetime range conditions to statement.
Parameters
----------
statement : StatementTypeT
The SQLAlchemy statement to modify
model : type[ModelT]
The SQLAlchemy model class
Returns:
--------
StatementTypeT
Modified statement with datetime range conditions
"""
field = self._get_instrumented_attr(model, self.field_name)
if self.before is not None:
statement = cast("StatementTypeT", statement.where(field < self.before))
if self.after is not None:
statement = cast("StatementTypeT", statement.where(field > self.after))
return statement
@dataclass
class OnBeforeAfter(StatementFilter):
"""DateTime range filter with inclusive bounds.
This filter creates date/time range conditions using <= and >= operators,
including the boundary values.
If either `on_or_before` or `on_or_after` is None, that boundary condition
is not applied.
See Also:
---------
:class:`BeforeAfter` : Exclusive datetime range filtering
"""
field_name: str
"""Name of the model attribute to filter on."""
on_or_before: Optional[datetime.datetime]
"""Filter results where field is on or earlier than this value."""
on_or_after: Optional[datetime.datetime]
"""Filter results where field is on or later than this value."""
def append_to_statement(self, statement: StatementTypeT, model: type[ModelT]) -> StatementTypeT:
"""Apply inclusive datetime range conditions to statement.
Parameters
----------
statement : StatementTypeT
The SQLAlchemy statement to modify
model : type[ModelT]
The SQLAlchemy model class
Returns:
--------
StatementTypeT
Modified statement with inclusive datetime range conditions
"""
field = self._get_instrumented_attr(model, self.field_name)
if self.on_or_before is not None:
statement = cast("StatementTypeT", statement.where(field <= self.on_or_before))
if self.on_or_after is not None:
statement = cast("StatementTypeT", statement.where(field >= self.on_or_after))
return statement
class InAnyFilter(StatementFilter, ABC):
"""Base class for filters using IN or ANY operators.
This abstract class provides common functionality for filters that check
membership in a collection using either the SQL IN operator or the ANY operator.
"""
@dataclass
class CollectionFilter(InAnyFilter, Generic[T]):
"""Data required to construct a WHERE ... IN (...) clause.
This filter restricts records based on a field's presence in a collection of values.
The filter supports both ``IN`` and ``ANY`` operators for collection membership testing.
Use ``prefer_any=True`` in ``append_to_statement`` to use the ``ANY`` operator.
"""
field_name: str
"""Name of the model attribute to filter on."""
values: Union[Collection[T], None]
"""Values for the ``IN`` clause. If this is None, no filter is applied.
An empty list will force an empty result set (WHERE 1=-1)"""
def append_to_statement(
self,
statement: StatementTypeT,
model: type[ModelT],
prefer_any: bool = False,
) -> StatementTypeT:
"""Apply a WHERE ... IN or WHERE ... ANY (...) clause to the statement.
Parameters
----------
statement : StatementTypeT
The SQLAlchemy statement to modify
model : type[ModelT]
The SQLAlchemy model class
prefer_any : bool, optional
If True, uses the SQLAlchemy :func:`any_` operator instead of
:func:`in_` for the filter condition
Returns:
--------
StatementTypeT
Modified statement with the appropriate IN conditions
"""
field = self._get_instrumented_attr(model, self.field_name)
if self.values is None:
return statement
if not self.values:
# Return empty result set by forcing a false condition
return cast("StatementTypeT", statement.where(text("1=-1")))
if prefer_any:
return cast("StatementTypeT", statement.where(any_(self.values) == field)) # type: ignore[arg-type]
return cast("StatementTypeT", statement.where(field.in_(self.values)))
@dataclass
class NotInCollectionFilter(InAnyFilter, Generic[T]):
"""Data required to construct a WHERE ... NOT IN (...) clause.
This filter restricts records based on a field's absence in a collection of values.
The filter supports both ``NOT IN`` and ``!= ANY`` operators for collection exclusion.
Use ``prefer_any=True`` in ``append_to_statement`` to use the ``ANY`` operator.
Parameters
----------
field_name : str
Name of the model attribute to filter on
values : abc.Collection[T] | None
Values for the ``NOT IN`` clause. If this is None or empty,
the filter is not applied.
"""
field_name: str
"""Name of the model attribute to filter on."""
values: Union[Collection[T], None]
"""Values for the ``NOT IN`` clause. If None or empty, no filter is applied."""
def append_to_statement(
self,
statement: StatementTypeT,
model: type[ModelT],
prefer_any: bool = False,
) -> StatementTypeT:
"""Apply a WHERE ... NOT IN or WHERE ... != ANY(...) clause to the statement.
Parameters
----------
statement : StatementTypeT
The SQLAlchemy statement to modify
model : type[ModelT]
The SQLAlchemy model class
prefer_any : bool, optional
If True, uses the SQLAlchemy :func:`any_` operator instead of
:func:`notin_` for the filter condition
Returns:
--------
StatementTypeT
Modified statement with the appropriate NOT IN conditions
"""
field = self._get_instrumented_attr(model, self.field_name)
if not self.values:
# If None or empty, we do not modify the statement
return statement
if prefer_any:
return cast("StatementTypeT", statement.where(any_(self.values) != field)) # type: ignore[arg-type]
return cast("StatementTypeT", statement.where(field.notin_(self.values)))
class PaginationFilter(StatementFilter, ABC):
"""Abstract base class for pagination filters.
Subclasses should implement pagination logic, such as limit/offset or
cursor-based pagination.
"""
@dataclass
class LimitOffset(PaginationFilter):
"""Limit and offset pagination filter.
Implements traditional pagination using SQL LIMIT and OFFSET clauses.
Only applies to SELECT statements; other statement types are returned unmodified.
Note:
This filter only modifies SELECT statements. For other statement types
(UPDATE, DELETE), the statement is returned unchanged.
See Also:
- :meth:`sqlalchemy.sql.expression.Select.limit`: SQLAlchemy LIMIT clause
- :meth:`sqlalchemy.sql.expression.Select.offset`: SQLAlchemy OFFSET clause
"""
limit: int
"""Maximum number of rows to return."""
offset: int
"""Number of rows to skip before returning results."""
def append_to_statement(self, statement: StatementTypeT, model: type[ModelT]) -> StatementTypeT:
"""Apply LIMIT/OFFSET pagination to the statement.
Args:
statement: The SQLAlchemy statement to modify
model: The SQLAlchemy model class
Returns:
StatementTypeT: Modified statement with limit and offset applied
Note:
Only modifies SELECT statements. Other statement types are returned as-is.
See Also:
:class:`sqlalchemy.sql.expression.Select`: SQLAlchemy SELECT statement
"""
if isinstance(statement, Select):
return cast("StatementTypeT", statement.limit(self.limit).offset(self.offset))
return statement
@dataclass
class OrderBy(StatementFilter):
"""Order by a specific field.
Appends an ORDER BY clause to SELECT statements, sorting records by the
specified field in ascending or descending order.
Note:
This filter only modifies SELECT statements. For other statement types,
the statement is returned unchanged.
See Also:
- :meth:`sqlalchemy.sql.expression.Select.order_by`: SQLAlchemy ORDER BY clause
- :meth:`sqlalchemy.sql.expression.ColumnElement.asc`: Ascending order
- :meth:`sqlalchemy.sql.expression.ColumnElement.desc`: Descending order
"""
field_name: str
"""Name of the model attribute to sort on."""
sort_order: Literal["asc", "desc"] = "asc"
"""Sort direction ("asc" or "desc")."""
def append_to_statement(self, statement: StatementTypeT, model: type[ModelT]) -> StatementTypeT:
"""Append an ORDER BY clause to the statement.
Args:
statement: The SQLAlchemy statement to modify
model: The SQLAlchemy model class
Returns:
StatementTypeT: Modified statement with an ORDER BY clause
Note:
Only modifies SELECT statements. Other statement types are returned as-is.
See Also:
:meth:`sqlalchemy.sql.expression.Select.order_by`: SQLAlchemy ORDER BY
"""
if not isinstance(statement, Select):
return statement
field = self._get_instrumented_attr(model, self.field_name)
if self.sort_order == "desc":
return cast("StatementTypeT", statement.order_by(field.desc()))
return cast("StatementTypeT", statement.order_by(field.asc()))
@dataclass
class SearchFilter(StatementFilter):
"""Case-sensitive or case-insensitive substring matching filter.
Implements text search using SQL LIKE or ILIKE operators. Can search across
multiple fields using OR conditions.
Note:
The search pattern automatically adds wildcards before and after the search
value, equivalent to SQL pattern '%value%'.
See Also:
- :class:`.NotInSearchFilter`: Opposite filter using NOT LIKE/ILIKE
- :meth:`sqlalchemy.sql.expression.ColumnOperators.like`: Case-sensitive LIKE
- :meth:`sqlalchemy.sql.expression.ColumnOperators.ilike`: Case-insensitive LIKE
"""
field_name: Union[str, set[str]]
"""Name or set of names of model attributes to search on."""
value: str
"""Text to match within the field(s)."""
ignore_case: Optional[bool] = False
"""Whether to use case-insensitive matching."""
@property
def _operator(self) -> Callable[..., ColumnElement[bool]]:
"""Return the SQL operator for combining multiple search clauses.
Returns:
Callable[..., ColumnElement[bool]]: The `or_` operator for OR conditions
See Also:
:func:`sqlalchemy.sql.expression.or_`: SQLAlchemy OR operator
"""
return or_
@property
def _func(self) -> "attrgetter[Callable[[str], BinaryExpression[bool]]]":
"""Return the appropriate LIKE or ILIKE operator as a function.
Returns:
attrgetter: Bound method for LIKE or ILIKE operations
See Also:
- :meth:`sqlalchemy.sql.expression.ColumnOperators.like`: LIKE operator
- :meth:`sqlalchemy.sql.expression.ColumnOperators.ilike`: ILIKE operator
"""
return attrgetter("ilike" if self.ignore_case else "like")
@property
def normalized_field_names(self) -> set[str]:
"""Convert field_name to a set if it's a single string.
Returns:
set[str]: Set of field names to be searched
"""
return {self.field_name} if isinstance(self.field_name, str) else self.field_name
def get_search_clauses(self, model: type[ModelT]) -> list[BinaryExpression[bool]]:
"""Generate the LIKE/ILIKE clauses for all specified fields.
Args:
model: The SQLAlchemy model class
Returns:
list[BinaryExpression[bool]]: List of text matching expressions
See Also:
:class:`sqlalchemy.sql.expression.BinaryExpression`: SQLAlchemy expression
"""
search_clause: list[BinaryExpression[bool]] = []
for field_name in self.normalized_field_names:
try:
field = self._get_instrumented_attr(model, field_name)
search_text = f"%{self.value}%"
search_clause.append(self._func(field)(search_text))
except AttributeError:
msg = f"Skipping search for field {field_name}. It is not found in model {model.__name__}"
logger.debug(msg)
continue
return search_clause
def append_to_statement(self, statement: StatementTypeT, model: type[ModelT]) -> StatementTypeT:
"""Append a LIKE/ILIKE clause to the statement.
Args:
statement: The SQLAlchemy statement to modify
model: The SQLAlchemy model class
Returns:
StatementTypeT: Modified statement with text search clauses
See Also:
:meth:`sqlalchemy.sql.expression.Select.where`: SQLAlchemy WHERE clause
"""
where_clause = self._operator(*self.get_search_clauses(model))
return cast("StatementTypeT", statement.where(where_clause))
# Regular typed dictionary for operators_map
operators_map: dict[str, Callable[[Any, Any], ColumnElement[bool]]] = {
"eq": op.eq,
"ne": op.ne,
"gt": op.gt,
"ge": op.ge,
"lt": op.lt,
"le": op.le,
"in": op.in_op,
"notin": op.notin_op,
"between": lambda c, v: c.between(v[0], v[1]),
"like": op.like_op,
"ilike": op.ilike_op,
"startswith": op.startswith_op,
"istartswith": lambda c, v: c.ilike(v + "%"),
"endswith": op.endswith_op,
"iendswith": lambda c, v: c.ilike(v + "%"),
"dateeq": lambda c, v: cast("Date", c) == v,
}
VALID_OPERATORS = set(operators_map.keys())
"""Set of valid operators that can be used in ComparisonFilter."""
@dataclass
class ComparisonFilter(StatementFilter):
"""Simple comparison filter for equality and inequality operations.
This filter applies basic comparison operators (=, !=, >, >=, <, <=) to a field.
It provides a generic way to perform common comparison operations.
Args:
field_name: Name of the model attribute to filter on
operator: Comparison operator to use (must be one of: 'eq', 'ne', 'gt', 'ge', 'lt', 'le', 'in', 'notin', 'between', 'like', 'ilike', 'startswith', 'istartswith', 'endswith', 'iendswith', 'dateeq')
value: Value to compare against
Raises:
ValueError: If an invalid operator is provided
"""
field_name: str
"""Name of the model attribute to filter on."""
operator: str
"""Comparison operator to use (one of 'eq', 'ne', 'gt', 'ge', 'lt', 'le')."""
value: Any
"""Value to compare against."""
def append_to_statement(self, statement: StatementTypeT, model: type[ModelT]) -> StatementTypeT:
"""Apply a comparison operation to the statement.
Args:
statement: The SQLAlchemy statement to modify
model: The SQLAlchemy model class
Returns:
StatementTypeT: Modified statement with the comparison condition
Raises:
ValueError: If an invalid operator is provided
"""
field = self._get_instrumented_attr(model, self.field_name)
operator_func = operators_map.get(self.operator)
if operator_func is None:
msg = f"Invalid operator '{self.operator}'. Must be one of: {', '.join(sorted(VALID_OPERATORS))}"
raise ValueError(msg)
condition = operator_func(field, self.value)
return cast("StatementTypeT", statement.where(condition))
@dataclass
class NotInSearchFilter(SearchFilter):
"""Filter for excluding records that match a substring.
Implements negative text search using SQL NOT LIKE or NOT ILIKE operators.
Can exclude across multiple fields using AND conditions.
Args:
field_name: Name or set of names of model attributes to search on
value: Text to exclude from the field(s)
ignore_case: If True, uses NOT ILIKE for case-insensitive matching
Note:
Uses AND for multiple fields, meaning records matching any field will be excluded.
See Also:
- :class:`.SearchFilter`: Opposite filter using LIKE/ILIKE
- :meth:`sqlalchemy.sql.expression.ColumnOperators.notlike`: NOT LIKE operator
- :meth:`sqlalchemy.sql.expression.ColumnOperators.notilike`: NOT ILIKE operator
"""
@property
def _operator(self) -> Callable[..., ColumnElement[bool]]:
"""Return the SQL operator for combining multiple negated search clauses.
Returns:
Callable[..., ColumnElement[bool]]: The `and_` operator for AND conditions
See Also:
:func:`sqlalchemy.sql.expression.and_`: SQLAlchemy AND operator
"""
return and_
@property
def _func(self) -> "attrgetter[Callable[[str], BinaryExpression[bool]]]":
"""Return the appropriate NOT LIKE or NOT ILIKE operator as a function.
Returns:
attrgetter: Bound method for NOT LIKE or NOT ILIKE operations
See Also:
- :meth:`sqlalchemy.sql.expression.ColumnOperators.notlike`: NOT LIKE
- :meth:`sqlalchemy.sql.expression.ColumnOperators.notilike`: NOT ILIKE
"""
return attrgetter("not_ilike" if self.ignore_case else "not_like")
@dataclass
class ExistsFilter(StatementFilter):
"""Filter for EXISTS subqueries.
This filter creates an EXISTS condition using a list of column expressions.
The expressions can be combined using either AND or OR logic. The filter applies
a correlated subquery that returns only the rows from the main query that match
the specified conditions.
For example, if searching movies with `Movie.genre == "Action"`, only rows where
the genre is "Action" will be returned.
Parameters
----------
values : list[ColumnElement[bool]]
values: List of SQLAlchemy column expressions to use in the EXISTS clause
operator : Literal["and", "or"], optional
operator: If "and", combines conditions with AND, otherwise uses OR. Defaults to "and".
Example:
--------
Basic usage with AND conditions::
from sqlalchemy import select
from advanced_alchemy.filters import ExistsFilter
filter = ExistsFilter(
values=[User.email.like("%@example.com%")],
)
statement = filter.append_to_statement(
select(Organization), Organization
)
This will return only organizations where the user's email contains "@example.com".
Using OR conditions::
filter = ExistsFilter(
values=[User.role == "admin", User.role == "owner"],
operator="or",
)
This will return organizations where the user's role is either "admin" OR "owner".
See Also:
--------
:class:`NotExistsFilter`: The inverse of this filter
:func:`sqlalchemy.sql.expression.exists`: SQLAlchemy EXISTS expression
"""
values: list[ColumnElement[bool]]
"""List of SQLAlchemy column expressions to use in the EXISTS clause."""
operator: Literal["and", "or"] = "and"
"""If "and", combines conditions with the AND operator, otherwise uses OR."""
@property
def _and(self) -> Callable[..., ColumnElement[bool]]:
"""Access the SQLAlchemy `and_` operator.
Returns:
Callable[..., ColumnElement[bool]]: The `and_` operator for AND conditions
See Also:
:func:`sqlalchemy.sql.expression.and_`: SQLAlchemy AND operator
"""
return and_
@property
def _or(self) -> Callable[..., ColumnElement[bool]]:
"""Access the SQLAlchemy `or_` operator.
Returns:
Callable[..., ColumnElement[bool]]: The `or_` operator for OR conditions
See Also:
:func:`sqlalchemy.sql.expression.or_`: SQLAlchemy OR operator
"""
return or_
def _get_combined_conditions(self) -> ColumnElement[bool]:
"""Combine the filter conditions using the specified operator.
Returns:
ColumnElement[bool]:
A SQLAlchemy column expression combining all conditions with AND or OR
"""
op = self._and if self.operator == "and" else self._or
return op(*self.values)
def get_exists_clause(self, model: type[ModelT]) -> ColumnElement[bool]:
"""Generate the EXISTS clause for the statement.
Args:
model : type[ModelT]
The SQLAlchemy model class to correlate with
Returns:
ColumnElement[bool]:
A correlated EXISTS expression for use in a WHERE clause
"""
# Handle empty values list case
if not self.values:
# Use explicitly imported 'false' from sqlalchemy
# Return SQLAlchemy FALSE expression
return false()
# Combine all values with AND or OR (using the operator specified in the filter)
# This creates a single boolean expression from multiple conditions
combined_conditions = self._get_combined_conditions()
# Create a correlated subquery with the combined conditions
try:
subquery = select(1).where(combined_conditions)
correlated_subquery = subquery.correlate(model.__table__)
return exists(correlated_subquery)
except Exception: # noqa: BLE001
return false()
def append_to_statement(self, statement: StatementTypeT, model: type[ModelT]) -> StatementTypeT:
"""Append EXISTS condition to the statement.
Args:
statement : StatementTypeT
The SQLAlchemy statement to modify
model : type[ModelT]
The SQLAlchemy model class
Returns:
StatementTypeT:
Modified statement with EXISTS condition
"""
# We apply the exists clause regardless of whether self.values is empty,
# as get_exists_clause handles the empty case by returning false().
exists_clause = self.get_exists_clause(model)
return cast("StatementTypeT", statement.where(exists_clause))
@dataclass
class NotExistsFilter(StatementFilter):
"""Filter for NOT EXISTS subqueries.
This filter creates a NOT EXISTS condition using a list of column expressions.
The expressions can be combined using either AND or OR logic. The filter applies
a correlated subquery that returns only the rows from the main query that DO NOT
match the specified conditions.
For example, if searching movies with `Movie.genre == "Action"`, only rows where
the genre is NOT "Action" will be returned.
Parameters
----------
values : list[ColumnElement[bool]]
values: List of SQLAlchemy column expressions to use in the NOT EXISTS clause
operator : Literal["and", "or"], optional
operator: If "and", combines conditions with AND, otherwise uses OR. Defaults to "and".
Example:
--------
Basic usage with AND conditions::
from sqlalchemy import select
from advanced_alchemy.filters import NotExistsFilter
filter = NotExistsFilter(
values=[User.email.like("%@example.com%")],
)
statement = filter.append_to_statement(
select(Organization), Organization
)
This will return only organizations where the user's email does NOT contain "@example.com".
Using OR conditions::
filter = NotExistsFilter(
values=[User.role == "admin", User.role == "owner"],
operator="or",
)
This will return organizations where the user's role is NEITHER "admin" NOR "owner".
See Also:
--------
:class:`ExistsFilter`: The inverse of this filter
:func:`sqlalchemy.sql.expression.not_`: SQLAlchemy NOT operator
:func:`sqlalchemy.sql.expression.exists`: SQLAlchemy EXISTS expression
"""
values: list[ColumnElement[bool]]
"""List of SQLAlchemy column expressions to use in the NOT EXISTS clause."""
operator: Literal["and", "or"] = "and"
"""If "and", combines conditions with the AND operator, otherwise uses OR."""
@property
def _and(self) -> Callable[..., ColumnElement[bool]]:
"""Access the SQLAlchemy `and_` operator.
Returns:
Callable[..., ColumnElement[bool]]: The `and_` operator for AND conditions
See Also:
:func:`sqlalchemy.sql.expression.and_`: SQLAlchemy AND operator
"""
return and_
@property
def _or(self) -> Callable[..., ColumnElement[bool]]:
"""Access the SQLAlchemy `or_` operator.
Returns:
Callable[..., ColumnElement[bool]]: The `or_` operator for OR conditions
See Also:
:func:`sqlalchemy.sql.expression.or_`: SQLAlchemy OR operator
"""
return or_
def _get_combined_conditions(self) -> ColumnElement[bool]:
op = self._and if self.operator == "and" else self._or
return op(*self.values)
def get_exists_clause(self, model: type[ModelT]) -> ColumnElement[bool]:
"""Generate the NOT EXISTS clause for the statement.
Args:
model : type[ModelT]
The SQLAlchemy model class to correlate with
Returns:
ColumnElement[bool]:
A correlated NOT EXISTS expression for use in a WHERE clause
"""
# Handle empty values list case
if not self.values:
# Return SQLAlchemy TRUE expression
return true()
# Combine conditions and create correlated subquery
combined_conditions = self._get_combined_conditions()
subquery = select(1).where(combined_conditions)
correlated_subquery = subquery.correlate(model.__table__)
return not_(exists(correlated_subquery))
def append_to_statement(self, statement: StatementTypeT, model: type[ModelT]) -> StatementTypeT:
"""Append NOT EXISTS condition to the statement.
Args:
statement : StatementTypeT
The SQLAlchemy statement to modify
model : type[ModelT]
The SQLAlchemy model class
Returns:
StatementTypeT:
Modified statement with NOT EXISTS condition
"""
# We apply the exists clause regardless of whether self.values is empty,
# as get_exists_clause handles the empty case by returning true.
exists_clause = self.get_exists_clause(model)
return cast("StatementTypeT", statement.where(exists_clause))
@dataclass
class FilterGroup(StatementFilter):
"""A group of filters combined with a logical operator.
This class combines multiple filters with a logical operator (AND/OR).
It provides a way to create complex nested filter conditions.
"""
logical_operator: Callable[..., ColumnElement[bool]]
"""Logical operator to combine the filters."""
filters: list[StatementFilter]
"""List of filters to combine."""
def append_to_statement(
self,
statement: StatementTypeT,
model: type[ModelT],
) -> "StatementTypeT":
"""Apply all filters combined with the logical operator.
Args:
statement: The SQLAlchemy statement to modify
model: The SQLAlchemy model class
Returns:
StatementTypeT: Modified statement with combined filters
"""
if not self.filters:
return statement
# Create a list of expressions from each filter
expressions = []
for filter_obj in self.filters:
# Each filter needs to be applied to a clean version of the statement
# to get just its expression
filter_statement = filter_obj.append_to_statement(select(), model)
# Extract the whereclause from the filter's statement
if hasattr(filter_statement, "whereclause") and filter_statement.whereclause is not None:
expressions.append(filter_statement.whereclause) # pyright: ignore
if expressions:
# Combine all expressions with the logical operator
combined = self.logical_operator(*expressions)
return cast("StatementTypeT", statement.where(combined))
return statement
@dataclass
class MultiFilter(StatementFilter):
"""Apply multiple filters to a query based on a JSON/dict input.
This filter provides a way to construct complex filter trees from
a structured dictionary input, supporting nested logical groups and
various filter types.
"""
filters: dict[str, Any]
"""JSON/dict structure representing the filters."""
# TypedDict class variables
_filter_map: ClassVar[FilterMap] = {
"before_after": BeforeAfter,
"on_before_after": OnBeforeAfter,
"collection": CollectionFilter,
"not_in_collection": NotInCollectionFilter,
"limit_offset": LimitOffset,
"order_by": OrderBy,
"search": SearchFilter,
"not_in_search": NotInSearchFilter,
"filter_group": FilterGroup,
"comparison": ComparisonFilter,
"exists": ExistsFilter,
"not_exists": NotExistsFilter,
}
_logical_map: ClassVar[LogicalOperatorMap] = {
"and_": and_,
"or_": or_,
}
def append_to_statement(
self,
statement: StatementTypeT,
model: type[ModelT],
) -> StatementTypeT:
"""Apply the filters to the statement based on the filter definitions.
Args:
statement: The SQLAlchemy statement to modify
model: The SQLAlchemy model class
Returns:
StatementTypeT: Modified statement with all filters applied
"""
for filter_type, conditions in self.filters.items():
operator = self._logical_map.get(filter_type)
if operator and isinstance(conditions, list):
# Create filters from the conditions
valid_filters = []
for cond in conditions: # pyright: ignore
filter_instance = self._create_filter(cond) # pyright: ignore
if filter_instance is not None:
valid_filters.append(filter_instance) # pyright: ignore
# Only create a filter group if we have valid filters
if valid_filters:
filter_group = FilterGroup(
logical_operator=operator, # type: ignore
filters=valid_filters, # pyright: ignore
)
statement = filter_group.append_to_statement(statement, model)
return statement
def _create_filter(self, condition: dict[str, Any]) -> Optional[StatementFilter]:
"""Create a filter instance from a condition dictionary.
Args:
condition: Dictionary defining a filter
Returns:
Optional[StatementFilter]: Filter instance if successfully created, None otherwise
"""
# Check if condition is a nested logical group
logical_keys = set(self._logical_map.keys())
intersect = logical_keys.intersection(condition.keys())
if intersect:
# It's a nested filter group
for key in intersect:
operator = self._logical_map.get(key)
if operator and isinstance(condition.get(key), list):
nested_filters = []
for cond in condition[key]:
filter_instance = self._create_filter(cond)
if filter_instance is not None:
nested_filters.append(filter_instance) # pyright: ignore
if nested_filters:
return FilterGroup(logical_operator=operator, filters=nested_filters) # type: ignore
else:
# Regular filter
filter_type = condition.get("type")
if filter_type is not None and isinstance(filter_type, str):
filter_class = self._filter_map.get(filter_type)
if filter_class is not None:
try:
# Create a copy of the condition without the type key
filter_args = {k: v for k, v in condition.items() if k != "type"}
return filter_class(**filter_args) # type: ignore
except Exception: # noqa: BLE001
return None
return None
# Define FilterTypes using direct class references
FilterTypes: TypeAlias = Union[
BeforeAfter,
OnBeforeAfter,
CollectionFilter[Any],
LimitOffset,
OrderBy,
SearchFilter,
NotInCollectionFilter[Any],
NotInSearchFilter,
ExistsFilter,
NotExistsFilter,
ComparisonFilter,
MultiFilter,
FilterGroup,
]
"""Aggregate type alias of the types supported for collection filtering."""
python-advanced-alchemy-1.4.1/advanced_alchemy/mixins/ 0000775 0000000 0000000 00000000000 15003544734 0023004 5 ustar 00root root 0000000 0000000 python-advanced-alchemy-1.4.1/advanced_alchemy/mixins/__init__.py 0000664 0000000 0000000 00000001176 15003544734 0025122 0 ustar 00root root 0000000 0000000 from advanced_alchemy.mixins.audit import AuditColumns
from advanced_alchemy.mixins.bigint import BigIntPrimaryKey
from advanced_alchemy.mixins.nanoid import NanoIDPrimaryKey
from advanced_alchemy.mixins.sentinel import SentinelMixin
from advanced_alchemy.mixins.slug import SlugKey
from advanced_alchemy.mixins.unique import UniqueMixin
from advanced_alchemy.mixins.uuid import UUIDPrimaryKey, UUIDv6PrimaryKey, UUIDv7PrimaryKey
__all__ = (
"AuditColumns",
"BigIntPrimaryKey",
"NanoIDPrimaryKey",
"SentinelMixin",
"SlugKey",
"UUIDPrimaryKey",
"UUIDv6PrimaryKey",
"UUIDv7PrimaryKey",
"UniqueMixin",
)
python-advanced-alchemy-1.4.1/advanced_alchemy/mixins/audit.py 0000664 0000000 0000000 00000001722 15003544734 0024466 0 ustar 00root root 0000000 0000000 import datetime
from sqlalchemy.orm import Mapped, declarative_mixin, mapped_column, validates
from advanced_alchemy.types import DateTimeUTC
@declarative_mixin
class AuditColumns:
"""Created/Updated At Fields Mixin."""
created_at: Mapped[datetime.datetime] = mapped_column(
DateTimeUTC(timezone=True),
default=lambda: datetime.datetime.now(datetime.timezone.utc),
)
"""Date/time of instance creation."""
updated_at: Mapped[datetime.datetime] = mapped_column(
DateTimeUTC(timezone=True),
default=lambda: datetime.datetime.now(datetime.timezone.utc),
onupdate=lambda: datetime.datetime.now(datetime.timezone.utc),
)
"""Date/time of instance last update."""
@validates("created_at", "updated_at")
def validate_tz_info(self, _: str, value: datetime.datetime) -> datetime.datetime:
if value.tzinfo is None:
value = value.replace(tzinfo=datetime.timezone.utc)
return value
python-advanced-alchemy-1.4.1/advanced_alchemy/mixins/bigint.py 0000664 0000000 0000000 00000001037 15003544734 0024633 0 ustar 00root root 0000000 0000000 from sqlalchemy import Sequence
from sqlalchemy.orm import Mapped, declarative_mixin, declared_attr, mapped_column
from advanced_alchemy.types import BigIntIdentity
@declarative_mixin
class BigIntPrimaryKey:
"""BigInt Primary Key Field Mixin."""
@declared_attr
def id(cls) -> Mapped[int]:
"""BigInt Primary key column."""
return mapped_column(
BigIntIdentity,
Sequence(f"{cls.__tablename__}_id_seq", optional=False), # type: ignore[attr-defined]
primary_key=True,
)
python-advanced-alchemy-1.4.1/advanced_alchemy/mixins/nanoid.py 0000664 0000000 0000000 00000001303 15003544734 0024623 0 ustar 00root root 0000000 0000000 from typing import TYPE_CHECKING
from sqlalchemy.orm import Mapped, declarative_mixin, mapped_column
from advanced_alchemy.mixins.sentinel import SentinelMixin
from advanced_alchemy.types import NANOID_INSTALLED
if NANOID_INSTALLED and not TYPE_CHECKING:
from fastnanoid import ( # type: ignore[import-not-found,unused-ignore] # pyright: ignore[reportMissingImports]
generate as nanoid,
)
else:
from uuid import uuid4 as nanoid # type: ignore[assignment,unused-ignore]
@declarative_mixin
class NanoIDPrimaryKey(SentinelMixin):
"""Nano ID Primary Key Field Mixin."""
id: Mapped[str] = mapped_column(default=nanoid, primary_key=True)
"""Nano ID Primary key column."""
python-advanced-alchemy-1.4.1/advanced_alchemy/mixins/sentinel.py 0000664 0000000 0000000 00000002040 15003544734 0025173 0 ustar 00root root 0000000 0000000 from typing import TypedDict
from sqlalchemy.orm import Mapped, MappedAsDataclass, declarative_mixin, declared_attr, mapped_column
from sqlalchemy.sql.schema import _InsertSentinelColumnDefault # pyright: ignore [reportPrivateUsage]
from typing_extensions import NotRequired
class SentinelKwargs(TypedDict):
init: NotRequired[bool]
@declarative_mixin
class SentinelMixin:
"""Mixin to add a sentinel column for SQLAlchemy models."""
__abstract__ = True
_sentinel_kwargs: SentinelKwargs = {}
def __init_subclass__(cls) -> None:
super().__init_subclass__()
if issubclass(cls, MappedAsDataclass):
cls._sentinel_kwargs["init"] = False
@declared_attr
def _sentinel(cls) -> Mapped[int]:
return mapped_column(
name="sa_orm_sentinel",
insert_default=_InsertSentinelColumnDefault(),
_omit_from_statements=True,
insert_sentinel=True,
use_existing_column=True,
nullable=True,
**cls._sentinel_kwargs,
)
python-advanced-alchemy-1.4.1/advanced_alchemy/mixins/slug.py 0000664 0000000 0000000 00000002617 15003544734 0024336 0 ustar 00root root 0000000 0000000 from typing import TYPE_CHECKING, Any
from sqlalchemy import Index, String, UniqueConstraint
from sqlalchemy.orm import Mapped, declarative_mixin, declared_attr, mapped_column
if TYPE_CHECKING:
from sqlalchemy.orm.decl_base import _TableArgsType as TableArgsType # pyright: ignore[reportPrivateUsage]
@declarative_mixin
class SlugKey:
"""Slug unique Field Model Mixin."""
@declared_attr
def slug(cls) -> Mapped[str]:
"""Slug field."""
return mapped_column(
String(length=100),
nullable=False,
)
@staticmethod
def _create_unique_slug_index(*_: Any, **kwargs: Any) -> bool:
return bool(kwargs["dialect"].name.startswith("spanner"))
@staticmethod
def _create_unique_slug_constraint(*_: Any, **kwargs: Any) -> bool:
return not kwargs["dialect"].name.startswith("spanner")
@declared_attr.directive
@classmethod
def __table_args__(cls) -> "TableArgsType":
return (
UniqueConstraint(
cls.slug,
name=f"uq_{cls.__tablename__}_slug", # type: ignore[attr-defined]
).ddl_if(callable_=cls._create_unique_slug_constraint),
Index(
f"ix_{cls.__tablename__}_slug_unique", # type: ignore[attr-defined]
cls.slug,
unique=True,
).ddl_if(callable_=cls._create_unique_slug_index),
)
python-advanced-alchemy-1.4.1/advanced_alchemy/mixins/unique.py 0000664 0000000 0000000 00000013132 15003544734 0024664 0 ustar 00root root 0000000 0000000 from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Optional, Union
from sqlalchemy import ColumnElement, select
from sqlalchemy.orm import declarative_mixin
from typing_extensions import Self
from advanced_alchemy.exceptions import wrap_sqlalchemy_exception
if TYPE_CHECKING:
from collections.abc import Hashable, Iterator
from sqlalchemy import Select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.ext.asyncio.scoping import async_scoped_session
from sqlalchemy.orm import Session
from sqlalchemy.orm.scoping import scoped_session
__all__ = ("UniqueMixin",)
@declarative_mixin
class UniqueMixin:
"""Mixin for instantiating objects while ensuring uniqueness on some field(s).
This is a slightly modified implementation derived from https://github.com/sqlalchemy/sqlalchemy/wiki/UniqueObject
"""
@classmethod
@contextmanager
def _prevent_autoflush(
cls,
session: "Union[AsyncSession, async_scoped_session[AsyncSession], Session, scoped_session[Session]]",
) -> "Iterator[None]":
with session.no_autoflush, wrap_sqlalchemy_exception():
yield
@classmethod
def _check_uniqueness(
cls,
cache: "Optional[dict[tuple[type[Self], Hashable], Self]]",
session: "Union[AsyncSession, async_scoped_session[AsyncSession], Session, scoped_session[Session]]",
key: "tuple[type[Self], Hashable]",
*args: Any,
**kwargs: Any,
) -> "tuple[dict[tuple[type[Self], Hashable], Self], Select[tuple[Self]], Optional[Self]]":
if cache is None:
cache = {}
setattr(session, "_unique_cache", cache)
statement = select(cls).where(cls.unique_filter(*args, **kwargs)).limit(2)
return cache, statement, cache.get(key)
@classmethod
async def as_unique_async(
cls,
session: "Union[AsyncSession, async_scoped_session[AsyncSession]]",
*args: Any,
**kwargs: Any,
) -> Self:
"""Instantiate and return a unique object within the provided session based on the given arguments.
If an object with the same unique identifier already exists in the session, it is returned from the cache.
Args:
session (AsyncSession | async_scoped_session[AsyncSession]): SQLAlchemy async session
*args (Any): Values used to instantiate the instance if no duplicate exists
**kwargs (Any): Values used to instantiate the instance if no duplicate exists
Returns:
Self: The unique object instance.
"""
key = cls, cls.unique_hash(*args, **kwargs)
cache, statement, obj = cls._check_uniqueness(
getattr(session, "_unique_cache", None),
session,
key,
*args,
**kwargs,
)
if obj:
return obj
with cls._prevent_autoflush(session):
if (obj := (await session.execute(statement)).scalar_one_or_none()) is None:
session.add(obj := cls(*args, **kwargs))
cache[key] = obj
return obj
@classmethod
def as_unique_sync(
cls,
session: "Union[Session, scoped_session[Session]]",
*args: Any,
**kwargs: Any,
) -> Self:
"""Instantiate and return a unique object within the provided session based on the given arguments.
If an object with the same unique identifier already exists in the session, it is returned from the cache.
Args:
session (Session | scoped_session[Session]): SQLAlchemy sync session
*args (Any): Values used to instantiate the instance if no duplicate exists
**kwargs (Any): Values used to instantiate the instance if no duplicate exists
Returns:
Self: The unique object instance.
"""
key = cls, cls.unique_hash(*args, **kwargs)
cache, statement, obj = cls._check_uniqueness(
getattr(session, "_unique_cache", None),
session,
key,
*args,
**kwargs,
)
if obj:
return obj
with cls._prevent_autoflush(session):
if (obj := session.execute(statement).scalar_one_or_none()) is None:
session.add(obj := cls(*args, **kwargs))
cache[key] = obj
return obj
@classmethod
def unique_hash(cls, *args: Any, **kwargs: Any) -> "Hashable":
"""Generate a unique key based on the provided arguments.
This method should be implemented in the subclass.
Args:
*args (Any): Values passed to the alternate classmethod constructors
**kwargs (Any): Values passed to the alternate classmethod constructors
Raises:
NotImplementedError: If not implemented in the subclass.
Returns:
Hashable: Any hashable object.
"""
msg = "Implement this in subclass"
raise NotImplementedError(msg)
@classmethod
def unique_filter(cls, *args: Any, **kwargs: Any) -> "ColumnElement[bool]":
"""Generate a filter condition for ensuring uniqueness.
This method should be implemented in the subclass.
Args:
*args (Any): Values passed to the alternate classmethod constructors
**kwargs (Any): Values passed to the alternate classmethod constructors
Raises:
NotImplementedError: If not implemented in the subclass.
Returns:
ColumnElement[bool]: Filter condition to establish the uniqueness.
"""
msg = "Implement this in subclass"
raise NotImplementedError(msg)
python-advanced-alchemy-1.4.1/advanced_alchemy/mixins/uuid.py 0000664 0000000 0000000 00000002401 15003544734 0024321 0 ustar 00root root 0000000 0000000 from typing import TYPE_CHECKING
from uuid import UUID, uuid4
from sqlalchemy.orm import Mapped, declarative_mixin, mapped_column
from advanced_alchemy.mixins.sentinel import SentinelMixin
from advanced_alchemy.types import UUID_UTILS_INSTALLED
if UUID_UTILS_INSTALLED and not TYPE_CHECKING:
from uuid_utils.compat import ( # type: ignore[no-redef,unused-ignore] # pyright: ignore[reportMissingImports]
uuid4,
uuid6,
uuid7,
)
else:
from uuid import uuid4 # type: ignore[no-redef,unused-ignore]
uuid6 = uuid4 # type: ignore[assignment, unused-ignore]
uuid7 = uuid4 # type: ignore[assignment, unused-ignore]
@declarative_mixin
class UUIDPrimaryKey(SentinelMixin):
"""UUID Primary Key Field Mixin."""
id: Mapped[UUID] = mapped_column(default=uuid4, primary_key=True)
"""UUID Primary key column."""
@declarative_mixin
class UUIDv6PrimaryKey(SentinelMixin):
"""UUID v6 Primary Key Field Mixin."""
id: Mapped[UUID] = mapped_column(default=uuid6, primary_key=True)
"""UUID Primary key column."""
@declarative_mixin
class UUIDv7PrimaryKey(SentinelMixin):
"""UUID v7 Primary Key Field Mixin."""
id: Mapped[UUID] = mapped_column(default=uuid7, primary_key=True)
"""UUID Primary key column."""
python-advanced-alchemy-1.4.1/advanced_alchemy/operations.py 0000664 0000000 0000000 00000002422 15003544734 0024232 0 ustar 00root root 0000000 0000000 """Advanced database operations for SQLAlchemy.
This module provides high-performance database operations that extend beyond basic CRUD
functionality. It implements specialized database operations optimized for bulk data
handling and schema management.
The operations module is designed to work seamlessly with SQLAlchemy Core and ORM,
providing efficient implementations for common database operations patterns.
Features
--------
- Table merging and upsert operations
- Dynamic table creation from SELECT statements
- Bulk data import/export operations
- Optimized copy operations for PostgreSQL
- Transaction-safe batch operations
Todo:
-----
- Implement merge operations with customizable conflict resolution
- Add CTAS (Create Table As Select) functionality
- Implement bulk copy operations (COPY TO/FROM) for PostgreSQL
- Add support for temporary table operations
- Implement materialized view operations
Notes:
------
This module is designed to be database-agnostic where possible, with specialized
optimizations for specific database backends where appropriate.
See Also:
---------
- :mod:`sqlalchemy.sql.expression` : SQLAlchemy Core expression language
- :mod:`sqlalchemy.orm` : SQLAlchemy ORM functionality
- :mod:`advanced_alchemy.extensions` : Additional database extensions
"""
python-advanced-alchemy-1.4.1/advanced_alchemy/py.typed 0000664 0000000 0000000 00000000000 15003544734 0023162 0 ustar 00root root 0000000 0000000 python-advanced-alchemy-1.4.1/advanced_alchemy/repository/ 0000775 0000000 0000000 00000000000 15003544734 0023714 5 ustar 00root root 0000000 0000000 python-advanced-alchemy-1.4.1/advanced_alchemy/repository/__init__.py 0000664 0000000 0000000 00000003025 15003544734 0026025 0 ustar 00root root 0000000 0000000 from advanced_alchemy.exceptions import ErrorMessages
from advanced_alchemy.repository._async import (
SQLAlchemyAsyncQueryRepository,
SQLAlchemyAsyncRepository,
SQLAlchemyAsyncRepositoryProtocol,
SQLAlchemyAsyncSlugRepository,
SQLAlchemyAsyncSlugRepositoryProtocol,
)
from advanced_alchemy.repository._sync import (
SQLAlchemySyncQueryRepository,
SQLAlchemySyncRepository,
SQLAlchemySyncRepositoryProtocol,
SQLAlchemySyncSlugRepository,
SQLAlchemySyncSlugRepositoryProtocol,
)
from advanced_alchemy.repository._util import (
DEFAULT_ERROR_MESSAGE_TEMPLATES,
FilterableRepository,
FilterableRepositoryProtocol,
LoadSpec,
get_instrumented_attr,
model_from_dict,
)
from advanced_alchemy.repository.typing import ModelOrRowMappingT, ModelT, OrderingPair
from advanced_alchemy.utils.dataclass import Empty, EmptyType
__all__ = (
"DEFAULT_ERROR_MESSAGE_TEMPLATES",
"Empty",
"EmptyType",
"ErrorMessages",
"FilterableRepository",
"FilterableRepositoryProtocol",
"LoadSpec",
"ModelOrRowMappingT",
"ModelT",
"OrderingPair",
"SQLAlchemyAsyncQueryRepository",
"SQLAlchemyAsyncRepository",
"SQLAlchemyAsyncRepositoryProtocol",
"SQLAlchemyAsyncSlugRepository",
"SQLAlchemyAsyncSlugRepositoryProtocol",
"SQLAlchemySyncQueryRepository",
"SQLAlchemySyncRepository",
"SQLAlchemySyncRepositoryProtocol",
"SQLAlchemySyncSlugRepository",
"SQLAlchemySyncSlugRepositoryProtocol",
"get_instrumented_attr",
"model_from_dict",
)
python-advanced-alchemy-1.4.1/advanced_alchemy/repository/_async.py 0000664 0000000 0000000 00000306363 15003544734 0025555 0 ustar 00root root 0000000 0000000 import random
import string
from collections.abc import Iterable, Sequence
from typing import (
TYPE_CHECKING,
Any,
Final,
Literal,
Optional,
Protocol,
Union,
cast,
runtime_checkable,
)
from sqlalchemy import (
Delete,
Result,
Row,
Select,
TextClause,
Update,
any_,
delete,
over,
select,
text,
update,
)
from sqlalchemy import func as sql_func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.ext.asyncio.scoping import async_scoped_session
from sqlalchemy.orm import InstrumentedAttribute
from sqlalchemy.orm.strategy_options import _AbstractLoad # pyright: ignore[reportPrivateUsage]
from sqlalchemy.sql import ColumnElement
from sqlalchemy.sql.dml import ReturningDelete, ReturningUpdate
from advanced_alchemy.exceptions import ErrorMessages, NotFoundError, RepositoryError, wrap_sqlalchemy_exception
from advanced_alchemy.filters import StatementFilter, StatementTypeT
from advanced_alchemy.repository._util import (
DEFAULT_ERROR_MESSAGE_TEMPLATES,
FilterableRepository,
FilterableRepositoryProtocol,
LoadSpec,
get_abstract_loader_options,
get_instrumented_attr,
)
from advanced_alchemy.repository.typing import MISSING, ModelT, OrderingPair, T
from advanced_alchemy.utils.dataclass import Empty, EmptyType
from advanced_alchemy.utils.text import slugify
if TYPE_CHECKING:
from sqlalchemy.engine.interfaces import _CoreSingleExecuteParams # pyright: ignore[reportPrivateUsage]
DEFAULT_INSERTMANYVALUES_MAX_PARAMETERS: Final = 950
POSTGRES_VERSION_SUPPORTING_MERGE: Final = 15
@runtime_checkable
class SQLAlchemyAsyncRepositoryProtocol(FilterableRepositoryProtocol[ModelT], Protocol[ModelT]):
"""Base Protocol"""
id_attribute: str
match_fields: Optional[Union[list[str], str]] = None
statement: Select[tuple[ModelT]]
session: Union[AsyncSession, async_scoped_session[AsyncSession]]
auto_expunge: bool
auto_refresh: bool
auto_commit: bool
order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None
error_messages: Optional[ErrorMessages] = None
wrap_exceptions: bool = True
def __init__(
self,
*,
statement: Optional[Select[tuple[ModelT]]] = None,
session: Union[AsyncSession, async_scoped_session[AsyncSession]],
auto_expunge: bool = False,
auto_refresh: bool = True,
auto_commit: bool = False,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
wrap_exceptions: bool = True,
**kwargs: Any,
) -> None: ...
@classmethod
def get_id_attribute_value(
cls,
item: Union[ModelT, type[ModelT]],
id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
) -> Any: ...
@classmethod
def set_id_attribute_value(
cls,
item_id: Any,
item: ModelT,
id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
) -> ModelT: ...
@staticmethod
def check_not_found(item_or_none: Optional[ModelT]) -> ModelT: ...
async def add(
self,
data: ModelT,
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
) -> ModelT: ...
async def add_many(
self,
data: list[ModelT],
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
) -> Sequence[ModelT]: ...
async def delete(
self,
item_id: Any,
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
) -> ModelT: ...
async def delete_many(
self,
item_ids: list[Any],
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
chunk_size: Optional[int] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
) -> Sequence[ModelT]: ...
async def delete_where(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
load: Optional[LoadSpec] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
execution_options: Optional[dict[str, Any]] = None,
sanity_check: bool = True,
**kwargs: Any,
) -> Sequence[ModelT]: ...
async def exists(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
load: Optional[LoadSpec] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
execution_options: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> bool: ...
async def get(
self,
item_id: Any,
*,
auto_expunge: Optional[bool] = None,
statement: Optional[Select[tuple[ModelT]]] = None,
id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
) -> ModelT: ...
async def get_one(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
auto_expunge: Optional[bool] = None,
statement: Optional[Select[tuple[ModelT]]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> ModelT: ...
async def get_one_or_none(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
auto_expunge: Optional[bool] = None,
statement: Optional[Select[tuple[ModelT]]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> Optional[ModelT]: ...
async def get_or_upsert(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
match_fields: Optional[Union[list[str], str]] = None,
upsert: bool = True,
attribute_names: Optional[Iterable[str]] = None,
with_for_update: Optional[bool] = None,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> tuple[ModelT, bool]: ...
async def get_and_update(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
match_fields: Optional[Union[list[str], str]] = None,
attribute_names: Optional[Iterable[str]] = None,
with_for_update: Optional[bool] = None,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> tuple[ModelT, bool]: ...
async def count(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
statement: Optional[Select[tuple[ModelT]]] = None,
load: Optional[LoadSpec] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
execution_options: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> int: ...
async def update(
self,
data: ModelT,
*,
attribute_names: Optional[Iterable[str]] = None,
with_for_update: Optional[bool] = None,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
) -> ModelT: ...
async def update_many(
self,
data: list[ModelT],
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
) -> list[ModelT]: ...
def _get_update_many_statement(
self,
model_type: type[ModelT],
supports_returning: bool,
loader_options: Optional[list[_AbstractLoad]],
execution_options: Optional[dict[str, Any]],
) -> Union[Update, ReturningUpdate[tuple[ModelT]]]: ...
async def upsert(
self,
data: ModelT,
*,
attribute_names: Optional[Iterable[str]] = None,
with_for_update: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_commit: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
match_fields: Optional[Union[list[str], str]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
) -> ModelT: ...
async def upsert_many(
self,
data: list[ModelT],
*,
auto_expunge: Optional[bool] = None,
auto_commit: Optional[bool] = None,
no_merge: bool = False,
match_fields: Optional[Union[list[str], str]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
) -> list[ModelT]: ...
async def list_and_count(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
auto_expunge: Optional[bool] = None,
statement: Optional[Select[tuple[ModelT]]] = None,
count_with_window_function: Optional[bool] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None,
**kwargs: Any,
) -> tuple[list[ModelT], int]: ...
async def list(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
auto_expunge: Optional[bool] = None,
statement: Optional[Select[tuple[ModelT]]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None,
**kwargs: Any,
) -> list[ModelT]: ...
@classmethod
async def check_health(cls, session: Union[AsyncSession, async_scoped_session[AsyncSession]]) -> bool: ...
@runtime_checkable
class SQLAlchemyAsyncSlugRepositoryProtocol(SQLAlchemyAsyncRepositoryProtocol[ModelT], Protocol[ModelT]):
"""Protocol for SQLAlchemy repositories that support slug-based operations.
Extends the base repository protocol to add slug-related functionality.
Type Parameters:
ModelT: The SQLAlchemy model type this repository handles.
"""
async def get_by_slug(
self,
slug: str,
*,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> Optional[ModelT]:
"""Get a model instance by its slug.
Args:
slug: The slug value to search for.
error_messages: Optional custom error message templates.
load: Specification for eager loading of relationships.
execution_options: Options for statement execution.
**kwargs: Additional filtering criteria.
Returns:
ModelT | None: The found model instance or None if not found.
"""
...
async def get_available_slug(
self,
value_to_slugify: str,
**kwargs: Any,
) -> str:
"""Generate a unique slug for a given value.
Args:
value_to_slugify: The string to convert to a slug.
**kwargs: Additional parameters for slug generation.
Returns:
str: A unique slug derived from the input value.
"""
...
class SQLAlchemyAsyncRepository(SQLAlchemyAsyncRepositoryProtocol[ModelT], FilterableRepository[ModelT]):
"""Async SQLAlchemy repository implementation.
Provides a complete implementation of async database operations using SQLAlchemy,
including CRUD operations, filtering, and relationship loading.
Type Parameters:
ModelT: The SQLAlchemy model type this repository handles.
.. seealso::
:class:`~advanced_alchemy.repository._util.FilterableRepository`
"""
id_attribute: str = "id"
"""Name of the unique identifier for the model."""
loader_options: Optional[LoadSpec] = None
"""Default loader options for the repository."""
error_messages: Optional[ErrorMessages] = None
"""Default error messages for the repository."""
wrap_exceptions: bool = True
"""Wrap SQLAlchemy exceptions in a ``RepositoryError``. When set to ``False``, the original exception will be raised."""
inherit_lazy_relationships: bool = True
"""Optionally ignore the default ``lazy`` configuration for model relationships. This is useful for when you want to
replace instead of merge the model's loaded relationships with the ones specified in the ``load`` or ``default_loader_options`` configuration."""
merge_loader_options: bool = True
"""Merges the default loader options with the loader options specified in the ``load`` argument. This is useful for when you want to totally
replace instead of merge the model's loaded relationships with the ones specified in the ``load`` or ``default_loader_options`` configuration."""
execution_options: Optional[dict[str, Any]] = None
"""Default execution options for the repository."""
match_fields: Optional[Union[list[str], str]] = None
"""List of dialects that prefer to use ``field.id = ANY(:1)`` instead of ``field.id IN (...)``."""
uniquify: bool = False
"""Optionally apply the ``unique()`` method to results before returning.
This is useful for certain SQLAlchemy uses cases such as applying ``contains_eager`` to a query containing a one-to-many relationship
"""
count_with_window_function: bool = True
"""Use an analytical window function to count results. This allows the count to be performed in a single query.
"""
def __init__(
self,
*,
statement: Optional[Select[tuple[ModelT]]] = None,
session: Union[AsyncSession, async_scoped_session[AsyncSession]],
auto_expunge: bool = False,
auto_refresh: bool = True,
auto_commit: bool = False,
order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
wrap_exceptions: bool = True,
uniquify: Optional[bool] = None,
count_with_window_function: Optional[bool] = None,
**kwargs: Any,
) -> None:
"""Repository for SQLAlchemy models.
Args:
statement: To facilitate customization of the underlying select query.
session: Session managing the unit-of-work for the operation.
auto_expunge: Remove object from session before returning.
auto_refresh: Refresh object from session before returning.
auto_commit: Commit objects before returning.
order_by: Set default order options for queries.
load: Set default relationships to be loaded
execution_options: Set default execution options
error_messages: A set of custom error messages to use for operations
wrap_exceptions: Wrap SQLAlchemy exceptions in a ``RepositoryError``. When set to ``False``, the original exception will be raised.
uniquify: Optionally apply the ``unique()`` method to results before returning.
count_with_window_function: When false, list and count will use two queries instead of an analytical window function.
**kwargs: Additional arguments.
"""
self.auto_expunge = auto_expunge
self.auto_refresh = auto_refresh
self.auto_commit = auto_commit
self.order_by = order_by
self.session = session
self.error_messages = self._get_error_messages(
error_messages=error_messages, default_messages=self.error_messages
)
self.wrap_exceptions = wrap_exceptions
self._uniquify = uniquify if uniquify is not None else self.uniquify
self.count_with_window_function = (
count_with_window_function if count_with_window_function is not None else self.count_with_window_function
)
self._default_loader_options, self._loader_options_have_wildcards = get_abstract_loader_options(
loader_options=load if load is not None else self.loader_options,
inherit_lazy_relationships=self.inherit_lazy_relationships,
merge_with_default=self.merge_loader_options,
)
execution_options = execution_options if execution_options is not None else self.execution_options
self._default_execution_options = execution_options or {}
self.statement = select(self.model_type) if statement is None else statement
self._dialect = self.session.bind.dialect if self.session.bind is not None else self.session.get_bind().dialect
self._prefer_any = any(self._dialect.name == engine_type for engine_type in self.prefer_any_dialects or ())
def _get_uniquify(self, uniquify: Optional[bool] = None) -> bool:
"""Get the uniquify value, preferring the method parameter over instance setting.
Args:
uniquify: Optional override for the uniquify setting.
Returns:
bool: The uniquify value to use.
"""
return bool(uniquify) if uniquify is not None else self._uniquify
@staticmethod
def _get_error_messages(
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
default_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
) -> Optional[ErrorMessages]:
if error_messages == Empty:
error_messages = None
if default_messages == Empty:
default_messages = None
messages = DEFAULT_ERROR_MESSAGE_TEMPLATES
if default_messages and isinstance(default_messages, dict):
messages.update(default_messages)
if error_messages:
messages.update(cast("ErrorMessages", error_messages))
return messages
@classmethod
def get_id_attribute_value(
cls,
item: Union[ModelT, type[ModelT]],
id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
) -> Any:
"""Get value of attribute named as :attr:`id_attribute` on ``item``.
Args:
item: Anything that should have an attribute named as :attr:`id_attribute` value.
id_attribute: Allows customization of the unique identifier to use for model fetching.
Defaults to `None`, but can reference any surrogate or candidate key for the table.
Returns:
The value of attribute on ``item`` named as :attr:`id_attribute`.
"""
if isinstance(id_attribute, InstrumentedAttribute):
id_attribute = id_attribute.key
return getattr(item, id_attribute if id_attribute is not None else cls.id_attribute)
@classmethod
def set_id_attribute_value(
cls,
item_id: Any,
item: ModelT,
id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
) -> ModelT:
"""Return the ``item`` after the ID is set to the appropriate attribute.
Args:
item_id: Value of ID to be set on instance
item: Anything that should have an attribute named as :attr:`id_attribute` value.
id_attribute: Allows customization of the unique identifier to use for model fetching.
Defaults to `None`, but can reference any surrogate or candidate key for the table.
Returns:
Item with ``item_id`` set to :attr:`id_attribute`
"""
if isinstance(id_attribute, InstrumentedAttribute):
id_attribute = id_attribute.key
setattr(item, id_attribute if id_attribute is not None else cls.id_attribute, item_id)
return item
@staticmethod
def check_not_found(item_or_none: Optional[ModelT]) -> ModelT:
"""Raise :exc:`advanced_alchemy.exceptions.NotFoundError` if ``item_or_none`` is ``None``.
Args:
item_or_none: Item (:class:`T `) to be tested for existence.
Raises:
NotFoundError: If ``item_or_none`` is ``None``
Returns:
The item, if it exists.
"""
if item_or_none is None:
msg = "No item found when one was expected"
raise NotFoundError(msg)
return item_or_none
def _get_execution_options(
self,
execution_options: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
if execution_options is None:
return self._default_execution_options
return execution_options
def _get_loader_options(
self,
loader_options: Optional[LoadSpec],
) -> Union[tuple[list[_AbstractLoad], bool], tuple[None, bool]]:
if loader_options is None:
# use the defaults set at initialization
return self._default_loader_options, self._loader_options_have_wildcards or self._uniquify
return get_abstract_loader_options(
loader_options=loader_options,
default_loader_options=self._default_loader_options,
default_options_have_wildcards=self._loader_options_have_wildcards or self._uniquify,
inherit_lazy_relationships=self.inherit_lazy_relationships,
merge_with_default=self.merge_loader_options,
)
async def add(
self,
data: ModelT,
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
) -> ModelT:
"""Add ``data`` to the collection.
Args:
data: Instance to be added to the collection.
auto_expunge: Remove object from session before returning.
auto_refresh: Refresh object from session before returning.
auto_commit: Commit objects before returning.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
Returns:
The added instance.
"""
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
instance = await self._attach_to_session(data)
await self._flush_or_commit(auto_commit=auto_commit)
await self._refresh(instance, auto_refresh=auto_refresh)
self._expunge(instance, auto_expunge=auto_expunge)
return instance
async def add_many(
self,
data: list[ModelT],
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
) -> Sequence[ModelT]:
"""Add many `data` to the collection.
Args:
data: list of Instances to be added to the collection.
auto_expunge: Remove object from session before returning.
auto_commit: Commit objects before returning.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
Returns:
The added instances.
"""
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
self.session.add_all(data)
await self._flush_or_commit(auto_commit=auto_commit)
for datum in data:
self._expunge(datum, auto_expunge=auto_expunge)
return data
async def delete(
self,
item_id: Any,
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> ModelT:
"""Delete instance identified by ``item_id``.
Args:
item_id: Identifier of instance to be deleted.
auto_expunge: Remove object from session before returning.
auto_commit: Commit objects before returning.
id_attribute: Allows customization of the unique identifier to use for model fetching.
Defaults to `id`, but can reference any surrogate or candidate key for the table.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
Returns:
The deleted instance.
"""
self._uniquify = self._get_uniquify(uniquify)
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
execution_options = self._get_execution_options(execution_options)
instance = await self.get(
item_id,
id_attribute=id_attribute,
load=load,
execution_options=execution_options,
)
await self.session.delete(instance)
await self._flush_or_commit(auto_commit=auto_commit)
self._expunge(instance, auto_expunge=auto_expunge)
return instance
async def delete_many(
self,
item_ids: list[Any],
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
chunk_size: Optional[int] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> Sequence[ModelT]:
"""Delete instance identified by `item_id`.
Args:
item_ids: Identifier of instance to be deleted.
auto_expunge: Remove object from session before returning.
auto_commit: Commit objects before returning.
id_attribute: Allows customization of the unique identifier to use for model fetching.
Defaults to `id`, but can reference any surrogate or candidate key for the table.
chunk_size: Allows customization of the ``insertmanyvalues_max_parameters`` setting for the driver.
Defaults to `950` if left unset.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
Returns:
The deleted instances.
"""
self._uniquify = self._get_uniquify(uniquify)
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
execution_options = self._get_execution_options(execution_options)
loader_options, _loader_options_have_wildcard = self._get_loader_options(load)
id_attribute = get_instrumented_attr(
self.model_type,
id_attribute if id_attribute is not None else self.id_attribute,
)
instances: list[ModelT] = []
if self._prefer_any:
chunk_size = len(item_ids) + 1
chunk_size = self._get_insertmanyvalues_max_parameters(chunk_size)
for idx in range(0, len(item_ids), chunk_size):
chunk = item_ids[idx : min(idx + chunk_size, len(item_ids))]
if self._dialect.delete_executemany_returning:
instances.extend(
await self.session.scalars(
self._get_delete_many_statement(
statement_type="delete",
model_type=self.model_type,
id_attribute=id_attribute,
id_chunk=chunk,
supports_returning=self._dialect.delete_executemany_returning,
loader_options=loader_options,
execution_options=execution_options,
),
),
)
else:
instances.extend(
await self.session.scalars(
self._get_delete_many_statement(
statement_type="select",
model_type=self.model_type,
id_attribute=id_attribute,
id_chunk=chunk,
supports_returning=self._dialect.delete_executemany_returning,
loader_options=loader_options,
execution_options=execution_options,
),
),
)
await self.session.execute(
self._get_delete_many_statement(
statement_type="delete",
model_type=self.model_type,
id_attribute=id_attribute,
id_chunk=chunk,
supports_returning=self._dialect.delete_executemany_returning,
loader_options=loader_options,
execution_options=execution_options,
),
)
await self._flush_or_commit(auto_commit=auto_commit)
for instance in instances:
self._expunge(instance, auto_expunge=auto_expunge)
return instances
@staticmethod
def _get_insertmanyvalues_max_parameters(chunk_size: Optional[int] = None) -> int:
return chunk_size if chunk_size is not None else DEFAULT_INSERTMANYVALUES_MAX_PARAMETERS
async def delete_where(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
sanity_check: bool = True,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> Sequence[ModelT]:
"""Delete instances specified by referenced kwargs and filters.
Args:
*filters: Types for specific filtering operations.
auto_expunge: Remove object from session before returning.
auto_commit: Commit objects before returning.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
sanity_check: When true, the length of selected instances is compared to the deleted row count
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
**kwargs: Arguments to apply to a delete
Raises:
RepositoryError: If the number of deleted rows does not match the number of selected instances
Returns:
The deleted instances.
"""
self._uniquify = self._get_uniquify(uniquify)
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
execution_options = self._get_execution_options(execution_options)
loader_options, _loader_options_have_wildcard = self._get_loader_options(load)
model_type = self.model_type
statement = self._get_base_stmt(
statement=delete(model_type),
loader_options=loader_options,
execution_options=execution_options,
)
statement = self._filter_select_by_kwargs(statement=statement, kwargs=kwargs)
statement = self._apply_filters(*filters, statement=statement, apply_pagination=False)
instances: list[ModelT] = []
if self._dialect.delete_executemany_returning:
instances.extend(await self.session.scalars(statement.returning(model_type)))
else:
instances.extend(
await self.list(
*filters,
load=load,
execution_options=execution_options,
auto_expunge=auto_expunge,
**kwargs,
),
)
result = await self.session.execute(statement)
row_count = getattr(result, "rowcount", -2)
if sanity_check and row_count >= 0 and len(instances) != row_count: # pyright: ignore
# backends will return a -1 if they can't determine impacted rowcount
# only compare length of selected instances to results if it's >= 0
await self.session.rollback()
raise RepositoryError(detail="Deleted count does not match fetched count. Rollback issued.")
await self._flush_or_commit(auto_commit=auto_commit)
for instance in instances:
self._expunge(instance, auto_expunge=auto_expunge)
return instances
async def exists(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> bool:
"""Return true if the object specified by ``kwargs`` exists.
Args:
*filters: Types for specific filtering operations.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
**kwargs: Identifier of the instance to be retrieved.
Returns:
True if the instance was found. False if not found..
"""
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
existing = await self.count(
*filters,
load=load,
execution_options=execution_options,
error_messages=error_messages,
**kwargs,
)
return existing > 0
@staticmethod
def _get_base_stmt(
*,
statement: StatementTypeT,
loader_options: Optional[list[_AbstractLoad]],
execution_options: Optional[dict[str, Any]],
) -> StatementTypeT:
"""Get base statement with options applied.
Args:
statement: The select statement to modify
loader_options: Options for loading relationships
execution_options: Options for statement execution
Returns:
Modified select statement
"""
if loader_options:
statement = cast("StatementTypeT", statement.options(*loader_options))
if execution_options:
statement = cast("StatementTypeT", statement.execution_options(**execution_options))
return statement
def _get_delete_many_statement(
self,
*,
model_type: type[ModelT],
id_attribute: InstrumentedAttribute[Any],
id_chunk: list[Any],
supports_returning: bool,
statement_type: Literal["delete", "select"] = "delete",
loader_options: Optional[list[_AbstractLoad]],
execution_options: Optional[dict[str, Any]],
) -> Union[Select[tuple[ModelT]], Delete, ReturningDelete[tuple[ModelT]]]:
# Base statement is static
statement = self._get_base_stmt(
statement=delete(model_type) if statement_type == "delete" else select(model_type),
loader_options=loader_options,
execution_options=execution_options,
)
if execution_options:
statement = statement.execution_options(**execution_options)
if supports_returning and statement_type != "select":
statement = cast("ReturningDelete[tuple[ModelT]]", statement.returning(model_type)) # type: ignore[union-attr,assignment] # pyright: ignore[reportUnknownLambdaType,reportUnknownMemberType,reportAttributeAccessIssue,reportUnknownVariableType]
if self._prefer_any:
return statement.where(any_(id_chunk) == id_attribute) # type: ignore[arg-type]
return statement.where(id_attribute.in_(id_chunk)) # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
async def get(
self,
item_id: Any,
*,
auto_expunge: Optional[bool] = None,
statement: Optional[Select[tuple[ModelT]]] = None,
id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> ModelT:
"""Get instance identified by `item_id`.
Args:
item_id: Identifier of the instance to be retrieved.
auto_expunge: Remove object from session before returning.
statement: To facilitate customization of the underlying select query.
id_attribute: Allows customization of the unique identifier to use for model fetching.
Defaults to `id`, but can reference any surrogate or candidate key for the table.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
Returns:
The retrieved instance.
"""
self._uniquify = self._get_uniquify(uniquify)
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
execution_options = self._get_execution_options(execution_options)
statement = self.statement if statement is None else statement
loader_options, loader_options_have_wildcard = self._get_loader_options(load)
id_attribute = id_attribute if id_attribute is not None else self.id_attribute
statement = self._get_base_stmt(
statement=statement,
loader_options=loader_options,
execution_options=execution_options,
)
statement = self._filter_select_by_kwargs(statement, [(id_attribute, item_id)])
instance = (await self._execute(statement, uniquify=loader_options_have_wildcard)).scalar_one_or_none()
instance = self.check_not_found(instance)
self._expunge(instance, auto_expunge=auto_expunge)
return instance
async def get_one(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
auto_expunge: Optional[bool] = None,
statement: Optional[Select[tuple[ModelT]]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> ModelT:
"""Get instance identified by ``kwargs``.
Args:
*filters: Types for specific filtering operations.
auto_expunge: Remove object from session before returning.
statement: To facilitate customization of the underlying select query.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
**kwargs: Identifier of the instance to be retrieved.
Returns:
The retrieved instance.
"""
self._uniquify = self._get_uniquify(uniquify)
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
execution_options = self._get_execution_options(execution_options)
statement = self.statement if statement is None else statement
loader_options, loader_options_have_wildcard = self._get_loader_options(load)
statement = self._get_base_stmt(
statement=statement,
loader_options=loader_options,
execution_options=execution_options,
)
statement = self._apply_filters(*filters, apply_pagination=False, statement=statement)
statement = self._filter_select_by_kwargs(statement, kwargs)
instance = (await self._execute(statement, uniquify=loader_options_have_wildcard)).scalar_one_or_none()
instance = self.check_not_found(instance)
self._expunge(instance, auto_expunge=auto_expunge)
return instance
async def get_one_or_none(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
auto_expunge: Optional[bool] = None,
statement: Optional[Select[tuple[ModelT]]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> Union[ModelT, None]:
"""Get instance identified by ``kwargs`` or None if not found.
Args:
*filters: Types for specific filtering operations.
auto_expunge: Remove object from session before returning.
statement: To facilitate customization of the underlying select query.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
**kwargs: Identifier of the instance to be retrieved.
Returns:
The retrieved instance or None
"""
self._uniquify = self._get_uniquify(uniquify)
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
execution_options = self._get_execution_options(execution_options)
statement = self.statement if statement is None else statement
loader_options, loader_options_have_wildcard = self._get_loader_options(load)
statement = self._get_base_stmt(
statement=statement,
loader_options=loader_options,
execution_options=execution_options,
)
statement = self._apply_filters(*filters, apply_pagination=False, statement=statement)
statement = self._filter_select_by_kwargs(statement, kwargs)
instance = cast(
"Result[tuple[ModelT]]",
(await self._execute(statement, uniquify=loader_options_have_wildcard)),
).scalar_one_or_none()
if instance:
self._expunge(instance, auto_expunge=auto_expunge)
return instance
async def get_or_upsert(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
match_fields: Optional[Union[list[str], str]] = None,
upsert: bool = True,
attribute_names: Optional[Iterable[str]] = None,
with_for_update: Optional[bool] = None,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_refresh: Union[bool, None] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> tuple[ModelT, bool]:
"""Get instance identified by ``kwargs`` or create if it doesn't exist.
Args:
*filters: Types for specific filtering operations.
match_fields: a list of keys to use to match the existing model. When
empty, all fields are matched.
upsert: When using match_fields and actual model values differ from
`kwargs`, automatically perform an update operation on the model.
attribute_names: an iterable of attribute names to pass into the ``update``
method.
with_for_update: indicating FOR UPDATE should be used, or may be a
dictionary containing flags to indicate a more specific set of
FOR UPDATE flags for the SELECT
auto_expunge: Remove object from session before returning.
auto_refresh: Refresh object from session before returning.
auto_commit: Commit objects before returning.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
**kwargs: Identifier of the instance to be retrieved.
Returns:
a tuple that includes the instance and whether it needed to be created.
When using match_fields and actual model values differ from ``kwargs``, the
model value will be updated.
"""
self._uniquify = self._get_uniquify(uniquify)
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
if match_fields := self._get_match_fields(match_fields=match_fields):
match_filter = {
field_name: kwargs.get(field_name)
for field_name in match_fields
if kwargs.get(field_name) is not None
}
else:
match_filter = kwargs
existing = await self.get_one_or_none(
*filters,
**match_filter,
load=load,
execution_options=execution_options,
)
if not existing:
return (
await self.add(
self.model_type(**kwargs),
auto_commit=auto_commit,
auto_expunge=auto_expunge,
auto_refresh=auto_refresh,
),
True,
)
if upsert:
for field_name, new_field_value in kwargs.items():
field = getattr(existing, field_name, MISSING)
if field is not MISSING and field != new_field_value:
setattr(existing, field_name, new_field_value)
existing = await self._attach_to_session(existing, strategy="merge")
await self._flush_or_commit(auto_commit=auto_commit)
await self._refresh(
existing,
attribute_names=attribute_names,
with_for_update=with_for_update,
auto_refresh=auto_refresh,
)
self._expunge(existing, auto_expunge=auto_expunge)
return existing, False
async def get_and_update(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
match_fields: Optional[Union[list[str], str]] = None,
attribute_names: Optional[Iterable[str]] = None,
with_for_update: Optional[bool] = None,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> tuple[ModelT, bool]:
"""Get instance identified by ``kwargs`` and update the model if the arguments are different.
Args:
*filters: Types for specific filtering operations.
match_fields: a list of keys to use to match the existing model. When
empty, all fields are matched.
attribute_names: an iterable of attribute names to pass into the ``update``
method.
with_for_update: indicating FOR UPDATE should be used, or may be a
dictionary containing flags to indicate a more specific set of
FOR UPDATE flags for the SELECT
auto_expunge: Remove object from session before returning.
auto_refresh: Refresh object from session before returning.
auto_commit: Commit objects before returning.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
**kwargs: Identifier of the instance to be retrieved.
Returns:
a tuple that includes the instance and whether it needed to be updated.
When using match_fields and actual model values differ from ``kwargs``, the
model value will be updated.
"""
self._uniquify = self._get_uniquify(uniquify)
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
if match_fields := self._get_match_fields(match_fields=match_fields):
match_filter = {
field_name: kwargs.get(field_name)
for field_name in match_fields
if kwargs.get(field_name) is not None
}
else:
match_filter = kwargs
existing = await self.get_one(*filters, **match_filter, load=load, execution_options=execution_options)
updated = False
for field_name, new_field_value in kwargs.items():
field = getattr(existing, field_name, MISSING)
if field is not MISSING and field != new_field_value:
updated = True
setattr(existing, field_name, new_field_value)
existing = await self._attach_to_session(existing, strategy="merge")
await self._flush_or_commit(auto_commit=auto_commit)
await self._refresh(
existing,
attribute_names=attribute_names,
with_for_update=with_for_update,
auto_refresh=auto_refresh,
)
self._expunge(existing, auto_expunge=auto_expunge)
return existing, updated
async def count(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
statement: Optional[Select[tuple[ModelT]]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> int:
"""Get the count of records returned by a query.
Args:
*filters: Types for specific filtering operations.
statement: To facilitate customization of the underlying select query.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
**kwargs: Instance attribute value filters.
Returns:
Count of records returned by query, ignoring pagination.
"""
self._uniquify = self._get_uniquify(uniquify)
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
execution_options = self._get_execution_options(execution_options)
statement = self.statement if statement is None else statement
loader_options, loader_options_have_wildcard = self._get_loader_options(load)
statement = self._get_base_stmt(
statement=statement,
loader_options=loader_options,
execution_options=execution_options,
)
statement = self._apply_filters(*filters, apply_pagination=False, statement=statement)
statement = self._filter_select_by_kwargs(statement, kwargs)
results = await self._execute(
statement=self._get_count_stmt(
statement=statement, loader_options=loader_options, execution_options=execution_options
),
uniquify=loader_options_have_wildcard,
)
return cast("int", results.scalar_one())
async def update(
self,
data: ModelT,
*,
attribute_names: Optional[Iterable[str]] = None,
with_for_update: Optional[bool] = None,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> ModelT:
"""Update instance with the attribute values present on `data`.
Args:
data: An instance that should have a value for `self.id_attribute` that
exists in the collection.
attribute_names: an iterable of attribute names to pass into the ``update``
method.
with_for_update: indicating FOR UPDATE should be used, or may be a
dictionary containing flags to indicate a more specific set of
FOR UPDATE flags for the SELECT
auto_expunge: Remove object from session before returning.
auto_refresh: Refresh object from session before returning.
auto_commit: Commit objects before returning.
id_attribute: Allows customization of the unique identifier to use for model fetching.
Defaults to `id`, but can reference any surrogate or candidate key for the table.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
Returns:
The updated instance.
"""
self._uniquify = self._get_uniquify(uniquify)
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
item_id = self.get_id_attribute_value(
data,
id_attribute=id_attribute,
)
# this will raise for not found, and will put the item in the session
await self.get(item_id, id_attribute=id_attribute, load=load, execution_options=execution_options)
# this will merge the inbound data to the instance we just put in the session
instance = await self._attach_to_session(data, strategy="merge")
await self._flush_or_commit(auto_commit=auto_commit)
await self._refresh(
instance,
attribute_names=attribute_names,
with_for_update=with_for_update,
auto_refresh=auto_refresh,
)
self._expunge(instance, auto_expunge=auto_expunge)
return instance
async def update_many(
self,
data: list[ModelT],
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> list[ModelT]:
"""Update one or more instances with the attribute values present on `data`.
This function has an optimized bulk update based on the configured SQL dialect:
- For backends supporting `RETURNING` with `executemany`, a single bulk update with returning clause is executed.
- For other backends, it does a bulk update and then returns the updated data after a refresh.
Args:
data: A list of instances to update. Each should have a value for `self.id_attribute` that exists in the
collection.
auto_expunge: Remove object from session before returning.
auto_commit: Commit objects before returning.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
Returns:
The updated instances.
"""
self._uniquify = self._get_uniquify(uniquify)
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
data_to_update: list[dict[str, Any]] = [v.to_dict() if isinstance(v, self.model_type) else v for v in data] # type: ignore[misc]
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
execution_options = self._get_execution_options(execution_options)
loader_options = self._get_loader_options(load)[0]
supports_returning = self._dialect.update_executemany_returning and self._dialect.name != "oracle"
statement = self._get_update_many_statement(
self.model_type,
supports_returning,
loader_options=loader_options,
execution_options=execution_options,
)
if supports_returning:
instances = list(
await self.session.scalars(
statement,
cast("_CoreSingleExecuteParams", data_to_update), # this is not correct but the only way
# currently to deal with an SQLAlchemy typing issue. See
# https://github.com/sqlalchemy/sqlalchemy/discussions/9925
execution_options=execution_options,
),
)
await self._flush_or_commit(auto_commit=auto_commit)
for instance in instances:
self._expunge(instance, auto_expunge=auto_expunge)
return instances
await self.session.execute(statement, data_to_update, execution_options=execution_options)
await self._flush_or_commit(auto_commit=auto_commit)
return data
def _get_update_many_statement(
self,
model_type: type[ModelT],
supports_returning: bool,
loader_options: Union[list[_AbstractLoad], None],
execution_options: Union[dict[str, Any], None],
) -> Union[Update, ReturningUpdate[tuple[ModelT]]]:
# Base update statement is static
statement = self._get_base_stmt(
statement=update(table=model_type), loader_options=loader_options, execution_options=execution_options
)
if supports_returning:
return statement.returning(model_type)
return statement
async def list_and_count(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
statement: Optional[Select[tuple[ModelT]]] = None,
auto_expunge: Optional[bool] = None,
count_with_window_function: Optional[bool] = None,
order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> tuple[list[ModelT], int]:
"""List records with total count.
Args:
*filters: Types for specific filtering operations.
statement: To facilitate customization of the underlying select query.
auto_expunge: Remove object from session before returning.
count_with_window_function: When false, list and count will use two queries instead of an analytical window function.
order_by: Set default order options for queries.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
**kwargs: Instance attribute value filters.
Returns:
Count of records returned by query, ignoring pagination.
"""
count_with_window_function = (
count_with_window_function if count_with_window_function is not None else self.count_with_window_function
)
self._uniquify = self._get_uniquify(uniquify)
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
if self._dialect.name in {"spanner", "spanner+spanner"} or not count_with_window_function:
return await self._list_and_count_basic(
*filters,
auto_expunge=auto_expunge,
statement=statement,
load=load,
execution_options=execution_options,
order_by=order_by,
error_messages=error_messages,
**kwargs,
)
return await self._list_and_count_window(
*filters,
auto_expunge=auto_expunge,
statement=statement,
load=load,
execution_options=execution_options,
error_messages=error_messages,
order_by=order_by,
**kwargs,
)
def _expunge(self, instance: ModelT, auto_expunge: Optional[bool]) -> None:
if auto_expunge is None:
auto_expunge = self.auto_expunge
return self.session.expunge(instance) if auto_expunge else None
async def _flush_or_commit(self, auto_commit: Optional[bool]) -> None:
if auto_commit is None:
auto_commit = self.auto_commit
return await self.session.commit() if auto_commit else await self.session.flush()
async def _refresh(
self,
instance: ModelT,
auto_refresh: Optional[bool],
attribute_names: Optional[Iterable[str]] = None,
with_for_update: Optional[bool] = None,
) -> None:
if auto_refresh is None:
auto_refresh = self.auto_refresh
return (
await self.session.refresh(
instance=instance,
attribute_names=attribute_names,
with_for_update=with_for_update,
)
if auto_refresh
else None
)
async def _list_and_count_window(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
auto_expunge: Optional[bool] = None,
statement: Optional[Select[tuple[ModelT]]] = None,
order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> tuple[list[ModelT], int]:
"""List records with total count.
Args:
*filters: Types for specific filtering operations.
auto_expunge: Remove object from session before returning.
statement: To facilitate customization of the underlying select query.
order_by: List[OrderingPair] | OrderingPair | None = None,
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set relationships to be loaded
execution_options: Set default execution options
**kwargs: Instance attribute value filters.
Returns:
Count of records returned by query using an analytical window function, ignoring pagination.
"""
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
execution_options = self._get_execution_options(execution_options)
statement = self.statement if statement is None else statement
loader_options, loader_options_have_wildcard = self._get_loader_options(load)
statement = self._get_base_stmt(
statement=statement,
loader_options=loader_options,
execution_options=execution_options,
)
if order_by is None:
order_by = self.order_by if self.order_by is not None else []
statement = self._apply_order_by(statement=statement, order_by=order_by)
statement = self._apply_filters(*filters, statement=statement)
statement = self._filter_select_by_kwargs(statement, kwargs)
result = await self._execute(
statement.add_columns(over(sql_func.count())), uniquify=loader_options_have_wildcard
)
count: int = 0
instances: list[ModelT] = []
for i, (instance, count_value) in enumerate(result):
self._expunge(instance, auto_expunge=auto_expunge)
instances.append(instance)
if i == 0:
count = count_value
return instances, count
async def _list_and_count_basic(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
auto_expunge: Optional[bool] = None,
statement: Optional[Select[tuple[ModelT]]] = None,
order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> tuple[list[ModelT], int]:
"""List records with total count.
Args:
*filters: Types for specific filtering operations.
auto_expunge: Remove object from session before returning.
statement: To facilitate customization of the underlying select query.
order_by: Set default order options for queries.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set relationships to be loaded
execution_options: Set default execution options
**kwargs: Instance attribute value filters.
Returns:
Count of records returned by query using 2 queries, ignoring pagination.
"""
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
execution_options = self._get_execution_options(execution_options)
statement = self.statement if statement is None else statement
loader_options, loader_options_have_wildcard = self._get_loader_options(load)
statement = self._get_base_stmt(
statement=statement,
loader_options=loader_options,
execution_options=execution_options,
)
if order_by is None:
order_by = self.order_by if self.order_by is not None else []
statement = self._apply_order_by(statement=statement, order_by=order_by)
statement = self._apply_filters(*filters, statement=statement)
statement = self._filter_select_by_kwargs(statement, kwargs)
count_result = await self.session.execute(
self._get_count_stmt(
statement,
loader_options=loader_options,
execution_options=execution_options,
),
)
count = count_result.scalar_one()
result = await self._execute(statement, uniquify=loader_options_have_wildcard)
instances: list[ModelT] = []
for (instance,) in result:
self._expunge(instance, auto_expunge=auto_expunge)
instances.append(instance)
return instances, count
@staticmethod
def _get_count_stmt(
statement: Select[tuple[ModelT]],
loader_options: Optional[list[_AbstractLoad]], # noqa: ARG004
execution_options: Optional[dict[str, Any]], # noqa: ARG004
) -> Select[tuple[int]]:
# Count statement transformations are static
return (
statement.with_only_columns(sql_func.count(text("1")), maintain_column_froms=True)
.limit(None)
.offset(None)
.order_by(None)
)
async def upsert(
self,
data: ModelT,
*,
attribute_names: Optional[Iterable[str]] = None,
with_for_update: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_commit: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
match_fields: Optional[Union[list[str], str]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> ModelT:
"""Modify or create instance.
Updates instance with the attribute values present on `data`, or creates a new instance if
one doesn't exist.
Args:
data: Instance to update existing, or be created. Identifier used to determine if an
existing instance exists is the value of an attribute on `data` named as value of
`self.id_attribute`.
attribute_names: an iterable of attribute names to pass into the ``update`` method.
with_for_update: indicating FOR UPDATE should be used, or may be a
dictionary containing flags to indicate a more specific set of
FOR UPDATE flags for the SELECT
auto_expunge: Remove object from session before returning.
auto_refresh: Refresh object from session before returning.
auto_commit: Commit objects before returning.
match_fields: a list of keys to use to match the existing model. When
empty, all fields are matched.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
Returns:
The updated or created instance.
"""
self._uniquify = self._get_uniquify(uniquify)
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
if match_fields := self._get_match_fields(match_fields=match_fields):
match_filter = {
field_name: getattr(data, field_name, None)
for field_name in match_fields
if getattr(data, field_name, None) is not None
}
elif getattr(data, self.id_attribute, None) is not None:
match_filter = {self.id_attribute: getattr(data, self.id_attribute, None)}
else:
match_filter = data.to_dict(exclude={self.id_attribute})
existing = await self.get_one_or_none(load=load, execution_options=execution_options, **match_filter)
if not existing:
return await self.add(
data,
auto_commit=auto_commit,
auto_expunge=auto_expunge,
auto_refresh=auto_refresh,
)
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
for field_name, new_field_value in data.to_dict(exclude={self.id_attribute}).items():
field = getattr(existing, field_name, MISSING)
if field is not MISSING and field != new_field_value:
setattr(existing, field_name, new_field_value)
instance = await self._attach_to_session(existing, strategy="merge")
await self._flush_or_commit(auto_commit=auto_commit)
await self._refresh(
instance,
attribute_names=attribute_names,
with_for_update=with_for_update,
auto_refresh=auto_refresh,
)
self._expunge(instance, auto_expunge=auto_expunge)
return instance
async def upsert_many(
self,
data: list[ModelT],
*,
auto_expunge: Optional[bool] = None,
auto_commit: Optional[bool] = None,
no_merge: bool = False,
match_fields: Optional[Union[list[str], str]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> list[ModelT]:
"""Modify or create multiple instances.
Update instances with the attribute values present on `data`, or create a new instance if
one doesn't exist.
!!! tip
In most cases, you will want to set `match_fields` to the combination of attributes, excluded the primary key, that define uniqueness for a row.
Args:
data: Instance to update existing, or be created. Identifier used to determine if an
existing instance exists is the value of an attribute on ``data`` named as value of
:attr:`id_attribute`.
auto_expunge: Remove object from session before returning.
auto_commit: Commit objects before returning.
no_merge: Skip the usage of optimized Merge statements
match_fields: a list of keys to use to match the existing model. When
empty, automatically uses ``self.id_attribute`` (`id` by default) to match .
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
Returns:
The updated or created instance.
"""
self._uniquify = self._get_uniquify(uniquify)
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
instances: list[ModelT] = []
data_to_update: list[ModelT] = []
data_to_insert: list[ModelT] = []
match_fields = self._get_match_fields(match_fields=match_fields)
if match_fields is None:
match_fields = [self.id_attribute]
match_filter: list[Union[StatementFilter, ColumnElement[bool]]] = []
if match_fields:
for field_name in match_fields:
field = get_instrumented_attr(self.model_type, field_name)
matched_values = [
field_data for datum in data if (field_data := getattr(datum, field_name)) is not None
]
match_filter.append(any_(matched_values) == field if self._prefer_any else field.in_(matched_values)) # type: ignore[arg-type]
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
existing_objs = await self.list(
*match_filter,
load=load,
execution_options=execution_options,
auto_expunge=False,
)
for field_name in match_fields:
field = get_instrumented_attr(self.model_type, field_name)
matched_values = list(
{getattr(datum, field_name) for datum in existing_objs if datum}, # ensure the list is unique
)
match_filter.append(any_(matched_values) == field if self._prefer_any else field.in_(matched_values)) # type: ignore[arg-type]
existing_ids = self._get_object_ids(existing_objs=existing_objs)
data = self._merge_on_match_fields(data, existing_objs, match_fields)
for datum in data:
if getattr(datum, self.id_attribute, None) in existing_ids:
data_to_update.append(datum)
else:
data_to_insert.append(datum)
if data_to_insert:
instances.extend(
await self.add_many(data_to_insert, auto_commit=False, auto_expunge=False),
)
if data_to_update:
instances.extend(
await self.update_many(
data_to_update,
auto_commit=False,
auto_expunge=False,
load=load,
execution_options=execution_options,
),
)
await self._flush_or_commit(auto_commit=auto_commit)
for instance in instances:
self._expunge(instance, auto_expunge=auto_expunge)
return instances
def _get_object_ids(self, existing_objs: list[ModelT]) -> list[Any]:
return [obj_id for datum in existing_objs if (obj_id := getattr(datum, self.id_attribute)) is not None]
def _get_match_fields(
self,
match_fields: Optional[Union[list[str], str]] = None,
id_attribute: Optional[str] = None,
) -> Optional[list[str]]:
id_attribute = id_attribute or self.id_attribute
match_fields = match_fields or self.match_fields
if isinstance(match_fields, str):
match_fields = [match_fields]
return match_fields
def _merge_on_match_fields(
self,
data: list[ModelT],
existing_data: list[ModelT],
match_fields: Optional[Union[list[str], str]] = None,
) -> list[ModelT]:
match_fields = self._get_match_fields(match_fields=match_fields)
if match_fields is None:
match_fields = [self.id_attribute]
for existing_datum in existing_data:
for datum in data:
match = all(
getattr(datum, field_name) == getattr(existing_datum, field_name) for field_name in match_fields
)
if match and getattr(existing_datum, self.id_attribute) is not None:
setattr(datum, self.id_attribute, getattr(existing_datum, self.id_attribute))
return data
async def list(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
auto_expunge: Optional[bool] = None,
statement: Optional[Select[tuple[ModelT]]] = None,
order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> list[ModelT]:
"""Get a list of instances, optionally filtered.
Args:
*filters: Types for specific filtering operations.
auto_expunge: Remove object from session before returning.
statement: To facilitate customization of the underlying select query.
order_by: Set default order options for queries.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
**kwargs: Instance attribute value filters.
Returns:
The list of instances, after filtering applied.
"""
self._uniquify = self._get_uniquify(uniquify)
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
execution_options = self._get_execution_options(execution_options)
statement = self.statement if statement is None else statement
loader_options, loader_options_have_wildcard = self._get_loader_options(load)
statement = self._get_base_stmt(
statement=statement,
loader_options=loader_options,
execution_options=execution_options,
)
if order_by is None:
order_by = self.order_by if self.order_by is not None else []
statement = self._apply_order_by(statement=statement, order_by=order_by)
statement = self._apply_filters(*filters, statement=statement)
statement = self._filter_select_by_kwargs(statement, kwargs)
result = await self._execute(statement, uniquify=loader_options_have_wildcard)
instances = list(result.scalars())
for instance in instances:
self._expunge(instance, auto_expunge=auto_expunge)
return cast("list[ModelT]", instances)
@classmethod
async def check_health(cls, session: Union[AsyncSession, async_scoped_session[AsyncSession]]) -> bool:
"""Perform a health check on the database.
Args:
session: through which we run a check statement
Returns:
``True`` if healthy.
"""
with wrap_sqlalchemy_exception():
return ( # type: ignore[no-any-return]
await session.execute(cls._get_health_check_statement(session))
).scalar_one() == 1
@staticmethod
def _get_health_check_statement(session: Union[AsyncSession, async_scoped_session[AsyncSession]]) -> TextClause:
if session.bind and session.bind.dialect.name == "oracle":
return text("SELECT 1 FROM DUAL")
return text("SELECT 1")
async def _attach_to_session(
self, model: ModelT, strategy: Literal["add", "merge"] = "add", load: bool = True
) -> ModelT:
"""Attach detached instance to the session.
Args:
model: The instance to be attached to the session.
strategy: How the instance should be attached.
- "add": New instance added to session
- "merge": Instance merged with existing, or new one added.
load: Boolean, when False, merge switches into
a "high performance" mode which causes it to forego emitting history
events as well as all database access. This flag is used for
cases such as transferring graphs of objects into a session
from a second level cache, or to transfer just-loaded objects
into the session owned by a worker thread or process
without re-querying the database.
Raises:
ValueError: If `strategy` is not one of the expected values.
Returns:
Instance attached to the session - if `"merge"` strategy, may not be same instance
that was provided.
"""
if strategy == "add":
self.session.add(model)
return model
if strategy == "merge":
return await self.session.merge(model, load=load)
msg = "Unexpected value for `strategy`, must be `'add'` or `'merge'`" # type: ignore[unreachable]
raise ValueError(msg)
async def _execute(
self,
statement: Select[Any],
uniquify: bool = False,
) -> Result[Any]:
result = await self.session.execute(statement)
if uniquify or self._uniquify:
result = result.unique()
return result
class SQLAlchemyAsyncSlugRepository(
SQLAlchemyAsyncRepository[ModelT],
SQLAlchemyAsyncSlugRepositoryProtocol[ModelT],
):
"""Extends the repository to include slug model features.."""
async def get_by_slug(
self,
slug: str,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> Optional[ModelT]:
"""Select record by slug value.
Returns:
The model instance or None if not found.
"""
return await self.get_one_or_none(
slug=slug,
load=load,
execution_options=execution_options,
error_messages=error_messages,
uniquify=uniquify,
)
async def get_available_slug(
self,
value_to_slugify: str,
**kwargs: Any,
) -> str:
"""Get a unique slug for the supplied value.
If the value is found to exist, a random 4 digit character is appended to the end.
Override this method to change the default behavior
Args:
value_to_slugify (str): A string that should be converted to a unique slug.
**kwargs: stuff
Returns:
str: a unique slug for the supplied value. This is safe for URLs and other unique identifiers.
"""
slug = slugify(value_to_slugify)
if await self._is_slug_unique(slug):
return slug
random_string = "".join(random.choices(string.ascii_lowercase + string.digits, k=4)) # noqa: S311
return f"{slug}-{random_string}"
async def _is_slug_unique(
self,
slug: str,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> bool:
return await self.exists(slug=slug, load=load, execution_options=execution_options, **kwargs) is False
class SQLAlchemyAsyncQueryRepository:
"""SQLAlchemy Query Repository.
This is a loosely typed helper to query for when you need to select data in ways that don't align to the normal repository pattern.
"""
error_messages: Optional[ErrorMessages] = None
def __init__(
self,
*,
session: Union[AsyncSession, async_scoped_session[AsyncSession]],
error_messages: Optional[ErrorMessages] = None,
**kwargs: Any,
) -> None:
"""Repository pattern for SQLAlchemy models.
Args:
session: Session managing the unit-of-work for the operation.
error_messages: A set of error messages to use for operations.
**kwargs: Additional arguments.
"""
super().__init__(**kwargs)
self.session = session
self.error_messages = error_messages
self._dialect = self.session.bind.dialect if self.session.bind is not None else self.session.get_bind().dialect
async def get_one(
self,
statement: Select[tuple[Any]],
**kwargs: Any,
) -> Row[Any]:
"""Get instance identified by ``kwargs``.
Args:
statement: To facilitate customization of the underlying select query.
**kwargs: Instance attribute value filters.
Returns:
The retrieved instance.
"""
with wrap_sqlalchemy_exception(error_messages=self.error_messages):
statement = self._filter_statement_by_kwargs(statement, **kwargs)
instance = (await self.execute(statement)).scalar_one_or_none()
return self.check_not_found(instance)
async def get_one_or_none(
self,
statement: Select[Any],
**kwargs: Any,
) -> Optional[Row[Any]]:
"""Get instance identified by ``kwargs`` or None if not found.
Args:
statement: To facilitate customization of the underlying select query.
**kwargs: Instance attribute value filters.
Returns:
The retrieved instance or None
"""
with wrap_sqlalchemy_exception(error_messages=self.error_messages):
statement = self._filter_statement_by_kwargs(statement, **kwargs)
instance = (await self.execute(statement)).scalar_one_or_none()
return instance or None
async def count(self, statement: Select[Any], **kwargs: Any) -> int:
"""Get the count of records returned by a query.
Args:
statement: To facilitate customization of the underlying select query.
**kwargs: Instance attribute value filters.
Returns:
Count of records returned by query, ignoring pagination.
"""
with wrap_sqlalchemy_exception(error_messages=self.error_messages):
statement = statement.with_only_columns(sql_func.count(text("1")), maintain_column_froms=True).order_by(
None,
)
statement = self._filter_statement_by_kwargs(statement, **kwargs)
results = await self.execute(statement)
return results.scalar_one() # type: ignore
async def list_and_count(
self,
statement: Select[Any],
count_with_window_function: Optional[bool] = None,
**kwargs: Any,
) -> tuple[list[Row[Any]], int]:
"""List records with total count.
Args:
statement: To facilitate customization of the underlying select query.
count_with_window_function: Force list and count to use two queries instead of an analytical window function.
**kwargs: Instance attribute value filters.
Returns:
Count of records returned by query, ignoring pagination.
"""
if self._dialect.name in {"spanner", "spanner+spanner"} or count_with_window_function:
return await self._list_and_count_basic(statement=statement, **kwargs)
return await self._list_and_count_window(statement=statement, **kwargs)
async def _list_and_count_window(
self,
statement: Select[Any],
**kwargs: Any,
) -> tuple[list[Row[Any]], int]:
"""List records with total count.
Args:
*filters: Types for specific filtering operations.
statement: To facilitate customization of the underlying select query.
**kwargs: Instance attribute value filters.
Returns:
Count of records returned by query using an analytical window function, ignoring pagination.
"""
with wrap_sqlalchemy_exception(error_messages=self.error_messages):
statement = statement.add_columns(over(sql_func.count(text("1"))))
statement = self._filter_statement_by_kwargs(statement, **kwargs)
result = await self.execute(statement)
count: int = 0
instances: list[Row[Any]] = []
for i, (instance, count_value) in enumerate(result):
instances.append(instance)
if i == 0:
count = count_value
return instances, count
@staticmethod
def _get_count_stmt(statement: Select[Any]) -> Select[Any]:
return statement.with_only_columns(sql_func.count(text("1")), maintain_column_froms=True).order_by(None) # pyright: ignore[reportUnknownVariable]
async def _list_and_count_basic(
self,
statement: Select[Any],
**kwargs: Any,
) -> tuple[list[Row[Any]], int]:
"""List records with total count.
Args:
statement: To facilitate customization of the underlying select query. .
**kwargs: Instance attribute value filters.
Returns:
Count of records returned by query using 2 queries, ignoring pagination.
"""
with wrap_sqlalchemy_exception(error_messages=self.error_messages):
statement = self._filter_statement_by_kwargs(statement, **kwargs)
count_result = await self.session.execute(self._get_count_stmt(statement))
count = count_result.scalar_one()
result = await self.execute(statement)
instances: list[Row[Any]] = []
for (instance,) in result:
instances.append(instance)
return instances, count
async def list(self, statement: Select[Any], **kwargs: Any) -> list[Row[Any]]:
"""Get a list of instances, optionally filtered.
Args:
statement: To facilitate customization of the underlying select query.
**kwargs: Instance attribute value filters.
Returns:
The list of instances, after filtering applied.
"""
with wrap_sqlalchemy_exception(error_messages=self.error_messages):
statement = self._filter_statement_by_kwargs(statement, **kwargs)
result = await self.execute(statement)
return list(result.all())
def _filter_statement_by_kwargs(
self,
statement: Select[Any],
/,
**kwargs: Any,
) -> Select[Any]:
"""Filter the collection by kwargs.
Args:
statement: statement to filter
**kwargs: key/value pairs such that objects remaining in the statement after filtering
have the property that their attribute named `key` has value equal to `value`.
Returns:
The filtered statement.
"""
with wrap_sqlalchemy_exception(error_messages=self.error_messages):
return statement.filter_by(**kwargs)
# the following is all sqlalchemy implementation detail, and shouldn't be directly accessed
@staticmethod
def check_not_found(item_or_none: Optional[T]) -> T:
"""Raise :class:`NotFoundError` if ``item_or_none`` is ``None``.
Args:
item_or_none: Item to be tested for existence.
Raises:
NotFoundError: If ``item_or_none`` is ``None``
Returns:
The item, if it exists.
"""
if item_or_none is None:
msg = "No item found when one was expected"
raise NotFoundError(msg)
return item_or_none
async def execute(
self,
statement: Union[
ReturningDelete[tuple[Any]], ReturningUpdate[tuple[Any]], Select[tuple[Any]], Update, Delete, Select[Any]
],
) -> Result[Any]:
return await self.session.execute(statement)
python-advanced-alchemy-1.4.1/advanced_alchemy/repository/_sync.py 0000664 0000000 0000000 00000304574 15003544734 0025416 0 ustar 00root root 0000000 0000000 # Do not edit this file directly. It has been autogenerated from
# advanced_alchemy/repository/_async.py
import random
import string
from collections.abc import Iterable, Sequence
from typing import (
TYPE_CHECKING,
Any,
Final,
Literal,
Optional,
Protocol,
Union,
cast,
runtime_checkable,
)
from sqlalchemy import (
Delete,
Result,
Row,
Select,
TextClause,
Update,
any_,
delete,
over,
select,
text,
update,
)
from sqlalchemy import func as sql_func
from sqlalchemy.orm import InstrumentedAttribute, Session
from sqlalchemy.orm.scoping import scoped_session
from sqlalchemy.orm.strategy_options import _AbstractLoad # pyright: ignore[reportPrivateUsage]
from sqlalchemy.sql import ColumnElement
from sqlalchemy.sql.dml import ReturningDelete, ReturningUpdate
from advanced_alchemy.exceptions import ErrorMessages, NotFoundError, RepositoryError, wrap_sqlalchemy_exception
from advanced_alchemy.filters import StatementFilter, StatementTypeT
from advanced_alchemy.repository._util import (
DEFAULT_ERROR_MESSAGE_TEMPLATES,
FilterableRepository,
FilterableRepositoryProtocol,
LoadSpec,
get_abstract_loader_options,
get_instrumented_attr,
)
from advanced_alchemy.repository.typing import MISSING, ModelT, OrderingPair, T
from advanced_alchemy.utils.dataclass import Empty, EmptyType
from advanced_alchemy.utils.text import slugify
if TYPE_CHECKING:
from sqlalchemy.engine.interfaces import _CoreSingleExecuteParams # pyright: ignore[reportPrivateUsage]
DEFAULT_INSERTMANYVALUES_MAX_PARAMETERS: Final = 950
POSTGRES_VERSION_SUPPORTING_MERGE: Final = 15
@runtime_checkable
class SQLAlchemySyncRepositoryProtocol(FilterableRepositoryProtocol[ModelT], Protocol[ModelT]):
"""Base Protocol"""
id_attribute: str
match_fields: Optional[Union[list[str], str]] = None
statement: Select[tuple[ModelT]]
session: Union[Session, scoped_session[Session]]
auto_expunge: bool
auto_refresh: bool
auto_commit: bool
order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None
error_messages: Optional[ErrorMessages] = None
wrap_exceptions: bool = True
def __init__(
self,
*,
statement: Optional[Select[tuple[ModelT]]] = None,
session: Union[Session, scoped_session[Session]],
auto_expunge: bool = False,
auto_refresh: bool = True,
auto_commit: bool = False,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
wrap_exceptions: bool = True,
**kwargs: Any,
) -> None: ...
@classmethod
def get_id_attribute_value(
cls,
item: Union[ModelT, type[ModelT]],
id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
) -> Any: ...
@classmethod
def set_id_attribute_value(
cls,
item_id: Any,
item: ModelT,
id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
) -> ModelT: ...
@staticmethod
def check_not_found(item_or_none: Optional[ModelT]) -> ModelT: ...
def add(
self,
data: ModelT,
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
) -> ModelT: ...
def add_many(
self,
data: list[ModelT],
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
) -> Sequence[ModelT]: ...
def delete(
self,
item_id: Any,
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
) -> ModelT: ...
def delete_many(
self,
item_ids: list[Any],
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
chunk_size: Optional[int] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
) -> Sequence[ModelT]: ...
def delete_where(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
load: Optional[LoadSpec] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
execution_options: Optional[dict[str, Any]] = None,
sanity_check: bool = True,
**kwargs: Any,
) -> Sequence[ModelT]: ...
def exists(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
load: Optional[LoadSpec] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
execution_options: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> bool: ...
def get(
self,
item_id: Any,
*,
auto_expunge: Optional[bool] = None,
statement: Optional[Select[tuple[ModelT]]] = None,
id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
) -> ModelT: ...
def get_one(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
auto_expunge: Optional[bool] = None,
statement: Optional[Select[tuple[ModelT]]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> ModelT: ...
def get_one_or_none(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
auto_expunge: Optional[bool] = None,
statement: Optional[Select[tuple[ModelT]]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> Optional[ModelT]: ...
def get_or_upsert(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
match_fields: Optional[Union[list[str], str]] = None,
upsert: bool = True,
attribute_names: Optional[Iterable[str]] = None,
with_for_update: Optional[bool] = None,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> tuple[ModelT, bool]: ...
def get_and_update(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
match_fields: Optional[Union[list[str], str]] = None,
attribute_names: Optional[Iterable[str]] = None,
with_for_update: Optional[bool] = None,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> tuple[ModelT, bool]: ...
def count(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
statement: Optional[Select[tuple[ModelT]]] = None,
load: Optional[LoadSpec] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
execution_options: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> int: ...
def update(
self,
data: ModelT,
*,
attribute_names: Optional[Iterable[str]] = None,
with_for_update: Optional[bool] = None,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
) -> ModelT: ...
def update_many(
self,
data: list[ModelT],
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
) -> list[ModelT]: ...
def _get_update_many_statement(
self,
model_type: type[ModelT],
supports_returning: bool,
loader_options: Optional[list[_AbstractLoad]],
execution_options: Optional[dict[str, Any]],
) -> Union[Update, ReturningUpdate[tuple[ModelT]]]: ...
def upsert(
self,
data: ModelT,
*,
attribute_names: Optional[Iterable[str]] = None,
with_for_update: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_commit: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
match_fields: Optional[Union[list[str], str]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
) -> ModelT: ...
def upsert_many(
self,
data: list[ModelT],
*,
auto_expunge: Optional[bool] = None,
auto_commit: Optional[bool] = None,
no_merge: bool = False,
match_fields: Optional[Union[list[str], str]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
) -> list[ModelT]: ...
def list_and_count(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
auto_expunge: Optional[bool] = None,
statement: Optional[Select[tuple[ModelT]]] = None,
count_with_window_function: Optional[bool] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None,
**kwargs: Any,
) -> tuple[list[ModelT], int]: ...
def list(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
auto_expunge: Optional[bool] = None,
statement: Optional[Select[tuple[ModelT]]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None,
**kwargs: Any,
) -> list[ModelT]: ...
@classmethod
def check_health(cls, session: Union[Session, scoped_session[Session]]) -> bool: ...
@runtime_checkable
class SQLAlchemySyncSlugRepositoryProtocol(SQLAlchemySyncRepositoryProtocol[ModelT], Protocol[ModelT]):
"""Protocol for SQLAlchemy repositories that support slug-based operations.
Extends the base repository protocol to add slug-related functionality.
Type Parameters:
ModelT: The SQLAlchemy model type this repository handles.
"""
def get_by_slug(
self,
slug: str,
*,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> Optional[ModelT]:
"""Get a model instance by its slug.
Args:
slug: The slug value to search for.
error_messages: Optional custom error message templates.
load: Specification for eager loading of relationships.
execution_options: Options for statement execution.
**kwargs: Additional filtering criteria.
Returns:
ModelT | None: The found model instance or None if not found.
"""
...
def get_available_slug(
self,
value_to_slugify: str,
**kwargs: Any,
) -> str:
"""Generate a unique slug for a given value.
Args:
value_to_slugify: The string to convert to a slug.
**kwargs: Additional parameters for slug generation.
Returns:
str: A unique slug derived from the input value.
"""
...
class SQLAlchemySyncRepository(SQLAlchemySyncRepositoryProtocol[ModelT], FilterableRepository[ModelT]):
"""Async SQLAlchemy repository implementation.
Provides a complete implementation of async database operations using SQLAlchemy,
including CRUD operations, filtering, and relationship loading.
Type Parameters:
ModelT: The SQLAlchemy model type this repository handles.
.. seealso::
:class:`~advanced_alchemy.repository._util.FilterableRepository`
"""
id_attribute: str = "id"
"""Name of the unique identifier for the model."""
loader_options: Optional[LoadSpec] = None
"""Default loader options for the repository."""
error_messages: Optional[ErrorMessages] = None
"""Default error messages for the repository."""
wrap_exceptions: bool = True
"""Wrap SQLAlchemy exceptions in a ``RepositoryError``. When set to ``False``, the original exception will be raised."""
inherit_lazy_relationships: bool = True
"""Optionally ignore the default ``lazy`` configuration for model relationships. This is useful for when you want to
replace instead of merge the model's loaded relationships with the ones specified in the ``load`` or ``default_loader_options`` configuration."""
merge_loader_options: bool = True
"""Merges the default loader options with the loader options specified in the ``load`` argument. This is useful for when you want to totally
replace instead of merge the model's loaded relationships with the ones specified in the ``load`` or ``default_loader_options`` configuration."""
execution_options: Optional[dict[str, Any]] = None
"""Default execution options for the repository."""
match_fields: Optional[Union[list[str], str]] = None
"""List of dialects that prefer to use ``field.id = ANY(:1)`` instead of ``field.id IN (...)``."""
uniquify: bool = False
"""Optionally apply the ``unique()`` method to results before returning.
This is useful for certain SQLAlchemy uses cases such as applying ``contains_eager`` to a query containing a one-to-many relationship
"""
count_with_window_function: bool = True
"""Use an analytical window function to count results. This allows the count to be performed in a single query.
"""
def __init__(
self,
*,
statement: Optional[Select[tuple[ModelT]]] = None,
session: Union[Session, scoped_session[Session]],
auto_expunge: bool = False,
auto_refresh: bool = True,
auto_commit: bool = False,
order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
wrap_exceptions: bool = True,
uniquify: Optional[bool] = None,
count_with_window_function: Optional[bool] = None,
**kwargs: Any,
) -> None:
"""Repository for SQLAlchemy models.
Args:
statement: To facilitate customization of the underlying select query.
session: Session managing the unit-of-work for the operation.
auto_expunge: Remove object from session before returning.
auto_refresh: Refresh object from session before returning.
auto_commit: Commit objects before returning.
order_by: Set default order options for queries.
load: Set default relationships to be loaded
execution_options: Set default execution options
error_messages: A set of custom error messages to use for operations
wrap_exceptions: Wrap SQLAlchemy exceptions in a ``RepositoryError``. When set to ``False``, the original exception will be raised.
uniquify: Optionally apply the ``unique()`` method to results before returning.
count_with_window_function: When false, list and count will use two queries instead of an analytical window function.
**kwargs: Additional arguments.
"""
self.auto_expunge = auto_expunge
self.auto_refresh = auto_refresh
self.auto_commit = auto_commit
self.order_by = order_by
self.session = session
self.error_messages = self._get_error_messages(
error_messages=error_messages, default_messages=self.error_messages
)
self.wrap_exceptions = wrap_exceptions
self._uniquify = uniquify if uniquify is not None else self.uniquify
self.count_with_window_function = (
count_with_window_function if count_with_window_function is not None else self.count_with_window_function
)
self._default_loader_options, self._loader_options_have_wildcards = get_abstract_loader_options(
loader_options=load if load is not None else self.loader_options,
inherit_lazy_relationships=self.inherit_lazy_relationships,
merge_with_default=self.merge_loader_options,
)
execution_options = execution_options if execution_options is not None else self.execution_options
self._default_execution_options = execution_options or {}
self.statement = select(self.model_type) if statement is None else statement
self._dialect = self.session.bind.dialect if self.session.bind is not None else self.session.get_bind().dialect
self._prefer_any = any(self._dialect.name == engine_type for engine_type in self.prefer_any_dialects or ())
def _get_uniquify(self, uniquify: Optional[bool] = None) -> bool:
"""Get the uniquify value, preferring the method parameter over instance setting.
Args:
uniquify: Optional override for the uniquify setting.
Returns:
bool: The uniquify value to use.
"""
return bool(uniquify) if uniquify is not None else self._uniquify
@staticmethod
def _get_error_messages(
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
default_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
) -> Optional[ErrorMessages]:
if error_messages == Empty:
error_messages = None
if default_messages == Empty:
default_messages = None
messages = DEFAULT_ERROR_MESSAGE_TEMPLATES
if default_messages and isinstance(default_messages, dict):
messages.update(default_messages)
if error_messages:
messages.update(cast("ErrorMessages", error_messages))
return messages
@classmethod
def get_id_attribute_value(
cls,
item: Union[ModelT, type[ModelT]],
id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
) -> Any:
"""Get value of attribute named as :attr:`id_attribute` on ``item``.
Args:
item: Anything that should have an attribute named as :attr:`id_attribute` value.
id_attribute: Allows customization of the unique identifier to use for model fetching.
Defaults to `None`, but can reference any surrogate or candidate key for the table.
Returns:
The value of attribute on ``item`` named as :attr:`id_attribute`.
"""
if isinstance(id_attribute, InstrumentedAttribute):
id_attribute = id_attribute.key
return getattr(item, id_attribute if id_attribute is not None else cls.id_attribute)
@classmethod
def set_id_attribute_value(
cls,
item_id: Any,
item: ModelT,
id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
) -> ModelT:
"""Return the ``item`` after the ID is set to the appropriate attribute.
Args:
item_id: Value of ID to be set on instance
item: Anything that should have an attribute named as :attr:`id_attribute` value.
id_attribute: Allows customization of the unique identifier to use for model fetching.
Defaults to `None`, but can reference any surrogate or candidate key for the table.
Returns:
Item with ``item_id`` set to :attr:`id_attribute`
"""
if isinstance(id_attribute, InstrumentedAttribute):
id_attribute = id_attribute.key
setattr(item, id_attribute if id_attribute is not None else cls.id_attribute, item_id)
return item
@staticmethod
def check_not_found(item_or_none: Optional[ModelT]) -> ModelT:
"""Raise :exc:`advanced_alchemy.exceptions.NotFoundError` if ``item_or_none`` is ``None``.
Args:
item_or_none: Item (:class:`T `) to be tested for existence.
Raises:
NotFoundError: If ``item_or_none`` is ``None``
Returns:
The item, if it exists.
"""
if item_or_none is None:
msg = "No item found when one was expected"
raise NotFoundError(msg)
return item_or_none
def _get_execution_options(
self,
execution_options: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
if execution_options is None:
return self._default_execution_options
return execution_options
def _get_loader_options(
self,
loader_options: Optional[LoadSpec],
) -> Union[tuple[list[_AbstractLoad], bool], tuple[None, bool]]:
if loader_options is None:
# use the defaults set at initialization
return self._default_loader_options, self._loader_options_have_wildcards or self._uniquify
return get_abstract_loader_options(
loader_options=loader_options,
default_loader_options=self._default_loader_options,
default_options_have_wildcards=self._loader_options_have_wildcards or self._uniquify,
inherit_lazy_relationships=self.inherit_lazy_relationships,
merge_with_default=self.merge_loader_options,
)
def add(
self,
data: ModelT,
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
) -> ModelT:
"""Add ``data`` to the collection.
Args:
data: Instance to be added to the collection.
auto_expunge: Remove object from session before returning.
auto_refresh: Refresh object from session before returning.
auto_commit: Commit objects before returning.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
Returns:
The added instance.
"""
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
instance = self._attach_to_session(data)
self._flush_or_commit(auto_commit=auto_commit)
self._refresh(instance, auto_refresh=auto_refresh)
self._expunge(instance, auto_expunge=auto_expunge)
return instance
def add_many(
self,
data: list[ModelT],
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
) -> Sequence[ModelT]:
"""Add many `data` to the collection.
Args:
data: list of Instances to be added to the collection.
auto_expunge: Remove object from session before returning.
auto_commit: Commit objects before returning.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
Returns:
The added instances.
"""
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
self.session.add_all(data)
self._flush_or_commit(auto_commit=auto_commit)
for datum in data:
self._expunge(datum, auto_expunge=auto_expunge)
return data
def delete(
self,
item_id: Any,
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> ModelT:
"""Delete instance identified by ``item_id``.
Args:
item_id: Identifier of instance to be deleted.
auto_expunge: Remove object from session before returning.
auto_commit: Commit objects before returning.
id_attribute: Allows customization of the unique identifier to use for model fetching.
Defaults to `id`, but can reference any surrogate or candidate key for the table.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
Returns:
The deleted instance.
"""
self._uniquify = self._get_uniquify(uniquify)
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
execution_options = self._get_execution_options(execution_options)
instance = self.get(
item_id,
id_attribute=id_attribute,
load=load,
execution_options=execution_options,
)
self.session.delete(instance)
self._flush_or_commit(auto_commit=auto_commit)
self._expunge(instance, auto_expunge=auto_expunge)
return instance
def delete_many(
self,
item_ids: list[Any],
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
chunk_size: Optional[int] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> Sequence[ModelT]:
"""Delete instance identified by `item_id`.
Args:
item_ids: Identifier of instance to be deleted.
auto_expunge: Remove object from session before returning.
auto_commit: Commit objects before returning.
id_attribute: Allows customization of the unique identifier to use for model fetching.
Defaults to `id`, but can reference any surrogate or candidate key for the table.
chunk_size: Allows customization of the ``insertmanyvalues_max_parameters`` setting for the driver.
Defaults to `950` if left unset.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
Returns:
The deleted instances.
"""
self._uniquify = self._get_uniquify(uniquify)
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
execution_options = self._get_execution_options(execution_options)
loader_options, _loader_options_have_wildcard = self._get_loader_options(load)
id_attribute = get_instrumented_attr(
self.model_type,
id_attribute if id_attribute is not None else self.id_attribute,
)
instances: list[ModelT] = []
if self._prefer_any:
chunk_size = len(item_ids) + 1
chunk_size = self._get_insertmanyvalues_max_parameters(chunk_size)
for idx in range(0, len(item_ids), chunk_size):
chunk = item_ids[idx : min(idx + chunk_size, len(item_ids))]
if self._dialect.delete_executemany_returning:
instances.extend(
self.session.scalars(
self._get_delete_many_statement(
statement_type="delete",
model_type=self.model_type,
id_attribute=id_attribute,
id_chunk=chunk,
supports_returning=self._dialect.delete_executemany_returning,
loader_options=loader_options,
execution_options=execution_options,
),
),
)
else:
instances.extend(
self.session.scalars(
self._get_delete_many_statement(
statement_type="select",
model_type=self.model_type,
id_attribute=id_attribute,
id_chunk=chunk,
supports_returning=self._dialect.delete_executemany_returning,
loader_options=loader_options,
execution_options=execution_options,
),
),
)
self.session.execute(
self._get_delete_many_statement(
statement_type="delete",
model_type=self.model_type,
id_attribute=id_attribute,
id_chunk=chunk,
supports_returning=self._dialect.delete_executemany_returning,
loader_options=loader_options,
execution_options=execution_options,
),
)
self._flush_or_commit(auto_commit=auto_commit)
for instance in instances:
self._expunge(instance, auto_expunge=auto_expunge)
return instances
@staticmethod
def _get_insertmanyvalues_max_parameters(chunk_size: Optional[int] = None) -> int:
return chunk_size if chunk_size is not None else DEFAULT_INSERTMANYVALUES_MAX_PARAMETERS
def delete_where(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
sanity_check: bool = True,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> Sequence[ModelT]:
"""Delete instances specified by referenced kwargs and filters.
Args:
*filters: Types for specific filtering operations.
auto_expunge: Remove object from session before returning.
auto_commit: Commit objects before returning.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
sanity_check: When true, the length of selected instances is compared to the deleted row count
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
**kwargs: Arguments to apply to a delete
Raises:
RepositoryError: If the number of deleted rows does not match the number of selected instances
Returns:
The deleted instances.
"""
self._uniquify = self._get_uniquify(uniquify)
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
execution_options = self._get_execution_options(execution_options)
loader_options, _loader_options_have_wildcard = self._get_loader_options(load)
model_type = self.model_type
statement = self._get_base_stmt(
statement=delete(model_type),
loader_options=loader_options,
execution_options=execution_options,
)
statement = self._filter_select_by_kwargs(statement=statement, kwargs=kwargs)
statement = self._apply_filters(*filters, statement=statement, apply_pagination=False)
instances: list[ModelT] = []
if self._dialect.delete_executemany_returning:
instances.extend(self.session.scalars(statement.returning(model_type)))
else:
instances.extend(
self.list(
*filters,
load=load,
execution_options=execution_options,
auto_expunge=auto_expunge,
**kwargs,
),
)
result = self.session.execute(statement)
row_count = getattr(result, "rowcount", -2)
if sanity_check and row_count >= 0 and len(instances) != row_count: # pyright: ignore
# backends will return a -1 if they can't determine impacted rowcount
# only compare length of selected instances to results if it's >= 0
self.session.rollback()
raise RepositoryError(detail="Deleted count does not match fetched count. Rollback issued.")
self._flush_or_commit(auto_commit=auto_commit)
for instance in instances:
self._expunge(instance, auto_expunge=auto_expunge)
return instances
def exists(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> bool:
"""Return true if the object specified by ``kwargs`` exists.
Args:
*filters: Types for specific filtering operations.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
**kwargs: Identifier of the instance to be retrieved.
Returns:
True if the instance was found. False if not found..
"""
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
existing = self.count(
*filters,
load=load,
execution_options=execution_options,
error_messages=error_messages,
**kwargs,
)
return existing > 0
@staticmethod
def _get_base_stmt(
*,
statement: StatementTypeT,
loader_options: Optional[list[_AbstractLoad]],
execution_options: Optional[dict[str, Any]],
) -> StatementTypeT:
"""Get base statement with options applied.
Args:
statement: The select statement to modify
loader_options: Options for loading relationships
execution_options: Options for statement execution
Returns:
Modified select statement
"""
if loader_options:
statement = cast("StatementTypeT", statement.options(*loader_options))
if execution_options:
statement = cast("StatementTypeT", statement.execution_options(**execution_options))
return statement
def _get_delete_many_statement(
self,
*,
model_type: type[ModelT],
id_attribute: InstrumentedAttribute[Any],
id_chunk: list[Any],
supports_returning: bool,
statement_type: Literal["delete", "select"] = "delete",
loader_options: Optional[list[_AbstractLoad]],
execution_options: Optional[dict[str, Any]],
) -> Union[Select[tuple[ModelT]], Delete, ReturningDelete[tuple[ModelT]]]:
# Base statement is static
statement = self._get_base_stmt(
statement=delete(model_type) if statement_type == "delete" else select(model_type),
loader_options=loader_options,
execution_options=execution_options,
)
if execution_options:
statement = statement.execution_options(**execution_options)
if supports_returning and statement_type != "select":
statement = cast("ReturningDelete[tuple[ModelT]]", statement.returning(model_type)) # type: ignore[union-attr,assignment] # pyright: ignore[reportUnknownLambdaType,reportUnknownMemberType,reportAttributeAccessIssue,reportUnknownVariableType]
if self._prefer_any:
return statement.where(any_(id_chunk) == id_attribute) # type: ignore[arg-type]
return statement.where(id_attribute.in_(id_chunk)) # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
def get(
self,
item_id: Any,
*,
auto_expunge: Optional[bool] = None,
statement: Optional[Select[tuple[ModelT]]] = None,
id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> ModelT:
"""Get instance identified by `item_id`.
Args:
item_id: Identifier of the instance to be retrieved.
auto_expunge: Remove object from session before returning.
statement: To facilitate customization of the underlying select query.
id_attribute: Allows customization of the unique identifier to use for model fetching.
Defaults to `id`, but can reference any surrogate or candidate key for the table.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
Returns:
The retrieved instance.
"""
self._uniquify = self._get_uniquify(uniquify)
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
execution_options = self._get_execution_options(execution_options)
statement = self.statement if statement is None else statement
loader_options, loader_options_have_wildcard = self._get_loader_options(load)
id_attribute = id_attribute if id_attribute is not None else self.id_attribute
statement = self._get_base_stmt(
statement=statement,
loader_options=loader_options,
execution_options=execution_options,
)
statement = self._filter_select_by_kwargs(statement, [(id_attribute, item_id)])
instance = (self._execute(statement, uniquify=loader_options_have_wildcard)).scalar_one_or_none()
instance = self.check_not_found(instance)
self._expunge(instance, auto_expunge=auto_expunge)
return instance
def get_one(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
auto_expunge: Optional[bool] = None,
statement: Optional[Select[tuple[ModelT]]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> ModelT:
"""Get instance identified by ``kwargs``.
Args:
*filters: Types for specific filtering operations.
auto_expunge: Remove object from session before returning.
statement: To facilitate customization of the underlying select query.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
**kwargs: Identifier of the instance to be retrieved.
Returns:
The retrieved instance.
"""
self._uniquify = self._get_uniquify(uniquify)
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
execution_options = self._get_execution_options(execution_options)
statement = self.statement if statement is None else statement
loader_options, loader_options_have_wildcard = self._get_loader_options(load)
statement = self._get_base_stmt(
statement=statement,
loader_options=loader_options,
execution_options=execution_options,
)
statement = self._apply_filters(*filters, apply_pagination=False, statement=statement)
statement = self._filter_select_by_kwargs(statement, kwargs)
instance = (self._execute(statement, uniquify=loader_options_have_wildcard)).scalar_one_or_none()
instance = self.check_not_found(instance)
self._expunge(instance, auto_expunge=auto_expunge)
return instance
def get_one_or_none(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
auto_expunge: Optional[bool] = None,
statement: Optional[Select[tuple[ModelT]]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> Union[ModelT, None]:
"""Get instance identified by ``kwargs`` or None if not found.
Args:
*filters: Types for specific filtering operations.
auto_expunge: Remove object from session before returning.
statement: To facilitate customization of the underlying select query.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
**kwargs: Identifier of the instance to be retrieved.
Returns:
The retrieved instance or None
"""
self._uniquify = self._get_uniquify(uniquify)
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
execution_options = self._get_execution_options(execution_options)
statement = self.statement if statement is None else statement
loader_options, loader_options_have_wildcard = self._get_loader_options(load)
statement = self._get_base_stmt(
statement=statement,
loader_options=loader_options,
execution_options=execution_options,
)
statement = self._apply_filters(*filters, apply_pagination=False, statement=statement)
statement = self._filter_select_by_kwargs(statement, kwargs)
instance = cast(
"Result[tuple[ModelT]]",
(self._execute(statement, uniquify=loader_options_have_wildcard)),
).scalar_one_or_none()
if instance:
self._expunge(instance, auto_expunge=auto_expunge)
return instance
def get_or_upsert(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
match_fields: Optional[Union[list[str], str]] = None,
upsert: bool = True,
attribute_names: Optional[Iterable[str]] = None,
with_for_update: Optional[bool] = None,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_refresh: Union[bool, None] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> tuple[ModelT, bool]:
"""Get instance identified by ``kwargs`` or create if it doesn't exist.
Args:
*filters: Types for specific filtering operations.
match_fields: a list of keys to use to match the existing model. When
empty, all fields are matched.
upsert: When using match_fields and actual model values differ from
`kwargs`, automatically perform an update operation on the model.
attribute_names: an iterable of attribute names to pass into the ``update``
method.
with_for_update: indicating FOR UPDATE should be used, or may be a
dictionary containing flags to indicate a more specific set of
FOR UPDATE flags for the SELECT
auto_expunge: Remove object from session before returning.
auto_refresh: Refresh object from session before returning.
auto_commit: Commit objects before returning.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
**kwargs: Identifier of the instance to be retrieved.
Returns:
a tuple that includes the instance and whether it needed to be created.
When using match_fields and actual model values differ from ``kwargs``, the
model value will be updated.
"""
self._uniquify = self._get_uniquify(uniquify)
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
if match_fields := self._get_match_fields(match_fields=match_fields):
match_filter = {
field_name: kwargs.get(field_name)
for field_name in match_fields
if kwargs.get(field_name) is not None
}
else:
match_filter = kwargs
existing = self.get_one_or_none(
*filters,
**match_filter,
load=load,
execution_options=execution_options,
)
if not existing:
return (
self.add(
self.model_type(**kwargs),
auto_commit=auto_commit,
auto_expunge=auto_expunge,
auto_refresh=auto_refresh,
),
True,
)
if upsert:
for field_name, new_field_value in kwargs.items():
field = getattr(existing, field_name, MISSING)
if field is not MISSING and field != new_field_value:
setattr(existing, field_name, new_field_value)
existing = self._attach_to_session(existing, strategy="merge")
self._flush_or_commit(auto_commit=auto_commit)
self._refresh(
existing,
attribute_names=attribute_names,
with_for_update=with_for_update,
auto_refresh=auto_refresh,
)
self._expunge(existing, auto_expunge=auto_expunge)
return existing, False
def get_and_update(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
match_fields: Optional[Union[list[str], str]] = None,
attribute_names: Optional[Iterable[str]] = None,
with_for_update: Optional[bool] = None,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> tuple[ModelT, bool]:
"""Get instance identified by ``kwargs`` and update the model if the arguments are different.
Args:
*filters: Types for specific filtering operations.
match_fields: a list of keys to use to match the existing model. When
empty, all fields are matched.
attribute_names: an iterable of attribute names to pass into the ``update``
method.
with_for_update: indicating FOR UPDATE should be used, or may be a
dictionary containing flags to indicate a more specific set of
FOR UPDATE flags for the SELECT
auto_expunge: Remove object from session before returning.
auto_refresh: Refresh object from session before returning.
auto_commit: Commit objects before returning.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
**kwargs: Identifier of the instance to be retrieved.
Returns:
a tuple that includes the instance and whether it needed to be updated.
When using match_fields and actual model values differ from ``kwargs``, the
model value will be updated.
"""
self._uniquify = self._get_uniquify(uniquify)
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
if match_fields := self._get_match_fields(match_fields=match_fields):
match_filter = {
field_name: kwargs.get(field_name)
for field_name in match_fields
if kwargs.get(field_name) is not None
}
else:
match_filter = kwargs
existing = self.get_one(*filters, **match_filter, load=load, execution_options=execution_options)
updated = False
for field_name, new_field_value in kwargs.items():
field = getattr(existing, field_name, MISSING)
if field is not MISSING and field != new_field_value:
updated = True
setattr(existing, field_name, new_field_value)
existing = self._attach_to_session(existing, strategy="merge")
self._flush_or_commit(auto_commit=auto_commit)
self._refresh(
existing,
attribute_names=attribute_names,
with_for_update=with_for_update,
auto_refresh=auto_refresh,
)
self._expunge(existing, auto_expunge=auto_expunge)
return existing, updated
def count(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
statement: Optional[Select[tuple[ModelT]]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> int:
"""Get the count of records returned by a query.
Args:
*filters: Types for specific filtering operations.
statement: To facilitate customization of the underlying select query.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
**kwargs: Instance attribute value filters.
Returns:
Count of records returned by query, ignoring pagination.
"""
self._uniquify = self._get_uniquify(uniquify)
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
execution_options = self._get_execution_options(execution_options)
statement = self.statement if statement is None else statement
loader_options, loader_options_have_wildcard = self._get_loader_options(load)
statement = self._get_base_stmt(
statement=statement,
loader_options=loader_options,
execution_options=execution_options,
)
statement = self._apply_filters(*filters, apply_pagination=False, statement=statement)
statement = self._filter_select_by_kwargs(statement, kwargs)
results = self._execute(
statement=self._get_count_stmt(
statement=statement, loader_options=loader_options, execution_options=execution_options
),
uniquify=loader_options_have_wildcard,
)
return cast("int", results.scalar_one())
def update(
self,
data: ModelT,
*,
attribute_names: Optional[Iterable[str]] = None,
with_for_update: Optional[bool] = None,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> ModelT:
"""Update instance with the attribute values present on `data`.
Args:
data: An instance that should have a value for `self.id_attribute` that
exists in the collection.
attribute_names: an iterable of attribute names to pass into the ``update``
method.
with_for_update: indicating FOR UPDATE should be used, or may be a
dictionary containing flags to indicate a more specific set of
FOR UPDATE flags for the SELECT
auto_expunge: Remove object from session before returning.
auto_refresh: Refresh object from session before returning.
auto_commit: Commit objects before returning.
id_attribute: Allows customization of the unique identifier to use for model fetching.
Defaults to `id`, but can reference any surrogate or candidate key for the table.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
Returns:
The updated instance.
"""
self._uniquify = self._get_uniquify(uniquify)
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
item_id = self.get_id_attribute_value(
data,
id_attribute=id_attribute,
)
# this will raise for not found, and will put the item in the session
self.get(item_id, id_attribute=id_attribute, load=load, execution_options=execution_options)
# this will merge the inbound data to the instance we just put in the session
instance = self._attach_to_session(data, strategy="merge")
self._flush_or_commit(auto_commit=auto_commit)
self._refresh(
instance,
attribute_names=attribute_names,
with_for_update=with_for_update,
auto_refresh=auto_refresh,
)
self._expunge(instance, auto_expunge=auto_expunge)
return instance
def update_many(
self,
data: list[ModelT],
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> list[ModelT]:
"""Update one or more instances with the attribute values present on `data`.
This function has an optimized bulk update based on the configured SQL dialect:
- For backends supporting `RETURNING` with `executemany`, a single bulk update with returning clause is executed.
- For other backends, it does a bulk update and then returns the updated data after a refresh.
Args:
data: A list of instances to update. Each should have a value for `self.id_attribute` that exists in the
collection.
auto_expunge: Remove object from session before returning.
auto_commit: Commit objects before returning.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
Returns:
The updated instances.
"""
self._uniquify = self._get_uniquify(uniquify)
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
data_to_update: list[dict[str, Any]] = [v.to_dict() if isinstance(v, self.model_type) else v for v in data] # type: ignore[misc]
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
execution_options = self._get_execution_options(execution_options)
loader_options = self._get_loader_options(load)[0]
supports_returning = self._dialect.update_executemany_returning and self._dialect.name != "oracle"
statement = self._get_update_many_statement(
self.model_type,
supports_returning,
loader_options=loader_options,
execution_options=execution_options,
)
if supports_returning:
instances = list(
self.session.scalars(
statement,
cast("_CoreSingleExecuteParams", data_to_update), # this is not correct but the only way
# currently to deal with an SQLAlchemy typing issue. See
# https://github.com/sqlalchemy/sqlalchemy/discussions/9925
execution_options=execution_options,
),
)
self._flush_or_commit(auto_commit=auto_commit)
for instance in instances:
self._expunge(instance, auto_expunge=auto_expunge)
return instances
self.session.execute(statement, data_to_update, execution_options=execution_options)
self._flush_or_commit(auto_commit=auto_commit)
return data
def _get_update_many_statement(
self,
model_type: type[ModelT],
supports_returning: bool,
loader_options: Union[list[_AbstractLoad], None],
execution_options: Union[dict[str, Any], None],
) -> Union[Update, ReturningUpdate[tuple[ModelT]]]:
# Base update statement is static
statement = self._get_base_stmt(
statement=update(table=model_type), loader_options=loader_options, execution_options=execution_options
)
if supports_returning:
return statement.returning(model_type)
return statement
def list_and_count(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
statement: Optional[Select[tuple[ModelT]]] = None,
auto_expunge: Optional[bool] = None,
count_with_window_function: Optional[bool] = None,
order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> tuple[list[ModelT], int]:
"""List records with total count.
Args:
*filters: Types for specific filtering operations.
statement: To facilitate customization of the underlying select query.
auto_expunge: Remove object from session before returning.
count_with_window_function: When false, list and count will use two queries instead of an analytical window function.
order_by: Set default order options for queries.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
**kwargs: Instance attribute value filters.
Returns:
Count of records returned by query, ignoring pagination.
"""
count_with_window_function = (
count_with_window_function if count_with_window_function is not None else self.count_with_window_function
)
self._uniquify = self._get_uniquify(uniquify)
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
if self._dialect.name in {"spanner", "spanner+spanner"} or not count_with_window_function:
return self._list_and_count_basic(
*filters,
auto_expunge=auto_expunge,
statement=statement,
load=load,
execution_options=execution_options,
order_by=order_by,
error_messages=error_messages,
**kwargs,
)
return self._list_and_count_window(
*filters,
auto_expunge=auto_expunge,
statement=statement,
load=load,
execution_options=execution_options,
error_messages=error_messages,
order_by=order_by,
**kwargs,
)
def _expunge(self, instance: ModelT, auto_expunge: Optional[bool]) -> None:
if auto_expunge is None:
auto_expunge = self.auto_expunge
return self.session.expunge(instance) if auto_expunge else None
def _flush_or_commit(self, auto_commit: Optional[bool]) -> None:
if auto_commit is None:
auto_commit = self.auto_commit
return self.session.commit() if auto_commit else self.session.flush()
def _refresh(
self,
instance: ModelT,
auto_refresh: Optional[bool],
attribute_names: Optional[Iterable[str]] = None,
with_for_update: Optional[bool] = None,
) -> None:
if auto_refresh is None:
auto_refresh = self.auto_refresh
return (
self.session.refresh(
instance=instance,
attribute_names=attribute_names,
with_for_update=with_for_update,
)
if auto_refresh
else None
)
def _list_and_count_window(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
auto_expunge: Optional[bool] = None,
statement: Optional[Select[tuple[ModelT]]] = None,
order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> tuple[list[ModelT], int]:
"""List records with total count.
Args:
*filters: Types for specific filtering operations.
auto_expunge: Remove object from session before returning.
statement: To facilitate customization of the underlying select query.
order_by: List[OrderingPair] | OrderingPair | None = None,
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set relationships to be loaded
execution_options: Set default execution options
**kwargs: Instance attribute value filters.
Returns:
Count of records returned by query using an analytical window function, ignoring pagination.
"""
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
execution_options = self._get_execution_options(execution_options)
statement = self.statement if statement is None else statement
loader_options, loader_options_have_wildcard = self._get_loader_options(load)
statement = self._get_base_stmt(
statement=statement,
loader_options=loader_options,
execution_options=execution_options,
)
if order_by is None:
order_by = self.order_by if self.order_by is not None else []
statement = self._apply_order_by(statement=statement, order_by=order_by)
statement = self._apply_filters(*filters, statement=statement)
statement = self._filter_select_by_kwargs(statement, kwargs)
result = self._execute(statement.add_columns(over(sql_func.count())), uniquify=loader_options_have_wildcard)
count: int = 0
instances: list[ModelT] = []
for i, (instance, count_value) in enumerate(result):
self._expunge(instance, auto_expunge=auto_expunge)
instances.append(instance)
if i == 0:
count = count_value
return instances, count
def _list_and_count_basic(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
auto_expunge: Optional[bool] = None,
statement: Optional[Select[tuple[ModelT]]] = None,
order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> tuple[list[ModelT], int]:
"""List records with total count.
Args:
*filters: Types for specific filtering operations.
auto_expunge: Remove object from session before returning.
statement: To facilitate customization of the underlying select query.
order_by: Set default order options for queries.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set relationships to be loaded
execution_options: Set default execution options
**kwargs: Instance attribute value filters.
Returns:
Count of records returned by query using 2 queries, ignoring pagination.
"""
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
execution_options = self._get_execution_options(execution_options)
statement = self.statement if statement is None else statement
loader_options, loader_options_have_wildcard = self._get_loader_options(load)
statement = self._get_base_stmt(
statement=statement,
loader_options=loader_options,
execution_options=execution_options,
)
if order_by is None:
order_by = self.order_by if self.order_by is not None else []
statement = self._apply_order_by(statement=statement, order_by=order_by)
statement = self._apply_filters(*filters, statement=statement)
statement = self._filter_select_by_kwargs(statement, kwargs)
count_result = self.session.execute(
self._get_count_stmt(
statement,
loader_options=loader_options,
execution_options=execution_options,
),
)
count = count_result.scalar_one()
result = self._execute(statement, uniquify=loader_options_have_wildcard)
instances: list[ModelT] = []
for (instance,) in result:
self._expunge(instance, auto_expunge=auto_expunge)
instances.append(instance)
return instances, count
@staticmethod
def _get_count_stmt(
statement: Select[tuple[ModelT]],
loader_options: Optional[list[_AbstractLoad]], # noqa: ARG004
execution_options: Optional[dict[str, Any]], # noqa: ARG004
) -> Select[tuple[int]]:
# Count statement transformations are static
return (
statement.with_only_columns(sql_func.count(text("1")), maintain_column_froms=True)
.limit(None)
.offset(None)
.order_by(None)
)
def upsert(
self,
data: ModelT,
*,
attribute_names: Optional[Iterable[str]] = None,
with_for_update: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_commit: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
match_fields: Optional[Union[list[str], str]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> ModelT:
"""Modify or create instance.
Updates instance with the attribute values present on `data`, or creates a new instance if
one doesn't exist.
Args:
data: Instance to update existing, or be created. Identifier used to determine if an
existing instance exists is the value of an attribute on `data` named as value of
`self.id_attribute`.
attribute_names: an iterable of attribute names to pass into the ``update`` method.
with_for_update: indicating FOR UPDATE should be used, or may be a
dictionary containing flags to indicate a more specific set of
FOR UPDATE flags for the SELECT
auto_expunge: Remove object from session before returning.
auto_refresh: Refresh object from session before returning.
auto_commit: Commit objects before returning.
match_fields: a list of keys to use to match the existing model. When
empty, all fields are matched.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
Returns:
The updated or created instance.
"""
self._uniquify = self._get_uniquify(uniquify)
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
if match_fields := self._get_match_fields(match_fields=match_fields):
match_filter = {
field_name: getattr(data, field_name, None)
for field_name in match_fields
if getattr(data, field_name, None) is not None
}
elif getattr(data, self.id_attribute, None) is not None:
match_filter = {self.id_attribute: getattr(data, self.id_attribute, None)}
else:
match_filter = data.to_dict(exclude={self.id_attribute})
existing = self.get_one_or_none(load=load, execution_options=execution_options, **match_filter)
if not existing:
return self.add(
data,
auto_commit=auto_commit,
auto_expunge=auto_expunge,
auto_refresh=auto_refresh,
)
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
for field_name, new_field_value in data.to_dict(exclude={self.id_attribute}).items():
field = getattr(existing, field_name, MISSING)
if field is not MISSING and field != new_field_value:
setattr(existing, field_name, new_field_value)
instance = self._attach_to_session(existing, strategy="merge")
self._flush_or_commit(auto_commit=auto_commit)
self._refresh(
instance,
attribute_names=attribute_names,
with_for_update=with_for_update,
auto_refresh=auto_refresh,
)
self._expunge(instance, auto_expunge=auto_expunge)
return instance
def upsert_many(
self,
data: list[ModelT],
*,
auto_expunge: Optional[bool] = None,
auto_commit: Optional[bool] = None,
no_merge: bool = False,
match_fields: Optional[Union[list[str], str]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> list[ModelT]:
"""Modify or create multiple instances.
Update instances with the attribute values present on `data`, or create a new instance if
one doesn't exist.
!!! tip
In most cases, you will want to set `match_fields` to the combination of attributes, excluded the primary key, that define uniqueness for a row.
Args:
data: Instance to update existing, or be created. Identifier used to determine if an
existing instance exists is the value of an attribute on ``data`` named as value of
:attr:`id_attribute`.
auto_expunge: Remove object from session before returning.
auto_commit: Commit objects before returning.
no_merge: Skip the usage of optimized Merge statements
match_fields: a list of keys to use to match the existing model. When
empty, automatically uses ``self.id_attribute`` (`id` by default) to match .
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
Returns:
The updated or created instance.
"""
self._uniquify = self._get_uniquify(uniquify)
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
instances: list[ModelT] = []
data_to_update: list[ModelT] = []
data_to_insert: list[ModelT] = []
match_fields = self._get_match_fields(match_fields=match_fields)
if match_fields is None:
match_fields = [self.id_attribute]
match_filter: list[Union[StatementFilter, ColumnElement[bool]]] = []
if match_fields:
for field_name in match_fields:
field = get_instrumented_attr(self.model_type, field_name)
matched_values = [
field_data for datum in data if (field_data := getattr(datum, field_name)) is not None
]
match_filter.append(any_(matched_values) == field if self._prefer_any else field.in_(matched_values)) # type: ignore[arg-type]
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
existing_objs = self.list(
*match_filter,
load=load,
execution_options=execution_options,
auto_expunge=False,
)
for field_name in match_fields:
field = get_instrumented_attr(self.model_type, field_name)
matched_values = list(
{getattr(datum, field_name) for datum in existing_objs if datum}, # ensure the list is unique
)
match_filter.append(any_(matched_values) == field if self._prefer_any else field.in_(matched_values)) # type: ignore[arg-type]
existing_ids = self._get_object_ids(existing_objs=existing_objs)
data = self._merge_on_match_fields(data, existing_objs, match_fields)
for datum in data:
if getattr(datum, self.id_attribute, None) in existing_ids:
data_to_update.append(datum)
else:
data_to_insert.append(datum)
if data_to_insert:
instances.extend(
self.add_many(data_to_insert, auto_commit=False, auto_expunge=False),
)
if data_to_update:
instances.extend(
self.update_many(
data_to_update,
auto_commit=False,
auto_expunge=False,
load=load,
execution_options=execution_options,
),
)
self._flush_or_commit(auto_commit=auto_commit)
for instance in instances:
self._expunge(instance, auto_expunge=auto_expunge)
return instances
def _get_object_ids(self, existing_objs: list[ModelT]) -> list[Any]:
return [obj_id for datum in existing_objs if (obj_id := getattr(datum, self.id_attribute)) is not None]
def _get_match_fields(
self,
match_fields: Optional[Union[list[str], str]] = None,
id_attribute: Optional[str] = None,
) -> Optional[list[str]]:
id_attribute = id_attribute or self.id_attribute
match_fields = match_fields or self.match_fields
if isinstance(match_fields, str):
match_fields = [match_fields]
return match_fields
def _merge_on_match_fields(
self,
data: list[ModelT],
existing_data: list[ModelT],
match_fields: Optional[Union[list[str], str]] = None,
) -> list[ModelT]:
match_fields = self._get_match_fields(match_fields=match_fields)
if match_fields is None:
match_fields = [self.id_attribute]
for existing_datum in existing_data:
for datum in data:
match = all(
getattr(datum, field_name) == getattr(existing_datum, field_name) for field_name in match_fields
)
if match and getattr(existing_datum, self.id_attribute) is not None:
setattr(datum, self.id_attribute, getattr(existing_datum, self.id_attribute))
return data
def list(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
auto_expunge: Optional[bool] = None,
statement: Optional[Select[tuple[ModelT]]] = None,
order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> list[ModelT]:
"""Get a list of instances, optionally filtered.
Args:
*filters: Types for specific filtering operations.
auto_expunge: Remove object from session before returning.
statement: To facilitate customization of the underlying select query.
order_by: Set default order options for queries.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
**kwargs: Instance attribute value filters.
Returns:
The list of instances, after filtering applied.
"""
self._uniquify = self._get_uniquify(uniquify)
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
execution_options = self._get_execution_options(execution_options)
statement = self.statement if statement is None else statement
loader_options, loader_options_have_wildcard = self._get_loader_options(load)
statement = self._get_base_stmt(
statement=statement,
loader_options=loader_options,
execution_options=execution_options,
)
if order_by is None:
order_by = self.order_by if self.order_by is not None else []
statement = self._apply_order_by(statement=statement, order_by=order_by)
statement = self._apply_filters(*filters, statement=statement)
statement = self._filter_select_by_kwargs(statement, kwargs)
result = self._execute(statement, uniquify=loader_options_have_wildcard)
instances = list(result.scalars())
for instance in instances:
self._expunge(instance, auto_expunge=auto_expunge)
return cast("list[ModelT]", instances)
@classmethod
def check_health(cls, session: Union[Session, scoped_session[Session]]) -> bool:
"""Perform a health check on the database.
Args:
session: through which we run a check statement
Returns:
``True`` if healthy.
"""
with wrap_sqlalchemy_exception():
return ( # type: ignore[no-any-return]
session.execute(cls._get_health_check_statement(session))
).scalar_one() == 1
@staticmethod
def _get_health_check_statement(session: Union[Session, scoped_session[Session]]) -> TextClause:
if session.bind and session.bind.dialect.name == "oracle":
return text("SELECT 1 FROM DUAL")
return text("SELECT 1")
def _attach_to_session(self, model: ModelT, strategy: Literal["add", "merge"] = "add", load: bool = True) -> ModelT:
"""Attach detached instance to the session.
Args:
model: The instance to be attached to the session.
strategy: How the instance should be attached.
- "add": New instance added to session
- "merge": Instance merged with existing, or new one added.
load: Boolean, when False, merge switches into
a "high performance" mode which causes it to forego emitting history
events as well as all database access. This flag is used for
cases such as transferring graphs of objects into a session
from a second level cache, or to transfer just-loaded objects
into the session owned by a worker thread or process
without re-querying the database.
Raises:
ValueError: If `strategy` is not one of the expected values.
Returns:
Instance attached to the session - if `"merge"` strategy, may not be same instance
that was provided.
"""
if strategy == "add":
self.session.add(model)
return model
if strategy == "merge":
return self.session.merge(model, load=load)
msg = "Unexpected value for `strategy`, must be `'add'` or `'merge'`" # type: ignore[unreachable]
raise ValueError(msg)
def _execute(
self,
statement: Select[Any],
uniquify: bool = False,
) -> Result[Any]:
result = self.session.execute(statement)
if uniquify or self._uniquify:
result = result.unique()
return result
class SQLAlchemySyncSlugRepository(
SQLAlchemySyncRepository[ModelT],
SQLAlchemySyncSlugRepositoryProtocol[ModelT],
):
"""Extends the repository to include slug model features.."""
def get_by_slug(
self,
slug: str,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> Optional[ModelT]:
"""Select record by slug value.
Returns:
The model instance or None if not found.
"""
return self.get_one_or_none(
slug=slug,
load=load,
execution_options=execution_options,
error_messages=error_messages,
uniquify=uniquify,
)
def get_available_slug(
self,
value_to_slugify: str,
**kwargs: Any,
) -> str:
"""Get a unique slug for the supplied value.
If the value is found to exist, a random 4 digit character is appended to the end.
Override this method to change the default behavior
Args:
value_to_slugify (str): A string that should be converted to a unique slug.
**kwargs: stuff
Returns:
str: a unique slug for the supplied value. This is safe for URLs and other unique identifiers.
"""
slug = slugify(value_to_slugify)
if self._is_slug_unique(slug):
return slug
random_string = "".join(random.choices(string.ascii_lowercase + string.digits, k=4)) # noqa: S311
return f"{slug}-{random_string}"
def _is_slug_unique(
self,
slug: str,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> bool:
return self.exists(slug=slug, load=load, execution_options=execution_options, **kwargs) is False
class SQLAlchemySyncQueryRepository:
"""SQLAlchemy Query Repository.
This is a loosely typed helper to query for when you need to select data in ways that don't align to the normal repository pattern.
"""
error_messages: Optional[ErrorMessages] = None
def __init__(
self,
*,
session: Union[Session, scoped_session[Session]],
error_messages: Optional[ErrorMessages] = None,
**kwargs: Any,
) -> None:
"""Repository pattern for SQLAlchemy models.
Args:
session: Session managing the unit-of-work for the operation.
error_messages: A set of error messages to use for operations.
**kwargs: Additional arguments.
"""
super().__init__(**kwargs)
self.session = session
self.error_messages = error_messages
self._dialect = self.session.bind.dialect if self.session.bind is not None else self.session.get_bind().dialect
def get_one(
self,
statement: Select[tuple[Any]],
**kwargs: Any,
) -> Row[Any]:
"""Get instance identified by ``kwargs``.
Args:
statement: To facilitate customization of the underlying select query.
**kwargs: Instance attribute value filters.
Returns:
The retrieved instance.
"""
with wrap_sqlalchemy_exception(error_messages=self.error_messages):
statement = self._filter_statement_by_kwargs(statement, **kwargs)
instance = (self.execute(statement)).scalar_one_or_none()
return self.check_not_found(instance)
def get_one_or_none(
self,
statement: Select[Any],
**kwargs: Any,
) -> Optional[Row[Any]]:
"""Get instance identified by ``kwargs`` or None if not found.
Args:
statement: To facilitate customization of the underlying select query.
**kwargs: Instance attribute value filters.
Returns:
The retrieved instance or None
"""
with wrap_sqlalchemy_exception(error_messages=self.error_messages):
statement = self._filter_statement_by_kwargs(statement, **kwargs)
instance = (self.execute(statement)).scalar_one_or_none()
return instance or None
def count(self, statement: Select[Any], **kwargs: Any) -> int:
"""Get the count of records returned by a query.
Args:
statement: To facilitate customization of the underlying select query.
**kwargs: Instance attribute value filters.
Returns:
Count of records returned by query, ignoring pagination.
"""
with wrap_sqlalchemy_exception(error_messages=self.error_messages):
statement = statement.with_only_columns(sql_func.count(text("1")), maintain_column_froms=True).order_by(
None,
)
statement = self._filter_statement_by_kwargs(statement, **kwargs)
results = self.execute(statement)
return results.scalar_one() # type: ignore
def list_and_count(
self,
statement: Select[Any],
count_with_window_function: Optional[bool] = None,
**kwargs: Any,
) -> tuple[list[Row[Any]], int]:
"""List records with total count.
Args:
statement: To facilitate customization of the underlying select query.
count_with_window_function: Force list and count to use two queries instead of an analytical window function.
**kwargs: Instance attribute value filters.
Returns:
Count of records returned by query, ignoring pagination.
"""
if self._dialect.name in {"spanner", "spanner+spanner"} or count_with_window_function:
return self._list_and_count_basic(statement=statement, **kwargs)
return self._list_and_count_window(statement=statement, **kwargs)
def _list_and_count_window(
self,
statement: Select[Any],
**kwargs: Any,
) -> tuple[list[Row[Any]], int]:
"""List records with total count.
Args:
*filters: Types for specific filtering operations.
statement: To facilitate customization of the underlying select query.
**kwargs: Instance attribute value filters.
Returns:
Count of records returned by query using an analytical window function, ignoring pagination.
"""
with wrap_sqlalchemy_exception(error_messages=self.error_messages):
statement = statement.add_columns(over(sql_func.count(text("1"))))
statement = self._filter_statement_by_kwargs(statement, **kwargs)
result = self.execute(statement)
count: int = 0
instances: list[Row[Any]] = []
for i, (instance, count_value) in enumerate(result):
instances.append(instance)
if i == 0:
count = count_value
return instances, count
@staticmethod
def _get_count_stmt(statement: Select[Any]) -> Select[Any]:
return statement.with_only_columns(sql_func.count(text("1")), maintain_column_froms=True).order_by(None) # pyright: ignore[reportUnknownVariable]
def _list_and_count_basic(
self,
statement: Select[Any],
**kwargs: Any,
) -> tuple[list[Row[Any]], int]:
"""List records with total count.
Args:
statement: To facilitate customization of the underlying select query. .
**kwargs: Instance attribute value filters.
Returns:
Count of records returned by query using 2 queries, ignoring pagination.
"""
with wrap_sqlalchemy_exception(error_messages=self.error_messages):
statement = self._filter_statement_by_kwargs(statement, **kwargs)
count_result = self.session.execute(self._get_count_stmt(statement))
count = count_result.scalar_one()
result = self.execute(statement)
instances: list[Row[Any]] = []
for (instance,) in result:
instances.append(instance)
return instances, count
def list(self, statement: Select[Any], **kwargs: Any) -> list[Row[Any]]:
"""Get a list of instances, optionally filtered.
Args:
statement: To facilitate customization of the underlying select query.
**kwargs: Instance attribute value filters.
Returns:
The list of instances, after filtering applied.
"""
with wrap_sqlalchemy_exception(error_messages=self.error_messages):
statement = self._filter_statement_by_kwargs(statement, **kwargs)
result = self.execute(statement)
return list(result.all())
def _filter_statement_by_kwargs(
self,
statement: Select[Any],
/,
**kwargs: Any,
) -> Select[Any]:
"""Filter the collection by kwargs.
Args:
statement: statement to filter
**kwargs: key/value pairs such that objects remaining in the statement after filtering
have the property that their attribute named `key` has value equal to `value`.
Returns:
The filtered statement.
"""
with wrap_sqlalchemy_exception(error_messages=self.error_messages):
return statement.filter_by(**kwargs)
# the following is all sqlalchemy implementation detail, and shouldn't be directly accessed
@staticmethod
def check_not_found(item_or_none: Optional[T]) -> T:
"""Raise :class:`NotFoundError` if ``item_or_none`` is ``None``.
Args:
item_or_none: Item to be tested for existence.
Raises:
NotFoundError: If ``item_or_none`` is ``None``
Returns:
The item, if it exists.
"""
if item_or_none is None:
msg = "No item found when one was expected"
raise NotFoundError(msg)
return item_or_none
def execute(
self,
statement: Union[
ReturningDelete[tuple[Any]], ReturningUpdate[tuple[Any]], Select[tuple[Any]], Update, Delete, Select[Any]
],
) -> Result[Any]:
return self.session.execute(statement)
python-advanced-alchemy-1.4.1/advanced_alchemy/repository/_util.py 0000664 0000000 0000000 00000033027 15003544734 0025407 0 ustar 00root root 0000000 0000000 from collections.abc import Iterable, Sequence
from typing import Any, Literal, Optional, Protocol, Union, cast, overload
from sqlalchemy import (
Delete,
Dialect,
Select,
UnaryExpression,
Update,
)
from sqlalchemy.orm import (
InstrumentedAttribute,
MapperProperty,
RelationshipProperty,
joinedload,
lazyload,
selectinload,
)
from sqlalchemy.orm.strategy_options import (
_AbstractLoad, # pyright: ignore[reportPrivateUsage] # pyright: ignore[reportPrivateUsage]
)
from sqlalchemy.sql import ColumnElement, ColumnExpressionArgument
from sqlalchemy.sql.base import ExecutableOption
from sqlalchemy.sql.dml import ReturningDelete, ReturningUpdate
from typing_extensions import TypeAlias
from advanced_alchemy.base import ModelProtocol
from advanced_alchemy.exceptions import ErrorMessages
from advanced_alchemy.exceptions import wrap_sqlalchemy_exception as _wrap_sqlalchemy_exception
from advanced_alchemy.filters import (
InAnyFilter,
PaginationFilter,
StatementFilter,
StatementTypeT,
)
from advanced_alchemy.repository.typing import ModelT, OrderingPair
WhereClauseT = ColumnExpressionArgument[bool]
SingleLoad: TypeAlias = Union[
_AbstractLoad,
Literal["*"],
InstrumentedAttribute[Any],
RelationshipProperty[Any],
MapperProperty[Any],
]
LoadCollection: TypeAlias = Sequence[Union[SingleLoad, Sequence[SingleLoad]]]
ExecutableOptions: TypeAlias = Sequence[ExecutableOption]
LoadSpec: TypeAlias = Union[LoadCollection, SingleLoad, ExecutableOption, ExecutableOptions]
OrderByT: TypeAlias = Union[
str,
InstrumentedAttribute[Any],
RelationshipProperty[Any],
]
# NOTE: For backward compatibility with Litestar - this is imported from here within the litestar codebase.
wrap_sqlalchemy_exception = _wrap_sqlalchemy_exception
DEFAULT_ERROR_MESSAGE_TEMPLATES: ErrorMessages = {
"integrity": "There was a data validation error during processing",
"foreign_key": "A foreign key is missing or invalid",
"multiple_rows": "Multiple matching rows found",
"duplicate_key": "A record matching the supplied data already exists.",
"other": "There was an error during data processing",
"check_constraint": "The data failed a check constraint during processing",
"not_found": "The requested resource was not found",
}
"""Default error messages for repository errors."""
def get_instrumented_attr(
model: type[ModelProtocol],
key: Union[str, InstrumentedAttribute[Any]],
) -> InstrumentedAttribute[Any]:
"""Get an instrumented attribute from a model.
Args:
model: SQLAlchemy model class.
key: Either a string attribute name or an :class:`sqlalchemy.orm.InstrumentedAttribute`.
Returns:
:class:`sqlalchemy.orm.InstrumentedAttribute`: The instrumented attribute from the model.
"""
if isinstance(key, str):
return cast("InstrumentedAttribute[Any]", getattr(model, key))
return key
def model_from_dict(model: type[ModelT], **kwargs: Any) -> ModelT:
"""Create an ORM model instance from a dictionary of attributes.
Args:
model: The SQLAlchemy model class to instantiate.
**kwargs: Keyword arguments containing model attribute values.
Returns:
ModelT: A new instance of the model populated with the provided values.
"""
data = {
column_name: kwargs[column_name]
for column_name in model.__mapper__.columns.keys() # noqa: SIM118 # pyright: ignore[reportUnknownMemberType]
if column_name in kwargs
}
return model(**data)
def get_abstract_loader_options(
loader_options: Union[LoadSpec, None],
default_loader_options: Union[list[_AbstractLoad], None] = None,
default_options_have_wildcards: bool = False,
merge_with_default: bool = True,
inherit_lazy_relationships: bool = True,
cycle_count: int = 0,
) -> tuple[list[_AbstractLoad], bool]:
"""Generate SQLAlchemy loader options for eager loading relationships.
Args:
loader_options :class:`~advanced_alchemy.repository.typing.LoadSpec`|:class:`None` Specification for how to load relationships. Can be:
- None: Use defaults
- :class:`sqlalchemy.orm.strategy_options._AbstractLoad`: Direct SQLAlchemy loader option
- :class:`sqlalchemy.orm.InstrumentedAttribute`: Model relationship attribute
- :class:`sqlalchemy.orm.RelationshipProperty`: SQLAlchemy relationship
- str: "*" for wildcard loading
- :class:`typing.Sequence` of the above
default_loader_options: :class:`typing.Sequence` of :class:`sqlalchemy.orm.strategy_options._AbstractLoad` loader options to start with.
default_options_have_wildcards: Whether the default options contain wildcards.
merge_with_default: Whether to merge the default options with the loader options.
inherit_lazy_relationships: Whether to inherit the ``lazy`` configuration from the model's relationships.
cycle_count: Number of times this function has been called recursively.
Returns:
tuple[:class:`list`[:class:`sqlalchemy.orm.strategy_options._AbstractLoad`], bool]: A tuple containing:
- :class:`list` of :class:`sqlalchemy.orm.strategy_options._AbstractLoad` SQLAlchemy loader option objects
- Boolean indicating if any wildcard loaders are present
"""
loads: list[_AbstractLoad] = []
if cycle_count == 0 and not inherit_lazy_relationships:
loads.append(lazyload("*"))
if cycle_count == 0 and merge_with_default and default_loader_options is not None:
loads.extend(default_loader_options)
options_have_wildcards = default_options_have_wildcards
if loader_options is None:
return (loads, options_have_wildcards)
if isinstance(loader_options, _AbstractLoad):
return ([loader_options], options_have_wildcards)
if isinstance(loader_options, InstrumentedAttribute):
loader_options = [loader_options.property]
if isinstance(loader_options, RelationshipProperty):
class_ = loader_options.class_attribute
return (
[selectinload(class_)]
if loader_options.uselist
else [joinedload(class_, innerjoin=loader_options.innerjoin)],
options_have_wildcards if loader_options.uselist else True,
)
if isinstance(loader_options, str) and loader_options == "*":
options_have_wildcards = True
return ([joinedload("*")], options_have_wildcards)
if isinstance(loader_options, (list, tuple)):
for attribute in loader_options: # pyright: ignore[reportUnknownVariableType]
if isinstance(attribute, (list, tuple)):
load_chain, options_have_wildcards = get_abstract_loader_options(
loader_options=attribute, # pyright: ignore[reportUnknownArgumentType]
default_options_have_wildcards=options_have_wildcards,
inherit_lazy_relationships=inherit_lazy_relationships,
merge_with_default=merge_with_default,
cycle_count=cycle_count + 1,
)
loader = load_chain[-1]
for sub_load in load_chain[-2::-1]:
loader = sub_load.options(loader)
loads.append(loader)
else:
load_chain, options_have_wildcards = get_abstract_loader_options(
loader_options=attribute, # pyright: ignore[reportUnknownArgumentType]
default_options_have_wildcards=options_have_wildcards,
inherit_lazy_relationships=inherit_lazy_relationships,
merge_with_default=merge_with_default,
cycle_count=cycle_count + 1,
)
loads.extend(load_chain)
return (loads, options_have_wildcards)
class FilterableRepositoryProtocol(Protocol[ModelT]):
"""Protocol defining the interface for filterable repositories.
This protocol defines the required attributes and methods that any
filterable repository implementation must provide.
"""
model_type: type[ModelT]
"""The SQLAlchemy model class this repository manages."""
class FilterableRepository(FilterableRepositoryProtocol[ModelT]):
"""Default implementation of a filterable repository.
Provides core filtering, ordering and pagination functionality for
SQLAlchemy models.
"""
model_type: type[ModelT]
"""The SQLAlchemy model class this repository manages."""
prefer_any_dialects: Optional[tuple[str]] = ("postgresql",)
"""List of dialects that prefer to use ``field.id = ANY(:1)`` instead of ``field.id IN (...)``."""
order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None
"""List or single :class:`~advanced_alchemy.repository.typing.OrderingPair` to use for sorting."""
_prefer_any: bool = False
"""Whether to prefer ANY() over IN() in queries."""
_dialect: Dialect
"""The SQLAlchemy :class:`sqlalchemy.dialects.Dialect` being used."""
@overload
def _apply_filters(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
apply_pagination: bool = True,
statement: Select[tuple[ModelT]],
) -> Select[tuple[ModelT]]: ...
@overload
def _apply_filters(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
apply_pagination: bool = True,
statement: Delete,
) -> Delete: ...
@overload
def _apply_filters(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
apply_pagination: bool = True,
statement: Union[ReturningDelete[tuple[ModelT]], ReturningUpdate[tuple[ModelT]]],
) -> Union[ReturningDelete[tuple[ModelT]], ReturningUpdate[tuple[ModelT]]]: ...
@overload
def _apply_filters(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
apply_pagination: bool = True,
statement: Update,
) -> Update: ...
def _apply_filters(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
apply_pagination: bool = True,
statement: StatementTypeT,
) -> StatementTypeT:
"""Apply filters to a SQL statement.
Args:
*filters: Filter conditions to apply.
apply_pagination: Whether to apply pagination filters.
statement: The base SQL statement to filter.
Returns:
StatementTypeT: The filtered SQL statement.
"""
for filter_ in filters:
if isinstance(filter_, (PaginationFilter,)):
if apply_pagination:
statement = filter_.append_to_statement(statement, self.model_type)
elif isinstance(filter_, (InAnyFilter,)):
statement = filter_.append_to_statement(statement, self.model_type)
elif isinstance(filter_, ColumnElement):
statement = cast("StatementTypeT", statement.where(filter_))
else:
statement = filter_.append_to_statement(statement, self.model_type)
return statement
def _filter_select_by_kwargs(
self,
statement: StatementTypeT,
kwargs: Union[dict[Any, Any], Iterable[tuple[Any, Any]]],
) -> StatementTypeT:
"""Filter a statement using keyword arguments.
Args:
statement: :class:`sqlalchemy.sql.Select` The SQL statement to filter.
kwargs: Dictionary or iterable of tuples containing filter criteria.
Keys should be model attribute names, values are what to filter for.
Returns:
StatementTypeT: The filtered SQL statement.
"""
for key, val in dict(kwargs).items():
field = get_instrumented_attr(self.model_type, key)
statement = cast("StatementTypeT", statement.where(field == val))
return statement
def _apply_order_by(
self,
statement: StatementTypeT,
order_by: Union[
OrderingPair,
list[OrderingPair],
],
) -> StatementTypeT:
"""Apply ordering to a SQL statement.
Args:
statement: The SQL statement to order.
order_by: Ordering specification. Either a single tuple or list of tuples where:
- First element is the field name or :class:`sqlalchemy.orm.InstrumentedAttribute` to order by
- Second element is a boolean indicating descending (True) or ascending (False)
Returns:
StatementTypeT: The ordered SQL statement.
"""
if not isinstance(order_by, list):
order_by = [order_by]
for order_field in order_by:
if isinstance(order_field, UnaryExpression):
statement = statement.order_by(order_field) # type: ignore
else:
field = get_instrumented_attr(self.model_type, order_field[0])
statement = self._order_by_attribute(statement, field, order_field[1])
return statement
@staticmethod
def _order_by_attribute(
statement: StatementTypeT,
field: InstrumentedAttribute[Any],
is_desc: bool,
) -> StatementTypeT:
"""Apply ordering by a single attribute to a SQL statement.
Args:
statement: The SQL statement to order.
field: The model attribute to order by.
is_desc: Whether to order in descending (True) or ascending (False) order.
Returns:
StatementTypeT: The ordered SQL statement.
"""
if not isinstance(statement, Select):
return statement
return cast("StatementTypeT", statement.order_by(field.desc() if is_desc else field.asc()))
python-advanced-alchemy-1.4.1/advanced_alchemy/repository/memory/ 0000775 0000000 0000000 00000000000 15003544734 0025224 5 ustar 00root root 0000000 0000000 python-advanced-alchemy-1.4.1/advanced_alchemy/repository/memory/__init__.py 0000664 0000000 0000000 00000000624 15003544734 0027337 0 ustar 00root root 0000000 0000000 from advanced_alchemy.repository.memory._async import SQLAlchemyAsyncMockRepository, SQLAlchemyAsyncMockSlugRepository
from advanced_alchemy.repository.memory._sync import SQLAlchemySyncMockRepository, SQLAlchemySyncMockSlugRepository
__all__ = [
"SQLAlchemyAsyncMockRepository",
"SQLAlchemyAsyncMockSlugRepository",
"SQLAlchemySyncMockRepository",
"SQLAlchemySyncMockSlugRepository",
]
python-advanced-alchemy-1.4.1/advanced_alchemy/repository/memory/_async.py 0000664 0000000 0000000 00000073174 15003544734 0027066 0 ustar 00root root 0000000 0000000 import datetime
import random
import re
import string
from collections import abc
from collections.abc import Iterable
from typing import Any, Optional, Union, cast, overload
from unittest.mock import create_autospec
from sqlalchemy import (
ColumnElement,
Dialect,
Select,
StatementLambdaElement,
Update,
)
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.ext.asyncio.scoping import async_scoped_session
from sqlalchemy.orm import InstrumentedAttribute
from sqlalchemy.orm.strategy_options import _AbstractLoad # pyright: ignore[reportPrivateUsage]
from sqlalchemy.sql.dml import ReturningUpdate
from typing_extensions import Self
from advanced_alchemy.exceptions import ErrorMessages, IntegrityError, NotFoundError, RepositoryError
from advanced_alchemy.filters import (
BeforeAfter,
CollectionFilter,
LimitOffset,
NotInCollectionFilter,
NotInSearchFilter,
OnBeforeAfter,
OrderBy,
SearchFilter,
StatementFilter,
)
from advanced_alchemy.repository._async import SQLAlchemyAsyncRepositoryProtocol, SQLAlchemyAsyncSlugRepositoryProtocol
from advanced_alchemy.repository._util import DEFAULT_ERROR_MESSAGE_TEMPLATES, LoadSpec
from advanced_alchemy.repository.memory.base import (
AnyObject,
InMemoryStore,
SQLAlchemyInMemoryStore,
SQLAlchemyMultiStore,
)
from advanced_alchemy.repository.typing import MISSING, ModelT, OrderingPair
from advanced_alchemy.utils.dataclass import Empty, EmptyType
from advanced_alchemy.utils.text import slugify
class SQLAlchemyAsyncMockRepository(SQLAlchemyAsyncRepositoryProtocol[ModelT]):
"""In memory repository."""
__database__: SQLAlchemyMultiStore[ModelT] = SQLAlchemyMultiStore(SQLAlchemyInMemoryStore)
__database_registry__: dict[type[Self], SQLAlchemyMultiStore[ModelT]] = {}
loader_options: Optional[LoadSpec] = None
"""Default loader options for the repository."""
execution_options: Optional[dict[str, Any]] = None
"""Default execution options for the repository."""
model_type: type[ModelT]
id_attribute: Any = "id"
match_fields: Optional[Union[list[str], str]] = None
uniquify: bool = False
_exclude_kwargs: set[str] = {
"statement",
"session",
"auto_expunge",
"auto_refresh",
"auto_commit",
"attribute_names",
"with_for_update",
"count_with_window_function",
"loader_options",
"execution_options",
"order_by",
"load",
"error_messages",
"wrap_exceptions",
"uniquify",
}
def __init__(
self,
*,
statement: Union[Select[tuple[ModelT]], StatementLambdaElement, None] = None,
session: Union[AsyncSession, async_scoped_session[AsyncSession]],
auto_expunge: bool = False,
auto_refresh: bool = True,
auto_commit: bool = False,
order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
wrap_exceptions: bool = True,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> None:
self.session = session
self.statement = create_autospec("Select[Tuple[ModelT]]", instance=True)
self.auto_expunge = auto_expunge
self.auto_refresh = auto_refresh
self.auto_commit = auto_commit
self.error_messages = self._get_error_messages(error_messages=error_messages)
self.wrap_exceptions = wrap_exceptions
self.order_by = order_by
self._dialect: Dialect = create_autospec(Dialect, instance=True)
self._dialect.name = "mock"
self.__filtered_store__: InMemoryStore[ModelT] = self.__database__.store_type()
self._default_options: Any = []
self._default_execution_options: Any = {}
self._loader_options: Any = []
self._loader_options_have_wildcards = False
self.uniquify = bool(uniquify)
def __init_subclass__(cls) -> None:
cls.__database_registry__[cls] = cls.__database__ # pyright: ignore[reportGeneralTypeIssues,reportUnknownMemberType]
@staticmethod
def _get_error_messages(
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
default_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
) -> Optional[ErrorMessages]:
if error_messages == Empty:
error_messages = None
default_messages = cast(
"Optional[ErrorMessages]",
default_messages if default_messages != Empty else DEFAULT_ERROR_MESSAGE_TEMPLATES,
)
if error_messages is not None and default_messages is not None:
default_messages.update(cast("ErrorMessages", error_messages))
return default_messages
@classmethod
def __database_add__(cls, identity: Any, data: ModelT) -> ModelT:
return cast("ModelT", cls.__database__.add(identity, data)) # pyright: ignore[reportUnnecessaryCast,reportGeneralTypeIssues]
@classmethod
def __database_clear__(cls) -> None:
for database in cls.__database_registry__.values(): # pyright: ignore[reportGeneralTypeIssues,reportUnknownMemberType]
database.remove_all()
@overload
def __collection__(self) -> InMemoryStore[ModelT]: ...
@overload
def __collection__(self, identity: type[AnyObject]) -> InMemoryStore[AnyObject]: ...
def __collection__(
self,
identity: Optional[type[AnyObject]] = None,
) -> Union[InMemoryStore[AnyObject], InMemoryStore[ModelT]]:
if identity:
return self.__database__.store(identity)
return self.__filtered_store__ or self.__database__.store(self.model_type)
@staticmethod
def check_not_found(item_or_none: Union[ModelT, None]) -> ModelT:
if item_or_none is None:
msg = "No item found when one was expected"
raise NotFoundError(msg)
return item_or_none
@classmethod
def get_id_attribute_value(
cls,
item: Union[ModelT, type[ModelT]],
id_attribute: Union[str, InstrumentedAttribute[Any], None] = None,
) -> Any:
"""Get value of attribute named as :attr:`id_attribute` on ``item``.
Args:
item: Anything that should have an attribute named as :attr:`id_attribute` value.
id_attribute: Allows customization of the unique identifier to use for model fetching.
Defaults to `None`, but can reference any surrogate or candidate key for the table.
Returns:
The value of attribute on ``item`` named as :attr:`id_attribute`.
"""
if isinstance(id_attribute, InstrumentedAttribute):
id_attribute = id_attribute.key
return getattr(item, id_attribute if id_attribute is not None else cls.id_attribute)
@classmethod
def set_id_attribute_value(
cls,
item_id: Any,
item: ModelT,
id_attribute: Union[str, InstrumentedAttribute[Any], None] = None,
) -> ModelT:
"""Return the ``item`` after the ID is set to the appropriate attribute.
Args:
item_id: Value of ID to be set on instance
item: Anything that should have an attribute named as :attr:`id_attribute` value.
id_attribute: Allows customization of the unique identifier to use for model fetching.
Defaults to `None`, but can reference any surrogate or candidate key for the table.
Returns:
Item with ``item_id`` set to :attr:`id_attribute`
"""
if isinstance(id_attribute, InstrumentedAttribute):
id_attribute = id_attribute.key
setattr(item, id_attribute if id_attribute is not None else cls.id_attribute, item_id)
return item
def _exclude_unused_kwargs(self, kwargs: dict[str, Any]) -> dict[str, Any]:
return {key: value for key, value in kwargs.items() if key not in self._exclude_kwargs}
@staticmethod
def _apply_limit_offset_pagination(result: list[ModelT], limit: int, offset: int) -> list[ModelT]:
return result[offset:limit]
@staticmethod
def _filter_in_collection(
result: list[ModelT],
field_name: str,
values: abc.Collection[Any],
) -> list[ModelT]:
return [item for item in result if getattr(item, field_name) in values]
@staticmethod
def _filter_not_in_collection(
result: list[ModelT],
field_name: str,
values: abc.Collection[Any],
) -> list[ModelT]:
if not values:
return result
return [item for item in result if getattr(item, field_name) not in values]
@staticmethod
def _filter_on_datetime_field(
result: list[ModelT],
field_name: str,
before: Optional[datetime.datetime] = None,
after: Optional[datetime.datetime] = None,
on_or_before: Optional[datetime.datetime] = None,
on_or_after: Optional[datetime.datetime] = None,
) -> list[ModelT]:
result_: list[ModelT] = []
for item in result:
attr: datetime.datetime = getattr(item, field_name)
if before is not None and attr < before:
result_.append(item)
if after is not None and attr > after:
result_.append(item)
if on_or_before is not None and attr <= on_or_before:
result_.append(item)
if on_or_after is not None and attr >= on_or_after:
result_.append(item)
return result_
@staticmethod
def _filter_by_like(
result: list[ModelT],
field_name: Union[str, set[str]],
value: str,
ignore_case: bool,
) -> list[ModelT]:
pattern = re.compile(rf".*{value}.*", re.IGNORECASE) if ignore_case else re.compile(rf".*{value}.*")
fields = {field_name} if isinstance(field_name, str) else field_name
items: list[ModelT] = []
for field in fields:
items.extend(
[
item
for item in result
if isinstance(getattr(item, field), str) and pattern.match(getattr(item, field))
],
)
return list(set(items))
@staticmethod
def _filter_by_not_like(
result: list[ModelT],
field_name: Union[str, set[str]],
value: str,
ignore_case: bool,
) -> list[ModelT]:
pattern = re.compile(rf".*{value}.*", re.IGNORECASE) if ignore_case else re.compile(rf".*{value}.*")
fields = {field_name} if isinstance(field_name, str) else field_name
items: list[ModelT] = []
for field in fields:
items.extend(
[
item
for item in result
if isinstance(getattr(item, field), str) and pattern.match(getattr(item, field))
],
)
return list(set(result).difference(set(items)))
def _filter_result_by_kwargs(
self,
result: Iterable[ModelT],
/,
kwargs: Union[dict[Any, Any], Iterable[tuple[Any, Any]]],
) -> list[ModelT]:
kwargs_: dict[Any, Any] = kwargs if isinstance(kwargs, dict) else dict(*kwargs) # pyright: ignore
kwargs_ = self._exclude_unused_kwargs(kwargs_) # pyright: ignore
try:
return [item for item in result if all(getattr(item, field) == value for field, value in kwargs_.items())] # pyright: ignore
except AttributeError as error:
raise RepositoryError from error
@staticmethod
def _order_by(result: list[ModelT], field_name: str, sort_desc: bool = False) -> list[ModelT]:
return sorted(result, key=lambda item: getattr(item, field_name), reverse=sort_desc)
def _apply_filters(
self,
result: list[ModelT],
*filters: Union[StatementFilter, ColumnElement[bool]],
apply_pagination: bool = True,
) -> list[ModelT]:
for filter_ in filters:
if isinstance(filter_, LimitOffset):
if apply_pagination:
result = self._apply_limit_offset_pagination(result, filter_.limit, filter_.offset)
elif isinstance(filter_, BeforeAfter):
result = self._filter_on_datetime_field(
result,
field_name=filter_.field_name,
before=filter_.before,
after=filter_.after,
)
elif isinstance(filter_, OnBeforeAfter):
result = self._filter_on_datetime_field(
result,
field_name=filter_.field_name,
on_or_before=filter_.on_or_before,
on_or_after=filter_.on_or_after,
)
elif isinstance(filter_, NotInCollectionFilter):
if filter_.values is not None: # pyright: ignore
result = self._filter_not_in_collection(result, filter_.field_name, filter_.values) # pyright: ignore
elif isinstance(filter_, CollectionFilter):
if filter_.values is not None: # pyright: ignore
result = self._filter_in_collection(result, filter_.field_name, filter_.values) # pyright: ignore
elif isinstance(filter_, OrderBy):
result = self._order_by(
result,
filter_.field_name,
sort_desc=filter_.sort_order == "desc",
)
elif isinstance(filter_, NotInSearchFilter):
result = self._filter_by_not_like(
result,
filter_.field_name,
value=filter_.value,
ignore_case=bool(filter_.ignore_case),
)
elif isinstance(filter_, SearchFilter):
result = self._filter_by_like(
result,
filter_.field_name,
value=filter_.value,
ignore_case=bool(filter_.ignore_case),
)
elif not isinstance(filter_, ColumnElement):
msg = f"Unexpected filter: {filter_}"
raise RepositoryError(msg)
return result
def _get_match_fields(
self,
match_fields: Union[list[str], str, None],
id_attribute: Optional[str] = None,
) -> Optional[list[str]]:
id_attribute = id_attribute or self.id_attribute
match_fields = match_fields or self.match_fields
if isinstance(match_fields, str):
match_fields = [match_fields]
return match_fields
async def _list_and_count_basic(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
**kwargs: Any,
) -> tuple[list[ModelT], int]:
result = await self.list(*filters, **kwargs)
return result, len(result)
async def _list_and_count_window(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
**kwargs: Any,
) -> tuple[list[ModelT], int]:
return await self._list_and_count_basic(*filters, **kwargs)
def _find_or_raise_not_found(self, id_: Any) -> ModelT:
return self.check_not_found(self.__collection__().get_or_none(id_))
@staticmethod
def _find_one_or_raise_error(result: list[ModelT]) -> ModelT:
if not result:
msg = "No item found when one was expected"
raise IntegrityError(msg)
if len(result) > 1:
msg = "Multiple objects when one was expected"
raise IntegrityError(msg)
return result[0] # pyright: ignore
def _get_update_many_statement(
self,
model_type: type[ModelT],
supports_returning: bool,
loader_options: Optional[list[_AbstractLoad]],
execution_options: Optional[dict[str, Any]],
) -> Union[Update, ReturningUpdate[tuple[ModelT]]]:
return self.statement # type: ignore[no-any-return] # pyright: ignore[reportReturnType]
@classmethod
async def check_health(cls, session: Union[AsyncSession, async_scoped_session[AsyncSession]]) -> bool:
return True
async def get(
self,
item_id: Any,
*,
auto_expunge: Optional[bool] = None,
statement: Union[Select[tuple[ModelT]], StatementLambdaElement, None] = None,
id_attribute: Union[str, InstrumentedAttribute[Any], None] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> ModelT:
return self._find_or_raise_not_found(item_id)
async def get_one(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
auto_expunge: Optional[bool] = None,
statement: Union[Select[tuple[ModelT]], StatementLambdaElement, None] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> ModelT:
return self.check_not_found(await self.get_one_or_none(**kwargs))
async def get_one_or_none(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
auto_expunge: Optional[bool] = None,
statement: Union[Select[tuple[ModelT]], StatementLambdaElement, None] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> Union[ModelT, None]:
result = self._filter_result_by_kwargs(self.__collection__().list(), kwargs)
if len(result) > 1:
msg = "Multiple objects when one was expected"
raise IntegrityError(msg)
return result[0] if result else None
async def get_or_upsert(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
match_fields: Union[list[str], str, None] = None,
upsert: bool = True,
attribute_names: Optional[Iterable[str]] = None,
with_for_update: Optional[bool] = None,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> tuple[ModelT, bool]:
kwargs_ = self._exclude_unused_kwargs(kwargs)
if match_fields := self._get_match_fields(match_fields=match_fields):
match_filter = {
# sourcery skip: remove-none-from-default-get
field_name: kwargs_.get(field_name, None)
for field_name in match_fields
if kwargs_.get(field_name, None) is not None
}
else:
match_filter = kwargs_
existing = await self.get_one_or_none(**match_filter)
if not existing:
return (await self.add(self.model_type(**kwargs_)), True)
if upsert:
for field_name, new_field_value in kwargs_.items():
field = getattr(existing, field_name, MISSING)
if field is not MISSING and field != new_field_value:
setattr(existing, field_name, new_field_value)
existing = await self.update(existing)
return existing, False
async def get_and_update(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
match_fields: Union[list[str], str, None] = None,
attribute_names: Optional[Iterable[str]] = None,
with_for_update: Optional[bool] = None,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> tuple[ModelT, bool]:
kwargs_ = self._exclude_unused_kwargs(kwargs)
if match_fields := self._get_match_fields(match_fields=match_fields):
match_filter = {
# sourcery skip: remove-none-from-default-get
field_name: kwargs_.get(field_name, None)
for field_name in match_fields
if kwargs_.get(field_name, None) is not None
}
else:
match_filter = kwargs_
existing = await self.get_one(**match_filter)
updated = False
for field_name, new_field_value in kwargs_.items():
field = getattr(existing, field_name, MISSING)
if field is not MISSING and field != new_field_value:
updated = True
setattr(existing, field_name, new_field_value)
existing = await self.update(existing)
return existing, updated
async def exists(
self,
*filters: "Union[StatementFilter, ColumnElement[bool]]",
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> bool:
existing = await self.count(*filters, **kwargs)
return existing > 0
async def count(
self,
*filters: "Union[StatementFilter, ColumnElement[bool]]",
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> int:
result = self._apply_filters(self.__collection__().list(), *filters)
return len(self._filter_result_by_kwargs(result, kwargs))
async def add(
self,
data: ModelT,
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
) -> ModelT:
try:
self.__database__.add(self.model_type, data)
except KeyError as exc:
msg = "Item already exist in collection"
raise IntegrityError(msg) from exc
return data
async def add_many(
self,
data: list[ModelT],
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
) -> list[ModelT]:
for obj in data:
await self.add(obj) # pyright: ignore[reportCallIssue]
return data
async def update(
self,
data: ModelT,
*,
attribute_names: Optional[Iterable[str]] = None,
with_for_update: Optional[bool] = None,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> ModelT:
self._find_or_raise_not_found(self.__collection__().key(data))
return self.__collection__().update(data)
async def update_many(
self,
data: list[ModelT],
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> list[ModelT]:
return [self.__collection__().update(obj) for obj in data if obj in self.__collection__()]
async def delete(
self,
item_id: Any,
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> ModelT:
try:
return self._find_or_raise_not_found(item_id)
finally:
self.__collection__().remove(item_id)
async def delete_many(
self,
item_ids: list[Any],
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
chunk_size: Optional[int] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> list[ModelT]:
deleted: list[ModelT] = []
for id_ in item_ids:
if obj := self.__collection__().get_or_none(id_):
deleted.append(obj)
self.__collection__().remove(id_)
return deleted
async def delete_where(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
sanity_check: bool = True,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> list[ModelT]:
result = self.__collection__().list()
result = self._apply_filters(result, *filters)
models = self._filter_result_by_kwargs(result, kwargs)
item_ids = [getattr(model, self.id_attribute) for model in models]
return await self.delete_many(item_ids=item_ids)
async def upsert(
self,
data: ModelT,
*,
attribute_names: Optional[Iterable[str]] = None,
with_for_update: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_commit: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
match_fields: Optional[Union[list[str], str]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> ModelT:
# sourcery skip: assign-if-exp, reintroduce-else
if data in self.__collection__():
return await self.update(data)
return await self.add(data)
async def upsert_many(
self,
data: list[ModelT],
*,
auto_expunge: Optional[bool] = None,
auto_commit: Optional[bool] = None,
no_merge: bool = False,
match_fields: Optional[Union[list[str], str]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> list[ModelT]:
return [await self.upsert(item) for item in data]
async def list_and_count(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
statement: Union[Select[tuple[ModelT]], StatementLambdaElement, None] = None,
auto_expunge: Optional[bool] = None,
count_with_window_function: Optional[bool] = None,
order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> tuple[list[ModelT], int]:
return await self._list_and_count_basic(*filters, **kwargs)
async def list(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> list[ModelT]:
result = self.__collection__().list()
result = self._apply_filters(result, *filters)
return self._filter_result_by_kwargs(result, kwargs)
class SQLAlchemyAsyncMockSlugRepository(
SQLAlchemyAsyncMockRepository[ModelT],
SQLAlchemyAsyncSlugRepositoryProtocol[ModelT],
):
async def get_by_slug(
self,
slug: str,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> Union[ModelT, None]:
return await self.get_one_or_none(slug=slug)
async def get_available_slug(
self,
value_to_slugify: str,
**kwargs: Any,
) -> str:
"""Get a unique slug for the supplied value.
If the value is found to exist, a random 4 digit character is appended to the end.
Override this method to change the default behavior
Args:
value_to_slugify (str): A string that should be converted to a unique slug.
**kwargs: stuff
Returns:
str: a unique slug for the supplied value. This is safe for URLs and other unique identifiers.
"""
slug = slugify(value_to_slugify)
if await self._is_slug_unique(slug):
return slug
random_string = "".join(random.choices(string.ascii_lowercase + string.digits, k=4)) # noqa: S311
return f"{slug}-{random_string}"
async def _is_slug_unique(
self,
slug: str,
**kwargs: Any,
) -> bool:
return await self.exists(slug=slug) is False
python-advanced-alchemy-1.4.1/advanced_alchemy/repository/memory/_sync.py 0000664 0000000 0000000 00000072623 15003544734 0026723 0 ustar 00root root 0000000 0000000 # Do not edit this file directly. It has been autogenerated from
# advanced_alchemy/repository/memory/_async.py
import datetime
import random
import re
import string
from collections import abc
from collections.abc import Iterable
from typing import Any, Optional, Union, cast, overload
from unittest.mock import create_autospec
from sqlalchemy import (
ColumnElement,
Dialect,
Select,
StatementLambdaElement,
Update,
)
from sqlalchemy.orm import InstrumentedAttribute, Session
from sqlalchemy.orm.scoping import scoped_session
from sqlalchemy.orm.strategy_options import _AbstractLoad # pyright: ignore[reportPrivateUsage]
from sqlalchemy.sql.dml import ReturningUpdate
from typing_extensions import Self
from advanced_alchemy.exceptions import ErrorMessages, IntegrityError, NotFoundError, RepositoryError
from advanced_alchemy.filters import (
BeforeAfter,
CollectionFilter,
LimitOffset,
NotInCollectionFilter,
NotInSearchFilter,
OnBeforeAfter,
OrderBy,
SearchFilter,
StatementFilter,
)
from advanced_alchemy.repository._sync import SQLAlchemySyncRepositoryProtocol, SQLAlchemySyncSlugRepositoryProtocol
from advanced_alchemy.repository._util import DEFAULT_ERROR_MESSAGE_TEMPLATES, LoadSpec
from advanced_alchemy.repository.memory.base import (
AnyObject,
InMemoryStore,
SQLAlchemyInMemoryStore,
SQLAlchemyMultiStore,
)
from advanced_alchemy.repository.typing import MISSING, ModelT, OrderingPair
from advanced_alchemy.utils.dataclass import Empty, EmptyType
from advanced_alchemy.utils.text import slugify
class SQLAlchemySyncMockRepository(SQLAlchemySyncRepositoryProtocol[ModelT]):
"""In memory repository."""
__database__: SQLAlchemyMultiStore[ModelT] = SQLAlchemyMultiStore(SQLAlchemyInMemoryStore)
__database_registry__: dict[type[Self], SQLAlchemyMultiStore[ModelT]] = {}
loader_options: Optional[LoadSpec] = None
"""Default loader options for the repository."""
execution_options: Optional[dict[str, Any]] = None
"""Default execution options for the repository."""
model_type: type[ModelT]
id_attribute: Any = "id"
match_fields: Optional[Union[list[str], str]] = None
uniquify: bool = False
_exclude_kwargs: set[str] = {
"statement",
"session",
"auto_expunge",
"auto_refresh",
"auto_commit",
"attribute_names",
"with_for_update",
"count_with_window_function",
"loader_options",
"execution_options",
"order_by",
"load",
"error_messages",
"wrap_exceptions",
"uniquify",
}
def __init__(
self,
*,
statement: Union[Select[tuple[ModelT]], StatementLambdaElement, None] = None,
session: Union[Session, scoped_session[Session]],
auto_expunge: bool = False,
auto_refresh: bool = True,
auto_commit: bool = False,
order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
wrap_exceptions: bool = True,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> None:
self.session = session
self.statement = create_autospec("Select[Tuple[ModelT]]", instance=True)
self.auto_expunge = auto_expunge
self.auto_refresh = auto_refresh
self.auto_commit = auto_commit
self.error_messages = self._get_error_messages(error_messages=error_messages)
self.wrap_exceptions = wrap_exceptions
self.order_by = order_by
self._dialect: Dialect = create_autospec(Dialect, instance=True)
self._dialect.name = "mock"
self.__filtered_store__: InMemoryStore[ModelT] = self.__database__.store_type()
self._default_options: Any = []
self._default_execution_options: Any = {}
self._loader_options: Any = []
self._loader_options_have_wildcards = False
self.uniquify = bool(uniquify)
def __init_subclass__(cls) -> None:
cls.__database_registry__[cls] = cls.__database__ # pyright: ignore[reportGeneralTypeIssues,reportUnknownMemberType]
@staticmethod
def _get_error_messages(
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
default_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
) -> Optional[ErrorMessages]:
if error_messages == Empty:
error_messages = None
default_messages = cast(
"Optional[ErrorMessages]",
default_messages if default_messages != Empty else DEFAULT_ERROR_MESSAGE_TEMPLATES,
)
if error_messages is not None and default_messages is not None:
default_messages.update(cast("ErrorMessages", error_messages))
return default_messages
@classmethod
def __database_add__(cls, identity: Any, data: ModelT) -> ModelT:
return cast("ModelT", cls.__database__.add(identity, data)) # pyright: ignore[reportUnnecessaryCast,reportGeneralTypeIssues]
@classmethod
def __database_clear__(cls) -> None:
for database in cls.__database_registry__.values(): # pyright: ignore[reportGeneralTypeIssues,reportUnknownMemberType]
database.remove_all()
@overload
def __collection__(self) -> InMemoryStore[ModelT]: ...
@overload
def __collection__(self, identity: type[AnyObject]) -> InMemoryStore[AnyObject]: ...
def __collection__(
self,
identity: Optional[type[AnyObject]] = None,
) -> Union[InMemoryStore[AnyObject], InMemoryStore[ModelT]]:
if identity:
return self.__database__.store(identity)
return self.__filtered_store__ or self.__database__.store(self.model_type)
@staticmethod
def check_not_found(item_or_none: Union[ModelT, None]) -> ModelT:
if item_or_none is None:
msg = "No item found when one was expected"
raise NotFoundError(msg)
return item_or_none
@classmethod
def get_id_attribute_value(
cls,
item: Union[ModelT, type[ModelT]],
id_attribute: Union[str, InstrumentedAttribute[Any], None] = None,
) -> Any:
"""Get value of attribute named as :attr:`id_attribute` on ``item``.
Args:
item: Anything that should have an attribute named as :attr:`id_attribute` value.
id_attribute: Allows customization of the unique identifier to use for model fetching.
Defaults to `None`, but can reference any surrogate or candidate key for the table.
Returns:
The value of attribute on ``item`` named as :attr:`id_attribute`.
"""
if isinstance(id_attribute, InstrumentedAttribute):
id_attribute = id_attribute.key
return getattr(item, id_attribute if id_attribute is not None else cls.id_attribute)
@classmethod
def set_id_attribute_value(
cls,
item_id: Any,
item: ModelT,
id_attribute: Union[str, InstrumentedAttribute[Any], None] = None,
) -> ModelT:
"""Return the ``item`` after the ID is set to the appropriate attribute.
Args:
item_id: Value of ID to be set on instance
item: Anything that should have an attribute named as :attr:`id_attribute` value.
id_attribute: Allows customization of the unique identifier to use for model fetching.
Defaults to `None`, but can reference any surrogate or candidate key for the table.
Returns:
Item with ``item_id`` set to :attr:`id_attribute`
"""
if isinstance(id_attribute, InstrumentedAttribute):
id_attribute = id_attribute.key
setattr(item, id_attribute if id_attribute is not None else cls.id_attribute, item_id)
return item
def _exclude_unused_kwargs(self, kwargs: dict[str, Any]) -> dict[str, Any]:
return {key: value for key, value in kwargs.items() if key not in self._exclude_kwargs}
@staticmethod
def _apply_limit_offset_pagination(result: list[ModelT], limit: int, offset: int) -> list[ModelT]:
return result[offset:limit]
@staticmethod
def _filter_in_collection(
result: list[ModelT],
field_name: str,
values: abc.Collection[Any],
) -> list[ModelT]:
return [item for item in result if getattr(item, field_name) in values]
@staticmethod
def _filter_not_in_collection(
result: list[ModelT],
field_name: str,
values: abc.Collection[Any],
) -> list[ModelT]:
if not values:
return result
return [item for item in result if getattr(item, field_name) not in values]
@staticmethod
def _filter_on_datetime_field(
result: list[ModelT],
field_name: str,
before: Optional[datetime.datetime] = None,
after: Optional[datetime.datetime] = None,
on_or_before: Optional[datetime.datetime] = None,
on_or_after: Optional[datetime.datetime] = None,
) -> list[ModelT]:
result_: list[ModelT] = []
for item in result:
attr: datetime.datetime = getattr(item, field_name)
if before is not None and attr < before:
result_.append(item)
if after is not None and attr > after:
result_.append(item)
if on_or_before is not None and attr <= on_or_before:
result_.append(item)
if on_or_after is not None and attr >= on_or_after:
result_.append(item)
return result_
@staticmethod
def _filter_by_like(
result: list[ModelT],
field_name: Union[str, set[str]],
value: str,
ignore_case: bool,
) -> list[ModelT]:
pattern = re.compile(rf".*{value}.*", re.IGNORECASE) if ignore_case else re.compile(rf".*{value}.*")
fields = {field_name} if isinstance(field_name, str) else field_name
items: list[ModelT] = []
for field in fields:
items.extend(
[
item
for item in result
if isinstance(getattr(item, field), str) and pattern.match(getattr(item, field))
],
)
return list(set(items))
@staticmethod
def _filter_by_not_like(
result: list[ModelT],
field_name: Union[str, set[str]],
value: str,
ignore_case: bool,
) -> list[ModelT]:
pattern = re.compile(rf".*{value}.*", re.IGNORECASE) if ignore_case else re.compile(rf".*{value}.*")
fields = {field_name} if isinstance(field_name, str) else field_name
items: list[ModelT] = []
for field in fields:
items.extend(
[
item
for item in result
if isinstance(getattr(item, field), str) and pattern.match(getattr(item, field))
],
)
return list(set(result).difference(set(items)))
def _filter_result_by_kwargs(
self,
result: Iterable[ModelT],
/,
kwargs: Union[dict[Any, Any], Iterable[tuple[Any, Any]]],
) -> list[ModelT]:
kwargs_: dict[Any, Any] = kwargs if isinstance(kwargs, dict) else dict(*kwargs) # pyright: ignore
kwargs_ = self._exclude_unused_kwargs(kwargs_) # pyright: ignore
try:
return [item for item in result if all(getattr(item, field) == value for field, value in kwargs_.items())] # pyright: ignore
except AttributeError as error:
raise RepositoryError from error
@staticmethod
def _order_by(result: list[ModelT], field_name: str, sort_desc: bool = False) -> list[ModelT]:
return sorted(result, key=lambda item: getattr(item, field_name), reverse=sort_desc)
def _apply_filters(
self,
result: list[ModelT],
*filters: Union[StatementFilter, ColumnElement[bool]],
apply_pagination: bool = True,
) -> list[ModelT]:
for filter_ in filters:
if isinstance(filter_, LimitOffset):
if apply_pagination:
result = self._apply_limit_offset_pagination(result, filter_.limit, filter_.offset)
elif isinstance(filter_, BeforeAfter):
result = self._filter_on_datetime_field(
result,
field_name=filter_.field_name,
before=filter_.before,
after=filter_.after,
)
elif isinstance(filter_, OnBeforeAfter):
result = self._filter_on_datetime_field(
result,
field_name=filter_.field_name,
on_or_before=filter_.on_or_before,
on_or_after=filter_.on_or_after,
)
elif isinstance(filter_, NotInCollectionFilter):
if filter_.values is not None: # pyright: ignore
result = self._filter_not_in_collection(result, filter_.field_name, filter_.values) # pyright: ignore
elif isinstance(filter_, CollectionFilter):
if filter_.values is not None: # pyright: ignore
result = self._filter_in_collection(result, filter_.field_name, filter_.values) # pyright: ignore
elif isinstance(filter_, OrderBy):
result = self._order_by(
result,
filter_.field_name,
sort_desc=filter_.sort_order == "desc",
)
elif isinstance(filter_, NotInSearchFilter):
result = self._filter_by_not_like(
result,
filter_.field_name,
value=filter_.value,
ignore_case=bool(filter_.ignore_case),
)
elif isinstance(filter_, SearchFilter):
result = self._filter_by_like(
result,
filter_.field_name,
value=filter_.value,
ignore_case=bool(filter_.ignore_case),
)
elif not isinstance(filter_, ColumnElement):
msg = f"Unexpected filter: {filter_}"
raise RepositoryError(msg)
return result
def _get_match_fields(
self,
match_fields: Union[list[str], str, None],
id_attribute: Optional[str] = None,
) -> Optional[list[str]]:
id_attribute = id_attribute or self.id_attribute
match_fields = match_fields or self.match_fields
if isinstance(match_fields, str):
match_fields = [match_fields]
return match_fields
def _list_and_count_basic(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
**kwargs: Any,
) -> tuple[list[ModelT], int]:
result = self.list(*filters, **kwargs)
return result, len(result)
def _list_and_count_window(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
**kwargs: Any,
) -> tuple[list[ModelT], int]:
return self._list_and_count_basic(*filters, **kwargs)
def _find_or_raise_not_found(self, id_: Any) -> ModelT:
return self.check_not_found(self.__collection__().get_or_none(id_))
@staticmethod
def _find_one_or_raise_error(result: list[ModelT]) -> ModelT:
if not result:
msg = "No item found when one was expected"
raise IntegrityError(msg)
if len(result) > 1:
msg = "Multiple objects when one was expected"
raise IntegrityError(msg)
return result[0] # pyright: ignore
def _get_update_many_statement(
self,
model_type: type[ModelT],
supports_returning: bool,
loader_options: Optional[list[_AbstractLoad]],
execution_options: Optional[dict[str, Any]],
) -> Union[Update, ReturningUpdate[tuple[ModelT]]]:
return self.statement # type: ignore[no-any-return] # pyright: ignore[reportReturnType]
@classmethod
def check_health(cls, session: Union[Session, scoped_session[Session]]) -> bool:
return True
def get(
self,
item_id: Any,
*,
auto_expunge: Optional[bool] = None,
statement: Union[Select[tuple[ModelT]], StatementLambdaElement, None] = None,
id_attribute: Union[str, InstrumentedAttribute[Any], None] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> ModelT:
return self._find_or_raise_not_found(item_id)
def get_one(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
auto_expunge: Optional[bool] = None,
statement: Union[Select[tuple[ModelT]], StatementLambdaElement, None] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> ModelT:
return self.check_not_found(self.get_one_or_none(**kwargs))
def get_one_or_none(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
auto_expunge: Optional[bool] = None,
statement: Union[Select[tuple[ModelT]], StatementLambdaElement, None] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> Union[ModelT, None]:
result = self._filter_result_by_kwargs(self.__collection__().list(), kwargs)
if len(result) > 1:
msg = "Multiple objects when one was expected"
raise IntegrityError(msg)
return result[0] if result else None
def get_or_upsert(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
match_fields: Union[list[str], str, None] = None,
upsert: bool = True,
attribute_names: Optional[Iterable[str]] = None,
with_for_update: Optional[bool] = None,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> tuple[ModelT, bool]:
kwargs_ = self._exclude_unused_kwargs(kwargs)
if match_fields := self._get_match_fields(match_fields=match_fields):
match_filter = {
# sourcery skip: remove-none-from-default-get
field_name: kwargs_.get(field_name, None)
for field_name in match_fields
if kwargs_.get(field_name, None) is not None
}
else:
match_filter = kwargs_
existing = self.get_one_or_none(**match_filter)
if not existing:
return (self.add(self.model_type(**kwargs_)), True)
if upsert:
for field_name, new_field_value in kwargs_.items():
field = getattr(existing, field_name, MISSING)
if field is not MISSING and field != new_field_value:
setattr(existing, field_name, new_field_value)
existing = self.update(existing)
return existing, False
def get_and_update(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
match_fields: Union[list[str], str, None] = None,
attribute_names: Optional[Iterable[str]] = None,
with_for_update: Optional[bool] = None,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> tuple[ModelT, bool]:
kwargs_ = self._exclude_unused_kwargs(kwargs)
if match_fields := self._get_match_fields(match_fields=match_fields):
match_filter = {
# sourcery skip: remove-none-from-default-get
field_name: kwargs_.get(field_name, None)
for field_name in match_fields
if kwargs_.get(field_name, None) is not None
}
else:
match_filter = kwargs_
existing = self.get_one(**match_filter)
updated = False
for field_name, new_field_value in kwargs_.items():
field = getattr(existing, field_name, MISSING)
if field is not MISSING and field != new_field_value:
updated = True
setattr(existing, field_name, new_field_value)
existing = self.update(existing)
return existing, updated
def exists(
self,
*filters: "Union[StatementFilter, ColumnElement[bool]]",
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> bool:
existing = self.count(*filters, **kwargs)
return existing > 0
def count(
self,
*filters: "Union[StatementFilter, ColumnElement[bool]]",
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> int:
result = self._apply_filters(self.__collection__().list(), *filters)
return len(self._filter_result_by_kwargs(result, kwargs))
def add(
self,
data: ModelT,
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
) -> ModelT:
try:
self.__database__.add(self.model_type, data)
except KeyError as exc:
msg = "Item already exist in collection"
raise IntegrityError(msg) from exc
return data
def add_many(
self,
data: list[ModelT],
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
) -> list[ModelT]:
for obj in data:
self.add(obj) # pyright: ignore[reportCallIssue]
return data
def update(
self,
data: ModelT,
*,
attribute_names: Optional[Iterable[str]] = None,
with_for_update: Optional[bool] = None,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> ModelT:
self._find_or_raise_not_found(self.__collection__().key(data))
return self.__collection__().update(data)
def update_many(
self,
data: list[ModelT],
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> list[ModelT]:
return [self.__collection__().update(obj) for obj in data if obj in self.__collection__()]
def delete(
self,
item_id: Any,
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> ModelT:
try:
return self._find_or_raise_not_found(item_id)
finally:
self.__collection__().remove(item_id)
def delete_many(
self,
item_ids: list[Any],
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
chunk_size: Optional[int] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> list[ModelT]:
deleted: list[ModelT] = []
for id_ in item_ids:
if obj := self.__collection__().get_or_none(id_):
deleted.append(obj)
self.__collection__().remove(id_)
return deleted
def delete_where(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
sanity_check: bool = True,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> list[ModelT]:
result = self.__collection__().list()
result = self._apply_filters(result, *filters)
models = self._filter_result_by_kwargs(result, kwargs)
item_ids = [getattr(model, self.id_attribute) for model in models]
return self.delete_many(item_ids=item_ids)
def upsert(
self,
data: ModelT,
*,
attribute_names: Optional[Iterable[str]] = None,
with_for_update: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_commit: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
match_fields: Optional[Union[list[str], str]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> ModelT:
# sourcery skip: assign-if-exp, reintroduce-else
if data in self.__collection__():
return self.update(data)
return self.add(data)
def upsert_many(
self,
data: list[ModelT],
*,
auto_expunge: Optional[bool] = None,
auto_commit: Optional[bool] = None,
no_merge: bool = False,
match_fields: Optional[Union[list[str], str]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> list[ModelT]:
return [self.upsert(item) for item in data]
def list_and_count(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
statement: Union[Select[tuple[ModelT]], StatementLambdaElement, None] = None,
auto_expunge: Optional[bool] = None,
count_with_window_function: Optional[bool] = None,
order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> tuple[list[ModelT], int]:
return self._list_and_count_basic(*filters, **kwargs)
def list(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> list[ModelT]:
result = self.__collection__().list()
result = self._apply_filters(result, *filters)
return self._filter_result_by_kwargs(result, kwargs)
class SQLAlchemySyncMockSlugRepository(
SQLAlchemySyncMockRepository[ModelT],
SQLAlchemySyncSlugRepositoryProtocol[ModelT],
):
def get_by_slug(
self,
slug: str,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> Union[ModelT, None]:
return self.get_one_or_none(slug=slug)
def get_available_slug(
self,
value_to_slugify: str,
**kwargs: Any,
) -> str:
"""Get a unique slug for the supplied value.
If the value is found to exist, a random 4 digit character is appended to the end.
Override this method to change the default behavior
Args:
value_to_slugify (str): A string that should be converted to a unique slug.
**kwargs: stuff
Returns:
str: a unique slug for the supplied value. This is safe for URLs and other unique identifiers.
"""
slug = slugify(value_to_slugify)
if self._is_slug_unique(slug):
return slug
random_string = "".join(random.choices(string.ascii_lowercase + string.digits, k=4)) # noqa: S311
return f"{slug}-{random_string}"
def _is_slug_unique(
self,
slug: str,
**kwargs: Any,
) -> bool:
return self.exists(slug=slug) is False
python-advanced-alchemy-1.4.1/advanced_alchemy/repository/memory/base.py 0000664 0000000 0000000 00000030146 15003544734 0026514 0 ustar 00root root 0000000 0000000 # ruff: noqa: PD011
import builtins
import contextlib
from collections import defaultdict
from inspect import isclass, signature
from typing import TYPE_CHECKING, Any, Generic, Union, cast, overload
from sqlalchemy import ColumnElement, inspect
from sqlalchemy.orm import RelationshipProperty, Session, class_mapper, object_mapper
from typing_extensions import TypeVar
from advanced_alchemy.exceptions import AdvancedAlchemyError
from advanced_alchemy.repository.typing import _MISSING, MISSING, ModelT # pyright: ignore[reportPrivateUsage]
if TYPE_CHECKING:
from collections.abc import Iterable
from sqlalchemy.orm import Mapper
CollectionT = TypeVar("CollectionT")
T = TypeVar("T")
AnyObject = TypeVar("AnyObject", bound="Any")
class _NotSet:
pass
class InMemoryStore(Generic[T]):
def __init__(self) -> None:
self._store: dict[Any, T] = {}
def _resolve_key(self, key: Any) -> Any:
"""Test different key representations
Args:
key: The key to test
Raises:
KeyError: Raised if key is not present
Returns:
The key representation that is present in the store
"""
for key_ in (key, str(key)):
if key_ in self._store:
return key_
raise KeyError
def key(self, obj: T) -> Any:
return hash(obj)
def add(self, obj: T) -> T:
if (key := self.key(obj)) not in self._store:
self._store[key] = obj
return obj
raise KeyError
def update(self, obj: T) -> T:
key = self._resolve_key(self.key(obj))
self._store[key] = obj
return obj
@overload
def get(self, key: Any, default: type[_NotSet] = _NotSet) -> T: ...
@overload
def get(self, key: Any, default: AnyObject) -> "Union[T, AnyObject]": ...
def get(
self, key: Any, default: "Union[AnyObject, type[_NotSet]]" = _NotSet
) -> "Union[T, AnyObject]": # pragma: no cover
"""Get the object identified by `key`, or return `default` if set or raise a `KeyError` otherwise
Args:
key: The key to test
default: Value to return if key is not present. Defaults to _NotSet.
Raises:
KeyError: Raised if key is not present
Returns:
The object identified by key
"""
try:
key = self._resolve_key(key)
except KeyError as error:
if isclass(default) and not issubclass(default, _NotSet): # pyright: ignore[reportUnnecessaryIsInstance]
return cast("AnyObject", default)
raise KeyError from error
return self._store[key]
def get_or_none(self, key: Any, default: Any = _NotSet) -> "Union[T, None]":
return self.get(key) if default is _NotSet else self.get(key, default)
def remove(self, key: Any) -> T:
return self._store.pop(self._resolve_key(key))
def list(self) -> list[T]:
return list(self._store.values())
def remove_all(self) -> None:
self._store = {}
def __contains__(self, obj: T) -> bool:
try:
self._resolve_key(self.key(obj))
except KeyError:
return False
else:
return True
def __bool__(self) -> bool:
return bool(self._store)
class MultiStore(Generic[T]):
def __init__(self, store_type: "type[InMemoryStore[T]]") -> None:
self.store_type = store_type
self._store: defaultdict[Any, InMemoryStore[T]] = defaultdict(store_type)
def add(self, identity: Any, obj: T) -> T:
return self._store[identity].add(obj)
def store(self, identity: Any) -> "InMemoryStore[T]":
return self._store[identity]
def identity(self, obj: T) -> Any:
return type(obj)
def remove_all(self) -> None:
self._store = defaultdict(self.store_type)
class SQLAlchemyInMemoryStore(InMemoryStore[ModelT]):
id_attribute: str = "id"
def _update_relationship(self, data: ModelT, ref: ModelT) -> None: # pragma: no cover
"""Set relationship data fields targeting ref class to ref.
Example:
```python
class Parent(Base):
child = relationship("Child")
class Child(Base):
pass
```
If data and ref are respectively a `Parent` and `Child` instances,
then `data.child` will be set to `ref`
Args:
data: Model instance on which to update relationships
ref: Target model instance to set on data relationships
"""
ref_mapper = object_mapper(ref)
for relationship in object_mapper(data).relationships:
local = next(iter(relationship.local_columns))
remote = next(iter(relationship.remote_side))
if not local.key or not remote.key:
msg = f"Cannot update relationship {relationship} for model {ref_mapper.class_}"
raise AdvancedAlchemyError(msg)
value = getattr(data, relationship.key)
if not value and relationship.mapper.class_ is ref_mapper.class_:
if relationship.uselist:
for elem in value:
if local_value := getattr(data, local.key):
setattr(elem, remote.key, local_value)
else:
setattr(data, relationship.key, ref)
def _update_fks(self, data: ModelT) -> None: # pragma: no cover
"""Update foreign key fields according to their corresponding relationships.
This make sure that `data.child_id` == `data.child.id`
or `data.children[0].parent_id` == `data.id`
Args:
data: Instance to be updated
"""
ref_mapper = object_mapper(data)
for relationship in ref_mapper.relationships:
if value := getattr(data, relationship.key):
local = next(iter(relationship.local_columns))
remote = next(iter(relationship.remote_side))
if not local.key or not remote.key:
msg = f"Cannot update relationship {relationship} for model {ref_mapper.class_}"
raise AdvancedAlchemyError(msg)
if relationship.uselist:
for elem in value:
if local_value := getattr(data, local.key):
setattr(elem, remote.key, local_value)
self._update_relationship(elem, data)
# Remove duplicates added by orm when updating list items
if isinstance(value, list):
setattr(data, relationship.key, type(value)(set(value))) # pyright: ignore[reportUnknownArgumentType]
else:
if remote_value := getattr(value, remote.key):
setattr(data, local.key, remote_value)
self._update_relationship(value, data)
def _set_defaults(self, data: ModelT) -> None: # pragma: no cover
"""Set fields with dynamic defaults.
Args:
data: Instance to be updated
"""
for elem in object_mapper(data).c:
default = getattr(elem, "default", MISSING)
value = getattr(data, elem.key, MISSING)
# If value is MISSING, it may be a declared_attr whose name can't be
# determined from the column/relationship element returned
if value is not MISSING and not value and not isinstance(default, _MISSING) and default is not None:
if default.is_scalar:
default_value: Any = default.arg
elif default.is_callable:
default_callable = default.arg.__func__ if isinstance(default.arg, staticmethod) else default.arg # pyright: ignore[reportUnknownMemberType]
if (
# Eager test because inspect.signature() does not
# recognize builtins
hasattr(builtins, default_callable.__name__)
# If present, context contains information about the current
# statement and can be used to access values from other columns.
# As we can't reproduce such context in Pydantic, we don't want
# include a default_factory in that case.
or "context" not in signature(default_callable).parameters
):
default_value = default.arg({}) # pyright: ignore
else:
continue
else:
continue
setattr(data, elem.key, default_value)
@staticmethod
def changed_attrs(data: ModelT) -> "Iterable[str]": # pragma: no cover
res: list[str] = []
mapper = inspect(data)
if mapper is None:
msg = f"Cannot inspect {data.__class__} model"
raise AdvancedAlchemyError(msg)
attrs = class_mapper(data.__class__).column_attrs
for attr in attrs:
hist = getattr(mapper.attrs, attr.key).history
if hist.has_changes():
res.append(attr.key)
return res
def key(self, obj: ModelT) -> str:
return str(getattr(obj, self.id_attribute))
def add(self, obj: ModelT) -> ModelT:
self._set_defaults(obj)
self._update_fks(obj)
return super().add(obj)
def update(self, obj: ModelT) -> ModelT:
existing = self.get(self.key(obj))
for attr in self.changed_attrs(obj):
setattr(existing, attr, getattr(obj, attr))
self._update_fks(existing)
return super().update(existing)
class SQLAlchemyMultiStore(MultiStore[ModelT]):
@staticmethod
def _new_instances(instance: ModelT) -> "Iterable[ModelT]":
session = Session()
session.add(instance)
relations = list(session.new)
session.expunge_all()
return relations
def _set_relationships_for_fks(self, data: ModelT) -> None: # pragma: no cover
"""Set relationships matching newly added foreign keys on the instance.
Example:
```python
class Parent(Base):
id: Mapped[UUID]
class Child(Base):
id: Mapped[UUID]
parent_id: Mapped[UUID] = mapped_column(ForeignKey("parent.id"))
parent: Mapped[Parent] = relationship(Parent)
```
If `data` is a Child instance and `parent_id` is set, `parent` will be set
to the matching Parent instance if found in the repository
Args:
data: The model to update
"""
obj_mapper = object_mapper(data)
mappers: dict[str, Mapper[Any]] = {}
column_relationships: dict[ColumnElement[Any], RelationshipProperty[Any]] = {}
for mapper in obj_mapper.registry.mappers:
for table in mapper.tables:
mappers[table.name] = mapper
for relationship in obj_mapper.relationships:
for column in relationship.local_columns:
column_relationships[column] = relationship
# sourcery skip: assign-if-exp
if state := inspect(data):
new_attrs: dict[str, Any] = state.dict
else:
new_attrs = {}
for column in obj_mapper.columns:
if column.key not in new_attrs or not column.foreign_keys:
continue
remote_mapper = mappers[next(iter(column.foreign_keys))._table_key()] # noqa: SLF001 # pyright: ignore[reportPrivateUsage]
try:
obj = self.store(remote_mapper.class_).get(new_attrs.get(column.key, None))
except KeyError:
continue
with contextlib.suppress(KeyError):
setattr(data, column_relationships[column].key, obj)
def add(self, identity: Any, obj: ModelT) -> ModelT:
for relation in self._new_instances(obj):
instance_type = self.identity(relation)
self._set_relationships_for_fks(relation)
if relation in self.store(instance_type):
continue
self.store(instance_type).add(relation)
return obj
python-advanced-alchemy-1.4.1/advanced_alchemy/repository/typing.py 0000664 0000000 0000000 00000005443 15003544734 0025606 0 ustar 00root root 0000000 0000000 from typing import TYPE_CHECKING, Any, Union
from sqlalchemy import UnaryExpression
from sqlalchemy.orm import InstrumentedAttribute
from typing_extensions import TypeAlias, TypeVar
if TYPE_CHECKING:
from sqlalchemy import RowMapping, Select
from advanced_alchemy import base
from advanced_alchemy.repository._async import SQLAlchemyAsyncRepository
from advanced_alchemy.repository._sync import SQLAlchemySyncRepository
from advanced_alchemy.repository.memory._async import SQLAlchemyAsyncMockRepository
from advanced_alchemy.repository.memory._sync import SQLAlchemySyncMockRepository
__all__ = (
"MISSING",
"ModelOrRowMappingT",
"ModelT",
"OrderingPair",
"RowMappingT",
"RowT",
"SQLAlchemyAsyncRepositoryT",
"SQLAlchemySyncRepositoryT",
"SelectT",
"T",
)
T = TypeVar("T")
ModelT = TypeVar("ModelT", bound="base.ModelProtocol")
"""Type variable for SQLAlchemy models.
:class:`~advanced_alchemy.base.ModelProtocol`
"""
SelectT = TypeVar("SelectT", bound="Select[Any]")
"""Type variable for SQLAlchemy select statements.
:class:`~sqlalchemy.sql.Select`
"""
RowT = TypeVar("RowT", bound=tuple[Any, ...])
"""Type variable for rows.
:class:`~sqlalchemy.engine.Row`
"""
RowMappingT = TypeVar("RowMappingT", bound="RowMapping")
"""Type variable for row mappings.
:class:`~sqlalchemy.engine.RowMapping`
"""
ModelOrRowMappingT = TypeVar("ModelOrRowMappingT", bound="Union[base.ModelProtocol, RowMapping]")
"""Type variable for models or row mappings.
:class:`~advanced_alchemy.base.ModelProtocol` | :class:`~sqlalchemy.engine.RowMapping`
"""
SQLAlchemySyncRepositoryT = TypeVar(
"SQLAlchemySyncRepositoryT",
bound="Union[SQLAlchemySyncRepository[Any], SQLAlchemySyncMockRepository[Any]]",
default="Any",
)
"""Type variable for synchronous SQLAlchemy repositories.
:class:`~advanced_alchemy.repository.SQLAlchemySyncRepository`
"""
SQLAlchemyAsyncRepositoryT = TypeVar(
"SQLAlchemyAsyncRepositoryT",
bound="Union[SQLAlchemyAsyncRepository[Any], SQLAlchemyAsyncMockRepository[Any]]",
default="Any",
)
"""Type variable for asynchronous SQLAlchemy repositories.
:class:`~advanced_alchemy.repository.SQLAlchemyAsyncRepository`
"""
OrderingPair: TypeAlias = Union[tuple[Union[str, InstrumentedAttribute[Any]], bool], UnaryExpression[Any]]
"""Type alias for ordering pairs.
A tuple of (column, ascending) where:
- column: Union[str, :class:`sqlalchemy.orm.InstrumentedAttribute`]
- ascending: bool
- or a :class:`sqlalchemy.sql.elements.UnaryExpression` which is the standard way to express an ordering in SQLAlchemy
This type is used to specify ordering criteria for repository queries.
"""
class _MISSING:
"""Placeholder for missing values."""
MISSING = _MISSING()
"""Missing value placeholder.
:class:`~advanced_alchemy.repository.typing._MISSING`
"""
python-advanced-alchemy-1.4.1/advanced_alchemy/service/ 0000775 0000000 0000000 00000000000 15003544734 0023135 5 ustar 00root root 0000000 0000000 python-advanced-alchemy-1.4.1/advanced_alchemy/service/__init__.py 0000664 0000000 0000000 00000004463 15003544734 0025255 0 ustar 00root root 0000000 0000000 from advanced_alchemy.repository import (
DEFAULT_ERROR_MESSAGE_TEMPLATES,
Empty,
EmptyType,
ErrorMessages,
LoadSpec,
ModelOrRowMappingT,
ModelT,
OrderingPair,
model_from_dict,
)
from advanced_alchemy.service._async import (
SQLAlchemyAsyncQueryService,
SQLAlchemyAsyncRepositoryReadService,
SQLAlchemyAsyncRepositoryService,
)
from advanced_alchemy.service._sync import (
SQLAlchemySyncQueryService,
SQLAlchemySyncRepositoryReadService,
SQLAlchemySyncRepositoryService,
)
from advanced_alchemy.service._util import ResultConverter, find_filter
from advanced_alchemy.service.pagination import OffsetPagination
from advanced_alchemy.service.typing import (
FilterTypeT,
ModelDictListT,
ModelDictT,
ModelDTOT,
SupportedSchemaModel,
is_dict,
is_dict_with_field,
is_dict_without_field,
is_dto_data,
is_msgspec_struct,
is_msgspec_struct_with_field,
is_msgspec_struct_without_field,
is_pydantic_model,
is_pydantic_model_with_field,
is_pydantic_model_without_field,
is_schema,
is_schema_or_dict,
is_schema_or_dict_with_field,
is_schema_or_dict_without_field,
is_schema_with_field,
is_schema_without_field,
schema_dump,
)
__all__ = (
"DEFAULT_ERROR_MESSAGE_TEMPLATES",
"Empty",
"EmptyType",
"ErrorMessages",
"FilterTypeT",
"LoadSpec",
"ModelDTOT",
"ModelDictListT",
"ModelDictT",
"ModelOrRowMappingT",
"ModelT",
"OffsetPagination",
"OrderingPair",
"ResultConverter",
"SQLAlchemyAsyncQueryService",
"SQLAlchemyAsyncRepositoryReadService",
"SQLAlchemyAsyncRepositoryService",
"SQLAlchemySyncQueryService",
"SQLAlchemySyncRepositoryReadService",
"SQLAlchemySyncRepositoryService",
"SupportedSchemaModel",
"find_filter",
"is_dict",
"is_dict_with_field",
"is_dict_without_field",
"is_dto_data",
"is_msgspec_struct",
"is_msgspec_struct_with_field",
"is_msgspec_struct_without_field",
"is_pydantic_model",
"is_pydantic_model_with_field",
"is_pydantic_model_without_field",
"is_schema",
"is_schema_or_dict",
"is_schema_or_dict_with_field",
"is_schema_or_dict_without_field",
"is_schema_with_field",
"is_schema_without_field",
"model_from_dict",
"schema_dump",
)
python-advanced-alchemy-1.4.1/advanced_alchemy/service/_async.py 0000664 0000000 0000000 00000136276 15003544734 0025002 0 ustar 00root root 0000000 0000000 # ruff: noqa: PLR6301
"""Service object implementation for SQLAlchemy.
RepositoryService object is generic on the domain model type which
should be a SQLAlchemy model.
"""
from collections.abc import AsyncIterator, Iterable, Sequence
from contextlib import asynccontextmanager
from functools import cached_property
from typing import Any, ClassVar, Generic, Optional, Union, cast
from sqlalchemy import Select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.ext.asyncio.scoping import async_scoped_session
from sqlalchemy.orm import InstrumentedAttribute
from sqlalchemy.sql import ColumnElement
from typing_extensions import Self
from advanced_alchemy.config.asyncio import SQLAlchemyAsyncConfig
from advanced_alchemy.exceptions import AdvancedAlchemyError, ErrorMessages, ImproperConfigurationError, RepositoryError
from advanced_alchemy.filters import StatementFilter
from advanced_alchemy.repository import (
SQLAlchemyAsyncQueryRepository,
)
from advanced_alchemy.repository._util import LoadSpec, model_from_dict
from advanced_alchemy.repository.typing import ModelT, OrderingPair, SQLAlchemyAsyncRepositoryT
from advanced_alchemy.service._util import ResultConverter
from advanced_alchemy.service.typing import (
BulkModelDictT,
ModelDictListT,
ModelDictT,
is_dict,
is_dto_data,
is_msgspec_struct,
is_pydantic_model,
)
from advanced_alchemy.utils.dataclass import Empty, EmptyType
class SQLAlchemyAsyncQueryService(ResultConverter):
"""Simple service to execute the basic Query repository.."""
def __init__(
self,
session: Union[AsyncSession, async_scoped_session[AsyncSession]],
**repo_kwargs: Any,
) -> None:
"""Configure the service object.
Args:
session: Session managing the unit-of-work for the operation.
**repo_kwargs: Optional configuration values to pass into the repository
"""
self.repository = SQLAlchemyAsyncQueryRepository(
session=session,
**repo_kwargs,
)
@classmethod
@asynccontextmanager
async def new(
cls,
session: Optional[Union[AsyncSession, async_scoped_session[AsyncSession]]] = None,
config: Optional[SQLAlchemyAsyncConfig] = None,
) -> AsyncIterator[Self]:
"""Context manager that returns instance of service object.
Handles construction of the database session._create_select_for_model
Raises:
AdvancedAlchemyError: If no configuration or session is provided.
Yields:
The service object instance.
"""
if not config and not session:
raise AdvancedAlchemyError(detail="Please supply an optional configuration or session to use.")
if session:
yield cls(session=session)
elif config:
async with config.get_session() as db_session:
yield cls(session=db_session)
class SQLAlchemyAsyncRepositoryReadService(ResultConverter, Generic[ModelT, SQLAlchemyAsyncRepositoryT]):
"""Service object that operates on a repository object."""
repository_type: type[SQLAlchemyAsyncRepositoryT]
"""Type of the repository to use."""
loader_options: ClassVar[Optional[LoadSpec]] = None
"""Default loader options for the repository."""
execution_options: ClassVar[Optional[dict[str, Any]]] = None
"""Default execution options for the repository."""
match_fields: ClassVar[Optional[Union[list[str], str]]] = None
"""List of dialects that prefer to use ``field.id = ANY(:1)`` instead of ``field.id IN (...)``."""
uniquify: ClassVar[bool] = False
"""Optionally apply the ``unique()`` method to results before returning."""
count_with_window_function: ClassVar[bool] = True
"""Use an analytical window function to count results. This allows the count to be performed in a single query."""
_repository_instance: SQLAlchemyAsyncRepositoryT
def __init__(
self,
session: Union[AsyncSession, async_scoped_session[AsyncSession]],
*,
statement: Optional[Select[Any]] = None,
auto_expunge: bool = False,
auto_refresh: bool = True,
auto_commit: bool = False,
order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
wrap_exceptions: bool = True,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
count_with_window_function: Optional[bool] = None,
**repo_kwargs: Any,
) -> None:
"""Configure the service object.
Args:
session: Session managing the unit-of-work for the operation.
statement: To facilitate customization of the underlying select query.
auto_expunge: Remove object from session before returning.
auto_refresh: Refresh object from session before returning.
auto_commit: Commit objects before returning.
order_by: Set default order options for queries.
error_messages: A set of custom error messages to use for operations
wrap_exceptions: Wrap exceptions in a RepositoryError
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
count_with_window_function: When false, list and count will use two queries instead of an analytical window function.
**repo_kwargs: passed as keyword args to repo instantiation.
"""
load = load if load is not None else self.loader_options
execution_options = execution_options if execution_options is not None else self.execution_options
count_with_window_function = (
count_with_window_function if count_with_window_function is not None else self.count_with_window_function
)
self._repository_instance: SQLAlchemyAsyncRepositoryT = self.repository_type( # type: ignore[assignment]
session=session,
statement=statement,
auto_expunge=auto_expunge,
auto_refresh=auto_refresh,
auto_commit=auto_commit,
order_by=order_by,
error_messages=error_messages,
wrap_exceptions=wrap_exceptions,
load=load,
execution_options=execution_options,
uniquify=self._get_uniquify(uniquify),
count_with_window_function=count_with_window_function,
**repo_kwargs,
)
def _get_uniquify(self, uniquify: Optional[bool] = None) -> bool:
return bool(uniquify or self.uniquify)
@property
def repository(self) -> SQLAlchemyAsyncRepositoryT:
"""Return the repository instance.
Raises:
ImproperConfigurationError: If the repository is not initialized.
Returns:
The repository instance.
"""
if not self._repository_instance:
msg = "Repository not initialized"
raise ImproperConfigurationError(msg)
return self._repository_instance
@cached_property
def model_type(self) -> type[ModelT]:
"""Return the model type."""
return cast("type[ModelT]", self.repository.model_type)
async def count(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
statement: Optional[Select[tuple[ModelT]]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> int:
"""Count of records returned by query.
Args:
*filters: arguments for filtering.
statement: To facilitate customization of the underlying select query.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
**kwargs: key value pairs of filter types.
Returns:
A count of the collection, filtered, but ignoring pagination.
"""
return await self.repository.count(
*filters,
statement=statement,
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=self._get_uniquify(uniquify),
**kwargs,
)
async def exists(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> bool:
"""Wrap repository exists operation.
Args:
*filters: Types for specific filtering operations.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
**kwargs: Keyword arguments for attribute based filtering.
Returns:
Representation of instance with identifier `item_id`.
"""
return await self.repository.exists(
*filters,
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=self._get_uniquify(uniquify),
**kwargs,
)
async def get(
self,
item_id: Any,
*,
statement: Optional[Select[tuple[ModelT]]] = None,
id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
auto_expunge: Optional[bool] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> ModelT:
"""Wrap repository scalar operation.
Args:
item_id: Identifier of instance to be retrieved.
auto_expunge: Remove object from session before returning.
statement: To facilitate customization of the underlying select query.
id_attribute: Allows customization of the unique identifier to use for model fetching.
Defaults to `id`, but can reference any surrogate or candidate key for the table.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
Returns:
Representation of instance with identifier `item_id`.
"""
return cast(
"ModelT",
await self.repository.get(
item_id=item_id,
auto_expunge=auto_expunge,
statement=statement,
id_attribute=id_attribute,
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=self._get_uniquify(uniquify),
),
)
async def get_one(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
statement: Optional[Select[tuple[ModelT]]] = None,
auto_expunge: Optional[bool] = None,
load: Optional[LoadSpec] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> ModelT:
"""Wrap repository scalar operation.
Args:
*filters: Types for specific filtering operations.
auto_expunge: Remove object from session before returning.
statement: To facilitate customization of the underlying select query.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
**kwargs: Identifier of the instance to be retrieved.
Returns:
Representation of instance with identifier `item_id`.
"""
return cast(
"ModelT",
await self.repository.get_one(
*filters,
auto_expunge=auto_expunge,
statement=statement,
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=self._get_uniquify(uniquify),
**kwargs,
),
)
async def get_one_or_none(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
statement: Optional[Select[tuple[ModelT]]] = None,
auto_expunge: Optional[bool] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> Optional[ModelT]:
"""Wrap repository scalar operation.
Args:
*filters: Types for specific filtering operations.
auto_expunge: Remove object from session before returning.
statement: To facilitate customization of the underlying select query.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
**kwargs: Identifier of the instance to be retrieved.
Returns:
Representation of instance with identifier `item_id`.
"""
return cast(
"Optional[ModelT]",
await self.repository.get_one_or_none(
*filters,
auto_expunge=auto_expunge,
statement=statement,
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=self._get_uniquify(uniquify),
**kwargs,
),
)
async def to_model_on_create(self, data: "ModelDictT[ModelT]") -> "ModelDictT[ModelT]":
"""Convenience method to allow for custom behavior on create.
Args:
data: The data to be converted to a model.
Returns:
The data to be converted to a model.
"""
return data
async def to_model_on_update(self, data: "ModelDictT[ModelT]") -> "ModelDictT[ModelT]":
"""Convenience method to allow for custom behavior on update.
Args:
data: The data to be converted to a model.
Returns:
The data to be converted to a model.
"""
return data
async def to_model_on_delete(self, data: "ModelDictT[ModelT]") -> "ModelDictT[ModelT]":
"""Convenience method to allow for custom behavior on delete.
Args:
data: The data to be converted to a model.
Returns:
The data to be converted to a model.
"""
return data
async def to_model_on_upsert(self, data: "ModelDictT[ModelT]") -> "ModelDictT[ModelT]":
"""Convenience method to allow for custom behavior on upsert.
Args:
data: The data to be converted to a model.
Returns:
The data to be converted to a model.
"""
return data
async def to_model(
self,
data: "ModelDictT[ModelT]",
operation: Optional[str] = None,
) -> ModelT:
"""Parse and Convert input into a model.
Args:
data: Representations to be created.
operation: Optional operation flag so that you can provide behavior based on CRUD operation
Returns:
Representation of created instances.
"""
operation_map = {
"create": self.to_model_on_create,
"update": self.to_model_on_update,
"delete": self.to_model_on_delete,
"upsert": self.to_model_on_upsert,
}
if operation and (op := operation_map.get(operation)):
data = await op(data)
if is_dict(data):
return model_from_dict(model=self.model_type, **data)
if is_pydantic_model(data):
return model_from_dict(
model=self.model_type,
**data.model_dump(exclude_unset=True),
)
if is_msgspec_struct(data):
from msgspec import UNSET
return model_from_dict(
model=self.model_type,
**{f: val for f in data.__struct_fields__ if (val := getattr(data, f, None)) != UNSET},
)
if is_dto_data(data):
return cast("ModelT", data.create_instance())
return cast("ModelT", data)
async def list_and_count(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
statement: Optional[Select[tuple[ModelT]]] = None,
auto_expunge: Optional[bool] = None,
count_with_window_function: Optional[bool] = None,
order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> tuple[Sequence[ModelT], int]:
"""List of records and total count returned by query.
Args:
*filters: Types for specific filtering operations.
statement: To facilitate customization of the underlying select query.
auto_expunge: Remove object from session before returning.
count_with_window_function: When false, list and count will use two queries instead of an analytical window function.
order_by: Set default order options for queries.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
**kwargs: Instance attribute value filters.
Returns:
List of instances and count of total collection, ignoring pagination.
"""
return cast(
"tuple[Sequence[ModelT], int]",
await self.repository.list_and_count(
*filters,
statement=statement,
auto_expunge=auto_expunge,
count_with_window_function=count_with_window_function,
order_by=order_by,
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=self._get_uniquify(uniquify),
**kwargs,
),
)
@classmethod
@asynccontextmanager
async def new(
cls,
session: Optional[Union[AsyncSession, async_scoped_session[AsyncSession]]] = None,
statement: Optional[Select[tuple[ModelT]]] = None,
config: Optional[SQLAlchemyAsyncConfig] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
count_with_window_function: Optional[bool] = None,
) -> AsyncIterator[Self]:
"""Context manager that returns instance of service object.
Handles construction of the database session._create_select_for_model
Raises:
AdvancedAlchemyError: If no configuration or session is provided.
Yields:
The service object instance.
"""
if not config and not session:
raise AdvancedAlchemyError(detail="Please supply an optional configuration or session to use.")
if session:
yield cls(
statement=statement,
session=session,
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=uniquify,
count_with_window_function=count_with_window_function,
)
elif config:
async with config.get_session() as db_session:
yield cls(
statement=statement,
session=db_session,
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=uniquify,
count_with_window_function=count_with_window_function,
)
async def list(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
statement: Optional[Select[tuple[ModelT]]] = None,
auto_expunge: Optional[bool] = None,
order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> Sequence[ModelT]:
"""Wrap repository scalars operation.
Args:
*filters: Types for specific filtering operations.
auto_expunge: Remove object from session before returning.
statement: To facilitate customization of the underlying select query.
order_by: Set default order options for queries.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
**kwargs: Instance attribute value filters.
Returns:
The list of instances retrieved from the repository.
"""
return cast(
"Sequence[ModelT]",
await self.repository.list(
*filters,
statement=statement,
auto_expunge=auto_expunge,
order_by=order_by,
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=self._get_uniquify(uniquify),
**kwargs,
),
)
class SQLAlchemyAsyncRepositoryService(
SQLAlchemyAsyncRepositoryReadService[ModelT, SQLAlchemyAsyncRepositoryT],
Generic[ModelT, SQLAlchemyAsyncRepositoryT],
):
"""Service object that operates on a repository object."""
async def create(
self,
data: "ModelDictT[ModelT]",
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
) -> "ModelT":
"""Wrap repository instance creation.
Args:
data: Representation to be created.
auto_expunge: Remove object from session before returning.
auto_refresh: Refresh object from session before returning.
auto_commit: Commit objects before returning.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
Returns:
Representation of created instance.
"""
data = await self.to_model(data, "create")
return cast(
"ModelT",
await self.repository.add(
data=data,
auto_commit=auto_commit,
auto_expunge=auto_expunge,
auto_refresh=auto_refresh,
error_messages=error_messages,
),
)
async def create_many(
self,
data: "BulkModelDictT[ModelT]",
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
) -> Sequence[ModelT]:
"""Wrap repository bulk instance creation.
Args:
data: Representations to be created.
auto_expunge: Remove object from session before returning.
auto_commit: Commit objects before returning.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
uniquify: Optionally apply the ``unique()`` method to results before returning.
Returns:
Representation of created instances.
"""
if is_dto_data(data):
data = data.create_instance()
data = [(await self.to_model(datum, "create")) for datum in cast("ModelDictListT[ModelT]", data)]
return cast(
"Sequence[ModelT]",
await self.repository.add_many(
data=cast("list[ModelT]", data), # pyright: ignore[reportUnnecessaryCast]
auto_commit=auto_commit,
auto_expunge=auto_expunge,
error_messages=error_messages,
),
)
async def update(
self,
data: "ModelDictT[ModelT]",
item_id: Optional[Any] = None,
*,
attribute_names: Optional[Iterable[str]] = None,
with_for_update: Optional[bool] = None,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> "ModelT":
"""Wrap repository update operation.
Args:
data: Representation to be updated.
item_id: Identifier of item to be updated.
attribute_names: an iterable of attribute names to pass into the ``update``
method.
with_for_update: indicating FOR UPDATE should be used, or may be a
dictionary containing flags to indicate a more specific set of
FOR UPDATE flags for the SELECT
auto_expunge: Remove object from session before returning.
auto_refresh: Refresh object from session before returning.
auto_commit: Commit objects before returning.
id_attribute: Allows customization of the unique identifier to use for model fetching.
Defaults to `id`, but can reference any surrogate or candidate key for the table.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
Raises:
RepositoryError: If no configuration or session is provided.
Returns:
Updated representation.
"""
data = await self.to_model(data, "update")
if (
item_id is None
and self.repository.get_id_attribute_value( # pyright: ignore[reportUnknownMemberType]
item=data,
id_attribute=id_attribute,
)
is None
):
msg = (
"Could not identify ID attribute value. One of the following is required: "
f"``item_id`` or ``data.{id_attribute or self.repository.id_attribute}``"
)
raise RepositoryError(msg)
if item_id is not None:
data = self.repository.set_id_attribute_value(item_id=item_id, item=data, id_attribute=id_attribute) # pyright: ignore[reportUnknownMemberType]
return cast(
"ModelT",
await self.repository.update(
data=data,
attribute_names=attribute_names,
with_for_update=with_for_update,
auto_commit=auto_commit,
auto_expunge=auto_expunge,
auto_refresh=auto_refresh,
id_attribute=id_attribute,
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=self._get_uniquify(uniquify),
),
)
async def update_many(
self,
data: "BulkModelDictT[ModelT]",
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> Sequence[ModelT]:
"""Wrap repository bulk instance update.
Args:
data: Representations to be updated.
auto_expunge: Remove object from session before returning.
auto_commit: Commit objects before returning.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
Returns:
Representation of updated instances.
"""
if is_dto_data(data):
data = data.create_instance()
data = [(await self.to_model(datum, "update")) for datum in cast("ModelDictListT[ModelT]", data)]
return cast(
"Sequence[ModelT]",
await self.repository.update_many(
cast("list[ModelT]", data), # pyright: ignore[reportUnnecessaryCast]
auto_commit=auto_commit,
auto_expunge=auto_expunge,
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=self._get_uniquify(uniquify),
),
)
async def upsert(
self,
data: "ModelDictT[ModelT]",
item_id: Optional[Any] = None,
*,
attribute_names: Optional[Iterable[str]] = None,
with_for_update: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_commit: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
match_fields: Optional[Union[list[str], str]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> ModelT:
"""Wrap repository upsert operation.
Args:
data: Instance to update existing, or be created. Identifier used to determine if an
existing instance exists is the value of an attribute on `data` named as value of
`self.id_attribute`.
item_id: Identifier of the object for upsert.
attribute_names: an iterable of attribute names to pass into the ``update`` method.
with_for_update: indicating FOR UPDATE should be used, or may be a
dictionary containing flags to indicate a more specific set of
FOR UPDATE flags for the SELECT
auto_expunge: Remove object from session before returning.
auto_refresh: Refresh object from session before returning.
auto_commit: Commit objects before returning.
match_fields: a list of keys to use to match the existing model. When
empty, all fields are matched.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
Returns:
Updated or created representation.
"""
data = await self.to_model(data, "upsert")
item_id = item_id if item_id is not None else self.repository.get_id_attribute_value(item=data) # pyright: ignore[reportUnknownMemberType]
if item_id is not None:
self.repository.set_id_attribute_value(item_id, data) # pyright: ignore[reportUnknownMemberType]
return cast(
"ModelT",
await self.repository.upsert(
data=data,
attribute_names=attribute_names,
with_for_update=with_for_update,
auto_expunge=auto_expunge,
auto_commit=auto_commit,
auto_refresh=auto_refresh,
match_fields=match_fields,
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=self._get_uniquify(uniquify),
),
)
async def upsert_many(
self,
data: "BulkModelDictT[ModelT]",
*,
auto_expunge: Optional[bool] = None,
auto_commit: Optional[bool] = None,
no_merge: bool = False,
match_fields: Optional[Union[list[str], str]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> Sequence[ModelT]:
"""Wrap repository upsert operation.
Args:
data: Instance to update existing, or be created.
auto_expunge: Remove object from session before returning.
auto_commit: Commit objects before returning.
no_merge: Skip the usage of optimized Merge statements (**reserved for future use**)
match_fields: a list of keys to use to match the existing model. When
empty, all fields are matched.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
Returns:
Updated or created representation.
"""
if is_dto_data(data):
data = data.create_instance()
data = [(await self.to_model(datum, "upsert")) for datum in cast("ModelDictListT[ModelT]", data)]
return cast(
"Sequence[ModelT]",
await self.repository.upsert_many(
data=cast("list[ModelT]", data), # pyright: ignore[reportUnnecessaryCast]
auto_expunge=auto_expunge,
auto_commit=auto_commit,
no_merge=no_merge,
match_fields=match_fields,
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=self._get_uniquify(uniquify),
),
)
async def get_or_upsert(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
match_fields: Optional[Union[list[str], str]] = None,
upsert: bool = True,
attribute_names: Optional[Iterable[str]] = None,
with_for_update: Optional[bool] = None,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> tuple[ModelT, bool]:
"""Wrap repository instance creation.
Args:
*filters: Types for specific filtering operations.
match_fields: a list of keys to use to match the existing model. When
empty, all fields are matched.
upsert: When using match_fields and actual model values differ from
`kwargs`, perform an update operation on the model.
create: Should a model be created. If no model is found, an exception is raised.
attribute_names: an iterable of attribute names to pass into the ``update``
method.
with_for_update: indicating FOR UPDATE should be used, or may be a
dictionary containing flags to indicate a more specific set of
FOR UPDATE flags for the SELECT
auto_expunge: Remove object from session before returning.
auto_refresh: Refresh object from session before returning.
auto_commit: Commit objects before returning.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
**kwargs: Identifier of the instance to be retrieved.
Returns:
Representation of created instance.
"""
match_fields = match_fields or self.match_fields
validated_model = await self.to_model(kwargs, "create")
return cast(
"tuple[ModelT, bool]",
await self.repository.get_or_upsert(
*filters,
match_fields=match_fields,
upsert=upsert,
attribute_names=attribute_names,
with_for_update=with_for_update,
auto_commit=auto_commit,
auto_expunge=auto_expunge,
auto_refresh=auto_refresh,
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=self._get_uniquify(uniquify),
**validated_model.to_dict(),
),
)
async def get_and_update(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
match_fields: Optional[Union[list[str], str]] = None,
attribute_names: Optional[Iterable[str]] = None,
with_for_update: Optional[bool] = None,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> tuple[ModelT, bool]:
"""Wrap repository instance creation.
Args:
*filters: Types for specific filtering operations.
match_fields: a list of keys to use to match the existing model. When
empty, all fields are matched.
attribute_names: an iterable of attribute names to pass into the ``update``
method.
with_for_update: indicating FOR UPDATE should be used, or may be a
dictionary containing flags to indicate a more specific set of
FOR UPDATE flags for the SELECT
auto_expunge: Remove object from session before returning.
auto_refresh: Refresh object from session before returning.
auto_commit: Commit objects before returning.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
**kwargs: Identifier of the instance to be retrieved.
Returns:
Representation of updated instance.
"""
match_fields = match_fields or self.match_fields
validated_model = await self.to_model(kwargs, "update")
return cast(
"tuple[ModelT, bool]",
await self.repository.get_and_update(
*filters,
match_fields=match_fields,
attribute_names=attribute_names,
with_for_update=with_for_update,
auto_commit=auto_commit,
auto_expunge=auto_expunge,
auto_refresh=auto_refresh,
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=self._get_uniquify(uniquify),
**validated_model.to_dict(),
),
)
async def delete(
self,
item_id: Any,
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> ModelT:
"""Wrap repository delete operation.
Args:
item_id: Identifier of instance to be deleted.
auto_commit: Commit objects before returning.
auto_expunge: Remove object from session before returning.
id_attribute: Allows customization of the unique identifier to use for model fetching.
Defaults to `id`, but can reference any surrogate or candidate key for the table.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
Returns:
Representation of the deleted instance.
"""
return cast(
"ModelT",
await self.repository.delete(
item_id=item_id,
auto_commit=auto_commit,
auto_expunge=auto_expunge,
id_attribute=id_attribute,
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=self._get_uniquify(uniquify),
),
)
async def delete_many(
self,
item_ids: list[Any],
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
chunk_size: Optional[int] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> Sequence[ModelT]:
"""Wrap repository bulk instance deletion.
Args:
item_ids: Identifier of instance to be deleted.
auto_expunge: Remove object from session before returning.
auto_commit: Commit objects before returning.
id_attribute: Allows customization of the unique identifier to use for model fetching.
Defaults to `id`, but can reference any surrogate or candidate key for the table.
chunk_size: Allows customization of the ``insertmanyvalues_max_parameters`` setting for the driver.
Defaults to `950` if left unset.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
Returns:
Representation of removed instances.
"""
return cast(
"Sequence[ModelT]",
await self.repository.delete_many(
item_ids=item_ids,
auto_commit=auto_commit,
auto_expunge=auto_expunge,
id_attribute=id_attribute,
chunk_size=chunk_size,
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=self._get_uniquify(uniquify),
),
)
async def delete_where(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
sanity_check: bool = True,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> Sequence[ModelT]:
"""Wrap repository scalars operation.
Args:
*filters: Types for specific filtering operations.
auto_expunge: Remove object from session before returning.
auto_commit: Commit objects before returning.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
sanity_check: When true, the length of selected instances is compared to the deleted row count
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
**kwargs: Instance attribute value filters.
Returns:
The list of instances deleted from the repository.
"""
return cast(
"Sequence[ModelT]",
await self.repository.delete_where(
*filters,
auto_commit=auto_commit,
auto_expunge=auto_expunge,
error_messages=error_messages,
sanity_check=sanity_check,
load=load,
execution_options=execution_options,
uniquify=self._get_uniquify(uniquify),
**kwargs,
),
)
python-advanced-alchemy-1.4.1/advanced_alchemy/service/_sync.py 0000664 0000000 0000000 00000135460 15003544734 0024633 0 ustar 00root root 0000000 0000000 # Do not edit this file directly. It has been autogenerated from
# advanced_alchemy/service/_async.py
# ruff: noqa: PLR6301
"""Service object implementation for SQLAlchemy.
RepositoryService object is generic on the domain model type which
should be a SQLAlchemy model.
"""
from collections.abc import Iterable, Iterator, Sequence
from contextlib import contextmanager
from functools import cached_property
from typing import Any, ClassVar, Generic, Optional, Union, cast
from sqlalchemy import Select
from sqlalchemy.orm import InstrumentedAttribute, Session
from sqlalchemy.orm.scoping import scoped_session
from sqlalchemy.sql import ColumnElement
from typing_extensions import Self
from advanced_alchemy.config.sync import SQLAlchemySyncConfig
from advanced_alchemy.exceptions import AdvancedAlchemyError, ErrorMessages, ImproperConfigurationError, RepositoryError
from advanced_alchemy.filters import StatementFilter
from advanced_alchemy.repository import SQLAlchemySyncQueryRepository
from advanced_alchemy.repository._util import LoadSpec, model_from_dict
from advanced_alchemy.repository.typing import ModelT, OrderingPair, SQLAlchemySyncRepositoryT
from advanced_alchemy.service._util import ResultConverter
from advanced_alchemy.service.typing import (
BulkModelDictT,
ModelDictListT,
ModelDictT,
is_dict,
is_dto_data,
is_msgspec_struct,
is_pydantic_model,
)
from advanced_alchemy.utils.dataclass import Empty, EmptyType
class SQLAlchemySyncQueryService(ResultConverter):
"""Simple service to execute the basic Query repository.."""
def __init__(
self,
session: Union[Session, scoped_session[Session]],
**repo_kwargs: Any,
) -> None:
"""Configure the service object.
Args:
session: Session managing the unit-of-work for the operation.
**repo_kwargs: Optional configuration values to pass into the repository
"""
self.repository = SQLAlchemySyncQueryRepository(
session=session,
**repo_kwargs,
)
@classmethod
@contextmanager
def new(
cls,
session: Optional[Union[Session, scoped_session[Session]]] = None,
config: Optional[SQLAlchemySyncConfig] = None,
) -> Iterator[Self]:
"""Context manager that returns instance of service object.
Handles construction of the database session._create_select_for_model
Raises:
AdvancedAlchemyError: If no configuration or session is provided.
Yields:
The service object instance.
"""
if not config and not session:
raise AdvancedAlchemyError(detail="Please supply an optional configuration or session to use.")
if session:
yield cls(session=session)
elif config:
with config.get_session() as db_session:
yield cls(session=db_session)
class SQLAlchemySyncRepositoryReadService(ResultConverter, Generic[ModelT, SQLAlchemySyncRepositoryT]):
"""Service object that operates on a repository object."""
repository_type: type[SQLAlchemySyncRepositoryT]
"""Type of the repository to use."""
loader_options: ClassVar[Optional[LoadSpec]] = None
"""Default loader options for the repository."""
execution_options: ClassVar[Optional[dict[str, Any]]] = None
"""Default execution options for the repository."""
match_fields: ClassVar[Optional[Union[list[str], str]]] = None
"""List of dialects that prefer to use ``field.id = ANY(:1)`` instead of ``field.id IN (...)``."""
uniquify: ClassVar[bool] = False
"""Optionally apply the ``unique()`` method to results before returning."""
count_with_window_function: ClassVar[bool] = True
"""Use an analytical window function to count results. This allows the count to be performed in a single query."""
_repository_instance: SQLAlchemySyncRepositoryT
def __init__(
self,
session: Union[Session, scoped_session[Session]],
*,
statement: Optional[Select[Any]] = None,
auto_expunge: bool = False,
auto_refresh: bool = True,
auto_commit: bool = False,
order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
wrap_exceptions: bool = True,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
count_with_window_function: Optional[bool] = None,
**repo_kwargs: Any,
) -> None:
"""Configure the service object.
Args:
session: Session managing the unit-of-work for the operation.
statement: To facilitate customization of the underlying select query.
auto_expunge: Remove object from session before returning.
auto_refresh: Refresh object from session before returning.
auto_commit: Commit objects before returning.
order_by: Set default order options for queries.
error_messages: A set of custom error messages to use for operations
wrap_exceptions: Wrap exceptions in a RepositoryError
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
count_with_window_function: When false, list and count will use two queries instead of an analytical window function.
**repo_kwargs: passed as keyword args to repo instantiation.
"""
load = load if load is not None else self.loader_options
execution_options = execution_options if execution_options is not None else self.execution_options
count_with_window_function = (
count_with_window_function if count_with_window_function is not None else self.count_with_window_function
)
self._repository_instance: SQLAlchemySyncRepositoryT = self.repository_type( # type: ignore[assignment]
session=session,
statement=statement,
auto_expunge=auto_expunge,
auto_refresh=auto_refresh,
auto_commit=auto_commit,
order_by=order_by,
error_messages=error_messages,
wrap_exceptions=wrap_exceptions,
load=load,
execution_options=execution_options,
uniquify=self._get_uniquify(uniquify),
count_with_window_function=count_with_window_function,
**repo_kwargs,
)
def _get_uniquify(self, uniquify: Optional[bool] = None) -> bool:
return bool(uniquify or self.uniquify)
@property
def repository(self) -> SQLAlchemySyncRepositoryT:
"""Return the repository instance.
Raises:
ImproperConfigurationError: If the repository is not initialized.
Returns:
The repository instance.
"""
if not self._repository_instance:
msg = "Repository not initialized"
raise ImproperConfigurationError(msg)
return self._repository_instance
@cached_property
def model_type(self) -> type[ModelT]:
"""Return the model type."""
return cast("type[ModelT]", self.repository.model_type)
def count(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
statement: Optional[Select[tuple[ModelT]]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> int:
"""Count of records returned by query.
Args:
*filters: arguments for filtering.
statement: To facilitate customization of the underlying select query.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
**kwargs: key value pairs of filter types.
Returns:
A count of the collection, filtered, but ignoring pagination.
"""
return self.repository.count(
*filters,
statement=statement,
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=self._get_uniquify(uniquify),
**kwargs,
)
def exists(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> bool:
"""Wrap repository exists operation.
Args:
*filters: Types for specific filtering operations.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
**kwargs: Keyword arguments for attribute based filtering.
Returns:
Representation of instance with identifier `item_id`.
"""
return self.repository.exists(
*filters,
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=self._get_uniquify(uniquify),
**kwargs,
)
def get(
self,
item_id: Any,
*,
statement: Optional[Select[tuple[ModelT]]] = None,
id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
auto_expunge: Optional[bool] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> ModelT:
"""Wrap repository scalar operation.
Args:
item_id: Identifier of instance to be retrieved.
auto_expunge: Remove object from session before returning.
statement: To facilitate customization of the underlying select query.
id_attribute: Allows customization of the unique identifier to use for model fetching.
Defaults to `id`, but can reference any surrogate or candidate key for the table.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
Returns:
Representation of instance with identifier `item_id`.
"""
return cast(
"ModelT",
self.repository.get(
item_id=item_id,
auto_expunge=auto_expunge,
statement=statement,
id_attribute=id_attribute,
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=self._get_uniquify(uniquify),
),
)
def get_one(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
statement: Optional[Select[tuple[ModelT]]] = None,
auto_expunge: Optional[bool] = None,
load: Optional[LoadSpec] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> ModelT:
"""Wrap repository scalar operation.
Args:
*filters: Types for specific filtering operations.
auto_expunge: Remove object from session before returning.
statement: To facilitate customization of the underlying select query.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
**kwargs: Identifier of the instance to be retrieved.
Returns:
Representation of instance with identifier `item_id`.
"""
return cast(
"ModelT",
self.repository.get_one(
*filters,
auto_expunge=auto_expunge,
statement=statement,
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=self._get_uniquify(uniquify),
**kwargs,
),
)
def get_one_or_none(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
statement: Optional[Select[tuple[ModelT]]] = None,
auto_expunge: Optional[bool] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> Optional[ModelT]:
"""Wrap repository scalar operation.
Args:
*filters: Types for specific filtering operations.
auto_expunge: Remove object from session before returning.
statement: To facilitate customization of the underlying select query.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
**kwargs: Identifier of the instance to be retrieved.
Returns:
Representation of instance with identifier `item_id`.
"""
return cast(
"Optional[ModelT]",
self.repository.get_one_or_none(
*filters,
auto_expunge=auto_expunge,
statement=statement,
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=self._get_uniquify(uniquify),
**kwargs,
),
)
def to_model_on_create(self, data: "ModelDictT[ModelT]") -> "ModelDictT[ModelT]":
"""Convenience method to allow for custom behavior on create.
Args:
data: The data to be converted to a model.
Returns:
The data to be converted to a model.
"""
return data
def to_model_on_update(self, data: "ModelDictT[ModelT]") -> "ModelDictT[ModelT]":
"""Convenience method to allow for custom behavior on update.
Args:
data: The data to be converted to a model.
Returns:
The data to be converted to a model.
"""
return data
def to_model_on_delete(self, data: "ModelDictT[ModelT]") -> "ModelDictT[ModelT]":
"""Convenience method to allow for custom behavior on delete.
Args:
data: The data to be converted to a model.
Returns:
The data to be converted to a model.
"""
return data
def to_model_on_upsert(self, data: "ModelDictT[ModelT]") -> "ModelDictT[ModelT]":
"""Convenience method to allow for custom behavior on upsert.
Args:
data: The data to be converted to a model.
Returns:
The data to be converted to a model.
"""
return data
def to_model(
self,
data: "ModelDictT[ModelT]",
operation: Optional[str] = None,
) -> ModelT:
"""Parse and Convert input into a model.
Args:
data: Representations to be created.
operation: Optional operation flag so that you can provide behavior based on CRUD operation
Returns:
Representation of created instances.
"""
operation_map = {
"create": self.to_model_on_create,
"update": self.to_model_on_update,
"delete": self.to_model_on_delete,
"upsert": self.to_model_on_upsert,
}
if operation and (op := operation_map.get(operation)):
data = op(data)
if is_dict(data):
return model_from_dict(model=self.model_type, **data)
if is_pydantic_model(data):
return model_from_dict(
model=self.model_type,
**data.model_dump(exclude_unset=True),
)
if is_msgspec_struct(data):
from msgspec import UNSET
return model_from_dict(
model=self.model_type,
**{f: val for f in data.__struct_fields__ if (val := getattr(data, f, None)) != UNSET},
)
if is_dto_data(data):
return cast("ModelT", data.create_instance())
return cast("ModelT", data)
def list_and_count(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
statement: Optional[Select[tuple[ModelT]]] = None,
auto_expunge: Optional[bool] = None,
count_with_window_function: Optional[bool] = None,
order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> tuple[Sequence[ModelT], int]:
"""List of records and total count returned by query.
Args:
*filters: Types for specific filtering operations.
statement: To facilitate customization of the underlying select query.
auto_expunge: Remove object from session before returning.
count_with_window_function: When false, list and count will use two queries instead of an analytical window function.
order_by: Set default order options for queries.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
**kwargs: Instance attribute value filters.
Returns:
List of instances and count of total collection, ignoring pagination.
"""
return cast(
"tuple[Sequence[ModelT], int]",
self.repository.list_and_count(
*filters,
statement=statement,
auto_expunge=auto_expunge,
count_with_window_function=count_with_window_function,
order_by=order_by,
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=self._get_uniquify(uniquify),
**kwargs,
),
)
@classmethod
@contextmanager
def new(
cls,
session: Optional[Union[Session, scoped_session[Session]]] = None,
statement: Optional[Select[tuple[ModelT]]] = None,
config: Optional[SQLAlchemySyncConfig] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
count_with_window_function: Optional[bool] = None,
) -> Iterator[Self]:
"""Context manager that returns instance of service object.
Handles construction of the database session._create_select_for_model
Raises:
AdvancedAlchemyError: If no configuration or session is provided.
Yields:
The service object instance.
"""
if not config and not session:
raise AdvancedAlchemyError(detail="Please supply an optional configuration or session to use.")
if session:
yield cls(
statement=statement,
session=session,
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=uniquify,
count_with_window_function=count_with_window_function,
)
elif config:
with config.get_session() as db_session:
yield cls(
statement=statement,
session=db_session,
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=uniquify,
count_with_window_function=count_with_window_function,
)
def list(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
statement: Optional[Select[tuple[ModelT]]] = None,
auto_expunge: Optional[bool] = None,
order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> Sequence[ModelT]:
"""Wrap repository scalars operation.
Args:
*filters: Types for specific filtering operations.
auto_expunge: Remove object from session before returning.
statement: To facilitate customization of the underlying select query.
order_by: Set default order options for queries.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
**kwargs: Instance attribute value filters.
Returns:
The list of instances retrieved from the repository.
"""
return cast(
"Sequence[ModelT]",
self.repository.list(
*filters,
statement=statement,
auto_expunge=auto_expunge,
order_by=order_by,
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=self._get_uniquify(uniquify),
**kwargs,
),
)
class SQLAlchemySyncRepositoryService(
SQLAlchemySyncRepositoryReadService[ModelT, SQLAlchemySyncRepositoryT],
Generic[ModelT, SQLAlchemySyncRepositoryT],
):
"""Service object that operates on a repository object."""
def create(
self,
data: "ModelDictT[ModelT]",
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
) -> "ModelT":
"""Wrap repository instance creation.
Args:
data: Representation to be created.
auto_expunge: Remove object from session before returning.
auto_refresh: Refresh object from session before returning.
auto_commit: Commit objects before returning.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
Returns:
Representation of created instance.
"""
data = self.to_model(data, "create")
return cast(
"ModelT",
self.repository.add(
data=data,
auto_commit=auto_commit,
auto_expunge=auto_expunge,
auto_refresh=auto_refresh,
error_messages=error_messages,
),
)
def create_many(
self,
data: "BulkModelDictT[ModelT]",
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
) -> Sequence[ModelT]:
"""Wrap repository bulk instance creation.
Args:
data: Representations to be created.
auto_expunge: Remove object from session before returning.
auto_commit: Commit objects before returning.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
uniquify: Optionally apply the ``unique()`` method to results before returning.
Returns:
Representation of created instances.
"""
if is_dto_data(data):
data = data.create_instance()
data = [(self.to_model(datum, "create")) for datum in cast("ModelDictListT[ModelT]", data)]
return cast(
"Sequence[ModelT]",
self.repository.add_many(
data=cast("list[ModelT]", data), # pyright: ignore[reportUnnecessaryCast]
auto_commit=auto_commit,
auto_expunge=auto_expunge,
error_messages=error_messages,
),
)
def update(
self,
data: "ModelDictT[ModelT]",
item_id: Optional[Any] = None,
*,
attribute_names: Optional[Iterable[str]] = None,
with_for_update: Optional[bool] = None,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> "ModelT":
"""Wrap repository update operation.
Args:
data: Representation to be updated.
item_id: Identifier of item to be updated.
attribute_names: an iterable of attribute names to pass into the ``update``
method.
with_for_update: indicating FOR UPDATE should be used, or may be a
dictionary containing flags to indicate a more specific set of
FOR UPDATE flags for the SELECT
auto_expunge: Remove object from session before returning.
auto_refresh: Refresh object from session before returning.
auto_commit: Commit objects before returning.
id_attribute: Allows customization of the unique identifier to use for model fetching.
Defaults to `id`, but can reference any surrogate or candidate key for the table.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
Raises:
RepositoryError: If no configuration or session is provided.
Returns:
Updated representation.
"""
data = self.to_model(data, "update")
if (
item_id is None
and self.repository.get_id_attribute_value( # pyright: ignore[reportUnknownMemberType]
item=data,
id_attribute=id_attribute,
)
is None
):
msg = (
"Could not identify ID attribute value. One of the following is required: "
f"``item_id`` or ``data.{id_attribute or self.repository.id_attribute}``"
)
raise RepositoryError(msg)
if item_id is not None:
data = self.repository.set_id_attribute_value(item_id=item_id, item=data, id_attribute=id_attribute) # pyright: ignore[reportUnknownMemberType]
return cast(
"ModelT",
self.repository.update(
data=data,
attribute_names=attribute_names,
with_for_update=with_for_update,
auto_commit=auto_commit,
auto_expunge=auto_expunge,
auto_refresh=auto_refresh,
id_attribute=id_attribute,
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=self._get_uniquify(uniquify),
),
)
def update_many(
self,
data: "BulkModelDictT[ModelT]",
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> Sequence[ModelT]:
"""Wrap repository bulk instance update.
Args:
data: Representations to be updated.
auto_expunge: Remove object from session before returning.
auto_commit: Commit objects before returning.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
Returns:
Representation of updated instances.
"""
if is_dto_data(data):
data = data.create_instance()
data = [(self.to_model(datum, "update")) for datum in cast("ModelDictListT[ModelT]", data)]
return cast(
"Sequence[ModelT]",
self.repository.update_many(
cast("list[ModelT]", data), # pyright: ignore[reportUnnecessaryCast]
auto_commit=auto_commit,
auto_expunge=auto_expunge,
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=self._get_uniquify(uniquify),
),
)
def upsert(
self,
data: "ModelDictT[ModelT]",
item_id: Optional[Any] = None,
*,
attribute_names: Optional[Iterable[str]] = None,
with_for_update: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_commit: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
match_fields: Optional[Union[list[str], str]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> ModelT:
"""Wrap repository upsert operation.
Args:
data: Instance to update existing, or be created. Identifier used to determine if an
existing instance exists is the value of an attribute on `data` named as value of
`self.id_attribute`.
item_id: Identifier of the object for upsert.
attribute_names: an iterable of attribute names to pass into the ``update`` method.
with_for_update: indicating FOR UPDATE should be used, or may be a
dictionary containing flags to indicate a more specific set of
FOR UPDATE flags for the SELECT
auto_expunge: Remove object from session before returning.
auto_refresh: Refresh object from session before returning.
auto_commit: Commit objects before returning.
match_fields: a list of keys to use to match the existing model. When
empty, all fields are matched.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
Returns:
Updated or created representation.
"""
data = self.to_model(data, "upsert")
item_id = item_id if item_id is not None else self.repository.get_id_attribute_value(item=data) # pyright: ignore[reportUnknownMemberType]
if item_id is not None:
self.repository.set_id_attribute_value(item_id, data) # pyright: ignore[reportUnknownMemberType]
return cast(
"ModelT",
self.repository.upsert(
data=data,
attribute_names=attribute_names,
with_for_update=with_for_update,
auto_expunge=auto_expunge,
auto_commit=auto_commit,
auto_refresh=auto_refresh,
match_fields=match_fields,
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=self._get_uniquify(uniquify),
),
)
def upsert_many(
self,
data: "BulkModelDictT[ModelT]",
*,
auto_expunge: Optional[bool] = None,
auto_commit: Optional[bool] = None,
no_merge: bool = False,
match_fields: Optional[Union[list[str], str]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> Sequence[ModelT]:
"""Wrap repository upsert operation.
Args:
data: Instance to update existing, or be created.
auto_expunge: Remove object from session before returning.
auto_commit: Commit objects before returning.
no_merge: Skip the usage of optimized Merge statements (**reserved for future use**)
match_fields: a list of keys to use to match the existing model. When
empty, all fields are matched.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
Returns:
Updated or created representation.
"""
if is_dto_data(data):
data = data.create_instance()
data = [(self.to_model(datum, "upsert")) for datum in cast("ModelDictListT[ModelT]", data)]
return cast(
"Sequence[ModelT]",
self.repository.upsert_many(
data=cast("list[ModelT]", data), # pyright: ignore[reportUnnecessaryCast]
auto_expunge=auto_expunge,
auto_commit=auto_commit,
no_merge=no_merge,
match_fields=match_fields,
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=self._get_uniquify(uniquify),
),
)
def get_or_upsert(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
match_fields: Optional[Union[list[str], str]] = None,
upsert: bool = True,
attribute_names: Optional[Iterable[str]] = None,
with_for_update: Optional[bool] = None,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> tuple[ModelT, bool]:
"""Wrap repository instance creation.
Args:
*filters: Types for specific filtering operations.
match_fields: a list of keys to use to match the existing model. When
empty, all fields are matched.
upsert: When using match_fields and actual model values differ from
`kwargs`, perform an update operation on the model.
create: Should a model be created. If no model is found, an exception is raised.
attribute_names: an iterable of attribute names to pass into the ``update``
method.
with_for_update: indicating FOR UPDATE should be used, or may be a
dictionary containing flags to indicate a more specific set of
FOR UPDATE flags for the SELECT
auto_expunge: Remove object from session before returning.
auto_refresh: Refresh object from session before returning.
auto_commit: Commit objects before returning.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
**kwargs: Identifier of the instance to be retrieved.
Returns:
Representation of created instance.
"""
match_fields = match_fields or self.match_fields
validated_model = self.to_model(kwargs, "create")
return cast(
"tuple[ModelT, bool]",
self.repository.get_or_upsert(
*filters,
match_fields=match_fields,
upsert=upsert,
attribute_names=attribute_names,
with_for_update=with_for_update,
auto_commit=auto_commit,
auto_expunge=auto_expunge,
auto_refresh=auto_refresh,
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=self._get_uniquify(uniquify),
**validated_model.to_dict(),
),
)
def get_and_update(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
match_fields: Optional[Union[list[str], str]] = None,
attribute_names: Optional[Iterable[str]] = None,
with_for_update: Optional[bool] = None,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> tuple[ModelT, bool]:
"""Wrap repository instance creation.
Args:
*filters: Types for specific filtering operations.
match_fields: a list of keys to use to match the existing model. When
empty, all fields are matched.
attribute_names: an iterable of attribute names to pass into the ``update``
method.
with_for_update: indicating FOR UPDATE should be used, or may be a
dictionary containing flags to indicate a more specific set of
FOR UPDATE flags for the SELECT
auto_expunge: Remove object from session before returning.
auto_refresh: Refresh object from session before returning.
auto_commit: Commit objects before returning.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
**kwargs: Identifier of the instance to be retrieved.
Returns:
Representation of updated instance.
"""
match_fields = match_fields or self.match_fields
validated_model = self.to_model(kwargs, "update")
return cast(
"tuple[ModelT, bool]",
self.repository.get_and_update(
*filters,
match_fields=match_fields,
attribute_names=attribute_names,
with_for_update=with_for_update,
auto_commit=auto_commit,
auto_expunge=auto_expunge,
auto_refresh=auto_refresh,
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=self._get_uniquify(uniquify),
**validated_model.to_dict(),
),
)
def delete(
self,
item_id: Any,
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> ModelT:
"""Wrap repository delete operation.
Args:
item_id: Identifier of instance to be deleted.
auto_commit: Commit objects before returning.
auto_expunge: Remove object from session before returning.
id_attribute: Allows customization of the unique identifier to use for model fetching.
Defaults to `id`, but can reference any surrogate or candidate key for the table.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
Returns:
Representation of the deleted instance.
"""
return cast(
"ModelT",
self.repository.delete(
item_id=item_id,
auto_commit=auto_commit,
auto_expunge=auto_expunge,
id_attribute=id_attribute,
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=self._get_uniquify(uniquify),
),
)
def delete_many(
self,
item_ids: list[Any],
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
chunk_size: Optional[int] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> Sequence[ModelT]:
"""Wrap repository bulk instance deletion.
Args:
item_ids: Identifier of instance to be deleted.
auto_expunge: Remove object from session before returning.
auto_commit: Commit objects before returning.
id_attribute: Allows customization of the unique identifier to use for model fetching.
Defaults to `id`, but can reference any surrogate or candidate key for the table.
chunk_size: Allows customization of the ``insertmanyvalues_max_parameters`` setting for the driver.
Defaults to `950` if left unset.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
Returns:
Representation of removed instances.
"""
return cast(
"Sequence[ModelT]",
self.repository.delete_many(
item_ids=item_ids,
auto_commit=auto_commit,
auto_expunge=auto_expunge,
id_attribute=id_attribute,
chunk_size=chunk_size,
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=self._get_uniquify(uniquify),
),
)
def delete_where(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
sanity_check: bool = True,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> Sequence[ModelT]:
"""Wrap repository scalars operation.
Args:
*filters: Types for specific filtering operations.
auto_expunge: Remove object from session before returning.
auto_commit: Commit objects before returning.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
sanity_check: When true, the length of selected instances is compared to the deleted row count
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
**kwargs: Instance attribute value filters.
Returns:
The list of instances deleted from the repository.
"""
return cast(
"Sequence[ModelT]",
self.repository.delete_where(
*filters,
auto_commit=auto_commit,
auto_expunge=auto_expunge,
error_messages=error_messages,
sanity_check=sanity_check,
load=load,
execution_options=execution_options,
uniquify=self._get_uniquify(uniquify),
**kwargs,
),
)
python-advanced-alchemy-1.4.1/advanced_alchemy/service/_typing.py 0000664 0000000 0000000 00000007432 15003544734 0025166 0 ustar 00root root 0000000 0000000 # ruff: noqa: DOC201, PLR6301
"""This is a simple wrapper around a few important classes in each library.
This is used to ensure compatibility when one or more of the libraries are installed.
"""
from typing import (
Any,
ClassVar,
Optional,
Protocol,
cast,
runtime_checkable,
)
from typing_extensions import TypeVar, dataclass_transform
T = TypeVar("T")
T_co = TypeVar("T_co", covariant=True)
try:
from pydantic import BaseModel, FailFast, TypeAdapter # pyright: ignore
PYDANTIC_INSTALLED = True
except ImportError:
@runtime_checkable
class BaseModel(Protocol): # type: ignore[no-redef]
"""Placeholder Implementation"""
model_fields: "ClassVar[dict[str, Any]]"
def model_dump(self, *args: Any, **kwargs: Any) -> "dict[str, Any]":
"""Placeholder"""
return {}
@runtime_checkable
class TypeAdapter(Protocol[T_co]): # type: ignore[no-redef]
"""Placeholder Implementation"""
def __init__(
self,
type: Any, # noqa: A002
*,
config: "Optional[Any]" = None,
_parent_depth: int = 2,
module: "Optional[str]" = None,
) -> None:
"""Init"""
def validate_python(
self,
object: Any, # noqa: A002
/,
*,
strict: "Optional[bool]" = None,
from_attributes: "Optional[bool]" = None,
context: "Optional[dict[str, Any]]" = None,
) -> T_co:
"""Stub"""
return cast("T_co", object)
@runtime_checkable
class FailFast(Protocol): # type: ignore[no-redef]
"""Placeholder Implementation for FailFast"""
def __init__(self, *args: Any, **kwargs: Any) -> None:
"""Init"""
PYDANTIC_INSTALLED = False # pyright: ignore[reportConstantRedefinition]
try:
from msgspec import (
UNSET,
Struct,
UnsetType, # pyright: ignore[reportAssignmentType,reportGeneralTypeIssues]
convert, # pyright: ignore[reportGeneralTypeIssues]
)
MSGSPEC_INSTALLED: bool = True
except ImportError:
import enum
@dataclass_transform()
@runtime_checkable
class Struct(Protocol): # type: ignore[no-redef]
"""Placeholder Implementation"""
__struct_fields__: "ClassVar[tuple[str, ...]]"
def convert(*args: Any, **kwargs: Any) -> Any: # type: ignore[no-redef] # noqa: ARG001
"""Placeholder implementation"""
return {}
class UnsetType(enum.Enum): # type: ignore[no-redef]
UNSET = "UNSET"
UNSET = UnsetType.UNSET # pyright: ignore[reportConstantRedefinition]
MSGSPEC_INSTALLED = False # pyright: ignore[reportConstantRedefinition]
try:
from litestar.dto.data_structures import DTOData
LITESTAR_INSTALLED = True
except ImportError:
@runtime_checkable
class DTOData(Protocol[T]): # type: ignore[no-redef]
"""Placeholder implementation"""
__slots__ = ("_backend", "_data_as_builtins")
def __init__(self, backend: Any, data_as_builtins: Any) -> None:
"""Placeholder init"""
def create_instance(self, **kwargs: Any) -> T:
"""Placeholder implementation"""
return cast("T", kwargs)
def update_instance(self, instance: T, **kwargs: Any) -> T:
"""Placeholder implementation"""
return cast("T", kwargs)
def as_builtins(self) -> Any:
"""Placeholder implementation"""
return {}
LITESTAR_INSTALLED = False # pyright: ignore[reportConstantRedefinition]
__all__ = (
"LITESTAR_INSTALLED",
"MSGSPEC_INSTALLED",
"PYDANTIC_INSTALLED",
"UNSET",
"BaseModel",
"DTOData",
"FailFast",
"Struct",
"TypeAdapter",
"UnsetType",
"convert",
)
python-advanced-alchemy-1.4.1/advanced_alchemy/service/_util.py 0000664 0000000 0000000 00000025554 15003544734 0024636 0 ustar 00root root 0000000 0000000 """Service object implementation for SQLAlchemy.
RepositoryService object is generic on the domain model type which
should be a SQLAlchemy model.
"""
import datetime
from collections.abc import Sequence
from enum import Enum
from functools import partial
from pathlib import Path, PurePath
from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast, overload
from uuid import UUID
from advanced_alchemy.exceptions import AdvancedAlchemyError
from advanced_alchemy.filters import LimitOffset, StatementFilter
from advanced_alchemy.repository.typing import ModelOrRowMappingT
from advanced_alchemy.service.pagination import OffsetPagination
from advanced_alchemy.service.typing import (
MSGSPEC_INSTALLED,
PYDANTIC_INSTALLED,
BaseModel,
FilterTypeT,
ModelDTOT,
Struct,
convert,
get_type_adapter,
)
if TYPE_CHECKING:
from sqlalchemy import ColumnElement, RowMapping
from advanced_alchemy.base import ModelProtocol
__all__ = ("ResultConverter", "find_filter")
DEFAULT_TYPE_DECODERS = [ # pyright: ignore[reportUnknownVariableType]
(lambda x: x is UUID, lambda t, v: t(v.hex)), # pyright: ignore[reportUnknownLambdaType,reportUnknownMemberType]
(lambda x: x is datetime.datetime, lambda t, v: t(v.isoformat())), # pyright: ignore[reportUnknownLambdaType,reportUnknownMemberType]
(lambda x: x is datetime.date, lambda t, v: t(v.isoformat())), # pyright: ignore[reportUnknownLambdaType,reportUnknownMemberType]
(lambda x: x is datetime.time, lambda t, v: t(v.isoformat())), # pyright: ignore[reportUnknownLambdaType,reportUnknownMemberType]
(lambda x: x is Enum, lambda t, v: t(v.value)), # pyright: ignore[reportUnknownLambdaType,reportUnknownMemberType]
]
def _default_msgspec_deserializer(
target_type: Any,
value: Any,
type_decoders: "Union[Sequence[tuple[Callable[[Any], bool], Callable[[Any, Any], Any]]], None]" = None,
) -> Any: # pragma: no cover
"""Transform values non-natively supported by ``msgspec``
Args:
target_type: Encountered type
value: Value to coerce
type_decoders: Optional sequence of type decoders
Raises:
TypeError: If the value cannot be coerced to the target type
Returns:
A ``msgspec``-supported type
"""
if isinstance(value, target_type):
return value
if type_decoders:
for predicate, decoder in type_decoders:
if predicate(target_type):
return decoder(target_type, value)
if issubclass(target_type, (Path, PurePath, UUID)):
return target_type(value)
try:
return target_type(value)
except Exception as e:
msg = f"Unsupported type: {type(value)!r}"
raise TypeError(msg) from e
def find_filter(
filter_type: type[FilterTypeT],
filters: "Union[Sequence[Union[StatementFilter, ColumnElement[bool]]], Sequence[StatementFilter]]",
) -> "Union[FilterTypeT, None]":
"""Get the filter specified by filter type from the filters.
Args:
filter_type: The type of filter to find.
filters: filter types to apply to the query
Returns:
The match filter instance or None
"""
return next(
(cast("Optional[FilterTypeT]", filter_) for filter_ in filters if isinstance(filter_, filter_type)),
None,
)
class ResultConverter:
"""Simple mixin to help convert to a paginated response model.
Single objects are transformed to the supplied schema type, and lists of objects are automatically transformed into an `OffsetPagination` response of the supplied schema type.
Args:
data: A database model instance or row mapping.
Type: :class:`~advanced_alchemy.repository.typing.ModelOrRowMappingT`
Returns:
The converted schema object.
"""
@overload
def to_schema(
self,
data: "ModelOrRowMappingT",
*,
schema_type: None = None,
) -> "ModelOrRowMappingT": ...
@overload
def to_schema(
self,
data: "Union[ModelProtocol, RowMapping]",
*,
schema_type: "type[ModelDTOT]",
) -> "ModelDTOT": ...
@overload
def to_schema(
self,
data: "ModelOrRowMappingT",
total: "Optional[int]" = None,
*,
schema_type: None = None,
) -> "ModelOrRowMappingT": ...
@overload
def to_schema(
self,
data: "Union[ModelProtocol, RowMapping]",
total: "Optional[int]" = None,
*,
schema_type: "type[ModelDTOT]",
) -> "ModelDTOT": ...
@overload
def to_schema(
self,
data: "ModelOrRowMappingT",
total: "Optional[int]" = None,
filters: "Union[Sequence[Union[StatementFilter, ColumnElement[bool]]], Sequence[StatementFilter], None]" = None,
*,
schema_type: None = None,
) -> "ModelOrRowMappingT": ...
@overload
def to_schema(
self,
data: "Union[ModelProtocol, RowMapping]",
total: "Optional[int]" = None,
filters: "Union[Sequence[Union[StatementFilter, ColumnElement[bool]]], Sequence[StatementFilter], None]" = None,
*,
schema_type: "type[ModelDTOT]",
) -> "ModelDTOT": ...
@overload
def to_schema(
self,
data: "Sequence[ModelOrRowMappingT]",
*,
schema_type: None = None,
) -> "OffsetPagination[ModelOrRowMappingT]": ...
@overload
def to_schema(
self,
data: "Union[Sequence[ModelProtocol], Sequence[RowMapping]]",
*,
schema_type: "type[ModelDTOT]",
) -> "OffsetPagination[ModelDTOT]": ...
@overload
def to_schema(
self,
data: "Sequence[ModelOrRowMappingT]",
total: "Optional[int]" = None,
filters: "Union[Sequence[Union[StatementFilter, ColumnElement[bool]]], Sequence[StatementFilter], None]" = None,
*,
schema_type: None = None,
) -> "OffsetPagination[ModelOrRowMappingT]": ...
@overload
def to_schema(
self,
data: "Union[Sequence[ModelProtocol], Sequence[RowMapping]]",
total: "Optional[int]" = None,
filters: "Union[Sequence[Union[StatementFilter, ColumnElement[bool]]], Sequence[StatementFilter], None]" = None,
*,
schema_type: "type[ModelDTOT]",
) -> "OffsetPagination[ModelDTOT]": ...
def to_schema(
self,
data: "Union[ModelOrRowMappingT, Sequence[ModelOrRowMappingT], ModelProtocol, Sequence[ModelProtocol], RowMapping, Sequence[RowMapping]]",
total: "Optional[int]" = None,
filters: "Union[Sequence[Union[StatementFilter, ColumnElement[bool]]], Sequence[StatementFilter], None]" = None,
*,
schema_type: "Optional[type[ModelDTOT]]" = None,
) -> "Union[ModelOrRowMappingT, OffsetPagination[ModelOrRowMappingT], ModelDTOT, OffsetPagination[ModelDTOT]]":
"""Convert the object to a response schema.
When `schema_type` is None, the model is returned with no conversion.
Args:
data: The return from one of the service calls.
Type: :class:`~advanced_alchemy.repository.typing.ModelOrRowMappingT`
total: The total number of rows in the data.
filters: :class:`~advanced_alchemy.filters.StatementFilter`| :class:`sqlalchemy.sql.expression.ColumnElement` Collection of route filters.
schema_type: :class:`~advanced_alchemy.service.typing.ModelDTOT` Optional schema type to convert the data to
Raises:
AdvancedAlchemyError: If `schema_type` is not a valid Pydantic or Msgspec schema and both libraries are not installed.
Returns:
:class:`~advanced_alchemy.base.ModelProtocol` | :class:`sqlalchemy.orm.RowMapping` | :class:`~advanced_alchemy.service.pagination.OffsetPagination` | :class:`msgspec.Struct` | :class:`pydantic.BaseModel`
"""
if filters is None:
filters = []
if schema_type is None:
if not isinstance(data, Sequence):
return cast("ModelOrRowMappingT", data) # type: ignore[unreachable,unused-ignore]
limit_offset = find_filter(LimitOffset, filters=filters)
total = total or len(data)
limit_offset = limit_offset if limit_offset is not None else LimitOffset(limit=len(data), offset=0)
return OffsetPagination[ModelOrRowMappingT](
items=cast("Sequence[ModelOrRowMappingT]", data),
limit=limit_offset.limit,
offset=limit_offset.offset,
total=total,
)
if MSGSPEC_INSTALLED and issubclass(schema_type, Struct):
if not isinstance(data, Sequence):
return cast(
"ModelDTOT",
convert(
obj=data,
type=schema_type,
from_attributes=True,
dec_hook=partial(
_default_msgspec_deserializer,
type_decoders=DEFAULT_TYPE_DECODERS,
),
),
)
limit_offset = find_filter(LimitOffset, filters=filters)
total = total or len(data)
limit_offset = limit_offset if limit_offset is not None else LimitOffset(limit=len(data), offset=0)
return OffsetPagination[ModelDTOT](
items=convert(
obj=data,
type=list[schema_type], # type: ignore[valid-type]
from_attributes=True,
dec_hook=partial(
_default_msgspec_deserializer,
type_decoders=DEFAULT_TYPE_DECODERS,
),
),
limit=limit_offset.limit,
offset=limit_offset.offset,
total=total,
)
if PYDANTIC_INSTALLED and issubclass(schema_type, BaseModel):
if not isinstance(data, Sequence):
return cast(
"ModelDTOT",
get_type_adapter(schema_type).validate_python(data, from_attributes=True),
)
limit_offset = find_filter(LimitOffset, filters=filters)
total = total or len(data)
limit_offset = limit_offset if limit_offset is not None else LimitOffset(limit=len(data), offset=0)
return OffsetPagination[ModelDTOT](
items=get_type_adapter(list[schema_type]).validate_python(data, from_attributes=True), # type: ignore[valid-type] # pyright: ignore[reportUnknownArgumentType]
limit=limit_offset.limit,
offset=limit_offset.offset,
total=total,
)
if not MSGSPEC_INSTALLED and not PYDANTIC_INSTALLED:
msg = "Either Msgspec or Pydantic must be installed to use schema conversion"
raise AdvancedAlchemyError(msg)
msg = "`schema_type` should be a valid Pydantic or Msgspec schema"
raise AdvancedAlchemyError(msg)
python-advanced-alchemy-1.4.1/advanced_alchemy/service/pagination.py 0000664 0000000 0000000 00000001153 15003544734 0025640 0 ustar 00root root 0000000 0000000 from collections.abc import Sequence
from dataclasses import dataclass
from typing import Generic, TypeVar
T = TypeVar("T")
__all__ = ("OffsetPagination",)
@dataclass
class OffsetPagination(Generic[T]):
"""Container for data returned using limit/offset pagination."""
__slots__ = ("items", "limit", "offset", "total")
items: Sequence[T]
"""List of data being sent as part of the response."""
limit: int
"""Maximal number of items to send."""
offset: int
"""Offset from the beginning of the query.
Identical to an index.
"""
total: int
"""Total number of items."""
python-advanced-alchemy-1.4.1/advanced_alchemy/service/typing.py 0000664 0000000 0000000 00000026360 15003544734 0025030 0 ustar 00root root 0000000 0000000 """Service object implementation for SQLAlchemy.
RepositoryService object is generic on the domain model type which
should be a SQLAlchemy model.
"""
from functools import lru_cache
from typing import (
TYPE_CHECKING,
Annotated,
Any,
TypeVar,
Union,
cast,
overload,
)
from typing_extensions import TypeAlias, TypeGuard
from advanced_alchemy.repository.typing import ModelT
from advanced_alchemy.service._typing import (
LITESTAR_INSTALLED,
MSGSPEC_INSTALLED,
PYDANTIC_INSTALLED,
UNSET,
BaseModel,
DTOData,
FailFast,
Struct,
TypeAdapter,
convert,
)
if TYPE_CHECKING:
from collections.abc import Sequence
from advanced_alchemy.filters import StatementFilter
PYDANTIC_USE_FAILFAST = False # leave permanently disabled for now
T = TypeVar("T")
FilterTypeT = TypeVar("FilterTypeT", bound="StatementFilter")
"""Type variable for filter types.
:class:`~advanced_alchemy.filters.StatementFilter`
"""
SupportedSchemaModel: TypeAlias = Union[Struct, BaseModel]
"""Type alias for objects that support to_dict or model_dump methods."""
ModelDTOT = TypeVar("ModelDTOT", bound="SupportedSchemaModel")
"""Type variable for model DTOs.
:class:`msgspec.Struct`|:class:`pydantic.BaseModel`
"""
PydanticOrMsgspecT = SupportedSchemaModel
"""Type alias for pydantic or msgspec models.
:class:`msgspec.Struct` or :class:`pydantic.BaseModel`
"""
ModelDictT: TypeAlias = "Union[dict[str, Any], ModelT, SupportedSchemaModel, DTOData[ModelT]]"
"""Type alias for model dictionaries.
Represents:
- :type:`dict[str, Any]` | :class:`~advanced_alchemy.base.ModelProtocol` | :class:`msgspec.Struct` | :class:`pydantic.BaseModel` | :class:`litestar.dto.data_structures.DTOData` | :class:`~advanced_alchemy.base.ModelProtocol`
"""
ModelDictListT: TypeAlias = "Sequence[Union[dict[str, Any], ModelT, SupportedSchemaModel]]"
"""Type alias for model dictionary lists.
A list or sequence of any of the following:
- :type:`Sequence`[:type:`dict[str, Any]` | :class:`~advanced_alchemy.base.ModelProtocol` | :class:`msgspec.Struct` | :class:`pydantic.BaseModel`]
"""
BulkModelDictT: TypeAlias = (
"Union[Sequence[Union[dict[str, Any], ModelT, SupportedSchemaModel]], DTOData[list[ModelT]]]"
)
"""Type alias for bulk model dictionaries.
:type:`Sequence`[ :type:`dict[str, Any]` | :class:`~advanced_alchemy.base.ModelProtocol` | :class:`msgspec.Struct` :class:`pydantic.BaseModel`] | :class:`litestar.dto.data_structures.DTOData`
"""
@lru_cache(typed=True)
def get_type_adapter(f: type[T]) -> TypeAdapter[T]:
"""Caches and returns a pydantic type adapter.
Args:
f: Type to create a type adapter for.
Returns:
:class:`pydantic.TypeAdapter`[:class:`typing.TypeVar`[T]]
"""
if PYDANTIC_USE_FAILFAST:
return TypeAdapter(
Annotated[f, FailFast()],
)
return TypeAdapter(f)
def is_dto_data(v: Any) -> TypeGuard[DTOData[Any]]:
"""Check if a value is a Litestar DTOData object.
Args:
v: Value to check.
Returns:
bool
"""
return LITESTAR_INSTALLED and isinstance(v, DTOData)
def is_pydantic_model(v: Any) -> TypeGuard[BaseModel]:
"""Check if a value is a pydantic model.
Args:
v: Value to check.
Returns:
bool
"""
return PYDANTIC_INSTALLED and isinstance(v, BaseModel)
def is_msgspec_struct(v: Any) -> TypeGuard[Struct]:
"""Check if a value is a msgspec struct.
Args:
v: Value to check.
Returns:
bool
"""
return MSGSPEC_INSTALLED and isinstance(v, Struct)
def is_dataclass(obj: Any) -> TypeGuard[Any]:
"""Check if an object is a dataclass."""
return hasattr(obj, "__dataclass_fields__")
def is_dataclass_with_field(obj: Any, field_name: str) -> TypeGuard[object]: # Can't specify dataclass type directly
"""Check if an object is a dataclass and has a specific field."""
return is_dataclass(obj) and hasattr(obj, field_name)
def is_dataclass_without_field(obj: Any, field_name: str) -> TypeGuard[object]:
"""Check if an object is a dataclass and does not have a specific field."""
return is_dataclass(obj) and not hasattr(obj, field_name)
def is_dict(v: Any) -> TypeGuard[dict[str, Any]]:
"""Check if a value is a dictionary.
Args:
v: Value to check.
Returns:
bool
"""
return isinstance(v, dict)
def is_dict_with_field(v: Any, field_name: str) -> TypeGuard[dict[str, Any]]:
"""Check if a dictionary has a specific field.
Args:
v: Value to check.
field_name: Field name to check for.
Returns:
bool
"""
return is_dict(v) and field_name in v
def is_dict_without_field(v: Any, field_name: str) -> TypeGuard[dict[str, Any]]:
"""Check if a dictionary does not have a specific field.
Args:
v: Value to check.
field_name: Field name to check for.
Returns:
bool
"""
return is_dict(v) and field_name not in v
def is_pydantic_model_with_field(v: Any, field_name: str) -> TypeGuard[BaseModel]:
"""Check if a pydantic model has a specific field.
Args:
v: Value to check.
field_name: Field name to check for.
Returns:
bool
"""
return is_pydantic_model(v) and hasattr(v, field_name)
def is_pydantic_model_without_field(v: Any, field_name: str) -> TypeGuard[BaseModel]:
"""Check if a pydantic model does not have a specific field.
Args:
v: Value to check.
field_name: Field name to check for.
Returns:
bool
"""
return is_pydantic_model(v) and not hasattr(v, field_name)
def is_msgspec_struct_with_field(v: Any, field_name: str) -> TypeGuard[Struct]:
"""Check if a msgspec struct has a specific field.
Args:
v: Value to check.
field_name: Field name to check for.
Returns:
bool
"""
return is_msgspec_struct(v) and hasattr(v, field_name)
def is_msgspec_struct_without_field(v: Any, field_name: str) -> "TypeGuard[Struct]":
"""Check if a msgspec struct does not have a specific field.
Args:
v: Value to check.
field_name: Field name to check for.
Returns:
bool
"""
return is_msgspec_struct(v) and not hasattr(v, field_name)
def is_schema(v: Any) -> "TypeGuard[SupportedSchemaModel]":
"""Check if a value is a msgspec Struct or Pydantic model.
Args:
v: Value to check.
Returns:
bool
"""
return is_msgspec_struct(v) or is_pydantic_model(v)
def is_schema_or_dict(v: Any) -> "TypeGuard[Union[SupportedSchemaModel, dict[str, Any]]]":
"""Check if a value is a msgspec Struct, Pydantic model, or dict.
Args:
v: Value to check.
Returns:
bool
"""
return is_schema(v) or is_dict(v)
def is_schema_with_field(v: Any, field_name: str) -> "TypeGuard[SupportedSchemaModel]":
"""Check if a value is a msgspec Struct or Pydantic model with a specific field.
Args:
v: Value to check.
field_name: Field name to check for.
Returns:
bool
"""
return is_msgspec_struct_with_field(v, field_name) or is_pydantic_model_with_field(v, field_name)
def is_schema_without_field(v: Any, field_name: str) -> "TypeGuard[SupportedSchemaModel]":
"""Check if a value is a msgspec Struct or Pydantic model without a specific field.
Args:
v: Value to check.
field_name: Field name to check for.
Returns:
bool
"""
return not is_schema_with_field(v, field_name)
def is_schema_or_dict_with_field(v: Any, field_name: str) -> "TypeGuard[Union[SupportedSchemaModel, dict[str, Any]]]":
"""Check if a value is a msgspec Struct, Pydantic model, or dict with a specific field.
Args:
v: Value to check.
field_name: Field name to check for.
Returns:
bool
"""
return is_schema_with_field(v, field_name) or is_dict_with_field(v, field_name)
def is_schema_or_dict_without_field(
v: Any, field_name: str
) -> "TypeGuard[Union[SupportedSchemaModel, dict[str, Any]]]":
"""Check if a value is a msgspec Struct, Pydantic model, or dict without a specific field.
Args:
v: Value to check.
field_name: Field name to check for.
Returns:
bool
"""
return not is_schema_or_dict_with_field(v, field_name)
@overload
def schema_dump(
data: "Union[dict[str, Any], SupportedSchemaModel, DTOData[ModelT]]", exclude_unset: bool = True
) -> "Union[dict[str, Any], ModelT]": ...
@overload
def schema_dump(data: ModelT, exclude_unset: bool = True) -> ModelT: ...
def schema_dump( # noqa: PLR0911
data: "Union[dict[str, Any], ModelT, SupportedSchemaModel, DTOData[ModelT]]", exclude_unset: bool = True
) -> "Union[dict[str, Any], ModelT]":
"""Dump a data object to a dictionary.
Args:
data: :type:`dict[str, Any]` | :class:`advanced_alchemy.base.ModelProtocol` | :class:`msgspec.Struct` | :class:`pydantic.BaseModel` | :class:`litestar.dto.data_structures.DTOData[ModelT]`
exclude_unset: :type:`bool` Whether to exclude unset values.
Returns:
Union[:type: dict[str, Any], :class:`~advanced_alchemy.base.ModelProtocol`]
"""
if is_dict(data):
return data
if is_pydantic_model(data):
return data.model_dump(exclude_unset=exclude_unset)
if is_msgspec_struct(data):
if exclude_unset:
return {f: val for f in data.__struct_fields__ if (val := getattr(data, f, None)) != UNSET}
return {f: getattr(data, f, None) for f in data.__struct_fields__}
if is_dto_data(data):
return cast("dict[str, Any]", data.as_builtins())
if hasattr(data, "__dict__"):
return data.__dict__
return cast("ModelT", data) # type: ignore[no-return-any]
__all__ = (
"LITESTAR_INSTALLED",
"MSGSPEC_INSTALLED",
"PYDANTIC_INSTALLED",
"PYDANTIC_USE_FAILFAST",
"UNSET",
"BaseModel",
"BulkModelDictT",
"DTOData",
"FailFast",
"FilterTypeT",
"ModelDTOT",
"ModelDictListT",
"ModelDictT",
"PydanticOrMsgspecT",
"Struct",
"SupportedSchemaModel",
"TypeAdapter",
"UnsetType",
"convert",
"get_type_adapter",
"is_dataclass",
"is_dataclass_with_field",
"is_dataclass_without_field",
"is_dict",
"is_dict_with_field",
"is_dict_without_field",
"is_dto_data",
"is_msgspec_struct",
"is_msgspec_struct_with_field",
"is_msgspec_struct_without_field",
"is_pydantic_model",
"is_pydantic_model_with_field",
"is_pydantic_model_without_field",
"is_schema",
"is_schema_or_dict",
"is_schema_or_dict_with_field",
"is_schema_or_dict_without_field",
"is_schema_with_field",
"is_schema_without_field",
"schema_dump",
)
if TYPE_CHECKING:
if not PYDANTIC_INSTALLED:
from advanced_alchemy.service._typing import BaseModel, FailFast, TypeAdapter
else:
from pydantic import BaseModel, FailFast, TypeAdapter # type: ignore[assignment] # noqa: TC004
if not MSGSPEC_INSTALLED:
from advanced_alchemy.service._typing import UNSET, Struct, UnsetType, convert
else:
from msgspec import UNSET, Struct, UnsetType, convert # type: ignore[assignment] # noqa: TC004
if not LITESTAR_INSTALLED:
from advanced_alchemy.service._typing import DTOData
else:
from litestar.dto import DTOData # type: ignore[assignment] # noqa: TC004
python-advanced-alchemy-1.4.1/advanced_alchemy/types/ 0000775 0000000 0000000 00000000000 15003544734 0022641 5 ustar 00root root 0000000 0000000 python-advanced-alchemy-1.4.1/advanced_alchemy/types/__init__.py 0000664 0000000 0000000 00000002515 15003544734 0024755 0 ustar 00root root 0000000 0000000 """SQLAlchemy custom types for use with the ORM."""
from advanced_alchemy.types import encrypted_string, file_object, password_hash
from advanced_alchemy.types.datetime import DateTimeUTC
from advanced_alchemy.types.encrypted_string import (
EncryptedString,
EncryptedText,
EncryptionBackend,
FernetBackend,
)
from advanced_alchemy.types.file_object import (
FileObject,
FileObjectList,
StorageBackend,
StorageBackendT,
StorageRegistry,
StoredObject,
storages,
)
from advanced_alchemy.types.guid import GUID, NANOID_INSTALLED, UUID_UTILS_INSTALLED
from advanced_alchemy.types.identity import BigIntIdentity
from advanced_alchemy.types.json import ORA_JSONB, JsonB
from advanced_alchemy.types.mutables import MutableList
from advanced_alchemy.types.password_hash.base import HashedPassword, PasswordHash
__all__ = (
"GUID",
"NANOID_INSTALLED",
"ORA_JSONB",
"UUID_UTILS_INSTALLED",
"BigIntIdentity",
"DateTimeUTC",
"EncryptedString",
"EncryptedText",
"EncryptionBackend",
"FernetBackend",
"FileObject",
"FileObjectList",
"HashedPassword",
"JsonB",
"MutableList",
"PasswordHash",
"StorageBackend",
"StorageBackendT",
"StorageRegistry",
"StoredObject",
"encrypted_string",
"file_object",
"password_hash",
"storages",
)
python-advanced-alchemy-1.4.1/advanced_alchemy/types/datetime.py 0000664 0000000 0000000 00000002225 15003544734 0025010 0 ustar 00root root 0000000 0000000 # ruff: noqa: FA100
import datetime
from typing import Optional
from sqlalchemy import DateTime
from sqlalchemy.engine import Dialect
from sqlalchemy.types import TypeDecorator
__all__ = ("DateTimeUTC",)
class DateTimeUTC(TypeDecorator[datetime.datetime]):
"""Timezone Aware DateTime.
Ensure UTC is stored in the database and that TZ aware dates are returned for all dialects.
"""
impl = DateTime(timezone=True)
cache_ok = True
@property
def python_type(self) -> type[datetime.datetime]:
return datetime.datetime
def process_bind_param(self, value: Optional[datetime.datetime], dialect: Dialect) -> Optional[datetime.datetime]:
if value is None:
return value
if not value.tzinfo:
msg = "tzinfo is required"
raise TypeError(msg)
return value.astimezone(datetime.timezone.utc)
def process_result_value(self, value: Optional[datetime.datetime], dialect: Dialect) -> Optional[datetime.datetime]:
if value is None:
return value
if value.tzinfo is None:
return value.replace(tzinfo=datetime.timezone.utc)
return value
python-advanced-alchemy-1.4.1/advanced_alchemy/types/encrypted_string.py 0000664 0000000 0000000 00000030557 15003544734 0026610 0 ustar 00root root 0000000 0000000 import abc
import base64
import contextlib
import os
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
from sqlalchemy import String, Text, TypeDecorator
from sqlalchemy import func as sql_func
from advanced_alchemy.exceptions import IntegrityError
if TYPE_CHECKING:
from sqlalchemy.engine import Dialect
cryptography = None # type: ignore[var-annotated,unused-ignore]
with contextlib.suppress(ImportError):
from cryptography.fernet import Fernet
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
__all__ = ("EncryptedString", "EncryptedText", "EncryptionBackend", "FernetBackend", "PGCryptoBackend")
class EncryptionBackend(abc.ABC):
"""Abstract base class for encryption backends.
This class defines the interface that all encryption backends must implement.
Concrete implementations should provide the actual encryption/decryption logic.
Attributes:
passphrase (bytes): The encryption passphrase used by the backend.
"""
def mount_vault(self, key: "Union[str, bytes]") -> None:
"""Mounts the vault with the provided encryption key.
Args:
key (str | bytes): The encryption key used to initialize the backend.
"""
if isinstance(key, str):
key = key.encode()
@abc.abstractmethod
def init_engine(self, key: "Union[bytes, str]") -> None: # pragma: no cover
"""Initializes the encryption engine with the provided key.
Args:
key (bytes | str): The encryption key.
Raises:
NotImplementedError: If the method is not implemented by the subclass.
"""
@abc.abstractmethod
def encrypt(self, value: Any) -> str: # pragma: no cover
"""Encrypts the given value.
Args:
value (Any): The value to encrypt.
Returns:
str: The encrypted value.
Raises:
NotImplementedError: If the method is not implemented by the subclass.
"""
@abc.abstractmethod
def decrypt(self, value: Any) -> str: # pragma: no cover
"""Decrypts the given value.
Args:
value (Any): The value to decrypt.
Returns:
str: The decrypted value.
Raises:
NotImplementedError: If the method is not implemented by the subclass.
"""
class PGCryptoBackend(EncryptionBackend):
"""PostgreSQL pgcrypto-based encryption backend.
This backend uses PostgreSQL's pgcrypto extension for encryption/decryption operations.
Requires the pgcrypto extension to be installed in the database.
Attributes:
passphrase (bytes): The base64-encoded passphrase used for encryption and decryption.
"""
def init_engine(self, key: "Union[bytes, str]") -> None:
"""Initializes the pgcrypto engine with the provided key.
Args:
key (bytes | str): The encryption key.
"""
if isinstance(key, str):
key = key.encode()
self.passphrase = base64.urlsafe_b64encode(key)
def encrypt(self, value: Any) -> str:
"""Encrypts the given value using pgcrypto.
Args:
value (Any): The value to encrypt.
Returns:
str: The encrypted value.
"""
if not isinstance(value, str): # pragma: no cover
value = repr(value)
value = value.encode()
return sql_func.pgp_sym_encrypt(value, self.passphrase) # type: ignore[return-value]
def decrypt(self, value: Any) -> str:
"""Decrypts the given value using pgcrypto.
Args:
value (Any): The value to decrypt.
Returns:
str: The decrypted value.
"""
if not isinstance(value, str): # pragma: no cover
value = str(value)
return sql_func.pgp_sym_decrypt(value, self.passphrase) # type: ignore[return-value]
class FernetBackend(EncryptionBackend):
"""Fernet-based encryption backend.
This backend uses the Python cryptography library's Fernet implementation
for encryption/decryption operations. Provides symmetric encryption with
built-in rotation support.
Attributes:
key (bytes): The base64-encoded key used for encryption and decryption.
fernet (cryptography.fernet.Fernet): The Fernet instance used for encryption/decryption.
"""
def mount_vault(self, key: "Union[str, bytes]") -> None:
"""Mounts the vault with the provided encryption key.
This method hashes the key using SHA256 before initializing the engine.
Args:
key (str | bytes): The encryption key.
"""
if isinstance(key, str):
key = key.encode()
digest = hashes.Hash(hashes.SHA256(), backend=default_backend()) # pyright: ignore[reportPossiblyUnboundVariable]
digest.update(key)
engine_key = digest.finalize()
self.init_engine(engine_key)
def init_engine(self, key: "Union[bytes, str]") -> None:
"""Initializes the Fernet engine with the provided key.
Args:
key (bytes | str): The encryption key.
"""
if isinstance(key, str):
key = key.encode()
self.key = base64.urlsafe_b64encode(key)
self.fernet = Fernet(self.key) # pyright: ignore[reportPossiblyUnboundVariable]
def encrypt(self, value: Any) -> str:
"""Encrypts the given value using Fernet.
Args:
value (Any): The value to encrypt.
Returns:
str: The encrypted value.
"""
if not isinstance(value, str):
value = repr(value)
value = value.encode()
encrypted = self.fernet.encrypt(value)
return encrypted.decode("utf-8")
def decrypt(self, value: Any) -> str:
"""Decrypts the given value using Fernet.
Args:
value (Any): The value to decrypt.
Returns:
str: The decrypted value.
"""
if not isinstance(value, str): # pragma: no cover
value = str(value)
decrypted: Union[str, bytes] = self.fernet.decrypt(value.encode())
if not isinstance(decrypted, str):
decrypted = decrypted.decode("utf-8") # pyright: ignore[reportAttributeAccessIssue]
return decrypted
DEFAULT_ENCRYPTION_KEY = os.urandom(32)
class EncryptedString(TypeDecorator[str]):
"""SQLAlchemy TypeDecorator for storing encrypted string values in a database.
This type provides transparent encryption/decryption of string values using the specified backend.
It extends :class:`sqlalchemy.types.TypeDecorator` and implements String as its underlying type.
Args:
key (str | bytes | Callable[[], str | bytes] | None): The encryption key. Can be a string, bytes, or callable returning either. Defaults to os.urandom(32).
backend (Type[EncryptionBackend] | None): The encryption backend class to use. Defaults to FernetBackend.
length (int | None): The length of the unencrypted string. This is used for documentation and validation purposes only, as encrypted strings will be longer.
**kwargs (Any | None): Additional arguments passed to the underlying String type.
Attributes:
key (str | bytes | Callable[[], str | bytes]): The encryption key.
backend (EncryptionBackend): The encryption backend instance.
length (int | None): The unencrypted string length.
"""
impl = String
cache_ok = True
def __init__(
self,
key: "Union[str, bytes, Callable[[], Union[str, bytes]]]" = DEFAULT_ENCRYPTION_KEY,
backend: "type[EncryptionBackend]" = FernetBackend,
length: "Optional[int]" = None,
**kwargs: Any,
) -> None:
"""Initializes the EncryptedString TypeDecorator.
Args:
key (str | bytes | Callable[[], str | bytes] | None): The encryption key. Can be a string, bytes, or callable returning either. Defaults to os.urandom(32).
backend (Type[EncryptionBackend] | None): The encryption backend class to use. Defaults to FernetBackend.
length (int | None): The length of the unencrypted string. This is used for documentation and validation purposes only.
**kwargs (Any | None): Additional arguments passed to the underlying String type.
"""
super().__init__()
self.key = key
self.backend = backend()
self.length = length
@property
def python_type(self) -> type[str]:
"""Returns the Python type for this type decorator.
Returns:
Type[str]: The Python string type.
"""
return str
def load_dialect_impl(self, dialect: "Dialect") -> Any:
"""Loads the appropriate dialect implementation based on the database dialect.
Note: The actual column length will be larger than the specified length due to encryption overhead.
For most encryption methods, the encrypted string will be approximately 1.35x longer than the original.
Args:
dialect (Dialect): The SQLAlchemy dialect.
Returns:
Any: The dialect-specific type descriptor.
"""
if dialect.name in {"mysql", "mariadb"}:
# For MySQL/MariaDB, always use Text to avoid length limitations
return dialect.type_descriptor(Text())
if dialect.name == "oracle":
# Oracle has a 4000-byte limit for VARCHAR2 (by default)
return dialect.type_descriptor(String(length=4000))
return dialect.type_descriptor(String())
def process_bind_param(self, value: Any, dialect: "Dialect") -> "Union[str, None]":
"""Processes the value before binding it to the SQL statement.
This method encrypts the value using the specified backend and validates length if specified.
Args:
value (Any): The value to process.
dialect (Dialect): The SQLAlchemy dialect.
Raises:
IntegrityError: If the unencrypted value exceeds the maximum length.
Returns:
str | None: The encrypted value or None if the input is None.
"""
if value is None:
return value
# Validate length if specified
if self.length is not None and len(str(value)) > self.length:
msg = f"Unencrypted value exceeds maximum unencrypted length of {self.length}"
raise IntegrityError(msg)
self.mount_vault()
return self.backend.encrypt(value)
def process_result_value(self, value: Any, dialect: "Dialect") -> "Union[str, None]":
"""Processes the value after retrieving it from the database.
This method decrypts the value using the specified backend.
Args:
value (Any): The value to process.
dialect (Dialect): The SQLAlchemy dialect.
Returns:
str | None: The decrypted value or None if the input is None.
"""
if value is None:
return value
self.mount_vault()
return self.backend.decrypt(value)
def mount_vault(self) -> None:
"""Mounts the vault with the encryption key.
If the key is callable, it is called to retrieve the key. Otherwise, the key is used directly.
"""
key = self.key() if callable(self.key) else self.key
self.backend.mount_vault(key)
class EncryptedText(EncryptedString):
"""SQLAlchemy TypeDecorator for storing encrypted text/CLOB values in a database.
This type provides transparent encryption/decryption of text values using the specified backend.
It extends :class:`EncryptedString` and implements Text as its underlying type.
This is suitable for storing larger encrypted text content compared to EncryptedString.
Args:
key (str | bytes | Callable[[], str | bytes] | None): The encryption key. Can be a string, bytes, or callable returning either. Defaults to os.urandom(32).
backend (Type[EncryptionBackend] | None): The encryption backend class to use. Defaults to FernetBackend.
**kwargs (Any | None): Additional arguments passed to the underlying String type.
"""
impl = Text
cache_ok = True
def load_dialect_impl(self, dialect: "Dialect") -> Any:
"""Loads the appropriate dialect implementation for Text type.
Args:
dialect (Dialect): The SQLAlchemy dialect.
Returns:
Any: The dialect-specific Text type descriptor.
"""
return dialect.type_descriptor(Text())
python-advanced-alchemy-1.4.1/advanced_alchemy/types/file_object/ 0000775 0000000 0000000 00000000000 15003544734 0025106 5 ustar 00root root 0000000 0000000 python-advanced-alchemy-1.4.1/advanced_alchemy/types/file_object/__init__.py 0000664 0000000 0000000 00000002257 15003544734 0027225 0 ustar 00root root 0000000 0000000 """File object types for handling file metadata and operations using storage backends.
Provides `FileObject` for representing file metadata and `StoredObject` as the SQLAlchemy
type for database persistence. Includes support for various storage backends (`fsspec`, `obstore`).
The overall design, including concepts like storage backends and the separation of file
representation from the stored type, draws inspiration from the `sqlalchemy-file` library
[https://github.com/jowilf/sqlalchemy-file]. Special thanks to its contributors.
"""
from advanced_alchemy.types.file_object.base import AsyncDataLike, PathLike, StorageBackend, StorageBackendT
from advanced_alchemy.types.file_object.data_type import StoredObject
from advanced_alchemy.types.file_object.file import FileObject, FileObjectList
from advanced_alchemy.types.file_object.registry import StorageRegistry, storages
from advanced_alchemy.types.file_object.session_tracker import FileObjectSessionTracker
__all__ = [
"AsyncDataLike",
"FileObject",
"FileObjectList",
"FileObjectSessionTracker",
"PathLike",
"StorageBackend",
"StorageBackendT",
"StorageRegistry",
"StoredObject",
"storages",
]
python-advanced-alchemy-1.4.1/advanced_alchemy/types/file_object/_typing.py 0000664 0000000 0000000 00000003062 15003544734 0027132 0 ustar 00root root 0000000 0000000 """Internal typing helpers for file_object, handling optional Pydantic integration."""
from typing import Any, Protocol, TypeVar, runtime_checkable
# Define a generic type variable for CoreSchema placeholder if needed
CoreSchemaT = TypeVar("CoreSchemaT")
try:
# Attempt to import real Pydantic components
from pydantic import GetCoreSchemaHandler # pyright: ignore
from pydantic_core import core_schema # pyright: ignore
PYDANTIC_INSTALLED = True
except ImportError:
PYDANTIC_INSTALLED = False # pyright: ignore
@runtime_checkable
class GetCoreSchemaHandler(Protocol): # type: ignore[no-redef]
"""Placeholder for Pydantic's GetCoreSchemaHandler."""
def __call__(self, source_type: Any) -> Any: ...
def __getattr__(self, item: str) -> Any: # Allow arbitrary attribute access
return Any
# Define a placeholder for core_schema module
class CoreSchemaModulePlaceholder:
"""Placeholder for pydantic_core.core_schema module."""
# Define placeholder types/functions used in FileObject.__get_pydantic_core_schema__
CoreSchema = Any # Placeholder for the CoreSchema type itself
def __getattr__(self, name: str) -> Any:
"""Return a dummy function/type for any requested attribute."""
def dummy_schema_func(*args: Any, **kwargs: Any) -> Any: # noqa: ARG001
return Any
return dummy_schema_func
core_schema = CoreSchemaModulePlaceholder() # type: ignore[assignment]
__all__ = ("GetCoreSchemaHandler", "core_schema")
python-advanced-alchemy-1.4.1/advanced_alchemy/types/file_object/_utils.py 0000664 0000000 0000000 00000004635 15003544734 0026767 0 ustar 00root root 0000000 0000000 """Utility functions for file object types."""
from datetime import datetime
from typing import TYPE_CHECKING, Any, Optional
from zlib import adler32
if TYPE_CHECKING:
from advanced_alchemy.types.file_object.base import PathLike
from advanced_alchemy.types.file_object.file import FileObject
def get_mtime_equivalent(info: dict[str, Any]) -> Optional[float]:
"""Return standardized mtime from different implementations.
Args:
info: Dictionary containing file metadata
Returns:
Standardized timestamp or None if not available
"""
# Check these keys in order of preference
mtime_keys = (
"mtime",
"last_modified",
"uploaded_at",
"timestamp",
"Last-Modified",
"modified_at",
"modification_time",
)
mtime = next((info[key] for key in mtime_keys if key in info), None)
if mtime is None or isinstance(mtime, float):
return mtime
if isinstance(mtime, datetime):
return mtime.timestamp()
if isinstance(mtime, str):
try:
return datetime.fromisoformat(mtime.replace("Z", "+00:00")).timestamp()
except ValueError:
pass
return None
def get_or_generate_etag(file_object: "FileObject", info: dict[str, Any], modified_time: Optional[float] = None) -> str:
"""Return standardized etag from different implementations.
Args:
file_object: Path to the file
info: Dictionary containing file metadata
modified_time: Optional modified time for the file
Returns:
Standardized etag or None if not available
"""
# Check these keys in order of preference
etag_keys = (
"e_tag",
"etag",
"etag_key",
)
etag = next((info[key] for key in etag_keys if key in info), None)
if etag is not None:
return str(etag)
if file_object.etag is not None:
return file_object.etag
return create_etag_for_file(file_object.path, modified_time, info.get("size", file_object.size))
def create_etag_for_file(path: "PathLike", modified_time: Optional[float], file_size: int) -> str:
"""Create an etag.
Notes:
- Function is derived from flask.
Returns:
An etag.
"""
check = adler32(str(path).encode("utf-8")) & 0xFFFFFFFF
parts = [str(file_size), str(check)]
if modified_time:
parts.insert(0, str(modified_time))
return f'"{"-".join(parts)}"'
python-advanced-alchemy-1.4.1/advanced_alchemy/types/file_object/backends/ 0000775 0000000 0000000 00000000000 15003544734 0026660 5 ustar 00root root 0000000 0000000 python-advanced-alchemy-1.4.1/advanced_alchemy/types/file_object/backends/__init__.py 0000664 0000000 0000000 00000000000 15003544734 0030757 0 ustar 00root root 0000000 0000000 python-advanced-alchemy-1.4.1/advanced_alchemy/types/file_object/backends/fsspec.py 0000664 0000000 0000000 00000032222 15003544734 0030516 0 ustar 00root root 0000000 0000000 # advanced_alchemy/types/file_object/backends/fsspec.py
# ruff: noqa: PLR0904, SLF001, PLR1702, PLR6301
"""FSSpec-backed storage backend for file objects."""
import datetime
import os
from collections.abc import AsyncIterable, AsyncIterator, Iterable, Sequence
from pathlib import Path
from typing import IO, TYPE_CHECKING, Any, Optional, Union, cast
from advanced_alchemy.exceptions import MissingDependencyError
from advanced_alchemy.types.file_object._utils import get_mtime_equivalent, get_or_generate_etag
from advanced_alchemy.types.file_object.base import (
PathLike,
StorageBackend,
)
from advanced_alchemy.types.file_object.file import FileObject
from advanced_alchemy.utils.sync_tools import async_
try:
# Correct import for AsyncFileSystem and try importing async file handle
import fsspec # pyright: ignore[reportMissingTypeStubs]
from fsspec.asyn import AsyncFileSystem # pyright: ignore[reportMissingTypeStubs]
except ImportError as e:
msg = "fsspec"
raise MissingDependencyError(msg) from e
if TYPE_CHECKING:
from fsspec import AbstractFileSystem # pyright: ignore[reportMissingTypeStubs]
def _join_path(prefix: str, path: str) -> str:
if not prefix:
return path
prefix = prefix.rstrip("/")
path = path.lstrip("/")
return f"{prefix}/{path}"
class FSSpecBackend(StorageBackend):
"""FSSpec-backed storage backend implementing both sync and async operations."""
driver = "fsspec" # Changed backend identifier to driver
default_expires_in = 3600
prefix: Optional[str]
def __init__(
self,
key: str,
fs: "Union[AbstractFileSystem, AsyncFileSystem, str]",
prefix: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Initialize FSSpecBackend.
Args:
key: The key of the backend instance.
fs: The FSSpec filesystem instance (sync or async) or protocol string.
prefix: Optional path prefix to prepend to all paths.
**kwargs: Additional keyword arguments to pass to fsspec.filesystem.
"""
self.fs = fsspec.filesystem(fs, **kwargs) if isinstance(fs, str) else fs # pyright: ignore
self.is_async = isinstance(self.fs, AsyncFileSystem)
protocol = getattr(self.fs, "protocol", None)
protocol = cast("Optional[str]", protocol[0] if isinstance(protocol, (list, tuple)) else protocol)
self.protocol = protocol or "file"
self.key = key
self.prefix = prefix
self.kwargs = kwargs
def _prepare_path(self, path: PathLike) -> str:
path_str = self._to_path(path)
if self.prefix:
return _join_path(self.prefix, path_str)
return path_str
def get_content(self, path: PathLike, *, options: Optional[dict[str, Any]] = None) -> bytes:
"""Return the bytes stored at the specified location.
Args:
path: Path to retrieve (relative to prefix if set).
options: Optional backend-specific options passed to fsspec's open.
"""
content = self.fs.cat_file(self._prepare_path(path), **(options or {})) # pyright: ignore
if isinstance(content, str):
return content.encode("utf-8")
return cast("bytes", content)
async def get_content_async(self, path: PathLike, *, options: Optional[dict[str, Any]] = None) -> bytes:
"""Return the bytes stored at the specified location asynchronously.
Args:
path: Path to retrieve (relative to prefix if set).
options: Optional backend-specific options passed to fsspec's open.
"""
if not self.is_async:
# Fallback for sync filesystems - Note: get_content is sync, wrapping with async_
# Pass the original relative path to the sync method wrapper
return await async_(self.get_content)(path=path, options=options)
content = await self.fs._cat_file(self._prepare_path(path), **(options or {})) # pyright: ignore
if isinstance(content, str):
return content.encode("utf-8")
return cast("bytes", content)
def save_object(
self,
file_object: FileObject,
data: Union[bytes, IO[bytes], Path, Iterable[bytes]],
*,
use_multipart: Optional[bool] = None,
chunk_size: int = 5 * 1024 * 1024,
max_concurrency: int = 12,
) -> FileObject:
"""Save data to the specified path using info from FileObject.
Args:
file_object: FileObject instance with metadata (path, content_type, etc.)
Path should be relative if prefix is used.
data: The data to save (bytes, byte iterator, file-like object, Path)
use_multipart: Ignored.
chunk_size: Size of chunks when reading from IO/Path.
max_concurrency: Ignored.
Returns:
FileObject object representing the saved file, potentially updated.
"""
full_path = self._prepare_path(file_object.path)
if isinstance(data, Path):
self.fs.put(full_path, data) # pyright: ignore
else:
self.fs.pipe(full_path, data) # pyright: ignore
info = file_object.to_dict()
fs_info = self.fs.info(full_path) # pyright: ignore
if isinstance(fs_info, dict):
info.update(fs_info) # pyright: ignore
file_object.size = cast("int", info.get("size", file_object.size)) # pyright: ignore
file_object.last_modified = (
get_mtime_equivalent(info) or datetime.datetime.now(tz=datetime.timezone.utc).timestamp() # pyright: ignore
)
file_object.etag = get_or_generate_etag(file_object, info, file_object.last_modified) # pyright: ignore
# Merge backend metadata if available and different
backend_meta: dict[str, Any] = info.get("metadata", {}) # pyright: ignore
if backend_meta and backend_meta != file_object.metadata:
file_object.update_metadata(backend_meta) # pyright: ignore
return file_object
async def save_object_async(
self,
file_object: FileObject,
data: Union[bytes, IO[bytes], Path, Iterable[bytes], AsyncIterable[bytes]],
*,
use_multipart: Optional[bool] = None,
chunk_size: int = 5 * 1024 * 1024,
max_concurrency: int = 12,
) -> FileObject:
"""Save data to the specified path asynchronously using info from FileObject.
Args:
file_object: FileObject instance with metadata (path, content_type, etc.)
Path should be relative if prefix is used.
data: The data to save (bytes, async byte iterator, file-like object, Path)
use_multipart: Ignored.
chunk_size: Size of chunks when reading from IO/Path/AsyncIterator.
max_concurrency: Ignored.
Returns:
FileObject object representing the saved file, potentially updated.
"""
full_path = self._prepare_path(file_object.path)
if not self.is_async:
# Fallback for sync filesystems. Handle async data carefully.
# Pass the original relative path to the sync method wrapper
if isinstance(data, (AsyncIterator, AsyncIterable)) and not isinstance(data, (bytes, str)):
# Read async stream into memory for sync backend (potential memory issue)
all_data = b"".join([chunk async for chunk in data])
return await async_(self.save_object)(file_object=file_object, data=all_data, chunk_size=chunk_size)
return await async_(self.save_object)(file_object=file_object, data=data, chunk_size=chunk_size) # type: ignore
if isinstance(data, Path):
await self.fs._put(full_path, data) # pyright: ignore
else:
await self.fs._pipe(full_path, data) # pyright: ignore
info = file_object.to_dict()
fs_info = await self.fs._info(full_path) # pyright: ignore
if isinstance(fs_info, dict):
info.update(fs_info) # pyright: ignore
file_object.size = cast("int", info.get("size", file_object.size)) # pyright: ignore
file_object.last_modified = (
get_mtime_equivalent(info) or datetime.datetime.now(tz=datetime.timezone.utc).timestamp() # pyright: ignore
)
file_object.etag = get_or_generate_etag(file_object, info, file_object.last_modified) # pyright: ignore
# Merge backend metadata if available and different
backend_meta: dict[str, Any] = info.get("metadata", {}) # pyright: ignore
if backend_meta and backend_meta != file_object.metadata:
file_object.update_metadata(backend_meta) # pyright: ignore
return file_object
def delete_object(self, paths: Union[PathLike, Sequence[PathLike]]) -> None:
"""Delete the object(s) at the specified location(s).
Args:
paths: Path or sequence of paths to delete (relative to prefix if set).
"""
if isinstance(paths, (str, Path, os.PathLike)):
path_list = [self._prepare_path(paths)]
else:
path_list = [self._prepare_path(p) for p in paths]
self.fs.rm(path_list, recursive=False) # pyright: ignore
async def delete_object_async(self, paths: Union[PathLike, Sequence[PathLike]]) -> None:
"""Delete the object(s) at the specified location(s) asynchronously.
Args:
paths: Path or sequence of paths to delete (relative to prefix if set).
"""
if not self.is_async:
# Pass the original relative path(s) to the sync method wrapper
return await async_(self.delete_object)(paths=paths)
path_list = (
[self._prepare_path(paths)]
if isinstance(paths, (str, Path, os.PathLike))
else [self._prepare_path(p) for p in paths]
)
await self.fs._rm(path_list, recursive=False) # pyright: ignore
return None
def sign(
self,
paths: Union[PathLike, Sequence[PathLike]],
*,
expires_in: Optional[int] = None,
for_upload: bool = False, # Often not directly supported by generic fsspec sign
) -> Union[str, list[str]]:
"""Create signed URLs for accessing files.
Note: Upload URL generation (`for_upload=True`) is generally not supported
by fsspec's generic `sign` method. This typically requires
backend-specific methods (e.g., S3 presigned POST URLs).
Args:
paths: The path or paths of the file(s) (relative to prefix if set).
expires_in: The expiration time of the URL in seconds (backend-dependent default).
for_upload: If True, attempt to generate an upload URL (likely unsupported).
Returns:
A signed URL string if a single path is given, or a list of strings
if multiple paths are provided.
Raises:
NotImplementedError: If the backend doesn't support signing or if `for_upload=True`.
"""
if for_upload:
msg = "Generating signed URLs for upload is generally not supported by fsspec's generic sign method."
raise NotImplementedError(msg)
expires_in = expires_in or self.default_expires_in
is_single = isinstance(paths, (str, Path, os.PathLike))
path_list = [self._prepare_path(paths)] if is_single else [self._prepare_path(p) for p in paths] # type: ignore
if not hasattr(self.fs, "sign"):
msg = f"Filesystem object {type(self.fs).__name__} does not have a 'sign' method."
raise NotImplementedError(msg)
signed_urls: list[str] = []
try:
# fsspec sign method might take expiration in seconds
# Ensure this is a list comprehension, not a generator expression
signed_urls.extend([self.fs.sign(path_str, expiration=expires_in) for path_str in path_list]) # pyright: ignore
except NotImplementedError as e:
# This might be raised by the sign method itself if not implemented for the protocol
msg = f"Signing URLs not supported by {self.protocol} backend via fsspec."
raise NotImplementedError(msg) from e
return signed_urls[0] if is_single else signed_urls
async def sign_async(
self,
paths: Union[PathLike, Sequence[PathLike]],
*,
expires_in: Optional[int] = None,
for_upload: bool = False,
) -> Union[str, list[str]]:
"""Create signed URLs for accessing files asynchronously.
Note: Upload URL generation (`for_upload=True`) is generally not supported
by fsspec's generic `sign` method. This typically requires
backend-specific methods (e.g., S3 presigned POST URLs).
Args:
paths: The path or paths of the file(s) (relative to prefix if set).
expires_in: The expiration time of the URL in seconds (backend-dependent default).
for_upload: If True, attempt to generate an upload URL (likely unsupported).
Returns:
A signed URL string if a single path is given, or a list of strings
if multiple paths are provided.
"""
return await async_(self.sign)(paths=paths, expires_in=expires_in, for_upload=for_upload)
python-advanced-alchemy-1.4.1/advanced_alchemy/types/file_object/backends/obstore.py 0000664 0000000 0000000 00000025517 15003544734 0030721 0 ustar 00root root 0000000 0000000 # ruff: noqa: PLR0904, PLC2701
"""Obstore-backed storage backend for file objects."""
import datetime
import os
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional, Union, cast
from advanced_alchemy.exceptions import MissingDependencyError
from advanced_alchemy.types.file_object._utils import get_mtime_equivalent, get_or_generate_etag
from advanced_alchemy.types.file_object.base import (
AsyncDataLike,
DataLike,
PathLike,
StorageBackend,
)
if TYPE_CHECKING:
from collections.abc import Sequence
from advanced_alchemy.types.file_object.file import FileObject
try:
from obstore import sign as obstore_sign
from obstore import sign_async as obstore_sign_async
from obstore.store import ObjectStore, from_url
except ImportError as e:
raise MissingDependencyError(package="obstore") from e
def schema_from_type(obj: Any) -> str: # noqa: PLR0911
"""Extract the schema from an object.
Args:
obj: Object to parse
Returns:
The schema extracted from the object
"""
from obstore.store import AzureStore, GCSStore, HTTPStore, LocalStore, MemoryStore, S3Store
if isinstance(obj, S3Store):
return "s3"
if isinstance(obj, AzureStore):
return "azure"
if isinstance(obj, GCSStore):
return "gcs"
if isinstance(obj, LocalStore):
return "file"
if isinstance(obj, HTTPStore):
return "http"
if isinstance(obj, MemoryStore):
return "memory"
return "file"
class ObstoreBackend(StorageBackend):
"""Obstore-backed storage backend implementing both sync and async operations."""
driver = "obstore"
def __init__(self, key: str, fs: "Union[ObjectStore, str]", **kwargs: "Any") -> None:
"""Initialize ObstoreBackend.
Args:
fs: The ObjectStore instance from the obstore package
key: The key for the storage backend
kwargs: Additional keyword arguments to pass to the ObjectStore constructor
"""
self.fs = from_url(fs, **kwargs) if isinstance(fs, str) else fs # pyright: ignore
self.protocol = schema_from_type(self.fs) # pyright: ignore
self.key = key
self.options = kwargs
def get_content(self, path: "PathLike", *, options: "Optional[dict[str, Any]]" = None) -> bytes:
"""Return the bytes stored at the specified location.
Args:
path: Path to retrieve
options: Optional backend-specific options
"""
options = options or {}
# Filter out unsupported options
supported_options = {
k: v for k, v in options.items() if k in {"use_multipart", "chunk_size", "max_concurrency"}
}
obj = self.fs.get(self._to_path(path), **supported_options)
return obj.bytes().to_bytes() # type: ignore[no-any-return]
async def get_content_async(self, path: "PathLike", *, options: "Optional[dict[str, Any]]" = None) -> bytes:
"""Return the bytes stored at the specified location asynchronously.
Args:
path: Path to retrieve
options: Optional backend-specific options
"""
options = options or {}
# Filter out unsupported options
supported_options = {
k: v for k, v in options.items() if k in {"use_multipart", "chunk_size", "max_concurrency"}
}
obj = await self.fs.get_async(self._to_path(path), **supported_options)
return (await obj.bytes_async()).to_bytes() # type: ignore[no-any-return]
def save_object(
self,
file_object: "FileObject",
data: "DataLike",
*,
use_multipart: "Optional[bool]" = None,
chunk_size: int = 5 * 1024 * 1024,
max_concurrency: int = 12,
) -> "FileObject":
"""Save data to the specified path using info from FileObject.
Args:
file_object: FileObject instance with metadata (path, content_type, etc.).
data: The data to save.
use_multipart: Whether to use multipart upload.
chunk_size: Size of each chunk in bytes.
max_concurrency: Maximum number of concurrent uploads.
Returns:
A FileObject object representing the saved file, potentially updated.
"""
_ = self.fs.put(
file_object.path,
data,
use_multipart=use_multipart,
chunk_size=chunk_size,
max_concurrency=max_concurrency,
)
info = self.fs.head(file_object.path)
file_object.size = cast("int", info.get("size", file_object.size)) # pyright: ignore
file_object.last_modified = (
get_mtime_equivalent(info) or datetime.datetime.now(tz=datetime.timezone.utc).timestamp() # pyright: ignore
)
file_object.etag = get_or_generate_etag(file_object, info, file_object.last_modified) # pyright: ignore
# Merge backend metadata if available and different
backend_meta: dict[str, Any] = info.get("metadata", {}) # pyright: ignore
if backend_meta and backend_meta != file_object.metadata:
file_object.update_metadata(backend_meta) # pyright: ignore
return file_object
async def save_object_async(
self,
file_object: "FileObject",
data: "AsyncDataLike",
*,
use_multipart: "Optional[bool]" = None,
chunk_size: int = 5 * 1024 * 1024,
max_concurrency: int = 12,
) -> "FileObject":
"""Save data to the specified path asynchronously using info from FileObject.
Args:
file_object: FileObject instance with metadata (path, content_type, etc.).
data: The data to save.
use_multipart: Whether to use multipart upload.
chunk_size: Size of each chunk in bytes.
max_concurrency: Maximum number of concurrent uploads.
Returns:
A FileObject object representing the saved file, potentially updated.
"""
_ = await self.fs.put_async(
file_object.path,
data,
use_multipart=use_multipart,
chunk_size=chunk_size,
max_concurrency=max_concurrency,
)
info = await self.fs.head_async(file_object.path)
file_object.size = cast("int", info.get("size", file_object.size)) # pyright: ignore
file_object.last_modified = (
get_mtime_equivalent(info) or datetime.datetime.now(tz=datetime.timezone.utc).timestamp() # pyright: ignore
)
file_object.etag = get_or_generate_etag(file_object, info, file_object.last_modified) # pyright: ignore
# Merge backend metadata if available and different
backend_meta: dict[str, Any] = info.get("metadata", {}) # pyright: ignore
if backend_meta and backend_meta != file_object.metadata:
file_object.update_metadata(backend_meta) # pyright: ignore
return file_object
def delete_object(self, paths: "Union[PathLike, Sequence[PathLike]]") -> None:
"""Delete the specified paths.
Args:
paths: Path or paths to delete
"""
if isinstance(paths, (str, Path, os.PathLike)):
path_list = [self._to_path(paths)]
else:
path_list = [self._to_path(p) for p in paths]
self.fs.delete(path_list)
async def delete_object_async(self, paths: "Union[PathLike, Sequence[PathLike]]") -> None:
"""Delete the specified paths asynchronously.
Args:
paths: Path or paths to delete
"""
if isinstance(paths, (str, Path, os.PathLike)):
path_list = [self._to_path(paths)]
else:
path_list = [self._to_path(p) for p in paths]
await self.fs.delete_async(path_list)
def sign(
self,
paths: "Union[PathLike, Sequence[PathLike]]",
*,
expires_in: "Optional[int]" = None,
for_upload: bool = False,
) -> "Union[str, list[str]]":
"""Create a signed URL for accessing or uploading the file.
Args:
paths: The path or list of paths of the file
expires_in: The expiration time of the URL in seconds
for_upload: If True, generates a URL suitable for uploads (e.g., presigned POST)
Returns:
A URL or list of URLs for accessing the file
"""
http_method = "PUT" if for_upload else "GET"
expires_delta = (
datetime.timedelta(seconds=expires_in) if expires_in is not None else datetime.timedelta(hours=1)
)
if isinstance(paths, (str, Path, os.PathLike)):
single_path = self._to_path(paths)
try:
return obstore_sign(store=self.fs, method=http_method, paths=single_path, expires_in=expires_delta) # type: ignore
except ValueError as e:
msg = f"Error signing path {single_path}: {e}"
raise NotImplementedError(msg) from e
path_list = [self._to_path(p) for p in paths]
try:
return obstore_sign(store=self.fs, method=http_method, paths=path_list, expires_in=expires_delta) # type: ignore
except ValueError as e:
msg = f"Error signing paths {path_list}: {e}"
raise NotImplementedError(msg) from e
async def sign_async(
self,
paths: "Union[PathLike, Sequence[PathLike]]",
*,
expires_in: "Optional[int]" = None,
for_upload: bool = False,
) -> "Union[str, list[str]]":
"""Sign a URL for a given path asynchronously.
Args:
paths: Path to sign
expires_in: Expiration time in seconds
for_upload: Whether the URL is for uploading a file
Returns:
A URL or list of URLs for accessing the file
"""
http_method = "PUT" if for_upload else "GET"
expires_delta = (
datetime.timedelta(seconds=expires_in) if expires_in is not None else datetime.timedelta(hours=1)
)
if isinstance(paths, (str, Path, os.PathLike)):
single_path = self._to_path(paths)
try:
return await obstore_sign_async( # type: ignore
store=self.fs, # pyright: ignore
method=http_method,
paths=single_path,
expires_in=expires_delta,
)
except ValueError as e:
msg = f"Error signing path {single_path}: {e}"
raise NotImplementedError(msg) from e
path_list = [self._to_path(p) for p in paths]
try:
return await obstore_sign_async( # type: ignore
store=self.fs, # pyright: ignore
method=http_method,
paths=path_list,
expires_in=expires_delta,
)
except ValueError as e:
msg = f"Error signing paths {path_list}: {e}"
raise NotImplementedError(msg) from e
python-advanced-alchemy-1.4.1/advanced_alchemy/types/file_object/base.py 0000664 0000000 0000000 00000013044 15003544734 0026374 0 ustar 00root root 0000000 0000000 # ruff: noqa: PLR0904, PLR6301
"""Generic unified storage protocol compatible with multiple backend implementations."""
import os
from abc import ABC, abstractmethod
from collections.abc import AsyncIterable, AsyncIterator, Iterable, Iterator, Sequence
from pathlib import Path
from typing import IO, TYPE_CHECKING, Any, Optional, TypeVar, Union
from typing_extensions import TypeAlias
if TYPE_CHECKING:
from advanced_alchemy.types.file_object.file import FileObject
# Type variables
T = TypeVar("T")
StorageBackendT = TypeVar("StorageBackendT", bound="StorageBackend")
PathLike: TypeAlias = Union[str, Path, os.PathLike[Any]]
DataLike: TypeAlias = Union[IO[bytes], Path, bytes, Iterator[bytes], Iterable[bytes]]
AsyncDataLike: TypeAlias = Union[
IO[bytes], Path, bytes, AsyncIterator[bytes], AsyncIterable[bytes], Iterator[bytes], Iterable[bytes]
]
class StorageBackend(ABC):
"""Unified protocol for storage backend implementations supporting both sync and async operations."""
driver: str
"""The name of the storage backend."""
protocol: str
"""The protocol used by the storage backend."""
key: str
"""The key of the backend instance."""
def __init__(self, key: str, fs: Any, **kwargs: Any) -> None:
"""Initialize the storage backend.
Args:
key: The key of the backend instance
fs: The filesystem or storage client
**kwargs: Additional keyword arguments
"""
self.fs = fs
self.key = key
self.options = kwargs
@staticmethod
def _to_path(path: "PathLike") -> str:
"""Convert a path-like object to a string.
Args:
path: The path to convert
Returns:
str: The string representation of the path
"""
return str(path)
@abstractmethod
def get_content(self, path: "PathLike", *, options: Optional[dict[str, Any]] = None) -> bytes:
"""Get the content of a file.
Args:
path: Path to the file
options: Optional backend-specific options
Returns:
bytes: The file content
"""
@abstractmethod
async def get_content_async(self, path: "PathLike", *, options: Optional[dict[str, Any]] = None) -> bytes:
"""Get the content of a file asynchronously.
Args:
path: Path to the file
options: Optional backend-specific options
Returns:
bytes: The file content
"""
@abstractmethod
def save_object(
self,
file_object: "FileObject",
data: "DataLike",
*,
use_multipart: "Optional[bool]" = None,
chunk_size: int = 5 * 1024 * 1024,
max_concurrency: int = 12,
) -> "FileObject":
"""Store a file using information from a FileObject.
Args:
file_object: A FileObject instance containing metadata like path, content_type.
data: The file data to store.
use_multipart: Whether to use multipart upload.
chunk_size: Size of chunks for multipart upload.
max_concurrency: Maximum number of concurrent uploads.
Returns:
FileObject: The stored file object, potentially updated with backend info (size, etag, etc.).
"""
@abstractmethod
async def save_object_async(
self,
file_object: "FileObject",
data: AsyncDataLike,
*,
use_multipart: Optional[bool] = None,
chunk_size: int = 5 * 1024 * 1024,
max_concurrency: int = 12,
) -> "FileObject":
"""Store a file asynchronously using information from a FileObject.
Args:
file_object: A FileObject instance containing metadata like path, content_type.
data: The file data to store.
use_multipart: Whether to use multipart upload.
chunk_size: Size of chunks for multipart upload.
max_concurrency: Maximum number of concurrent uploads.
Returns:
FileObject: The stored file object, potentially updated with backend info (size, etag, etc.).
"""
@abstractmethod
def delete_object(self, paths: Union[PathLike, Sequence[PathLike]]) -> None:
"""Delete one or more files.
Args:
paths: Path or paths to delete
"""
@abstractmethod
async def delete_object_async(self, paths: Union[PathLike, Sequence[PathLike]]) -> None:
"""Delete one or more files asynchronously.
Args:
paths: Path or paths to delete
"""
@abstractmethod
def sign(
self,
paths: Union[PathLike, Sequence[PathLike]],
*,
expires_in: Optional[int] = None,
for_upload: bool = False,
) -> Union[str, list[str]]:
"""Generate a signed URL for one or more files.
Args:
paths: Path or paths to generate URLs for
expires_in: Optional expiration time in seconds
for_upload: Whether the URL is for upload
Returns:
str: The signed URL
"""
@abstractmethod
async def sign_async(
self,
paths: Union[PathLike, Sequence[PathLike]],
*,
expires_in: Optional[int] = None,
for_upload: bool = False,
) -> Union[str, list[str]]:
"""Generate a signed URL for one or more files asynchronously.
Args:
paths: Path or paths to generate URLs for
expires_in: Optional expiration time in seconds
for_upload: Whether the URL is for upload
Returns:
str: The signed URL
"""
python-advanced-alchemy-1.4.1/advanced_alchemy/types/file_object/data_type.py 0000664 0000000 0000000 00000012316 15003544734 0027435 0 ustar 00root root 0000000 0000000 from typing import Any, Optional, Union, cast
from sqlalchemy import TypeDecorator
from advanced_alchemy._serialization import decode_json
from advanced_alchemy.types.file_object.base import StorageBackend
from advanced_alchemy.types.file_object.file import FileObject
from advanced_alchemy.types.file_object.registry import storages
from advanced_alchemy.types.json import JsonB
from advanced_alchemy.types.mutables import MutableList
# Define the type hint for the value this TypeDecorator handles
FileObjectOrList = Union[FileObject, list[FileObject], set[FileObject], MutableList[FileObject]]
OptionalFileObjectOrList = Optional[FileObjectOrList]
class StoredObject(TypeDecorator[OptionalFileObjectOrList]):
"""Custom SQLAlchemy type for storing single or multiple file metadata.
Stores file metadata in JSONB and handles file validation, processing,
and storage operations through a configured storage backend.
"""
impl = JsonB
cache_ok = True
# Default settings
multiple: bool
_raw_backend: Union[str, StorageBackend]
_resolved_backend: "Optional[StorageBackend]" = None
@property
def python_type(self) -> "type[OptionalFileObjectOrList]":
"""Specifies the Python type used, accounting for the `multiple` flag."""
# This provides a hint to SQLAlchemy and type checkers
return MutableList[FileObject] if self.multiple else Optional[FileObject] # type: ignore
@property
def backend(self) -> "StorageBackend":
"""Resolves and returns the storage backend instance."""
# Return cached version if available
if self._resolved_backend is None:
self._resolved_backend = (
storages.get_backend(self._raw_backend) if isinstance(self._raw_backend, str) else self._raw_backend
)
return self._resolved_backend
@property
def storage_key(self) -> str:
"""Returns the storage key from the resolved backend."""
return self.backend.key
def __init__(
self,
backend: Union[str, StorageBackend],
multiple: bool = False,
*args: "Any",
**kwargs: "Any",
) -> None:
"""Initialize StoredObject type.
Args:
backend: Key to retrieve the backend or from the storage registry or storage backend to use.
multiple: If True, stores a list of files; otherwise, a single file.
*args: Additional positional arguments for TypeDecorator.
**kwargs: Additional keyword arguments for TypeDecorator.
"""
super().__init__(*args, **kwargs)
self.multiple = multiple
self._raw_backend = backend
def process_bind_param(
self,
value: "Optional[FileObjectOrList]",
dialect: "Any",
) -> "Optional[Union[dict[str, Any], list[dict[str, Any]]]]":
"""Convert FileObject(s) to JSON representation for the database.
Injects the configured backend into the FileObject before conversion.
Note: This method expects an already processed FileInfo or its dict representation.
Use handle_upload() or handle_upload_async() for processing raw uploads.
Args:
value: The value to process
dialect: The SQLAlchemy dialect
Raises:
TypeError: If the input value is not a FileObject or a list of FileObjects.
Returns:
A dictionary representing the file metadata, or None if the input value is None.
"""
if value is None:
return None
if self.multiple:
if not isinstance(value, (list, MutableList, set)):
return [value.to_dict()] if value else []
return [item.to_dict() for item in value if item]
if isinstance(value, (list, MutableList, set)):
msg = f"Expected a single FileObject for multiple=False, got {type(value)}"
raise TypeError(msg)
return value.to_dict() if value else None
def process_result_value(
self, value: "Optional[Union[bytes, str, dict[str, Any], list[dict[str, Any]]]]", dialect: "Any"
) -> "Optional[FileObjectOrList]":
"""Convert database JSON back to FileObject or MutableList[FileObject].
Args:
value: The value to process
dialect: The SQLAlchemy dialect
Raises:
TypeError: If the input value is not a list of dicts.
Returns:
FileObject or MutableList[FileObject] or None.
"""
if value is None:
return None
if self.multiple:
if isinstance(value, dict):
# If the DB returns a single dict, wrap it in a list
value = [value]
elif isinstance(value, (str, bytes)):
# Decode JSON string or bytes to dict
value = [cast("dict[str, Any]", decode_json(value))]
return MutableList[FileObject]([FileObject(**v) for v in value if v]) # pyright: ignore
if isinstance(value, list):
msg = f"Expected dict from DB for multiple=False, got {type(value)}"
raise TypeError(msg)
if isinstance(value, (bytes, str)):
value = cast("dict[str,Any]", decode_json(value))
return FileObject(**value)
python-advanced-alchemy-1.4.1/advanced_alchemy/types/file_object/file.py 0000664 0000000 0000000 00000043310 15003544734 0026400 0 ustar 00root root 0000000 0000000 """Generic unified storage protocol compatible with multiple backend implementations."""
import mimetypes
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional, Union
from sqlalchemy.ext.mutable import MutableList
from typing_extensions import TypeAlias
from advanced_alchemy.exceptions import MissingDependencyError
from advanced_alchemy.types.file_object._typing import PYDANTIC_INSTALLED, GetCoreSchemaHandler, core_schema
from advanced_alchemy.types.file_object.base import AsyncDataLike, DataLike, StorageBackend
from advanced_alchemy.types.file_object.registry import storages
if TYPE_CHECKING:
from advanced_alchemy.types.file_object.base import PathLike
class FileObject:
"""Represents file metadata during processing using a dataclass structure.
This class provides a unified interface for handling file metadata and operations
across different storage backends.
Content or a source path can optionally be provided during initialization via kwargs, store it internally, and add save/save_async methods to persist this pending data using the configured backend.
"""
__slots__ = (
"_checksum",
"_content_type",
"_etag",
"_filename",
"_last_modified",
"_metadata",
"_pending_source_content",
"_pending_source_path",
"_raw_backend",
"_resolved_backend",
"_size",
"_to_filename",
"_version_id",
)
def __init__(
self,
backend: "Union[str, StorageBackend]",
filename: str,
to_filename: Optional[str] = None,
content_type: Optional[str] = None,
size: Optional[int] = None,
last_modified: Optional[float] = None,
checksum: Optional[str] = None,
etag: Optional[str] = None,
version_id: Optional[str] = None,
metadata: Optional[dict[str, Any]] = None,
source_path: "Optional[PathLike]" = None,
content: "Optional[Union[DataLike, AsyncDataLike]]" = None,
) -> None:
"""Perform post-initialization validation and setup.
Handles default path, content type guessing, backend protocol inference,
and processing of 'content' or 'source_path' from extra kwargs.
Raises:
ValueError: If filename is not provided, size is negative, backend/protocol mismatch,
or both 'content' and 'source_path' are provided.
"""
self._size = size
self._last_modified = last_modified
self._checksum = checksum
self._etag = etag
self._version_id = version_id
self._metadata = metadata or {}
self._filename = filename
self._content_type = content_type
self._to_filename = to_filename
self._resolved_backend: Optional[StorageBackend] = backend if isinstance(backend, StorageBackend) else None
self._raw_backend = backend
self._pending_source_path = Path(source_path) if source_path is not None else None
self._pending_source_content = content
if self._pending_source_content is not None and self._pending_source_path is not None:
msg = "Cannot provide both 'source_content' and 'source_path' during initialization."
raise ValueError(msg)
def __repr__(self) -> str:
"""Return a string representation of the FileObject."""
return f"FileObject(filename={self.path}, backend={self.backend.key}, size={self.size}, content_type={self.content_type}, etag={self.etag}, last_modified={self.last_modified}, version_id={self.version_id})"
def __eq__(self, other: object) -> bool:
"""Check equality based on filename and backend key.
Args:
other: The object to compare with.
Returns:
bool: True if the objects are equal, False otherwise.
"""
if not isinstance(other, FileObject):
return False
return self.path == other.path and self.backend.key == other.backend.key
def __hash__(self) -> int:
"""Return a hash based on filename and backend key."""
return hash((self.path, self.backend.key))
@property
def backend(self) -> "StorageBackend":
if self._resolved_backend is None:
self._resolved_backend = (
storages.get_backend(self._raw_backend) if isinstance(self._raw_backend, str) else self._raw_backend
)
return self._resolved_backend
@property
def filename(self) -> str:
return self.path
@property
def content_type(self) -> str:
if self._content_type is None:
guessed_type, _ = mimetypes.guess_type(self._filename)
self._content_type = guessed_type or "application/octet-stream"
return self._content_type
@property
def protocol(self) -> str:
return self.backend.protocol if self.backend else "file"
@property
def path(self) -> str:
return self._to_filename or self._filename
@property
def has_pending_data(self) -> bool:
return bool(self._pending_source_content or self._pending_source_path)
@property
def metadata(self) -> dict[str, Any]:
return self._metadata
@metadata.setter
def metadata(self, value: dict[str, Any]) -> None:
self._metadata = value
@property
def size(self) -> "Optional[int]":
return self._size
@size.setter
def size(self, value: int) -> None:
self._size = value
@property
def last_modified(self) -> "Optional[float]":
return self._last_modified
@last_modified.setter
def last_modified(self, value: float) -> None:
self._last_modified = value
@property
def checksum(self) -> "Optional[str]":
return self._checksum
@checksum.setter
def checksum(self, value: str) -> None:
self._checksum = value
@property
def etag(self) -> "Optional[str]":
return self._etag
@etag.setter
def etag(self, value: str) -> None:
self._etag = value
@property
def version_id(self) -> "Optional[str]":
return self._version_id
@version_id.setter
def version_id(self, value: str) -> None:
self._version_id = value
def update_metadata(self, metadata: "dict[str, Any]") -> None:
"""Update the file metadata.
Args:
metadata: New metadata to merge with existing metadata.
"""
self.metadata.update(metadata)
def to_dict(self) -> "dict[str, Any]":
"""Convert FileObject to a dictionary for storage or serialization.
Note: The 'backend' attribute is intentionally excluded as it's often
not serializable or relevant for storage representations.
The 'extra' dict is included.
Returns:
dict[str, Any]: A dictionary representation of the file information.
"""
# Use dataclasses.asdict and filter out the backend
return {
"filename": self.path,
"content_type": self.content_type,
"size": self.size,
"last_modified": self.last_modified,
"checksum": self.checksum,
"etag": self.etag,
"version_id": self.version_id,
"metadata": self.metadata,
"backend": self.backend.key,
}
def get_content(self, *, options: "Optional[dict[str, Any]]" = None) -> bytes:
"""Get the file content from the storage backend.
Args:
options: Optional backend-specific options.
Returns:
bytes: The file content.
"""
return self.backend.get_content(self.path, options=options)
async def get_content_async(self, *, options: "Optional[dict[str, Any]]" = None) -> bytes:
"""Get the file content from the storage backend asynchronously.
Args:
options: Optional backend-specific options.
Returns:
bytes: The file content.
"""
return await self.backend.get_content_async(self.path, options=options)
def sign(
self,
*,
expires_in: "Optional[int]" = None,
for_upload: bool = False,
) -> str:
"""Generate a signed URL for the file.
Args:
expires_in: Optional expiration time in seconds.
for_upload: Whether the URL is for upload.
Raises:
RuntimeError: If no signed URL is generated.
Returns:
str: The signed URL.
"""
result = self.backend.sign(self.path, expires_in=expires_in, for_upload=for_upload)
if isinstance(result, list):
if not result:
msg = "No signed URL generated"
raise RuntimeError(msg)
return result[0]
return result
async def sign_async(
self,
*,
expires_in: "Optional[int]" = None,
for_upload: bool = False,
) -> str:
"""Generate a signed URL for the file asynchronously.
Args:
expires_in: Optional expiration time in seconds.
for_upload: Whether the URL is for upload.
Returns:
str: The signed URL.
Raises:
RuntimeError: If no signed URL is generated.
"""
result = await self.backend.sign_async(self.path, expires_in=expires_in, for_upload=for_upload)
if isinstance(result, list):
if not result:
msg = "No signed URL generated"
raise RuntimeError(msg)
return result[0]
return result
def delete(self) -> None:
"""Delete the file from storage.
Raises:
RuntimeError: If no backend is configured or path is missing.
"""
if not self.backend:
msg = "No storage backend configured"
raise RuntimeError(msg)
self.backend.delete_object(self.path)
async def delete_async(self) -> None:
"""Delete the file from storage asynchronously."""
await self.backend.delete_object_async(self.path)
def save(
self,
data: Optional[DataLike] = None,
*,
use_multipart: Optional[bool] = None,
chunk_size: int = 5 * 1024 * 1024,
max_concurrency: int = 12,
) -> "FileObject":
"""Save data to the storage backend using this FileObject's metadata.
If `data` is provided, it is used directly.
If `data` is None, checks internal source_content or source_path.
Clears pending attributes after successful save.
Args:
data: Optional data to save (bytes, iterator, file-like, Path). If None,
internal pending data is used.
use_multipart: Passed to the backend's save method.
chunk_size: Passed to the backend's save method.
max_concurrency: Passed to the backend's save method.
Returns:
The updated FileObject instance returned by the backend.
Raises:
TypeError: If trying to save async data synchronously.
"""
if data is None and self._pending_source_content is not None:
data = self._pending_source_content # type: ignore[assignment]
elif data is None and self._pending_source_path is not None:
data = self._pending_source_path
if data is None:
msg = "No data provided and no pending content/path found to save."
raise TypeError(msg)
# The backend's save method is expected to update the FileObject instance in-place
# and return the updated instance.
updated_self = self.backend.save_object(
file_object=self,
data=data,
use_multipart=use_multipart,
chunk_size=chunk_size,
max_concurrency=max_concurrency,
)
# Clear pending attributes after successful save
self._pending_source_content = None
self._pending_source_path = None
return updated_self
async def save_async(
self,
data: Optional[AsyncDataLike] = None,
*,
use_multipart: Optional[bool] = None,
chunk_size: int = 5 * 1024 * 1024,
max_concurrency: int = 12,
) -> "FileObject":
"""Save data to the storage backend asynchronously.
If `data` is provided, it is used directly.
If `data` is None, checks internal source_content or source_path.
Clears pending attributes after successful save.
Uses asyncio.to_thread for reading source_path if backend doesn't handle Path directly.
Args:
data: Optional data to save (bytes, async iterator, file-like, Path, etc.).
If None, internal pending data is used.
use_multipart: Passed to the backend's async save method.
chunk_size: Passed to the backend's async save method.
max_concurrency: Passed to the backend's async save method.
Returns:
The updated FileObject instance returned by the backend.
Raises:
TypeError: If trying to save sync data asynchronously.
"""
if data is None and self._pending_source_content is not None:
data = self._pending_source_content
elif data is None and self._pending_source_path is not None:
data = self._pending_source_path
if data is None:
msg = "No data provided and no pending content/path found to save."
raise TypeError(msg)
# Backend's async save method updates the FileObject instance
updated_self = await self.backend.save_object_async(
file_object=self,
data=data, # Pass the determined data source
use_multipart=use_multipart,
chunk_size=chunk_size,
max_concurrency=max_concurrency,
)
# Clear pending attributes after successful save
self._pending_source_content = None
self._pending_source_path = None
return updated_self
@classmethod
def __get_pydantic_core_schema__(
cls,
source_type: Any,
handler: "GetCoreSchemaHandler", # Use imported GetCoreSchemaHandler
) -> "core_schema.CoreSchema": # Use imported core_schema
"""Get the Pydantic core schema for FileObject.
This method defines how Pydantic should validate and serialize FileObject instances.
It creates a schema that validates dictionaries with the required fields and
converts them to FileObject instances.
Raises:
MissingDependencyError: If Pydantic is not installed when this method is called.
Args:
source_type: The source type (FileObject)
handler: The Pydantic schema handler
Returns:
A Pydantic core schema for FileObject
"""
if not PYDANTIC_INSTALLED:
raise MissingDependencyError(package="pydantic")
def validate_from_dict(data: dict[str, Any]) -> "FileObject":
# We expect a dictionary derived from to_dict()
# We need to resolve the backend string back to an instance if needed
backend_input = data.get("backend")
if backend_input is None:
msg = "backend is required"
raise TypeError(msg)
key = backend_input if isinstance(backend_input, str) else backend_input.key
return cls(
backend=key,
filename=data["filename"],
to_filename=data.get("to_filename"),
content_type=data.get("content_type"),
size=data.get("size"),
last_modified=data.get("last_modified"),
checksum=data.get("checksum"),
etag=data.get("etag"),
version_id=data.get("version_id"),
metadata=data.get("metadata"),
)
typed_dict_schema = core_schema.typed_dict_schema(
{
"filename": core_schema.typed_dict_field(core_schema.str_schema()),
"backend": core_schema.typed_dict_field(core_schema.str_schema()),
"to_filename": core_schema.typed_dict_field(core_schema.str_schema(), required=False),
"content_type": core_schema.typed_dict_field(core_schema.str_schema(), required=False),
"size": core_schema.typed_dict_field(core_schema.int_schema(), required=False),
"last_modified": core_schema.typed_dict_field(core_schema.float_schema(), required=False),
"checksum": core_schema.typed_dict_field(core_schema.str_schema(), required=False),
"etag": core_schema.typed_dict_field(core_schema.str_schema(), required=False),
"version_id": core_schema.typed_dict_field(core_schema.str_schema(), required=False),
"metadata": core_schema.typed_dict_field(
core_schema.nullable_schema(
core_schema.dict_schema(core_schema.str_schema(), core_schema.any_schema())
),
required=False,
),
}
)
validation_schema = core_schema.union_schema(
[
core_schema.is_instance_schema(cls),
core_schema.chain_schema(
[
typed_dict_schema,
core_schema.no_info_plain_validator_function(validate_from_dict),
]
),
]
)
return core_schema.json_or_python_schema(
json_schema=validation_schema,
python_schema=validation_schema,
serialization=core_schema.plain_serializer_function_ser_schema(
lambda instance: instance.to_dict(), # pyright: ignore
info_arg=False,
return_schema=typed_dict_schema,
), # pyright: ignore
)
FileObjectList: TypeAlias = MutableList[FileObject]
python-advanced-alchemy-1.4.1/advanced_alchemy/types/file_object/registry.py 0000664 0000000 0000000 00000010504 15003544734 0027330 0 ustar 00root root 0000000 0000000 from importlib.util import find_spec
from typing import TYPE_CHECKING, Any, Callable, Optional, Union, overload
from advanced_alchemy._serialization import decode_json, encode_json
from advanced_alchemy.exceptions import ImproperConfigurationError
from advanced_alchemy.utils.module_loader import import_string
from advanced_alchemy.utils.singleton import SingletonMeta
if TYPE_CHECKING:
from advanced_alchemy.types.file_object.base import StorageBackend
DEFAULT_BACKEND = (
"advanced_alchemy.types.file_object.backends.obstore.ObstoreBackend"
if find_spec("obstore")
else "advanced_alchemy.types.file_object.backends.fsspec.FSSpecBackend"
)
class StorageRegistry(metaclass=SingletonMeta):
"""A provider for creating and managing threaded portals."""
def __init__(
self,
json_serializer: "Callable[[Any], str]" = encode_json,
json_deserializer: Callable[[Union[str, bytes]], Any] = decode_json,
default_backend: "Union[str, type[StorageBackend]]" = DEFAULT_BACKEND,
) -> None:
"""Initialize the PortalProvider."""
self._registry: dict[str, StorageBackend] = {}
self.json_serializer = json_serializer
self.json_deserializer = json_deserializer
self.default_backend: str = (
DEFAULT_BACKEND if isinstance(default_backend, str) else default_backend.__qualname__
)
def set_default_backend(self, default_backend: "Union[str, type[StorageBackend]]") -> None:
"""Set the default storage backend.
Args:
default_backend: The default storage backend
"""
self.default_backend = default_backend if isinstance(default_backend, str) else default_backend.__qualname__
def is_registered(self, key: str) -> bool:
"""Check if a storage backend is registered in the registry.
Args:
key: The key of the storage backend
Returns:
bool: True if the storage backend is registered, False otherwise.
"""
return key in self._registry
def get_backend(self, key: str) -> "StorageBackend":
"""Retrieve a configured storage backend from the registry.
Returns:
StorageBackend: The storage backend associaStorageBackendiven key.
Raises:
ImproperConfigurationError: If no storage backend is registered with the given key.
"""
try:
return self._registry[key]
except KeyError as e:
msg = f"No storage backend registered with key {key}"
raise ImproperConfigurationError(msg) from e
@overload
def register_backend(self, value: "str") -> None: ...
@overload
def register_backend(self, value: "str", key: None = None) -> None: ...
@overload
def register_backend(self, value: "str", key: str) -> None: ...
@overload
def register_backend(self, value: "StorageBackend", key: None = None) -> None: ...
@overload
def register_backend(self, value: "StorageBackend", key: str) -> None: ...
def register_backend(self, value: "Union[StorageBackend, str]", key: "Optional[str]" = None) -> None:
"""Register a new storage backend in the registry.
Args:
value: The storage backend to register.
key: The key to register the storage backend with.
Raises:
ImproperConfigurationError: If a string value is provided without a key.
"""
if isinstance(value, str):
if key is None:
msg = "key is required when registering a string value"
raise ImproperConfigurationError(msg)
self._registry[key] = import_string(self.default_backend)(fs=value, key=key)
else:
if key is not None:
msg = "key is not allowed when registering a StorageBackend"
raise ImproperConfigurationError(msg)
self._registry[value.key] = value
def unregister_backend(self, key: str) -> None:
"""Unregister a storage backend from the registry."""
if key in self._registry:
del self._registry[key]
def clear_backends(self) -> None:
"""Clear the registry."""
self._registry.clear()
def registered_backends(self) -> list[str]:
"""Return a list of all registered keys."""
return list(self._registry.keys())
storages = StorageRegistry()
python-advanced-alchemy-1.4.1/advanced_alchemy/types/file_object/session_tracker.py 0000664 0000000 0000000 00000012474 15003544734 0030666 0 ustar 00root root 0000000 0000000 # ruff: noqa: UP037
"""Application ORM configuration."""
import asyncio
import logging
from typing import TYPE_CHECKING, Any, Union
if TYPE_CHECKING:
from collections.abc import Awaitable
from pathlib import Path
from advanced_alchemy.types.file_object import FileObject
logger = logging.getLogger("advanced_alchemy")
class FileObjectSessionTracker:
"""Tracks FileObject changes within a single session transaction."""
def __init__(self) -> None:
"""Initialize the tracker."""
# Stores objects that have pending data to be saved on commit.
# Maps FileObject -> data source (bytes or Path)
self.pending_saves: "dict[FileObject, Union[bytes, Path]]" = {}
# Stores objects that should be deleted from storage on commit.
self.pending_deletes: "set[FileObject]" = set()
# Stores objects that were successfully saved within this transaction,
# needed for rollback cleanup.
self._saved_in_transaction: "set[FileObject]" = set()
def add_pending_save(self, obj: "FileObject", data: "Union[bytes, Path]") -> None:
"""Mark a FileObject for saving."""
self.pending_saves[obj] = data
# If this object was previously marked for deletion, unmark it.
self.pending_deletes.discard(obj)
def add_pending_delete(self, obj: "FileObject") -> None:
"""Mark a FileObject for deletion."""
# If this object was pending save, unmark it.
self.pending_saves.pop(obj, None)
# Only add to pending deletes if it actually exists in storage (has a path)
if obj.path:
self.pending_deletes.add(obj)
def commit(self) -> None:
"""Process pending saves and deletes after a successful commit."""
for obj, data in self.pending_saves.items():
try:
obj.save(data)
except Exception as e: # noqa: BLE001
logger.warning("Error saving file for object %s: %s", obj, e.__cause__)
for obj in self.pending_deletes:
try:
obj.delete()
except FileNotFoundError:
# Ignore if the file is already gone (shouldn't happen often here)
pass
except Exception as e: # noqa: BLE001
logger.warning("Error deleting file for object %s: %s", obj, e.__cause__)
self.clear()
async def commit_async(self) -> None:
"""Process pending saves and deletes after a successful commit."""
save_tasks: list[Awaitable[Any]] = []
for obj, data in self.pending_saves.items():
save_tasks.append(obj.save_async(data))
self._saved_in_transaction.add(obj)
delete_tasks: list[Awaitable[Any]] = [obj.delete_async() for obj in self.pending_deletes]
# Run save and delete tasks concurrently
save_results = await asyncio.gather(*save_tasks, return_exceptions=True)
delete_results = await asyncio.gather(*delete_tasks, return_exceptions=True)
# Process save results (log errors)
for result, (obj, _data) in zip(save_results, self.pending_saves.items()):
if isinstance(result, Exception):
logger.warning("Error saving file for object %s: %s", obj, result.__cause__)
# Process delete results (log errors, ignore FileNotFoundError)
for result, obj_to_delete in zip(delete_results, self.pending_deletes):
if isinstance(result, FileNotFoundError):
continue
if isinstance(result, Exception):
logger.warning("Error deleting file %s: %s", obj_to_delete.path, result.__cause__)
self.clear()
def rollback(self) -> None:
"""Clean up files saved during a transaction that is being rolled back."""
for obj in self._saved_in_transaction:
if obj.path:
try:
obj.delete()
except FileNotFoundError:
# Ignore if the file is already gone (shouldn't happen often here)
pass
except Exception as e: # noqa: BLE001
logger.warning("Error deleting file during rollback %s: %s", obj.path, e.__cause__)
self.clear()
async def rollback_async(self) -> None:
"""Clean up files saved during a transaction that is being rolled back."""
rollback_delete_tasks: list[Awaitable[Any]] = []
objects_to_delete_on_rollback: list[FileObject] = []
# Only delete files that were actually saved *during this transaction*
for obj in self._saved_in_transaction:
if obj.path:
rollback_delete_tasks.append(obj.delete_async())
objects_to_delete_on_rollback.append(obj)
for task, obj_to_delete in zip(rollback_delete_tasks, objects_to_delete_on_rollback):
try:
await task
except FileNotFoundError:
# Ignore if the file is already gone (shouldn't happen often here)
pass
except Exception as e: # noqa: BLE001
logger.warning("Error deleting file during rollback %s: %s", obj_to_delete.path, e.__cause__)
self.clear()
def clear(self) -> None:
"""Clear the tracker's state."""
self.pending_saves.clear()
self.pending_deletes.clear()
self._saved_in_transaction.clear()
python-advanced-alchemy-1.4.1/advanced_alchemy/types/guid.py 0000664 0000000 0000000 00000006453 15003544734 0024153 0 ustar 00root root 0000000 0000000 # ruff: noqa: FA100
from base64 import b64decode
from importlib.util import find_spec
from typing import Any, Optional, Union, cast
from uuid import UUID
from sqlalchemy.dialects.mssql import UNIQUEIDENTIFIER as MSSQL_UNIQUEIDENTIFIER
from sqlalchemy.dialects.oracle import RAW as ORA_RAW
from sqlalchemy.dialects.postgresql import UUID as PG_UUID
from sqlalchemy.engine import Dialect
from sqlalchemy.types import BINARY, CHAR, TypeDecorator
from typing_extensions import Buffer
__all__ = ("GUID",)
UUID_UTILS_INSTALLED = find_spec("uuid_utils")
NANOID_INSTALLED = find_spec("fastnanoid")
class GUID(TypeDecorator[UUID]):
"""Platform-independent GUID type.
Uses PostgreSQL's UUID type (Postgres, DuckDB, Cockroach),
MSSQL's UNIQUEIDENTIFIER type, Oracle's RAW(16) type,
otherwise uses BINARY(16) or CHAR(32),
storing as stringified hex values.
Will accept stringified UUIDs as a hexstring or an actual UUID
"""
impl = BINARY(16)
cache_ok = True
@property
def python_type(self) -> type[UUID]:
return UUID
def __init__(self, *args: Any, binary: bool = True, **kwargs: Any) -> None:
self.binary = binary
def load_dialect_impl(self, dialect: Dialect) -> Any:
if dialect.name in {"postgresql", "duckdb", "cockroachdb"}:
return dialect.type_descriptor(PG_UUID())
if dialect.name == "oracle":
return dialect.type_descriptor(ORA_RAW(16))
if dialect.name == "mssql":
return dialect.type_descriptor(MSSQL_UNIQUEIDENTIFIER())
if self.binary:
return dialect.type_descriptor(BINARY(16))
return dialect.type_descriptor(CHAR(32))
def process_bind_param(
self,
value: Optional[Union[bytes, str, UUID]],
dialect: Dialect,
) -> Optional[Union[bytes, str]]:
if value is None:
return value
if dialect.name in {"postgresql", "duckdb", "cockroachdb", "mssql"}:
return str(value)
value = self.to_uuid(value)
if value is None:
return value
if dialect.name in {"oracle", "spanner+spanner"}:
return value.bytes
return value.bytes if self.binary else value.hex
def process_result_value(
self,
value: Optional[Union[bytes, str, UUID]],
dialect: Dialect,
) -> Optional[UUID]:
if value is None:
return value
if value.__class__.__name__ == "UUID":
return cast("UUID", value)
if dialect.name == "spanner+spanner":
return UUID(bytes=b64decode(cast("str | Buffer", value)))
if self.binary:
return UUID(bytes=cast("bytes", value))
return UUID(hex=cast("str", value))
@staticmethod
def to_uuid(value: Any) -> Optional[UUID]:
if value.__class__.__name__ == "UUID" or value is None:
return cast("Optional[UUID]", value)
try:
value = UUID(hex=value)
except (TypeError, ValueError):
value = UUID(bytes=value)
return cast("Optional[UUID]", value)
def compare_values(self, x: Any, y: Any) -> bool:
"""Compare two values for equality."""
if x.__class__.__name__ == "UUID" and y.__class__.__name__ == "UUID":
return cast("bool", x.bytes == y.bytes)
return cast("bool", x == y)
python-advanced-alchemy-1.4.1/advanced_alchemy/types/identity.py 0000664 0000000 0000000 00000000310 15003544734 0025036 0 ustar 00root root 0000000 0000000 from sqlalchemy.types import BigInteger, Integer
BigIntIdentity = BigInteger().with_variant(Integer, "sqlite")
"""A ``BigInteger`` variant that reverts to an ``Integer`` for unsupported variants."""
python-advanced-alchemy-1.4.1/advanced_alchemy/types/json.py 0000664 0000000 0000000 00000006160 15003544734 0024167 0 ustar 00root root 0000000 0000000 # ruff: noqa: FA100
from typing import Any, Optional, Union, cast
from sqlalchemy import text, util
from sqlalchemy.dialects.oracle import BLOB as ORA_BLOB
from sqlalchemy.dialects.postgresql import JSONB as PG_JSONB
from sqlalchemy.engine import Dialect
from sqlalchemy.types import JSON as _JSON
from sqlalchemy.types import SchemaType, TypeDecorator, TypeEngine
from advanced_alchemy._serialization import decode_json, encode_json
__all__ = ("ORA_JSONB",)
class ORA_JSONB(TypeDecorator[dict[str, Any]], SchemaType): # noqa: N801
"""Oracle Binary JSON type.
JsonB = _JSON().with_variant(PG_JSONB, "postgresql").with_variant(ORA_JSONB, "oracle")
"""
impl = ORA_BLOB
cache_ok = True
@property
def python_type(self) -> type[dict[str, Any]]:
return dict
def __init__(self, *args: Any, **kwargs: Any) -> None:
"""Initialize JSON type"""
self.name = kwargs.pop("name", None)
self.oracle_strict = kwargs.pop("oracle_strict", True)
def coerce_compared_value(self, op: Any, value: Any) -> Any:
return self.impl.coerce_compared_value(op=op, value=value) # type: ignore[no-untyped-call, call-arg]
def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]:
return dialect.type_descriptor(ORA_BLOB())
def process_bind_param(self, value: Any, dialect: Dialect) -> Optional[Any]:
return value if value is None else encode_json(value)
def process_result_value(self, value: Union[bytes, None], dialect: Dialect) -> Optional[Any]:
if dialect.oracledb_ver < (2,): # type: ignore[attr-defined]
return value if value is None else decode_json(value)
return value
def _should_create_constraint(self, compiler: Any, **kw: Any) -> bool:
return cast("bool", compiler.dialect.name == "oracle")
def _variant_mapping_for_set_table(self, column: Any) -> Optional[dict[str, Any]]:
if column.type._variant_mapping: # noqa: SLF001
variant_mapping = dict(column.type._variant_mapping) # noqa: SLF001
variant_mapping["_default"] = column.type
else:
variant_mapping = None
return variant_mapping
@util.preload_module("sqlalchemy.sql.schema")
def _set_table(self, column: Any, table: Any) -> None:
schema = util.preloaded.sql_schema
variant_mapping = self._variant_mapping_for_set_table(column)
constraint_options = "(strict)" if self.oracle_strict else ""
sqltext = text(f"{column.name} is json {constraint_options}")
e = schema.CheckConstraint(
sqltext,
name=f"{column.name}_is_json",
_create_rule=util.portable_instancemethod( # type: ignore[no-untyped-call]
self._should_create_constraint,
{"variant_mapping": variant_mapping},
),
_type_bound=True,
)
table.append_constraint(e)
JsonB = (
_JSON().with_variant(PG_JSONB, "postgresql").with_variant(ORA_JSONB, "oracle").with_variant(PG_JSONB, "cockroachdb")
)
"""A JSON type that uses native ``JSONB`` where possible and ``Binary`` or ``Blob`` as
an alternative.
"""
python-advanced-alchemy-1.4.1/advanced_alchemy/types/mutables.py 0000664 0000000 0000000 00000010655 15003544734 0025036 0 ustar 00root root 0000000 0000000 from typing import Any, TypeVar, cast, no_type_check
from sqlalchemy.ext.mutable import Mutable
from sqlalchemy.ext.mutable import MutableList as SQLMutableList
from typing_extensions import Self
T = TypeVar("T", bound="Any")
class MutableList(SQLMutableList[T]): # pragma: no cover
"""A list type that implements :class:`Mutable`.
The :class:`MutableList` object implements a list that will
emit change events to the underlying mapping when the contents of
the list are altered, including when values are added or removed.
This is a replication of default Mutablelist provide by SQLAlchemy.
The difference here is the properties _removed which keep every element
removed from the list in order to be able to delete them after commit
and keep them when session rolled back.
"""
def __init__(self, *args: "Any", **kwargs: "Any") -> None:
super().__init__(*args, **kwargs)
self._pending_removed: set[T] = set()
self._pending_append: list[T] = []
@classmethod
def coerce(cls, key: "Any", value: "Any") -> "Any": # pragma: no cover
if not isinstance(value, MutableList):
if isinstance(value, list):
return MutableList[T](value)
# this call will raise ValueError
return Mutable.coerce(key, value)
return cast("MutableList[T]", value)
@no_type_check
def __reduce_ex__(self, proto: int) -> "tuple[type[MutableList[T]], tuple[list[T]]]": # pragma: no cover
return self.__class__, (list(self),)
# needed for backwards compatibility with
# older pickles
def __getstate__(self) -> "tuple[list[T], set[T]]": # pragma: no cover
return list(self), self._pending_removed
def __setstate__(self, state: "Any") -> None: # pragma: no cover
self[:] = state[0]
self._pending_removed = state[1]
def __setitem__(self, index: "Any", value: "Any") -> None:
"""Detect list set events and emit change events."""
old_value = self[index] if isinstance(index, slice) else [self[index]]
list.__setitem__(self, index, value) # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType]
self.changed()
self._pending_removed.update(old_value) # pyright: ignore[reportArgumentType]
def __delitem__(self, index: "Any") -> None:
"""Detect list del events and emit change events."""
old_value = self[index] if isinstance(index, slice) else [self[index]]
list.__delitem__(self, index) # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType]
self.changed()
self._pending_removed.update(old_value) # pyright: ignore[reportArgumentType]
def pop(self, *arg: "Any") -> "T":
result = list.pop(self, *arg) # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
self.changed()
self._pending_removed.add(result) # pyright: ignore[reportArgumentType,reportUnknownArgumentType]
return result # pyright: ignore[reportUnknownVariableType]
def append(self, x: "Any") -> None:
list.append(self, x) # pyright: ignore[reportUnknownMemberType]
self._pending_append.append(x)
self.changed()
def extend(self, x: "Any") -> None:
list.extend(self, x) # pyright: ignore[reportUnknownMemberType]
self._pending_append.extend(x)
self.changed()
@no_type_check
def __iadd__(self, x: "Any") -> "Self":
self.extend(x)
return self
def insert(self, i: "Any", x: "Any") -> None:
list.insert(self, i, x) # pyright: ignore[reportUnknownMemberType]
self._pending_append.append(x)
self.changed()
def remove(self, i: "T") -> None:
list.remove(self, i) # pyright: ignore[reportUnknownMemberType]
self._pending_removed.add(i)
self.changed()
def clear(self) -> None:
self._pending_removed.update(self)
list.clear(self) # type: ignore[arg-type] # pyright: ignore[reportUnknownMemberType]
self.changed()
def sort(self, **kw: "Any") -> None:
list.sort(self, **kw) # pyright: ignore[reportUnknownMemberType]
self.changed()
def reverse(self) -> None:
list.reverse(self) # type: ignore[arg-type] # pyright: ignore[reportUnknownMemberType]
self.changed()
def _finalize_pending(self) -> None:
"""Finalize pending changes by clearing the pending append list."""
self._pending_append.clear()
python-advanced-alchemy-1.4.1/advanced_alchemy/types/password_hash/ 0000775 0000000 0000000 00000000000 15003544734 0025506 5 ustar 00root root 0000000 0000000 python-advanced-alchemy-1.4.1/advanced_alchemy/types/password_hash/__init__.py 0000664 0000000 0000000 00000000000 15003544734 0027605 0 ustar 00root root 0000000 0000000 python-advanced-alchemy-1.4.1/advanced_alchemy/types/password_hash/argon2.py 0000664 0000000 0000000 00000004354 15003544734 0027256 0 ustar 00root root 0000000 0000000 """Argon2 Hashing Backend using argon2-cffi."""
from typing import TYPE_CHECKING, Any, Union
from advanced_alchemy.types.password_hash.base import HashingBackend
if TYPE_CHECKING:
from sqlalchemy import BinaryExpression, ColumnElement
from argon2 import PasswordHasher as Argon2PasswordHasher # pyright: ignore
from argon2.exceptions import InvalidHash, VerifyMismatchError # pyright: ignore
class Argon2Hasher(HashingBackend):
"""Hashing backend using Argon2 via the argon2-cffi library."""
def __init__(self, **kwargs: Any) -> None:
"""Initialize Argon2Backend.
Args:
**kwargs: Optional keyword arguments to pass to the argon2.PasswordHasher constructor.
See argon2-cffi documentation for available parameters (e.g., time_cost,
memory_cost, parallelism, hash_len, salt_len, type).
"""
self.hasher = Argon2PasswordHasher(**kwargs) # pyright: ignore
def hash(self, value: "Union[str, bytes]") -> str:
"""Hash the password using Argon2.
Args:
value: The plain text password (will be encoded to UTF-8 if string).
Returns:
The Argon2 hash string.
"""
return self.hasher.hash(self._ensure_bytes(value))
def verify(self, plain: "Union[str, bytes]", hashed: str) -> bool:
"""Verify a plain text password against an Argon2 hash.
Args:
plain: The plain text password (will be encoded to UTF-8 if string).
hashed: The Argon2 hash string to verify against.
Returns:
True if the password matches the hash, False otherwise.
"""
try:
self.hasher.verify(hashed, self._ensure_bytes(plain))
except (VerifyMismatchError, InvalidHash):
return False
except Exception: # noqa: BLE001
return False
return True
def compare_expression(self, column: "ColumnElement[str]", plain: "Union[str, bytes]") -> "BinaryExpression[bool]":
"""Direct SQL comparison is not supported for Argon2.
Raises:
NotImplementedError: Always raised.
"""
msg = "Argon2Hasher does not support direct SQL comparison."
raise NotImplementedError(msg)
python-advanced-alchemy-1.4.1/advanced_alchemy/types/password_hash/base.py 0000664 0000000 0000000 00000013366 15003544734 0027003 0 ustar 00root root 0000000 0000000 """Base classes for password hashing backends."""
import abc
from typing import Any, Union, cast
from sqlalchemy import BinaryExpression, ColumnElement, FunctionElement, String, TypeDecorator
class HashingBackend(abc.ABC):
"""Abstract base class for password hashing backends.
This class defines the interface that all password hashing backends must implement.
Concrete implementations should provide the actual hashing and verification logic.
"""
@staticmethod
def _ensure_bytes(value: Union[str, bytes]) -> bytes:
if isinstance(value, str):
return value.encode("utf-8")
return value
@abc.abstractmethod
def hash(self, value: "Union[str, bytes]") -> "Union[str, Any]":
"""Hash the given value.
Args:
value: The plain text value to hash.
Returns:
Either a string (the hash) or a SQLAlchemy SQL expression for DB-side hashing.
"""
@abc.abstractmethod
def verify(self, plain: "Union[str, bytes]", hashed: str) -> bool:
"""Verify a plain text value against a hash.
Args:
plain: The plain text value to verify.
hashed: The hash to verify against.
Returns:
True if the plain text matches the hash, False otherwise.
"""
@abc.abstractmethod
def compare_expression(self, column: "ColumnElement[str]", plain: "Union[str, bytes]") -> "BinaryExpression[bool]":
"""Generate a SQLAlchemy expression for comparing a column with a plain text value.
Args:
column: The SQLAlchemy column to compare.
plain: The plain text value to compare against.
Returns:
A SQLAlchemy binary expression for the comparison.
"""
class HashedPassword:
"""Wrapper class for a hashed password.
This class holds the hash string and provides a method to verify a plain text password against it.
"""
def __hash__(self) -> int:
return hash(self.hash_string)
def __init__(self, hash_string: str, backend: "HashingBackend") -> None:
"""Initialize a HashedPassword object.
Args:
hash_string: The hash string.
backend: The hashing backend to use for verification.
"""
self.hash_string = hash_string
self.backend = backend
def verify(self, plain_password: "Union[str, bytes]") -> bool:
"""Verify a plain text password against this hash.
Args:
plain_password: The plain text password to verify.
Returns:
True if the password matches the hash, False otherwise.
"""
return self.backend.verify(plain_password, self.hash_string)
class PasswordHash(TypeDecorator[str]):
"""SQLAlchemy TypeDecorator for storing hashed passwords in a database.
This type provides transparent hashing of password values using the specified backend.
It extends :class:`sqlalchemy.types.TypeDecorator` and implements String as its underlying type.
"""
impl = String
cache_ok = True
def __init__(self, backend: "HashingBackend", length: int = 128) -> None:
"""Initialize the PasswordHash TypeDecorator.
Args:
backend: The hashing backend class to use
length: The maximum length of the hash string. Defaults to 128.
"""
self.length = length
super().__init__(length=length)
self.backend = backend
@property
def python_type(self) -> "type[str]":
"""Returns the Python type for this type decorator.
Returns:
The Python string type.
"""
return str
def process_bind_param(self, value: Any, dialect: Any) -> "Union[str, FunctionElement[str], None]":
"""Process the value before binding it to the SQL statement.
This method hashes the value using the specified backend.
If the backend returns a SQLAlchemy FunctionElement (for DB-side hashing),
it is returned directly. Otherwise, the hashed string is returned.
Args:
value: The value to process.
dialect: The SQLAlchemy dialect.
Returns:
The hashed string, a SQLAlchemy FunctionElement, or None.
"""
if value is None:
return value
hashed_value = self.backend.hash(value)
# Check if the backend returned a SQL function for DB-side hashing
if isinstance(hashed_value, FunctionElement):
return cast("FunctionElement[str]", hashed_value)
# Otherwise, assume it's a string or HashedPassword object (convert to string)
return str(hashed_value)
def process_result_value(self, value: Any, dialect: Any) -> "Union[HashedPassword, None]": # type: ignore[override]
"""Process the value after retrieving it from the database.
This method wraps the hash string in a HashedPassword object.
Args:
value: The value to process.
dialect: The SQLAlchemy dialect.
Returns:
A HashedPassword object or None if the input is None.
"""
if value is None:
return value
# Ensure the retrieved value is a string before passing to HashedPassword
return HashedPassword(str(value), self.backend)
def compare_value(
self, column: "ColumnElement[str]", plain_password: "Union[str, bytes]"
) -> "BinaryExpression[bool]":
"""Generate a SQLAlchemy expression for comparing a column with a plain text password.
Args:
column: The SQLAlchemy column to compare.
plain_password: The plain text password to compare against.
Returns:
A SQLAlchemy binary expression for the comparison.
"""
return self.backend.compare_expression(column, plain_password)
python-advanced-alchemy-1.4.1/advanced_alchemy/types/password_hash/passlib.py 0000664 0000000 0000000 00000003751 15003544734 0027523 0 ustar 00root root 0000000 0000000 """Passlib Hashing Backend."""
from typing import TYPE_CHECKING, Any, Union
from passlib.context import CryptContext # pyright: ignore
from advanced_alchemy.types.password_hash.base import HashingBackend
if TYPE_CHECKING:
from sqlalchemy import BinaryExpression, ColumnElement
class PasslibHasher(HashingBackend):
"""Hashing backend using Passlib.
Relies on the `passlib` package being installed.
Install with `pip install passlib` or `uv pip install passlib`.
"""
def __init__(self, context: CryptContext) -> None:
"""Initialize PasslibBackend.
Args:
context: The Passlib CryptContext to use for hashing and verification.
"""
self.context = context
def hash(self, value: "Union[str, bytes]") -> str:
"""Hash the given value using the Passlib context.
Args:
value: The plain text value to hash. Will be converted to string.
Returns:
The hashed string.
"""
return self.context.hash(self._ensure_bytes(value))
def verify(self, plain: "Union[str, bytes]", hashed: str) -> bool:
"""Verify a plain text value against a hash using the Passlib context.
Args:
plain: The plain text value to verify. Will be converted to string.
hashed: The hash to verify against.
Returns:
True if the plain text matches the hash, False otherwise.
"""
try:
return self.context.verify(self._ensure_bytes(plain), hashed)
except Exception: # noqa: BLE001
# Passlib can raise various errors for invalid hashes
return False
def compare_expression(self, column: "ColumnElement[str]", plain: Any) -> "BinaryExpression[bool]":
"""Direct SQL comparison is not supported for Passlib.
Raises:
NotImplementedError: Always raised.
"""
msg = "PasslibHasher does not support direct SQL comparison."
raise NotImplementedError(msg)
python-advanced-alchemy-1.4.1/advanced_alchemy/types/password_hash/pwdlib.py 0000664 0000000 0000000 00000003410 15003544734 0027337 0 ustar 00root root 0000000 0000000 """Pwdlib Hashing Backend."""
from typing import TYPE_CHECKING, Any, Union
from advanced_alchemy.types.password_hash.base import HashingBackend
if TYPE_CHECKING:
from sqlalchemy import BinaryExpression, ColumnElement
from pwdlib.hashers.base import HasherProtocol
class PwdlibHasher(HashingBackend):
"""Hashing backend using Pwdlib."""
def __init__(self, hasher: HasherProtocol) -> None:
"""Initialize PwdlibBackend.
Args:
hasher: The Pwdlib hasher to use for hashing and verification.
"""
self.hasher = hasher
def hash(self, value: "Union[str, bytes]") -> str:
"""Hash the given value using the Pwdlib context.
Args:
value: The plain text value to hash. Will be converted to string.
Returns:
The hashed string.
"""
return self.hasher.hash(self._ensure_bytes(value))
def verify(self, plain: "Union[str, bytes]", hashed: str) -> bool:
"""Verify a plain text value against a hash using the Pwdlib context.
Args:
plain: The plain text value to verify. Will be converted to string.
hashed: The hash to verify against.
Returns:
True if the plain text matches the hash, False otherwise.
"""
try:
return self.hasher.verify(self._ensure_bytes(plain), hashed)
except Exception: # noqa: BLE001
return False
def compare_expression(self, column: "ColumnElement[str]", plain: Any) -> "BinaryExpression[bool]":
"""Direct SQL comparison is not supported for Pwdlib.
Raises:
NotImplementedError: Always raised.
"""
msg = "PwdlibHasher does not support direct SQL comparison."
raise NotImplementedError(msg)
python-advanced-alchemy-1.4.1/advanced_alchemy/utils/ 0000775 0000000 0000000 00000000000 15003544734 0022635 5 ustar 00root root 0000000 0000000 python-advanced-alchemy-1.4.1/advanced_alchemy/utils/__init__.py 0000664 0000000 0000000 00000000000 15003544734 0024734 0 ustar 00root root 0000000 0000000 python-advanced-alchemy-1.4.1/advanced_alchemy/utils/dataclass.py 0000664 0000000 0000000 00000012124 15003544734 0025146 0 ustar 00root root 0000000 0000000 from dataclasses import Field, fields, is_dataclass
from inspect import isclass
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Protocol, final, runtime_checkable
if TYPE_CHECKING:
from collections.abc import Iterable
from collections.abc import Set as AbstractSet
from typing_extensions import TypeAlias, TypeGuard
__all__ = (
"DataclassProtocol",
"Empty",
"EmptyType",
"extract_dataclass_fields",
"extract_dataclass_items",
"is_dataclass_class",
"is_dataclass_instance",
"simple_asdict",
)
@final
class Empty:
"""A sentinel class used as placeholder."""
EmptyType: "TypeAlias" = type[Empty]
"""Type alias for the :class:`~advanced_alchemy.utils.dataclass.Empty` sentinel class."""
@runtime_checkable
class DataclassProtocol(Protocol):
"""Protocol for instance checking dataclasses"""
__dataclass_fields__: "ClassVar[dict[str, Any]]"
def extract_dataclass_fields(
dt: "DataclassProtocol",
exclude_none: bool = False,
exclude_empty: bool = False,
include: "Optional[AbstractSet[str]]" = None,
exclude: "Optional[AbstractSet[str]]" = None,
) -> "tuple[Field[Any], ...]":
"""Extract dataclass fields.
Args:
dt: :class:`DataclassProtocol` instance.
exclude_none: Whether to exclude None values.
exclude_empty: Whether to exclude Empty values.
include: An iterable of fields to include.
exclude: An iterable of fields to exclude.
Returns:
A tuple of dataclass fields.
"""
include = include or set()
exclude = exclude or set()
if common := (include & exclude):
msg = f"Fields {common} are both included and excluded."
raise ValueError(msg)
dataclass_fields: Iterable[Field[Any]] = fields(dt)
if exclude_none:
dataclass_fields = (field for field in dataclass_fields if getattr(dt, field.name) is not None)
if exclude_empty:
dataclass_fields = (field for field in dataclass_fields if getattr(dt, field.name) is not Empty)
if include:
dataclass_fields = (field for field in dataclass_fields if field.name in include)
if exclude:
dataclass_fields = (field for field in dataclass_fields if field.name not in exclude)
return tuple(dataclass_fields)
def extract_dataclass_items(
dt: "DataclassProtocol",
exclude_none: bool = False,
exclude_empty: bool = False,
include: "Optional[AbstractSet[str]]" = None,
exclude: "Optional[AbstractSet[str]]" = None,
) -> tuple[tuple[str, Any], ...]:
"""Extract dataclass name, value pairs.
Unlike the 'asdict' method exports by the stdlib, this function does not pickle values.
Args:
dt: :class:`DataclassProtocol` instance.
exclude_none: Whether to exclude None values.
exclude_empty: Whether to exclude Empty values.
include: An iterable of fields to include.
exclude: An iterable of fields to exclude.
Returns:
A tuple of key/value pairs.
"""
dataclass_fields = extract_dataclass_fields(dt, exclude_none, exclude_empty, include, exclude)
return tuple((field.name, getattr(dt, field.name)) for field in dataclass_fields)
def simple_asdict(
obj: "DataclassProtocol",
exclude_none: bool = False,
exclude_empty: bool = False,
convert_nested: bool = True,
exclude: "Optional[AbstractSet[str]]" = None,
) -> "dict[str, Any]":
"""Convert a dataclass to a dictionary.
This method has important differences to the standard library version:
- it does not deepcopy values
- it does not recurse into collections
Args:
obj: :class:`DataclassProtocol` instance.
exclude_none: Whether to exclude None values.
exclude_empty: Whether to exclude Empty values.
convert_nested: Whether to recursively convert nested dataclasses.
exclude: An iterable of fields to exclude.
Returns:
A dictionary of key/value pairs.
"""
ret: dict[str, Any] = {}
for field in extract_dataclass_fields(obj, exclude_none, exclude_empty, exclude=exclude):
value = getattr(obj, field.name)
if is_dataclass_instance(value) and convert_nested:
ret[field.name] = simple_asdict(value, exclude_none, exclude_empty)
else:
ret[field.name] = getattr(obj, field.name)
return ret
def is_dataclass_instance(obj: Any) -> "TypeGuard[DataclassProtocol]":
"""Check if an object is a dataclass instance.
Args:
obj: An object to check.
Returns:
True if the object is a dataclass instance.
"""
return hasattr(type(obj), "__dataclass_fields__") # pyright: ignore[reportUnknownArgumentType]
def is_dataclass_class(annotation: Any) -> "TypeGuard[type[DataclassProtocol]]":
"""Wrap :func:`is_dataclass ` in a :data:`typing.TypeGuard`.
Args:
annotation: tested to determine if instance or type of :class:`dataclasses.dataclass`.
Returns:
``True`` if instance or type of ``dataclass``.
"""
try:
return isclass(annotation) and is_dataclass(annotation)
except TypeError: # pragma: no cover
return False
python-advanced-alchemy-1.4.1/advanced_alchemy/utils/deprecation.py 0000664 0000000 0000000 00000007507 15003544734 0025515 0 ustar 00root root 0000000 0000000 import inspect
from functools import wraps
from typing import Callable, Literal, Optional
from warnings import warn
from typing_extensions import ParamSpec, TypeVar
__all__ = ("deprecated", "warn_deprecation")
T = TypeVar("T")
P = ParamSpec("P")
DeprecatedKind = Literal["function", "method", "classmethod", "attribute", "property", "class", "parameter", "import"]
def warn_deprecation(
version: str,
deprecated_name: str,
kind: DeprecatedKind,
*,
removal_in: Optional[str] = None,
alternative: Optional[str] = None,
info: Optional[str] = None,
pending: bool = False,
) -> None:
"""Warn about a call to a (soon to be) deprecated function.
Args:
version: Advanced Alchemy version where the deprecation will occur
deprecated_name: Name of the deprecated function
removal_in: Advanced Alchemy version where the deprecated function will be removed
alternative: Name of a function that should be used instead
info: Additional information
pending: Use :class:`warnings.PendingDeprecationWarning` instead of :class:`warnings.DeprecationWarning`
kind: Type of the deprecated thing
"""
parts = []
if kind == "import":
access_type = "Import of"
elif kind in {"function", "method"}:
access_type = "Call to"
else:
access_type = "Use of"
if pending:
parts.append(f"{access_type} {kind} awaiting deprecation {deprecated_name!r}") # pyright: ignore[reportUnknownMemberType]
else:
parts.append(f"{access_type} deprecated {kind} {deprecated_name!r}") # pyright: ignore[reportUnknownMemberType]
parts.extend( # pyright: ignore[reportUnknownMemberType]
(
f"Deprecated in advanced-alchemy {version}",
f"This {kind} will be removed in {removal_in or 'the next major version'}",
),
)
if alternative:
parts.append(f"Use {alternative!r} instead") # pyright: ignore[reportUnknownMemberType]
if info:
parts.append(info) # pyright: ignore[reportUnknownMemberType]
text = ". ".join(parts) # pyright: ignore[reportUnknownArgumentType]
warning_class = PendingDeprecationWarning if pending else DeprecationWarning
warn(text, warning_class, stacklevel=2)
def deprecated(
version: str,
*,
removal_in: Optional[str] = None,
alternative: Optional[str] = None,
info: Optional[str] = None,
pending: bool = False,
kind: Optional[Literal["function", "method", "classmethod", "property"]] = None,
) -> Callable[[Callable[P, T]], Callable[P, T]]:
"""Create a decorator wrapping a function, method or property with a warning call about a (pending) deprecation.
Args:
version: Advanced Alchemy version where the deprecation will occur
removal_in: Advanced Alchemy version where the deprecated function will be removed
alternative: Name of a function that should be used instead
info: Additional information
pending: Use :class:`warnings.PendingDeprecationWarning` instead of :class:`warnings.DeprecationWarning`
kind: Type of the deprecated callable. If ``None``, will use ``inspect`` to figure
out if it's a function or method
Returns:
A decorator wrapping the function call with a warning
"""
def decorator(func: Callable[P, T]) -> Callable[P, T]:
@wraps(func)
def wrapped(*args: P.args, **kwargs: P.kwargs) -> T:
warn_deprecation(
version=version,
deprecated_name=func.__name__,
info=info,
alternative=alternative,
pending=pending,
removal_in=removal_in,
kind=kind or ("method" if inspect.ismethod(func) else "function"),
)
return func(*args, **kwargs)
return wrapped
return decorator
python-advanced-alchemy-1.4.1/advanced_alchemy/utils/fixtures.py 0000664 0000000 0000000 00000004231 15003544734 0025060 0 ustar 00root root 0000000 0000000 from typing import TYPE_CHECKING, Any, Union
from advanced_alchemy._serialization import decode_json
from advanced_alchemy.exceptions import MissingDependencyError
if TYPE_CHECKING:
from pathlib import Path
from anyio import Path as AsyncPath
__all__ = ("open_fixture", "open_fixture_async")
def open_fixture(fixtures_path: "Union[Path, AsyncPath]", fixture_name: str) -> Any:
"""Loads JSON file with the specified fixture name
Args:
fixtures_path: :class:`pathlib.Path` | :class:`anyio.Path` The path to look for fixtures
fixture_name (str): The fixture name to load.
Raises:
:class:`FileNotFoundError`: Fixtures not found.
Returns:
Any: The parsed JSON data
"""
from pathlib import Path
fixture = Path(fixtures_path / f"{fixture_name}.json")
if fixture.exists():
with fixture.open(mode="r", encoding="utf-8") as f:
f_data = f.read()
return decode_json(f_data)
msg = f"Could not find the {fixture_name} fixture"
raise FileNotFoundError(msg)
async def open_fixture_async(fixtures_path: "Union[Path, AsyncPath]", fixture_name: str) -> Any:
"""Loads JSON file with the specified fixture name
Args:
fixtures_path: :class:`pathlib.Path` | :class:`anyio.Path` The path to look for fixtures
fixture_name (str): The fixture name to load.
Raises:
:class:`~advanced_alchemy.exceptions.MissingDependencyError`: The `anyio` library is required to use this function.
:class:`FileNotFoundError`: Fixtures not found.
Returns:
Any: The parsed JSON data
"""
try:
from anyio import Path as AsyncPath
except ImportError as exc:
msg = "The `anyio` library is required to use this function. Please install it with `pip install anyio`."
raise MissingDependencyError(msg) from exc
fixture = AsyncPath(fixtures_path / f"{fixture_name}.json")
if await fixture.exists():
async with await fixture.open(mode="r", encoding="utf-8") as f:
f_data = await f.read()
return decode_json(f_data)
msg = f"Could not find the {fixture_name} fixture"
raise FileNotFoundError(msg)
python-advanced-alchemy-1.4.1/advanced_alchemy/utils/module_loader.py 0000664 0000000 0000000 00000005213 15003544734 0026023 0 ustar 00root root 0000000 0000000 """General utility functions."""
import sys
from importlib import import_module
from importlib.util import find_spec
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional
if TYPE_CHECKING:
from types import ModuleType
__all__ = (
"import_string",
"module_to_os_path",
)
def module_to_os_path(dotted_path: str = "app") -> Path:
"""Find Module to OS Path.
Return a path to the base directory of the project or the module
specified by `dotted_path`.
Args:
dotted_path: The path to the module. Defaults to "app".
Raises:
TypeError: The module could not be found.
Returns:
Path: The path to the module.
"""
try:
if (src := find_spec(dotted_path)) is None: # pragma: no cover
msg = f"Couldn't find the path for {dotted_path}"
raise TypeError(msg)
except ModuleNotFoundError as e:
msg = f"Couldn't find the path for {dotted_path}"
raise TypeError(msg) from e
path = Path(str(src.origin))
return path.parent if path.is_file() else path
def import_string(dotted_path: str) -> Any:
"""Dotted Path Import.
Import a dotted module path and return the attribute/class designated by the
last name in the path. Raise ImportError if the import failed.
Args:
dotted_path: The path of the module to import.
Raises:
ImportError: Could not import the module.
Returns:
object: The imported object.
"""
def _is_loaded(module: "Optional[ModuleType]") -> bool:
spec = getattr(module, "__spec__", None)
initializing = getattr(spec, "_initializing", False)
return bool(module and spec and not initializing)
def _cached_import(module_path: str, class_name: str) -> Any:
"""Import and cache a class from a module.
Args:
module_path: dotted path to module.
class_name: Class or function name.
Returns:
object: The imported class or function
"""
# Check whether module is loaded and fully initialized.
module = sys.modules.get(module_path)
if not _is_loaded(module):
module = import_module(module_path)
return getattr(module, class_name)
try:
module_path, class_name = dotted_path.rsplit(".", 1)
except ValueError as e:
msg = "%s doesn't look like a module path"
raise ImportError(msg, dotted_path) from e
try:
return _cached_import(module_path, class_name)
except AttributeError as e:
msg = "Module '%s' does not define a '%s' attribute/class"
raise ImportError(msg, module_path, class_name) from e
python-advanced-alchemy-1.4.1/advanced_alchemy/utils/portals.py 0000664 0000000 0000000 00000016373 15003544734 0024705 0 ustar 00root root 0000000 0000000 """This module provides a portal provider and portal for calling async functions from synchronous code."""
import asyncio
import functools
import queue
import threading
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Optional, TypeVar, cast
from warnings import warn
from advanced_alchemy.exceptions import ImproperConfigurationError
if TYPE_CHECKING:
from collections.abc import Coroutine
__all__ = ("Portal", "PortalProvider", "PortalProviderSingleton")
_R = TypeVar("_R")
class PortalProviderSingleton(type):
"""A singleton metaclass for PortalProvider that creates unique instances per event loop."""
_instances: "ClassVar[dict[tuple[type, Optional[asyncio.AbstractEventLoop]], PortalProvider]]" = {}
def __call__(cls, *args: Any, **kwargs: Any) -> "PortalProvider":
# Use a tuple of the class and loop as the key
key = (cls, kwargs.get("loop"))
if key not in cls._instances:
cls._instances[key] = super().__call__(*args, **kwargs)
return cls._instances[key]
class PortalProvider(metaclass=PortalProviderSingleton):
"""A provider for creating and managing threaded portals."""
def __init__(self, /, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
"""Initialize the PortalProvider."""
self._request_queue: queue.Queue[
tuple[
Callable[..., Coroutine[Any, Any, Any]],
tuple[Any, ...],
dict[str, Any],
queue.Queue[tuple[Optional[Any], Optional[Exception]]],
]
] = queue.Queue()
self._result_queue: queue.Queue[tuple[Optional[Any], Optional[Exception]]] = queue.Queue()
self._loop: Optional[asyncio.AbstractEventLoop] = loop
self._thread: Optional[threading.Thread] = None
self._ready_event: threading.Event = threading.Event()
@property
def portal(self) -> "Portal":
"""The portal instance."""
return Portal(self)
@property
def is_running(self) -> bool:
"""Whether the portal provider is running."""
return self._thread is not None and self._thread.is_alive()
@property
def is_ready(self) -> bool:
"""Whether the portal provider is ready."""
return self._ready_event.is_set()
@property
def loop(self) -> "asyncio.AbstractEventLoop": # pragma: no cover
"""The event loop.
Raises:
ImproperConfigurationError: If the portal provider is not started.
"""
if self._loop is None:
msg = "The PortalProvider is not started. Did you forget to call .start()?"
raise ImproperConfigurationError(msg)
return self._loop
def start(self) -> None:
"""Starts the background thread and event loop."""
if self._thread is not None: # pragma: no cover
warn("PortalProvider already started", stacklevel=2)
return
self._thread = threading.Thread(target=self._run_event_loop, daemon=True)
self._thread.start()
self._ready_event.wait() # Wait for the loop to be ready
def stop(self) -> None:
"""Stops the background thread and event loop."""
if self._loop is None or self._thread is None:
return
self._loop.call_soon_threadsafe(self._loop.stop)
self._thread.join()
self._loop.close()
self._loop = None
self._thread = None
self._ready_event.clear()
def _run_event_loop(self) -> None: # pragma: no cover
"""The main function of the background thread."""
if self._loop is None:
self._loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._loop)
self._ready_event.set() # Signal that the loop is ready
self._loop.run_forever()
@staticmethod
async def _async_caller(
func: "Callable[..., Coroutine[Any, Any, _R]]",
args: tuple[Any, ...],
kwargs: dict[str, Any],
) -> _R:
"""Wrapper to run the async function and send the result to the result queue.
Args:
func: The async function to call.
args: Positional arguments to the function.
kwargs: Keyword arguments to the function.
Returns:
The result of the async function.
"""
result: _R = await func(*args, **kwargs)
return result
def call(self, func: "Callable[..., Coroutine[Any, Any, _R]]", *args: Any, **kwargs: Any) -> _R:
"""Calls an async function from a synchronous context.
Args:
func: The async function to call.
*args: Positional arguments to the function.
**kwargs: Keyword arguments to the function.
Raises:
ImproperConfigurationError: If the portal provider is not started.
Returns:
The result of the async function.
"""
if self._loop is None:
msg = "The PortalProvider is not started. Did you forget to call .start()?"
raise ImproperConfigurationError(msg)
# Create a new result queue
local_result_queue: queue.Queue[tuple[Optional[_R], Optional[Exception]]] = queue.Queue()
# Send the request to the background thread
self._request_queue.put((func, args, kwargs, local_result_queue))
# Trigger the execution in the event loop
_handle = self._loop.call_soon_threadsafe(self._process_request)
# Wait for the result from the background thread
result, exception = local_result_queue.get()
if exception:
raise exception
return cast("_R", result)
def _process_request(self) -> None: # pragma: no cover
"""Processes a request from the request queue in the event loop."""
assert self._loop is not None # noqa: S101
if not self._request_queue.empty():
func, args, kwargs, local_result_queue = self._request_queue.get()
future = asyncio.run_coroutine_threadsafe(self._async_caller(func, args, kwargs), self._loop)
# Attach a callback to handle the result/exception
future.add_done_callback(
functools.partial(self._handle_future_result, local_result_queue=local_result_queue), # pyright: ignore[reportArgumentType]
)
@staticmethod
def _handle_future_result(
future: "asyncio.Future[Any]",
local_result_queue: "queue.Queue[tuple[Optional[Any], Optional[Exception]]]",
) -> None: # pragma: no cover
"""Handles the result or exception from the completed future."""
try:
result = future.result()
local_result_queue.put((result, None))
except Exception as e: # noqa: BLE001
local_result_queue.put((None, e))
class Portal:
def __init__(self, provider: "PortalProvider") -> None:
self._provider = provider
def call(self, func: "Callable[..., Coroutine[Any, Any, _R]]", *args: Any, **kwargs: Any) -> _R:
"""Calls an async function using the associated PortalProvider.
Args:
func: The async function to call.
*args: Positional arguments to the function.
**kwargs: Keyword arguments to the function.
Returns:
The result of the async function.
"""
return self._provider.call(func, *args, **kwargs)
python-advanced-alchemy-1.4.1/advanced_alchemy/utils/singleton.py 0000664 0000000 0000000 00000002456 15003544734 0025220 0 ustar 00root root 0000000 0000000 from typing import Any, TypeVar
_T = TypeVar("_T")
class SingletonMeta(type):
"""Metaclass for singleton pattern."""
# We store instances keyed by the class type
_instances: dict[type, object] = {}
def __call__(cls: type[_T], *args: Any, **kwargs: Any) -> _T:
"""Call method for the singleton metaclass.
Args:
cls: The class being instantiated.
*args: Positional arguments for the class constructor.
**kwargs: Keyword arguments for the class constructor.
Returns:
The singleton instance of the class.
"""
# Use SingletonMeta._instances to access the class attribute
if cls not in SingletonMeta._instances: # pyright: ignore[reportUnnecessaryContains]
# Create the instance using super().__call__ which calls the class's __new__ and __init__
instance = super().__call__(*args, **kwargs) # type: ignore
SingletonMeta._instances[cls] = instance
# Return the cached instance. We cast here because the dictionary stores `object`,
# but we know it's of type _T for the given cls key.
# Mypy might need an ignore here depending on configuration, but pyright should handle it.
return SingletonMeta._instances[cls] # type: ignore[return-value]
python-advanced-alchemy-1.4.1/advanced_alchemy/utils/sync_tools.py 0000664 0000000 0000000 00000025665 15003544734 0025421 0 ustar 00root root 0000000 0000000 import asyncio
import functools
import inspect
import sys
from contextlib import AbstractAsyncContextManager, AbstractContextManager
from typing import (
TYPE_CHECKING,
Any,
Generic,
Optional,
TypeVar,
Union,
cast,
)
from typing_extensions import ParamSpec
if TYPE_CHECKING:
from collections.abc import Awaitable, Callable, Coroutine
from types import TracebackType
try:
import uvloop # pyright: ignore[reportMissingImports]
except ImportError:
uvloop = None # type: ignore[assignment]
ReturnT = TypeVar("ReturnT")
ParamSpecT = ParamSpec("ParamSpecT")
T = TypeVar("T")
class PendingType:
def __repr__(self) -> str:
return "AsyncPending"
Pending = PendingType()
class PendingValueError(Exception):
"""Exception raised when a value is accessed before it is ready."""
class SoonValue(Generic[T]):
"""Holds a value that will be available soon after an async operation."""
def __init__(self) -> None:
self._stored_value: Union[T, PendingType] = Pending
@property
def value(self) -> "T":
if isinstance(self._stored_value, PendingType):
msg = "The return value of this task is still pending."
raise PendingValueError(msg)
return self._stored_value
@property
def ready(self) -> bool:
return not isinstance(self._stored_value, PendingType)
class TaskGroup:
"""Manages a group of asyncio tasks, allowing them to be run concurrently."""
def __init__(self) -> None:
self._tasks: set[asyncio.Task[Any]] = set()
self._exceptions: list[BaseException] = []
self._closed = False
async def __aenter__(self) -> "TaskGroup":
if self._closed:
msg = "Cannot enter a task group that has already been closed."
raise RuntimeError(msg)
return self
async def __aexit__(
self,
exc_type: "Optional[type[BaseException]]", # noqa: PYI036
exc_val: "Optional[BaseException]", # noqa: PYI036
exc_tb: "Optional[TracebackType]", # noqa: PYI036
) -> None:
self._closed = True
if exc_val:
self._exceptions.append(exc_val)
if self._tasks:
await asyncio.wait(self._tasks)
if self._exceptions:
# Re-raise the first exception encountered.
raise self._exceptions[0]
def create_task(self, coro: "Coroutine[Any, Any, Any]") -> "asyncio.Task[Any]":
"""Create and add a coroutine as a task to the task group.
Args:
coro (Coroutine): The coroutine to be added as a task.
Returns:
asyncio.Task: The created asyncio task.
Raises:
RuntimeError: If the task group has already been closed.
"""
if self._closed:
msg = "Cannot create a task in a task group that has already been closed."
raise RuntimeError(msg)
task = asyncio.create_task(coro)
self._tasks.add(task)
task.add_done_callback(self._tasks.discard)
task.add_done_callback(self._check_result)
return task
def _check_result(self, task: "asyncio.Task[Any]") -> None:
"""Check and store exceptions from a completed task.
Args:
task (asyncio.Task): The task to check for exceptions.
"""
try:
task.result() # This will raise the exception if one occurred.
except Exception as e: # noqa: BLE001
self._exceptions.append(e)
def start_soon_(
self,
async_function: "Callable[ParamSpecT, Awaitable[T]]",
name: object = None,
) -> "Callable[ParamSpecT, SoonValue[T]]":
"""Create a function to start a new task in this task group.
Args:
async_function (Callable): An async function to call soon.
name (object, optional): Name of the task for introspection and debugging.
Returns:
Callable: A function that starts the task and returns a SoonValue object.
"""
@functools.wraps(async_function)
def wrapper(*args: "ParamSpecT.args", **kwargs: "ParamSpecT.kwargs") -> "SoonValue[T]":
partial_f = functools.partial(async_function, *args, **kwargs)
soon_value: SoonValue[T] = SoonValue()
@functools.wraps(partial_f)
async def value_wrapper(*_args: "Any") -> None:
value = await partial_f()
soon_value._stored_value = value # pyright: ignore[reportPrivateUsage] # noqa: SLF001
self.create_task(value_wrapper) # type: ignore[arg-type]
return soon_value
return wrapper
def create_task_group() -> "TaskGroup":
"""Create a TaskGroup for managing multiple concurrent async tasks.
Returns:
TaskGroup: A new TaskGroup instance.
"""
return TaskGroup()
class CapacityLimiter:
"""Limits the number of concurrent operations using a semaphore."""
def __init__(self, total_tokens: int) -> None:
self._semaphore = asyncio.Semaphore(total_tokens)
async def acquire(self) -> None:
await self._semaphore.acquire()
def release(self) -> None:
self._semaphore.release()
@property
def total_tokens(self) -> int:
return self._semaphore._value # noqa: SLF001
@total_tokens.setter
def total_tokens(self, value: int) -> None:
self._semaphore = asyncio.Semaphore(value)
async def __aenter__(self) -> None:
await self.acquire()
async def __aexit__(
self,
exc_type: "Optional[type[BaseException]]", # noqa: PYI036
exc_val: "Optional[BaseException]", # noqa: PYI036
exc_tb: "Optional[TracebackType]", # noqa: PYI036
) -> None:
self.release()
_default_limiter = CapacityLimiter(40)
def run_(async_function: "Callable[ParamSpecT, Coroutine[Any, Any, ReturnT]]") -> "Callable[ParamSpecT, ReturnT]":
"""Convert an async function to a blocking function using asyncio.run().
Args:
async_function (Callable): The async function to convert.
Returns:
Callable: A blocking function that runs the async function.
"""
@functools.wraps(async_function)
def wrapper(*args: "ParamSpecT.args", **kwargs: "ParamSpecT.kwargs") -> "ReturnT":
partial_f = functools.partial(async_function, *args, **kwargs)
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = None
if loop is not None:
# Running in an existing event loop
return asyncio.run(partial_f())
# Create a new event loop and run the function
if uvloop and sys.platform != "win32":
uvloop.install() # pyright: ignore[reportUnknownMemberType]
return asyncio.run(partial_f())
return wrapper
def await_(
async_function: "Callable[ParamSpecT, Coroutine[Any, Any, ReturnT]]",
raise_sync_error: bool = True,
) -> "Callable[ParamSpecT, ReturnT]":
"""Convert an async function to a blocking one, running in the main async loop.
Args:
async_function (Callable): The async function to convert.
raise_sync_error (bool, optional): If False, runs in a new event loop if no loop is present.
Returns:
Callable: A blocking function that runs the async function.
"""
@functools.wraps(async_function)
def wrapper(*args: "ParamSpecT.args", **kwargs: "ParamSpecT.kwargs") -> "ReturnT":
partial_f = functools.partial(async_function, *args, **kwargs)
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = None
if loop is None and raise_sync_error is False:
return asyncio.run(partial_f())
# Running in an existing event loop
return asyncio.run(partial_f())
return wrapper
def async_(
function: "Callable[ParamSpecT, ReturnT]",
*,
limiter: "Optional[CapacityLimiter]" = None,
) -> "Callable[ParamSpecT, Awaitable[ReturnT]]":
"""Convert a blocking function to an async one using asyncio.to_thread().
Args:
function (Callable): The blocking function to convert.
cancellable (bool, optional): Allow cancellation of the operation.
limiter (CapacityLimiter, optional): Limit the total number of threads.
Returns:
Callable: An async function that runs the original function in a thread.
"""
async def wrapper(
*args: "ParamSpecT.args",
**kwargs: "ParamSpecT.kwargs",
) -> "ReturnT":
partial_f = functools.partial(function, *args, **kwargs)
used_limiter = limiter or _default_limiter
async with used_limiter:
return await asyncio.to_thread(partial_f)
return wrapper
def maybe_async_(
function: "Callable[ParamSpecT, Union[Awaitable[ReturnT], ReturnT]]",
) -> "Callable[ParamSpecT, Awaitable[ReturnT]]":
"""Convert a function to an async one if it is not already.
Args:
function (Callable): The function to convert.
Returns:
Callable: An async function that runs the original function.
"""
if inspect.iscoroutinefunction(function):
return function
async def wrapper(*args: "ParamSpecT.args", **kwargs: "ParamSpecT.kwargs") -> "ReturnT":
result = function(*args, **kwargs)
if inspect.isawaitable(result):
return await result
return await async_(lambda: result)()
return wrapper
def wrap_sync(fn: "Callable[ParamSpecT, ReturnT]") -> "Callable[ParamSpecT, Awaitable[ReturnT]]":
"""Convert a sync function to an async one.
Args:
fn (Callable): The function to convert.
Returns:
Callable: An async function that runs the original function.
"""
if inspect.iscoroutinefunction(fn):
return fn
async def wrapped(*args: "ParamSpecT.args", **kwargs: "ParamSpecT.kwargs") -> ReturnT:
return await async_(functools.partial(fn, *args, **kwargs))()
return wrapped
class _ContextManagerWrapper(Generic[T]):
def __init__(self, cm: AbstractContextManager[T]) -> None:
self._cm = cm
async def __aenter__(self) -> T:
return self._cm.__enter__()
async def __aexit__(
self,
exc_type: "Optional[type[BaseException]]", # noqa: PYI036
exc_val: "Optional[BaseException]", # noqa: PYI036
exc_tb: "Optional[TracebackType]", # noqa: PYI036
) -> "Optional[bool]":
return self._cm.__exit__(exc_type, exc_val, exc_tb)
def maybe_async_context(
obj: "Union[AbstractContextManager[T], AbstractAsyncContextManager[T]]",
) -> "AbstractAsyncContextManager[T]":
"""Convert a context manager to an async one if it is not already.
Args:
obj (AbstractContextManager[T] or AbstractAsyncContextManager[T]): The context manager to convert.
Returns:
AbstractAsyncContextManager[T]: An async context manager that runs the original context manager.
"""
if isinstance(obj, AbstractContextManager):
return cast("AbstractAsyncContextManager[T]", _ContextManagerWrapper(obj))
return obj
python-advanced-alchemy-1.4.1/advanced_alchemy/utils/text.py 0000664 0000000 0000000 00000003755 15003544734 0024205 0 ustar 00root root 0000000 0000000 """General utility functions."""
import re
import unicodedata
from functools import lru_cache
from typing import Optional
__all__ = (
"check_email",
"slugify",
)
def check_email(email: str) -> str:
"""Validate an email.
Very simple email validation.
Args:
email (str): The email to validate.
Raises:
ValueError: If the email is invalid.
Returns:
str: The validated email.
"""
if "@" not in email:
msg = "Invalid email!"
raise ValueError(msg)
return email.lower()
def slugify(value: str, allow_unicode: bool = False, separator: Optional[str] = None) -> str:
"""Slugify.
Convert to ASCII if ``allow_unicode`` is ``False``. Convert spaces or repeated
dashes to single dashes. Remove characters that aren't alphanumerics,
underscores, or hyphens. Convert to lowercase. Also strip leading and
trailing whitespace, dashes, and underscores.
Args:
value (str): the string to slugify
allow_unicode (bool, optional): allow unicode characters in slug. Defaults to False.
separator (str, optional): by default a `-` is used to delimit word boundaries.
Set this to configure something different.
Returns:
str: a slugified string of the value parameter
"""
if allow_unicode:
value = unicodedata.normalize("NFKC", value)
else:
value = unicodedata.normalize("NFKD", value).encode("ascii", "ignore").decode("ascii")
value = re.sub(r"[^\w\s-]", "", value.lower())
if separator is not None:
return re.sub(r"[-\s]+", "-", value).strip("-_").replace("-", separator)
return re.sub(r"[-\s]+", "-", value).strip("-_")
@lru_cache(maxsize=100)
def camelize(string: str) -> str:
"""Convert a string to camel case.
Args:
string (str): The string to convert.
Returns:
str: The converted string.
"""
return "".join(word if index == 0 else word.capitalize() for index, word in enumerate(string.split("_")))
python-advanced-alchemy-1.4.1/codecov.yml 0000664 0000000 0000000 00000000253 15003544734 0020373 0 ustar 00root root 0000000 0000000 coverage:
status:
project:
default:
target: auto
threshold: 2%
patch:
default:
target: auto
comment:
require_changes: true
python-advanced-alchemy-1.4.1/docs/ 0000775 0000000 0000000 00000000000 15003544734 0017156 5 ustar 00root root 0000000 0000000 python-advanced-alchemy-1.4.1/docs/Makefile 0000664 0000000 0000000 00000001172 15003544734 0020617 0 ustar 00root root 0000000 0000000 # Minimal makefile for Sphinx documentation
#
# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS ?=
SPHINXBUILD ?= sphinx-build
SOURCEDIR = .
BUILDDIR = _build
# Put it first so that "make" without argument is like "make help".
help:
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
.PHONY: help Makefile
# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
python-advanced-alchemy-1.4.1/docs/PYPI_README.md 0000664 0000000 0000000 00000002265 15003544734 0021303 0 ustar 00root root 0000000 0000000