# Advanced Alchemy
Check out the [project documentation][project-docs] for more information.
## About
A carefully crafted, thoroughly tested, optimized companion library for SQLAlchemy,
offering:
- Sync and async repositories, featuring common CRUD and highly optimized bulk operations
- Integration with major web frameworks including Litestar, Starlette, FastAPI, Sanic
- Custom-built alembic configuration and CLI with optional framework integration
- Utility base classes with audit columns, primary keys and utility functions
- Optimized JSON types including a custom JSON type for Oracle
- Integrated support for UUID6 and UUID7 using [`uuid-utils`](https://github.com/aminalaee/uuid-utils) (install with the `uuid` extra)
- Integrated support for Nano ID using [`fastnanoid`](https://github.com/oliverlambson/fastnanoid) (install with the `nanoid` extra)
- Pre-configured base classes with audit columns, UUID or Big Integer primary keys, and
  a [sentinel column](https://docs.sqlalchemy.org/en/20/core/connections.html#configuring-sentinel-columns).
- Synchronous and asynchronous repositories featuring:
- Common CRUD operations for SQLAlchemy models
- Bulk inserts, updates, upserts, and deletes with dialect-specific enhancements
- Integrated counts, pagination, sorting, filtering with `LIKE`, `IN`, and dates before and/or after.
- Tested support for multiple database backends including:
- SQLite via [aiosqlite](https://aiosqlite.omnilib.dev/en/stable/) or [sqlite](https://docs.python.org/3/library/sqlite3.html)
- Postgres via [asyncpg](https://magicstack.github.io/asyncpg/current/) or [psycopg3 (async or sync)](https://www.psycopg.org/psycopg3/)
- MySQL via [asyncmy](https://github.com/long2ice/asyncmy)
- Oracle via [oracledb (async or sync)](https://oracle.github.io/python-oracledb/) (tested on 18c and 23c)
- Google Spanner via [spanner-sqlalchemy](https://github.com/googleapis/python-spanner-sqlalchemy/)
- DuckDB via [duckdb_engine](https://github.com/Mause/duckdb_engine)
- Microsoft SQL Server via [pyodbc](https://github.com/mkleehammer/pyodbc) or [aioodbc](https://github.com/aio-libs/aioodbc)
- CockroachDB via [sqlalchemy-cockroachdb (async or sync)](https://github.com/cockroachdb/sqlalchemy-cockroachdb)
- ...and much more
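The UUID7 keys mentioned above are time-ordered: the high bits carry a millisecond Unix timestamp, so newly inserted rows sort (and index) in creation order. A minimal stdlib-only sketch of the idea — not the `uuid-utils` implementation Advanced Alchemy actually uses — looks like this:

```python
import os
import time
import uuid


def uuid7_sketch() -> uuid.UUID:
    """Build a UUIDv7-style value: 48-bit unix-millisecond timestamp up
    front, then random bits, with version/variant fields set per RFC 9562."""
    ts_ms = time.time_ns() // 1_000_000
    value = ((ts_ms & ((1 << 48) - 1)) << 80) | int.from_bytes(os.urandom(10), "big")
    value = (value & ~(0xF << 76)) | (0x7 << 76)  # version nibble -> 7
    value = (value & ~(0x3 << 62)) | (0x2 << 62)  # variant bits -> 0b10
    return uuid.UUID(int=value)


a = uuid7_sketch()
time.sleep(0.002)  # a later millisecond...
b = uuid7_sketch()
assert a.version == 7 and a < b  # ...yields a strictly larger key
```

Because the timestamp leads, time-ordered keys avoid the random-insert index churn that purely random UUID4 primary keys cause on B-tree-backed tables.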
## Usage
### Installation
```shell
pip install advanced-alchemy
```
> [!IMPORTANT]\
> Check out [the installation guide][install-guide] in our official documentation!
### Repositories
Advanced Alchemy includes a set of asynchronous and synchronous repository classes for easy CRUD
operations on your SQLAlchemy models.
```python
from advanced_alchemy import base, config, repository
from advanced_alchemy.filters import LimitOffset
from sqlalchemy.orm import Mapped


class User(base.UUIDBase):
    # you can optionally override the generated table name by manually setting it.
    __tablename__ = "user_account"  # type: ignore[assignment]
    email: Mapped[str]
    name: Mapped[str]


class UserRepository(repository.SQLAlchemySyncRepository[User]):
    """User repository."""

    model_type = User


db = config.SQLAlchemySyncConfig(
    connection_string="duckdb:///:memory:",
    session_config=config.SyncSessionConfig(expire_on_commit=False),
)

# Initializes the database.
with db.get_engine().begin() as conn:
    User.metadata.create_all(conn)

with db.get_session() as db_session:
    repo = UserRepository(session=db_session)
    # 1) Create multiple users with `add_many`
    bulk_users = [
        {"email": "cody@litestar.dev", "name": "Cody"},
        {"email": "janek@litestar.dev", "name": "Janek"},
        {"email": "peter@litestar.dev", "name": "Peter"},
        {"email": "jacob@litestar.dev", "name": "Jacob"},
    ]
    objs = repo.add_many([User(**raw_user) for raw_user in bulk_users])
    db_session.commit()
    print(f"Created {len(objs)} new objects.")
    # 2) Select paginated data and total row count. Pass additional filters as kwargs
    created_objs, total_objs = repo.list_and_count(LimitOffset(limit=10, offset=0), name="Cody")
    print(f"Selected {len(created_objs)} records out of a total of {total_objs}.")
    # 3) Remove the batch of records selected.
    deleted_objs = repo.delete_many([new_obj.id for new_obj in created_objs])
    print(f"Removed {len(deleted_objs)} records out of a total of {total_objs}.")
    # 4) Count the remaining rows
    remaining_count = repo.count()
    print(f"Found {remaining_count} remaining records after delete.")
```
For a full standalone example, see the sample [here][standalone-example]
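Conceptually, a repository is just a typed facade over a storage backend. The sketch below — hypothetical names, an in-memory dict instead of SQL — shows the shape of the interface that `SQLAlchemySyncRepository` fills in with real queries and bulk statements:

```python
import uuid
from dataclasses import dataclass, field
from typing import Generic, TypeVar

T = TypeVar("T")


@dataclass
class Todo:  # stand-in for a SQLAlchemy model
    title: str
    id: uuid.UUID = field(default_factory=uuid.uuid4)


class InMemoryRepository(Generic[T]):
    """Toy repository: the same verbs, backed by a dict instead of a database."""

    def __init__(self) -> None:
        self._rows: "dict[uuid.UUID, T]" = {}

    def add_many(self, items: "list[T]") -> "list[T]":
        for item in items:
            self._rows[item.id] = item  # type: ignore[attr-defined]
        return items

    def list_and_count(self, **filters: object) -> "tuple[list[T], int]":
        rows = [
            row
            for row in self._rows.values()
            if all(getattr(row, name) == value for name, value in filters.items())
        ]
        return rows, len(rows)

    def delete_many(self, ids: "list[uuid.UUID]") -> "list[T]":
        return [self._rows.pop(i) for i in ids if i in self._rows]

    def count(self) -> int:
        return len(self._rows)


repo = InMemoryRepository[Todo]()
repo.add_many([Todo("write docs"), Todo("ship release")])
matches, total = repo.list_and_count(title="ship release")
assert (total, repo.count()) == (1, 2)
```

The real classes add what the toy omits: dialect-aware bulk SQL, pagination/filter objects, and session lifecycle handling.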
### Services
Advanced Alchemy includes an additional service class to make working with a repository easier.
This class is designed to accept data as a dictionary or SQLAlchemy model,
and it will handle the type conversions for you.
Here's the same example from above but using a service to create the data:
```python
from advanced_alchemy import base, config, repository, service
from advanced_alchemy.filters import LimitOffset
from sqlalchemy.orm import Mapped


class User(base.UUIDBase):
    # you can optionally override the generated table name by manually setting it.
    __tablename__ = "user_account"  # type: ignore[assignment]
    email: Mapped[str]
    name: Mapped[str]


class UserService(service.SQLAlchemySyncRepositoryService[User]):
    """User service."""

    class Repo(repository.SQLAlchemySyncRepository[User]):
        """User repository."""

        model_type = User

    repository_type = Repo


db = config.SQLAlchemySyncConfig(
    connection_string="duckdb:///:memory:",
    session_config=config.SyncSessionConfig(expire_on_commit=False),
)

# Initializes the database.
with db.get_engine().begin() as conn:
    User.metadata.create_all(conn)

with db.get_session() as db_session:
    users = UserService(session=db_session)
    # 1) Create multiple users with `create_many`
    objs = users.create_many(
        [
            {"email": "cody@litestar.dev", "name": "Cody"},
            {"email": "janek@litestar.dev", "name": "Janek"},
            {"email": "peter@litestar.dev", "name": "Peter"},
            {"email": "jacob@litestar.dev", "name": "Jacob"},
        ]
    )
    print(objs)
    print(f"Created {len(objs)} new objects.")
    # 2) Select paginated data and total row count. Pass additional filters as kwargs
    created_objs, total_objs = users.list_and_count(LimitOffset(limit=10, offset=0), name="Cody")
    print(f"Selected {len(created_objs)} records out of a total of {total_objs}.")
    # 3) Remove the batch of records selected.
    deleted_objs = users.delete_many([new_obj.id for new_obj in created_objs])
    print(f"Removed {len(deleted_objs)} records out of a total of {total_objs}.")
    # 4) Count the remaining rows
    remaining_count = users.count()
    print(f"Found {remaining_count} remaining records after delete.")
```
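The type conversion described above — accept either a raw `dict` or a model instance and always hand the repository a model — can be pictured with a plain-Python stand-in (the real service also handles other payload types; these names are illustrative only):

```python
from dataclasses import dataclass
from typing import Union


@dataclass
class UserModel:  # stand-in for the SQLAlchemy model
    email: str
    name: str


def to_model(data: "Union[dict, UserModel]") -> UserModel:
    """Normalize service input: pass model instances through, unpack dicts."""
    if isinstance(data, UserModel):
        return data
    return UserModel(**data)


assert to_model({"email": "cody@litestar.dev", "name": "Cody"}).name == "Cody"
instance = UserModel(email="janek@litestar.dev", name="Janek")
assert to_model(instance) is instance  # models pass through untouched
```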
### Web Frameworks
Advanced Alchemy works with nearly all Python web frameworks.
Several helpers for popular libraries are included, and additional PRs to support others are welcomed.
#### Litestar
Advanced Alchemy is the official SQLAlchemy integration for Litestar.
In addition to installing with `pip install advanced-alchemy`,
it can also be installed as a Litestar extra with `pip install litestar[sqlalchemy]`.
Litestar Example
```python
from litestar import Litestar
from litestar.plugins.sqlalchemy import SQLAlchemyPlugin, SQLAlchemyAsyncConfig
# alternately...
# from advanced_alchemy.extensions.litestar import SQLAlchemyAsyncConfig, SQLAlchemyPlugin
alchemy = SQLAlchemyPlugin(
config=SQLAlchemyAsyncConfig(connection_string="sqlite+aiosqlite:///test.sqlite"),
)
app = Litestar(plugins=[alchemy])
```
For a full Litestar example, check [here][litestar-example]
#### Flask
Flask Example
```python
from flask import Flask
from advanced_alchemy.extensions.flask import AdvancedAlchemy, SQLAlchemySyncConfig
app = Flask(__name__)
alchemy = AdvancedAlchemy(
config=SQLAlchemySyncConfig(connection_string="duckdb:///:memory:"), app=app,
)
```
For a full Flask example, see [here][flask-example]
#### FastAPI
FastAPI Example
```python
from advanced_alchemy.extensions.fastapi import AdvancedAlchemy, SQLAlchemyAsyncConfig
from fastapi import FastAPI
app = FastAPI()
alchemy = AdvancedAlchemy(
config=SQLAlchemyAsyncConfig(connection_string="sqlite+aiosqlite:///test.sqlite"), app=app,
)
```
For a full FastAPI example with optional CLI integration, see [here][fastapi-example]
#### Starlette
Starlette Example
```python
from advanced_alchemy.extensions.starlette import AdvancedAlchemy, SQLAlchemyAsyncConfig
from starlette.applications import Starlette
app = Starlette()
alchemy = AdvancedAlchemy(
config=SQLAlchemyAsyncConfig(connection_string="sqlite+aiosqlite:///test.sqlite"), app=app,
)
```
#### Sanic
Sanic Example
```python
from sanic import Sanic
from sanic_ext import Extend
from advanced_alchemy.extensions.sanic import AdvancedAlchemy, SQLAlchemyAsyncConfig
app = Sanic("AlchemySanicApp")
alchemy = AdvancedAlchemy(
sqlalchemy_config=SQLAlchemyAsyncConfig(connection_string="sqlite+aiosqlite:///test.sqlite"),
)
Extend.register(alchemy)
```
## Contributing
All [Litestar Organization][litestar-org] projects will always be community-centered and open to contributions of any size.
Before contributing, please review the [contribution guide][contributing].
If you have any questions, reach out to us on [Discord][discord], our org-wide [GitHub discussions][litestar-discussions] page,
or the [project-specific GitHub discussions page][project-discussions].
[litestar-org]: https://github.com/litestar-org
[contributing]: https://docs.advanced-alchemy.litestar.dev/latest/contribution-guide.html
[discord]: https://discord.gg/litestar
[litestar-discussions]: https://github.com/orgs/litestar-org/discussions
[project-discussions]: https://github.com/litestar-org/advanced-alchemy/discussions
[project-docs]: https://docs.advanced-alchemy.litestar.dev
[install-guide]: https://docs.advanced-alchemy.litestar.dev/latest/#installation
[fastapi-example]: https://github.com/litestar-org/advanced-alchemy/blob/main/examples/fastapi_service.py
[flask-example]: https://github.com/litestar-org/advanced-alchemy/blob/main/examples/flask/flask_services.py
[litestar-example]: https://github.com/litestar-org/advanced-alchemy/blob/main/examples/litestar.py
[standalone-example]: https://github.com/litestar-org/advanced-alchemy/blob/main/examples/standalone.py

# python-advanced-alchemy-1.0.1/advanced_alchemy/__init__.py
from advanced_alchemy import (
    alembic,
    base,
    cli,
    config,
    exceptions,
    extensions,
    filters,
    mixins,
    operations,
    service,
    types,
    utils,
)
__all__ = (
    "alembic",
    "base",
    "cli",
    "config",
    "exceptions",
    "extensions",
    "filters",
    "mixins",
    "operations",
    "service",
    "types",
    "utils",
)

# python-advanced-alchemy-1.0.1/advanced_alchemy/__main__.py
from advanced_alchemy.cli import add_migration_commands as build_cli_interface
def run_cli() -> None:  # pragma: no cover
    """Advanced Alchemy CLI"""
    build_cli_interface()()


if __name__ == "__main__":  # pragma: no cover
    run_cli()

# python-advanced-alchemy-1.0.1/advanced_alchemy/__metadata__.py
"""Metadata for the Project."""
from importlib.metadata import PackageNotFoundError, metadata, version # pragma: no cover
__all__ = ("__project__", "__version__") # pragma: no cover
try:  # pragma: no cover
    __version__ = version("advanced_alchemy")
    """Version of the project."""
    __project__ = metadata("advanced_alchemy")["Name"]
    """Name of the project."""
except PackageNotFoundError:  # pragma: no cover
    __version__ = "0.0.1"
    __project__ = "Advanced Alchemy"
finally:  # pragma: no cover
    del version, PackageNotFoundError, metadata

# python-advanced-alchemy-1.0.1/advanced_alchemy/_listeners.py
"""Application ORM configuration."""
import datetime
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
    from sqlalchemy.orm import Session


def touch_updated_timestamp(session: "Session", *_: Any) -> None:
    """Set timestamp on update.

    Called from SQLAlchemy's
    :meth:`before_flush <sqlalchemy.orm.SessionEvents.before_flush>` event to bump the ``updated``
    timestamp on modified instances.

    Args:
        session: The sync :class:`Session <sqlalchemy.orm.Session>` instance that underlies the async
            session.
    """
    for instance in session.dirty:
        if hasattr(instance, "updated_at"):
            instance.updated_at = datetime.datetime.now(datetime.timezone.utc)

# python-advanced-alchemy-1.0.1/advanced_alchemy/_serialization.py
import datetime
import enum
from typing import Any

from typing_extensions import runtime_checkable

try:
    from pydantic import BaseModel  # type: ignore # noqa: PGH003

    PYDANTIC_INSTALLED = True
except ImportError:
    from typing import ClassVar, Protocol

    @runtime_checkable
    class BaseModel(Protocol):  # type: ignore[no-redef]
        """Placeholder Implementation"""

        model_fields: ClassVar[dict[str, Any]]

        def model_dump_json(self, *args: Any, **kwargs: Any) -> str:
            """Placeholder"""
            return ""

    PYDANTIC_INSTALLED = False  # pyright: ignore[reportConstantRedefinition]


def _type_to_string(value: Any) -> str:  # pragma: no cover
    if isinstance(value, datetime.datetime):
        return convert_datetime_to_gmt_iso(value)
    if isinstance(value, datetime.date):
        return convert_date_to_iso(value)
    if isinstance(value, enum.Enum):
        return str(value.value)
    if PYDANTIC_INSTALLED and isinstance(value, BaseModel):
        return value.model_dump_json()
    try:
        val = str(value)
    except Exception as exc:
        raise TypeError from exc
    return val


try:
    from msgspec.json import Decoder, Encoder

    encoder, decoder = Encoder(enc_hook=_type_to_string), Decoder()
    decode_json = decoder.decode

    def encode_json(data: Any) -> str:  # pragma: no cover
        return encoder.encode(data).decode("utf-8")

except ImportError:
    try:
        from orjson import OPT_NAIVE_UTC, OPT_SERIALIZE_NUMPY, OPT_SERIALIZE_UUID
        from orjson import dumps as _encode_json
        from orjson import loads as decode_json  # type: ignore[no-redef,assignment]

        def encode_json(data: Any) -> str:  # pragma: no cover
            return _encode_json(
                data, default=_type_to_string, option=OPT_SERIALIZE_NUMPY | OPT_NAIVE_UTC | OPT_SERIALIZE_UUID
            ).decode("utf-8")  # type: ignore[no-any-return]

    except ImportError:
        from json import dumps as encode_json  # type: ignore[assignment] # noqa: F401
        from json import loads as decode_json  # type: ignore[assignment] # noqa: F401


def convert_datetime_to_gmt_iso(dt: datetime.datetime) -> str:  # pragma: no cover
    """Handle datetime serialization for nested timestamps."""
    if not dt.tzinfo:
        dt = dt.replace(tzinfo=datetime.timezone.utc)
    return dt.isoformat().replace("+00:00", "Z")


def convert_date_to_iso(dt: datetime.date) -> str:  # pragma: no cover
    """Handle datetime serialization for nested timestamps."""
    return dt.isoformat()

# python-advanced-alchemy-1.0.1/advanced_alchemy/alembic/__init__.py
# (empty)

# python-advanced-alchemy-1.0.1/advanced_alchemy/alembic/commands.py
import sys
from typing import TYPE_CHECKING, Any, Optional, TextIO, Union
from advanced_alchemy.config.asyncio import SQLAlchemyAsyncConfig
from alembic import command as migration_command
from alembic.config import Config as _AlembicCommandConfig
from alembic.ddl.impl import DefaultImpl
if TYPE_CHECKING:
    import os
    from argparse import Namespace
    from collections.abc import Mapping
    from pathlib import Path

    from sqlalchemy import Engine
    from sqlalchemy.ext.asyncio import AsyncEngine

    from advanced_alchemy.config.sync import SQLAlchemySyncConfig
    from alembic.runtime.environment import ProcessRevisionDirectiveFn
    from alembic.script.base import Script
class AlembicSpannerImpl(DefaultImpl):
    """Alembic implementation for Spanner."""

    __dialect__ = "spanner+spanner"


class AlembicDuckDBImpl(DefaultImpl):
    """Alembic implementation for DuckDB."""

    __dialect__ = "duckdb"
class AlembicCommandConfig(_AlembicCommandConfig):
    def __init__(
        self,
        engine: "Union[Engine, AsyncEngine]",
        version_table_name: str,
        bind_key: "Optional[str]" = None,
        file_: "Union[str, os.PathLike[str], None]" = None,
        ini_section: str = "alembic",
        output_buffer: "Optional[TextIO]" = None,
        stdout: "TextIO" = sys.stdout,
        cmd_opts: "Optional[Namespace]" = None,
        config_args: "Optional[Mapping[str, Any]]" = None,
        attributes: "Optional[dict[str, Any]]" = None,
        template_directory: "Optional[Path]" = None,
        version_table_schema: "Optional[str]" = None,
        render_as_batch: bool = True,
        compare_type: bool = False,
        user_module_prefix: "Optional[str]" = "sa.",
    ) -> None:
        """Initialize the AlembicCommandConfig.

        Args:
            engine (sqlalchemy.engine.Engine | sqlalchemy.ext.asyncio.AsyncEngine): The SQLAlchemy engine instance.
            version_table_name (str): The name of the version table.
            bind_key (str | None): The bind key for the metadata.
            file_ (str | os.PathLike[str] | None): The file path for the alembic configuration.
            ini_section (str): The ini section name.
            output_buffer (typing.TextIO | None): The output buffer for alembic commands.
            stdout (typing.TextIO): The standard output stream.
            cmd_opts (argparse.Namespace | None): Command line options.
            config_args (typing.Mapping[str, typing.Any] | None): Additional configuration arguments.
            attributes (dict[str, typing.Any] | None): Additional attributes for the configuration.
            template_directory (pathlib.Path | None): The directory for alembic templates.
            version_table_schema (str | None): The schema for the version table.
            render_as_batch (bool): Whether to render migrations as batch.
            compare_type (bool): Whether to compare types during migrations.
            user_module_prefix (str | None): The prefix for user modules.
        """
        self.template_directory = template_directory
        self.bind_key = bind_key
        self.version_table_name = version_table_name
        self.version_table_pk = engine.dialect.name != "spanner+spanner"
        self.version_table_schema = version_table_schema
        self.render_as_batch = render_as_batch
        self.user_module_prefix = user_module_prefix
        self.compare_type = compare_type
        self.engine = engine
        self.db_url = engine.url.render_as_string(hide_password=False)
        if config_args is None:
            config_args = {}
        super().__init__(file_, ini_section, output_buffer, stdout, cmd_opts, config_args, attributes)

    def get_template_directory(self) -> str:
        """Return the directory where Alembic setup templates are found.

        This method is used by the alembic ``init`` and ``list_templates``
        commands.
        """
        if self.template_directory is not None:
            return str(self.template_directory)
        return super().get_template_directory()


class AlembicCommands:
    def __init__(self, sqlalchemy_config: "Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig]") -> None:
        """Initialize the AlembicCommands.

        Args:
            sqlalchemy_config (SQLAlchemyAsyncConfig | SQLAlchemySyncConfig): The SQLAlchemy configuration.
        """
        self.sqlalchemy_config = sqlalchemy_config
        self.config = self._get_alembic_command_config()

    def upgrade(
        self,
        revision: str = "head",
        sql: bool = False,
        tag: "Optional[str]" = None,
    ) -> None:
        """Upgrade the database to a specified revision.

        Args:
            revision (str): The target revision to upgrade to.
            sql (bool): If True, generate SQL script instead of applying changes.
            tag (str | None): An optional tag to apply to the migration.
        """
        return migration_command.upgrade(config=self.config, revision=revision, tag=tag, sql=sql)

    def downgrade(
        self,
        revision: str = "head",
        sql: bool = False,
        tag: "Optional[str]" = None,
    ) -> None:
        """Downgrade the database to a specified revision.

        Args:
            revision (str): The target revision to downgrade to.
            sql (bool): If True, generate SQL script instead of applying changes.
            tag (str | None): An optional tag to apply to the migration.
        """
        return migration_command.downgrade(config=self.config, revision=revision, tag=tag, sql=sql)

    def check(self) -> None:
        """Check for pending upgrade operations.

        This method checks if there are any pending upgrade operations
        that need to be applied to the database.
        """
        return migration_command.check(config=self.config)

    def current(self, verbose: bool = False) -> None:
        """Display the current revision of the database.

        Args:
            verbose (bool): If True, display detailed information.
        """
        return migration_command.current(self.config, verbose=verbose)

    def edit(self, revision: str) -> None:
        """Edit the revision script using the system editor.

        Args:
            revision (str): The revision identifier to edit.
        """
        return migration_command.edit(config=self.config, rev=revision)

    def ensure_version(self, sql: bool = False) -> None:
        """Ensure the alembic version table exists.

        Args:
            sql (bool): If True, generate SQL script instead of applying changes.
        """
        return migration_command.ensure_version(config=self.config, sql=sql)

    def heads(self, verbose: bool = False, resolve_dependencies: bool = False) -> None:
        """Show current available heads in the script directory.

        Args:
            verbose (bool): If True, display detailed information.
            resolve_dependencies (bool): If True, resolve dependencies between heads.
        """
        return migration_command.heads(config=self.config, verbose=verbose, resolve_dependencies=resolve_dependencies)

    def history(
        self,
        rev_range: "Optional[str]" = None,
        verbose: bool = False,
        indicate_current: bool = False,
    ) -> None:
        """List changeset scripts in chronological order.

        Args:
            rev_range (str | None): The revision range to display.
            verbose (bool): If True, display detailed information.
            indicate_current (bool): If True, indicate the current revision.
        """
        return migration_command.history(
            config=self.config,
            rev_range=rev_range,
            verbose=verbose,
            indicate_current=indicate_current,
        )

    def merge(
        self,
        revisions: str,
        message: "Optional[str]" = None,
        branch_label: "Optional[str]" = None,
        rev_id: "Optional[str]" = None,
    ) -> "Union[Script, None]":
        """Merge two revisions together.

        Args:
            revisions (str): The revisions to merge.
            message (str | None): The commit message for the merge.
            branch_label (str | None): The branch label for the merge.
            rev_id (str | None): The revision ID for the merge.

        Returns:
            Script | None: The resulting script from the merge.
        """
        return migration_command.merge(
            config=self.config,
            revisions=revisions,
            message=message,
            branch_label=branch_label,
            rev_id=rev_id,
        )

    def revision(
        self,
        message: "Optional[str]" = None,
        autogenerate: bool = False,
        sql: bool = False,
        head: str = "head",
        splice: bool = False,
        branch_label: "Optional[str]" = None,
        version_path: "Optional[str]" = None,
        rev_id: "Optional[str]" = None,
        depends_on: "Optional[str]" = None,
        process_revision_directives: "Optional[ProcessRevisionDirectiveFn]" = None,
    ) -> "Union[Script, list[Optional[Script]], None]":
        """Create a new revision file.

        Args:
            message (str | None): The commit message for the revision.
            autogenerate (bool): If True, autogenerate the revision script.
            sql (bool): If True, generate SQL script instead of applying changes.
            head (str): The head revision to base the new revision on.
            splice (bool): If True, create a splice revision.
            branch_label (str | None): The branch label for the revision.
            version_path (str | None): The path for the version file.
            rev_id (str | None): The revision ID for the new revision.
            depends_on (str | None): The revisions this revision depends on.
            process_revision_directives (ProcessRevisionDirectiveFn | None): A function to process revision directives.

        Returns:
            Script | list[Script | None] | None: The resulting script(s) from the revision.
        """
        return migration_command.revision(
            config=self.config,
            message=message,
            autogenerate=autogenerate,
            sql=sql,
            head=head,
            splice=splice,
            branch_label=branch_label,
            version_path=version_path,
            rev_id=rev_id,
            depends_on=depends_on,
            process_revision_directives=process_revision_directives,
        )

    def show(
        self,
        rev: Any,
    ) -> None:
        """Show the revision(s) denoted by the given symbol.

        Args:
            rev (Any): The revision symbol to display.
        """
        return migration_command.show(config=self.config, rev=rev)

    def init(
        self,
        directory: str,
        package: bool = False,
        multidb: bool = False,
    ) -> None:
        """Initialize a new scripts directory.

        Args:
            directory (str): The directory to initialize.
            package (bool): If True, create a package.
            multidb (bool): If True, initialize for multiple databases.
        """
        template = "sync"
        if isinstance(self.sqlalchemy_config, SQLAlchemyAsyncConfig):
            template = "asyncio"
        if multidb:
            template = f"{template}-multidb"
            msg = "Multi database Alembic configurations are not currently supported."
            raise NotImplementedError(msg)
        return migration_command.init(
            config=self.config,
            directory=directory,
            template=template,
            package=package,
        )

    def list_templates(self) -> None:
        """List available templates.

        This method lists all available templates for alembic initialization.
        """
        return migration_command.list_templates(config=self.config)

    def stamp(
        self,
        revision: str,
        sql: bool = False,
        tag: "Optional[str]" = None,
        purge: bool = False,
    ) -> None:
        """Stamp the revision table with the given revision.

        Args:
            revision (str): The revision to stamp.
            sql (bool): If True, generate SQL script instead of applying changes.
            tag (str | None): An optional tag to apply to the migration.
            purge (bool): If True, purge the revision history.
        """
        return migration_command.stamp(config=self.config, revision=revision, sql=sql, tag=tag, purge=purge)

    def _get_alembic_command_config(self) -> "AlembicCommandConfig":
        """Get the Alembic command configuration.

        Returns:
            AlembicCommandConfig: The configuration for Alembic commands.
        """
        kwargs: dict[str, Any] = {}
        if self.sqlalchemy_config.alembic_config.script_config:
            kwargs["file_"] = self.sqlalchemy_config.alembic_config.script_config
        if self.sqlalchemy_config.alembic_config.template_path:
            kwargs["template_directory"] = self.sqlalchemy_config.alembic_config.template_path
        kwargs.update(
            {
                "engine": self.sqlalchemy_config.get_engine(),
                "version_table_name": self.sqlalchemy_config.alembic_config.version_table_name,
            },
        )
        self.config = AlembicCommandConfig(**kwargs)
        self.config.set_main_option("script_location", self.sqlalchemy_config.alembic_config.script_location)
        return self.config

# python-advanced-alchemy-1.0.1/advanced_alchemy/alembic/templates/__init__.py
# (empty)

# python-advanced-alchemy-1.0.1/advanced_alchemy/alembic/templates/asyncio/__init__.py
# (empty)

# python-advanced-alchemy-1.0.1/advanced_alchemy/alembic/templates/asyncio/alembic.ini.mako
# Advanced Alchemy Alembic Asyncio Config
[alembic]
prepend_sys_path = src:.
# path to migration scripts
script_location = migrations
# template used to generate migration files
file_template = %%(year)d-%%(month).2d-%%(day).2d_%%(slug)s_%%(rev)s
# This is not required to be set when running through `advanced_alchemy`
# sqlalchemy.url = driver://user:pass@localhost/dbname
# timezone to use when rendering the date
# within the migration file as well as the filename.
# string value is passed to dateutil.tz.gettz()
# leave blank for localtime
# timezone = UTC
# max length of characters to apply to the
# "slug" field
truncate_slug_length = 40
# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false
# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false
# version location specification; this defaults
# to alembic/versions. When using multiple version
# directories, initial revisions must be specified with --version-path
# version_locations = %(here)s/bar %(here)s/bat alembic/versions
# version path separator; As mentioned above, this is the character used to split
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
# Valid values for version_path_separator are:
#
# version_path_separator = :
# version_path_separator = ;
# version_path_separator = space
version_path_separator = os # Use os.pathsep. Default configuration used for new projects.
# set to 'true' to search source files recursively
# in each "version_locations" directory
# new in Alembic version 1.10
# recursive_version_locations = false
# the output encoding used when revision files
# are written from script.py.mako
output_encoding = utf-8
# [post_write_hooks]
# This section defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples
# format using "black" - use the console_scripts runner,
# against the "black" entrypoint
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -l 79 REVISION_SCRIPT_FILENAME
# lint with attempts to fix using "ruff" - use the exec runner, execute a binary
# hooks = ruff
# ruff.type = exec
# ruff.executable = %(here)s/.venv/bin/ruff
# ruff.options = --fix REVISION_SCRIPT_FILENAME

# python-advanced-alchemy-1.0.1/advanced_alchemy/alembic/templates/asyncio/env.py
import asyncio
from typing import TYPE_CHECKING, cast
from sqlalchemy import Column, pool
from sqlalchemy.ext.asyncio import AsyncEngine, async_engine_from_config
from advanced_alchemy.base import metadata_registry
from alembic import context
from alembic.autogenerate import rewriter
from alembic.operations import ops
if TYPE_CHECKING:
from sqlalchemy.engine import Connection
from advanced_alchemy.alembic.commands import AlembicCommandConfig
from alembic.runtime.environment import EnvironmentContext
__all__ = ("do_run_migrations", "run_migrations_offline", "run_migrations_online")
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config: "AlembicCommandConfig" = context.config # type: ignore # noqa: PGH003
writer = rewriter.Rewriter()
@writer.rewrites(ops.CreateTableOp)
def order_columns(
context: "EnvironmentContext", # noqa: ARG001
revision: tuple[str, ...], # noqa: ARG001
op: ops.CreateTableOp,
) -> ops.CreateTableOp:
"""Orders ID first and the audit columns at the end."""
special_names = {"id": -100, "sa_orm_sentinel": 3001, "created_at": 3002, "updated_at": 3003}
cols_by_key = [ # pyright: ignore[reportUnknownVariableType]
(
special_names.get(col.key, index) if isinstance(col, Column) else 2000,
col.copy(), # type: ignore[attr-defined]
)
for index, col in enumerate(op.columns)
]
columns = [col for _, col in sorted(cols_by_key, key=lambda entry: entry[0])] # pyright: ignore[reportUnknownVariableType,reportUnknownArgumentType,reportUnknownLambdaType]
return ops.CreateTableOp(
op.table_name,
columns, # pyright: ignore[reportUnknownArgumentType]
schema=op.schema,
# TODO: Remove when https://github.com/sqlalchemy/alembic/issues/1193 is fixed # noqa: FIX002
_namespace_metadata=op._namespace_metadata, # noqa: SLF001 # pyright: ignore[reportPrivateUsage]
**op.kw,
)
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode.
This configures the context with just a URL
and not an Engine, though an Engine is acceptable
here as well. By skipping the Engine creation
we don't even need a DBAPI to be available.
Calls to context.execute() here emit the given string to the
script output.
"""
context.configure(
url=config.db_url,
target_metadata=metadata_registry.get(config.bind_key),
literal_binds=True,
dialect_opts={"paramstyle": "named"},
compare_type=config.compare_type,
version_table=config.version_table_name,
version_table_pk=config.version_table_pk,
user_module_prefix=config.user_module_prefix,
render_as_batch=config.render_as_batch,
process_revision_directives=writer,
)
with context.begin_transaction():
context.run_migrations()
def do_run_migrations(connection: "Connection") -> None:
"""Run migrations."""
context.configure(
connection=connection,
target_metadata=metadata_registry.get(config.bind_key),
compare_type=config.compare_type,
version_table=config.version_table_name,
version_table_pk=config.version_table_pk,
user_module_prefix=config.user_module_prefix,
render_as_batch=config.render_as_batch,
process_revision_directives=writer,
)
with context.begin_transaction():
context.run_migrations()
async def run_migrations_online() -> None:
"""Run migrations in 'online' mode.
In this scenario we need to create an Engine and associate a
connection with the context.
"""
configuration = config.get_section(config.config_ini_section) or {}
configuration["sqlalchemy.url"] = config.db_url
connectable = cast(
"AsyncEngine",
config.engine
or async_engine_from_config(
configuration,
prefix="sqlalchemy.",
poolclass=pool.NullPool,
future=True,
),
)
if connectable is None: # pyright: ignore[reportUnnecessaryComparison]
        msg = "Could not get engine from config. Please ensure your `alembic.ini` is configured according to the official Alembic documentation."
raise RuntimeError(
msg,
)
async with connectable.connect() as connection:
await connection.run_sync(do_run_migrations)
await connectable.dispose()
if context.is_offline_mode():
run_migrations_offline()
else:
asyncio.run(run_migrations_online())
# file: python-advanced-alchemy-1.0.1/advanced_alchemy/alembic/templates/asyncio/script.py.mako
# type: ignore
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
import warnings
from typing import TYPE_CHECKING
import sqlalchemy as sa
from alembic import op
from advanced_alchemy.types import EncryptedString, EncryptedText, GUID, ORA_JSONB, DateTimeUTC
from sqlalchemy import Text # noqa: F401
${imports if imports else ""}
if TYPE_CHECKING:
from collections.abc import Sequence
__all__ = ["downgrade", "upgrade", "schema_upgrades", "schema_downgrades", "data_upgrades", "data_downgrades"]
sa.GUID = GUID
sa.DateTimeUTC = DateTimeUTC
sa.ORA_JSONB = ORA_JSONB
sa.EncryptedString = EncryptedString
sa.EncryptedText = EncryptedText
# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: str | None = ${repr(down_revision)}
branch_labels: str | Sequence[str] | None = ${repr(branch_labels)}
depends_on: str | Sequence[str] | None = ${repr(depends_on)}
def upgrade() -> None:
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=UserWarning)
with op.get_context().autocommit_block():
schema_upgrades()
data_upgrades()
def downgrade() -> None:
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=UserWarning)
with op.get_context().autocommit_block():
data_downgrades()
schema_downgrades()
def schema_upgrades() -> None:
"""schema upgrade migrations go here."""
${upgrades if upgrades else "pass"}
def schema_downgrades() -> None:
"""schema downgrade migrations go here."""
${downgrades if downgrades else "pass"}
def data_upgrades() -> None:
"""Add any optional data upgrade migrations here!"""
def data_downgrades() -> None:
"""Add any optional data downgrade migrations here!"""
# file: python-advanced-alchemy-1.0.1/advanced_alchemy/alembic/templates/sync/__init__.py (empty file)
# file: python-advanced-alchemy-1.0.1/advanced_alchemy/alembic/templates/sync/alembic.ini.mako
# Advanced Alchemy Alembic Sync Config
[alembic]
prepend_sys_path = src:.
# path to migration scripts
script_location = migrations
# template used to generate migration files
file_template = %%(year)d-%%(month).2d-%%(day).2d_%%(slug)s_%%(rev)s
# This is not required to be set when running through the `advanced_alchemy`
# sqlalchemy.url = driver://user:pass@localhost/dbname
# timezone to use when rendering the date
# within the migration file as well as the filename.
# string value is passed to dateutil.tz.gettz()
# leave blank for localtime
# timezone = UTC
# max length of characters to apply to the
# "slug" field
truncate_slug_length = 40
# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false
# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false
# version location specification; this defaults
# to alembic/versions. When using multiple version
# directories, initial revisions must be specified with --version-path
# version_locations = %(here)s/bar %(here)s/bat alembic/versions
# version path separator; As mentioned above, this is the character used to split
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
# Valid values for version_path_separator are:
#
# version_path_separator = :
# version_path_separator = ;
# version_path_separator = space
version_path_separator = os # Use os.pathsep. Default configuration used for new projects.
# set to 'true' to search source files recursively
# in each "version_locations" directory
# new in Alembic version 1.10
# recursive_version_locations = false
# the output encoding used when revision files
# are written from script.py.mako
output_encoding = utf-8
# [post_write_hooks]
# This section defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples
# format using "black" - use the console_scripts runner,
# against the "black" entrypoint
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -l 79 REVISION_SCRIPT_FILENAME
# lint with attempts to fix using "ruff" - use the exec runner, execute a binary
# hooks = ruff
# ruff.type = exec
# ruff.executable = %(here)s/.venv/bin/ruff
# ruff.options = --fix REVISION_SCRIPT_FILENAME
# file: python-advanced-alchemy-1.0.1/advanced_alchemy/alembic/templates/sync/env.py
from typing import TYPE_CHECKING, cast
from sqlalchemy import Column, Engine, engine_from_config, pool
from advanced_alchemy.base import metadata_registry
from alembic import context
from alembic.autogenerate import rewriter
from alembic.operations import ops
if TYPE_CHECKING:
from sqlalchemy.engine import Connection
from advanced_alchemy.alembic.commands import AlembicCommandConfig
from alembic.runtime.environment import EnvironmentContext
__all__ = ["do_run_migrations", "run_migrations_offline", "run_migrations_online"]
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config: "AlembicCommandConfig" = context.config # type: ignore # noqa: PGH003
writer = rewriter.Rewriter()
@writer.rewrites(ops.CreateTableOp)
def order_columns(
context: "EnvironmentContext", # noqa: ARG001
revision: tuple[str, ...], # noqa: ARG001
op: ops.CreateTableOp,
) -> ops.CreateTableOp:
"""Orders ID first and the audit columns at the end."""
special_names = {"id": -100, "sa_orm_sentinel": 3001, "created_at": 3002, "updated_at": 3003}
cols_by_key = [ # pyright: ignore[reportUnknownVariableType]
(
special_names.get(col.key, index) if isinstance(col, Column) else 2000,
col.copy(), # type: ignore[attr-defined]
)
for index, col in enumerate(op.columns)
]
columns = [col for _, col in sorted(cols_by_key, key=lambda entry: entry[0])] # pyright: ignore[reportUnknownVariableType,reportUnknownArgumentType,reportUnknownLambdaType]
return ops.CreateTableOp(
op.table_name,
columns, # pyright: ignore[reportUnknownArgumentType]
schema=op.schema,
# TODO: Remove when https://github.com/sqlalchemy/alembic/issues/1193 is fixed # noqa: FIX002
        _namespace_metadata=op._namespace_metadata, # noqa: SLF001 # pyright: ignore[reportPrivateUsage]
**op.kw,
)
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode.
This configures the context with just a URL
and not an Engine, though an Engine is acceptable
here as well. By skipping the Engine creation
we don't even need a DBAPI to be available.
Calls to context.execute() here emit the given string to the
script output.
"""
context.configure(
url=config.db_url,
target_metadata=metadata_registry.get(config.bind_key),
literal_binds=True,
dialect_opts={"paramstyle": "named"},
compare_type=config.compare_type,
version_table=config.version_table_name,
version_table_pk=config.version_table_pk,
user_module_prefix=config.user_module_prefix,
render_as_batch=config.render_as_batch,
process_revision_directives=writer,
)
with context.begin_transaction():
context.run_migrations()
def do_run_migrations(connection: "Connection") -> None:
"""Run migrations."""
context.configure(
connection=connection,
target_metadata=metadata_registry.get(config.bind_key),
compare_type=config.compare_type,
version_table=config.version_table_name,
version_table_pk=config.version_table_pk,
user_module_prefix=config.user_module_prefix,
render_as_batch=config.render_as_batch,
process_revision_directives=writer,
)
with context.begin_transaction():
context.run_migrations()
def run_migrations_online() -> None:
"""Run migrations in 'online' mode.
In this scenario we need to create an Engine and associate a
connection with the context.
"""
configuration = config.get_section(config.config_ini_section) or {}
configuration["sqlalchemy.url"] = config.db_url
connectable = cast(
"Engine",
config.engine
or engine_from_config(
configuration,
prefix="sqlalchemy.",
poolclass=pool.NullPool,
future=True,
),
)
if connectable is None: # pyright: ignore[reportUnnecessaryComparison]
        msg = "Could not get engine from config. Please ensure your `alembic.ini` is configured according to the official Alembic documentation."
raise RuntimeError(
msg,
)
with connectable.connect() as connection:
do_run_migrations(connection=connection)
connectable.dispose()
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()
# file: python-advanced-alchemy-1.0.1/advanced_alchemy/alembic/templates/sync/script.py.mako
# type: ignore
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
import warnings
from typing import TYPE_CHECKING
import sqlalchemy as sa
from alembic import op
from advanced_alchemy.types import EncryptedString, EncryptedText, GUID, ORA_JSONB, DateTimeUTC
from sqlalchemy import Text # noqa: F401
${imports if imports else ""}
if TYPE_CHECKING:
from collections.abc import Sequence
__all__ = ["downgrade", "upgrade", "schema_upgrades", "schema_downgrades", "data_upgrades", "data_downgrades"]
sa.GUID = GUID
sa.DateTimeUTC = DateTimeUTC
sa.ORA_JSONB = ORA_JSONB
sa.EncryptedString = EncryptedString
sa.EncryptedText = EncryptedText
# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: str | None = ${repr(down_revision)}
branch_labels: str | Sequence[str] | None = ${repr(branch_labels)}
depends_on: str | Sequence[str] | None = ${repr(depends_on)}
def upgrade() -> None:
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=UserWarning)
with op.get_context().autocommit_block():
schema_upgrades()
data_upgrades()
def downgrade() -> None:
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=UserWarning)
with op.get_context().autocommit_block():
data_downgrades()
schema_downgrades()
def schema_upgrades() -> None:
"""schema upgrade migrations go here."""
${upgrades if upgrades else "pass"}
def schema_downgrades() -> None:
"""schema downgrade migrations go here."""
${downgrades if downgrades else "pass"}
def data_upgrades() -> None:
"""Add any optional data upgrade migrations here!"""
def data_downgrades() -> None:
"""Add any optional data downgrade migrations here!"""
# file: python-advanced-alchemy-1.0.1/advanced_alchemy/alembic/utils.py
from contextlib import AbstractAsyncContextManager, AbstractContextManager
from pathlib import Path
from typing import TYPE_CHECKING, Union
from litestar.cli._utils import console
from sqlalchemy import Engine, MetaData, Table
from typing_extensions import TypeIs
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession
from sqlalchemy.orm import DeclarativeBase, Session
__all__ = ("drop_all", "dump_tables")
async def drop_all(engine: "Union[AsyncEngine, Engine]", version_table_name: str, metadata: MetaData) -> None:
def _is_sync(engine: "Union[Engine, AsyncEngine]") -> "TypeIs[Engine]":
return isinstance(engine, Engine)
def _drop_tables_sync(engine: Engine) -> None:
console.rule("[bold red]Connecting to database backend.")
with engine.begin() as db:
console.rule("[bold red]Dropping the db", align="left")
metadata.drop_all(db)
console.rule("[bold red]Dropping the version table", align="left")
Table(version_table_name, metadata).drop(db, checkfirst=True)
console.rule("[bold yellow]Successfully dropped all objects", align="left")
async def _drop_tables_async(engine: "AsyncEngine") -> None:
console.rule("[bold red]Connecting to database backend.", align="left")
async with engine.begin() as db:
console.rule("[bold red]Dropping the db", align="left")
await db.run_sync(metadata.drop_all)
console.rule("[bold red]Dropping the version table", align="left")
await db.run_sync(Table(version_table_name, metadata).drop, checkfirst=True)
console.rule("[bold yellow]Successfully dropped all objects", align="left")
if _is_sync(engine):
return _drop_tables_sync(engine)
return await _drop_tables_async(engine)
async def dump_tables(
dump_dir: Path,
session: "Union[AbstractContextManager[Session], AbstractAsyncContextManager[AsyncSession]]",
models: "list[type[DeclarativeBase]]",
) -> None:
from types import new_class
from advanced_alchemy._serialization import encode_json
def _is_sync(
session: "Union[AbstractAsyncContextManager[AsyncSession], AbstractContextManager[Session]]",
) -> "TypeIs[AbstractContextManager[Session]]":
return isinstance(session, AbstractContextManager)
def _dump_table_sync(session: "AbstractContextManager[Session]") -> None:
from advanced_alchemy.repository import SQLAlchemySyncRepository
with session as _session:
for model in models:
json_path = dump_dir / f"{model.__tablename__}.json"
console.rule(
f"[yellow bold]Dumping table '{json_path.stem}' to '{json_path}'",
style="yellow",
align="left",
)
repo = new_class(
"repo",
(SQLAlchemySyncRepository,),
exec_body=lambda ns, model=model: ns.setdefault("model_type", model), # type: ignore[misc]
)
json_path.write_text(encode_json([row.to_dict() for row in repo(session=_session).list()]))
async def _dump_table_async(session: "AbstractAsyncContextManager[AsyncSession]") -> None:
from advanced_alchemy.repository import SQLAlchemyAsyncRepository
async with session as _session:
for model in models:
json_path = dump_dir / f"{model.__tablename__}.json"
console.rule(
f"[yellow bold]Dumping table '{json_path.stem}' to '{json_path}'",
style="yellow",
align="left",
)
repo = new_class(
"repo",
(SQLAlchemyAsyncRepository,),
exec_body=lambda ns, model=model: ns.setdefault("model_type", model), # type: ignore[misc]
)
json_path.write_text(encode_json([row.to_dict() for row in await repo(session=_session).list()]))
dump_dir.mkdir(exist_ok=True)
if _is_sync(session):
return _dump_table_sync(session)
return await _dump_table_async(session)
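`dump_tables` builds one repository subclass per model at runtime with `types.new_class`, binding the loop variable through a lambda default argument. A minimal stdlib sketch of that dynamic-subclass pattern (`BaseRepo` and `User` are illustrative stand-ins, not Advanced Alchemy classes):

```python
from types import new_class

class BaseRepo:
    model_type: type = object

class User:  # stand-in for a SQLAlchemy model
    __tablename__ = "user"

# model=User as a default argument binds the value at definition
# time, matching the `lambda ns, model=model: ...` used above to
# avoid the classic late-binding-in-a-loop bug.
repo_cls = new_class(
    "repo",
    (BaseRepo,),
    exec_body=lambda ns, model=User: ns.setdefault("model_type", model),
)
print(repo_cls.model_type is User)  # True
```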
# file: python-advanced-alchemy-1.0.1/advanced_alchemy/base.py
# ruff: noqa: TC004
"""Common base classes for SQLAlchemy declarative models."""
import contextlib
import datetime
import re
from collections.abc import Iterator
from typing import TYPE_CHECKING, Any, Optional, Protocol, Union, cast, runtime_checkable
from uuid import UUID
from sqlalchemy import Date, MetaData, String
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlalchemy.orm import (
DeclarativeBase,
Mapper,
declared_attr,
)
from sqlalchemy.orm import (
registry as SQLAlchemyRegistry, # noqa: N812
)
from sqlalchemy.orm.decl_base import _TableArgsType as TableArgsType # pyright: ignore[reportPrivateUsage]
from sqlalchemy.types import TypeEngine
from typing_extensions import Self, TypeVar
from advanced_alchemy.mixins import (
AuditColumns,
BigIntPrimaryKey,
NanoIDPrimaryKey,
UUIDPrimaryKey,
UUIDv6PrimaryKey,
UUIDv7PrimaryKey,
)
from advanced_alchemy.types import GUID, DateTimeUTC, JsonB
from advanced_alchemy.utils.dataclass import DataclassProtocol
if TYPE_CHECKING:
from sqlalchemy.sql import FromClause
from sqlalchemy.sql.schema import (
_NamingSchemaParameter as NamingSchemaParameter, # pyright: ignore[reportPrivateUsage]
)
__all__ = (
"AdvancedDeclarativeBase",
"BasicAttributes",
"BigIntAuditBase",
"BigIntBase",
"BigIntBaseT",
"CommonTableAttributes",
"ModelProtocol",
"NanoIDAuditBase",
"NanoIDBase",
"NanoIDBaseT",
"SQLQuery",
"TableArgsType",
"UUIDAuditBase",
"UUIDBase",
"UUIDBaseT",
"UUIDv6AuditBase",
"UUIDv6Base",
"UUIDv6BaseT",
"UUIDv7AuditBase",
"UUIDv7Base",
"UUIDv7BaseT",
"convention",
"create_registry",
"merge_table_arguments",
"metadata_registry",
"orm_registry",
"table_name_regexp",
)
UUIDBaseT = TypeVar("UUIDBaseT", bound="UUIDBase")
"""Type variable for :class:`UUIDBase`."""
BigIntBaseT = TypeVar("BigIntBaseT", bound="BigIntBase")
"""Type variable for :class:`BigIntBase`."""
UUIDv6BaseT = TypeVar("UUIDv6BaseT", bound="UUIDv6Base")
"""Type variable for :class:`UUIDv6Base`."""
UUIDv7BaseT = TypeVar("UUIDv7BaseT", bound="UUIDv7Base")
"""Type variable for :class:`UUIDv7Base`."""
NanoIDBaseT = TypeVar("NanoIDBaseT", bound="NanoIDBase")
"""Type variable for :class:`NanoIDBase`."""
convention: "NamingSchemaParameter" = {
"ix": "ix_%(column_0_label)s",
"uq": "uq_%(table_name)s_%(column_0_name)s",
"ck": "ck_%(table_name)s_%(constraint_name)s",
"fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
"pk": "pk_%(table_name)s",
}
"""Templates for automated constraint name generation."""
table_name_regexp = re.compile("((?<=[a-z0-9])[A-Z]|(?!^)[A-Z](?=[a-z]))")
"""Regular expression for table name"""
def merge_table_arguments(cls: type[DeclarativeBase], table_args: Optional[TableArgsType] = None) -> TableArgsType:
"""Merge Table Arguments.
This function helps merge table arguments when using mixins that include their own table args,
making it easier to append additional information such as comments or constraints to the model.
Args:
cls (type[:class:`sqlalchemy.orm.DeclarativeBase`]): The model that will get the table args.
table_args (:class:`TableArgsType`, optional): Additional information to add to table_args.
Returns:
:class:`TableArgsType`: Merged table arguments.
"""
args: list[Any] = []
kwargs: dict[str, Any] = {}
mixin_table_args = (getattr(super(base_cls, cls), "__table_args__", None) for base_cls in cls.__bases__) # pyright: ignore[reportUnknownParameter,reportUnknownArgumentType,reportArgumentType]
for arg_to_merge in (*mixin_table_args, table_args):
if arg_to_merge:
if isinstance(arg_to_merge, tuple):
last_positional_arg = arg_to_merge[-1] # pyright: ignore[reportUnknownVariableType]
args.extend(arg_to_merge[:-1]) # pyright: ignore[reportUnknownArgumentType]
if isinstance(last_positional_arg, dict):
kwargs.update(last_positional_arg) # pyright: ignore[reportUnknownArgumentType]
else:
args.append(last_positional_arg)
else:
kwargs.update(arg_to_merge)
if args:
if kwargs:
return (*args, kwargs)
return tuple(args)
return kwargs
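`merge_table_arguments` follows SQLAlchemy's `__table_args__` convention: each source is either a dict of keyword options or a tuple of positional constraints optionally ending in such a dict. A simplified, standalone sketch of the merge rules (plain strings stand in for constraint objects):

```python
from typing import Any, Optional, Union

TableArgs = Union[tuple[Any, ...], dict[str, Any]]

def merge(args_to_merge: list[Optional[TableArgs]]) -> TableArgs:
    args: list[Any] = []
    kwargs: dict[str, Any] = {}
    for arg in args_to_merge:
        if not arg:
            continue
        if isinstance(arg, tuple):
            # A trailing dict in a tuple holds keyword options;
            # everything before it is a positional constraint.
            last = arg[-1]
            args.extend(arg[:-1])
            if isinstance(last, dict):
                kwargs.update(last)
            else:
                args.append(last)
        else:
            kwargs.update(arg)
    if args:
        return (*args, kwargs) if kwargs else tuple(args)
    return kwargs

print(merge([("uq_email",), {"comment": "users"}]))
# ('uq_email', {'comment': 'users'})
```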
@runtime_checkable
class ModelProtocol(Protocol):
"""The base SQLAlchemy model protocol.
Attributes:
__table__ (:class:`sqlalchemy.sql.FromClause`): The table associated with the model.
__mapper__ (:class:`sqlalchemy.orm.Mapper`): The mapper for the model.
__name__ (str): The name of the model.
"""
if TYPE_CHECKING:
__table__: FromClause
__mapper__: Mapper[Any]
__name__: str
def to_dict(self, exclude: Optional[set[str]] = None) -> dict[str, Any]:
"""Convert model to dictionary.
Returns:
Dict[str, Any]: A dict representation of the model
"""
...
class BasicAttributes:
"""Basic attributes for SQLAlchemy tables and queries.
Provides a method to convert the model to a dictionary representation.
Methods:
to_dict: Converts the model to a dictionary, excluding specified fields. :no-index:
"""
if TYPE_CHECKING:
__name__: str
__table__: FromClause
__mapper__: Mapper[Any]
def to_dict(self, exclude: Optional[set[str]] = None) -> dict[str, Any]:
"""Convert model to dictionary.
Returns:
Dict[str, Any]: A dict representation of the model
"""
exclude = {"sa_orm_sentinel", "_sentinel"}.union(self._sa_instance_state.unloaded).union(exclude or []) # type: ignore[attr-defined]
return {
field: getattr(self, field)
for field in self.__mapper__.columns.keys() # noqa: SIM118
if field not in exclude
}
class CommonTableAttributes(BasicAttributes):
"""Common attributes for SQLAlchemy tables.
Inherits from :class:`BasicAttributes` and provides a mechanism to infer table names from class names.
Attributes:
__tablename__ (str): The inferred table name.
"""
if TYPE_CHECKING:
__tablename__: str
else:
@declared_attr.directive
def __tablename__(cls) -> str:
"""Infer table name from class name."""
return table_name_regexp.sub(r"_\1", cls.__name__).lower()
def create_registry(
custom_annotation_map: Optional[dict[Any, Union[type[TypeEngine[Any]], TypeEngine[Any]]]] = None,
) -> SQLAlchemyRegistry:
"""Create a new SQLAlchemy registry.
Args:
custom_annotation_map (dict, optional): Custom type annotations to use for the registry.
Returns:
:class:`sqlalchemy.orm.registry`: A new SQLAlchemy registry with the specified type annotations.
"""
import uuid as core_uuid
meta = MetaData(naming_convention=convention)
type_annotation_map: dict[Any, Union[type[TypeEngine[Any]], TypeEngine[Any]]] = {
UUID: GUID,
core_uuid.UUID: GUID,
datetime.datetime: DateTimeUTC,
datetime.date: Date,
dict: JsonB,
dict[str, Any]: JsonB,
dict[str, str]: JsonB,
DataclassProtocol: JsonB,
}
with contextlib.suppress(ImportError):
from pydantic import AnyHttpUrl, AnyUrl, EmailStr, IPvAnyAddress, IPvAnyInterface, IPvAnyNetwork, Json
type_annotation_map.update(
{
EmailStr: String,
AnyUrl: String,
AnyHttpUrl: String,
Json: JsonB,
IPvAnyAddress: String,
IPvAnyInterface: String,
IPvAnyNetwork: String,
}
)
with contextlib.suppress(ImportError):
from msgspec import Struct
type_annotation_map[Struct] = JsonB
if custom_annotation_map is not None:
type_annotation_map.update(custom_annotation_map)
return SQLAlchemyRegistry(metadata=meta, type_annotation_map=type_annotation_map)
orm_registry = create_registry()
class MetadataRegistry:
"""A registry for metadata.
Provides methods to get and set metadata for different bind keys.
Methods:
get: Retrieves the metadata for a given bind key.
set: Sets the metadata for a given bind key.
"""
_instance: Optional["MetadataRegistry"] = None
_registry: dict[Union[str, None], MetaData] = {None: orm_registry.metadata}
def __new__(cls) -> Self:
if cls._instance is None:
cls._instance = super().__new__(cls)
return cast("Self", cls._instance)
def get(self, bind_key: Optional[str] = None) -> MetaData:
"""Get the metadata for the given bind key."""
return self._registry.setdefault(bind_key, MetaData(naming_convention=convention))
def set(self, bind_key: Optional[str], metadata: MetaData) -> None:
"""Set the metadata for the given bind key."""
self._registry[bind_key] = metadata
def __iter__(self) -> Iterator[Union[str, None]]:
return iter(self._registry)
def __getitem__(self, bind_key: Union[str, None]) -> MetaData:
return self._registry[bind_key]
def __setitem__(self, bind_key: Union[str, None], metadata: MetaData) -> None:
self._registry[bind_key] = metadata
def __contains__(self, bind_key: Union[str, None]) -> bool:
return bind_key in self._registry
metadata_registry = MetadataRegistry()
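`MetadataRegistry` is a process-wide singleton keyed by bind key. A minimal sketch of the same `__new__`-based singleton pattern, with plain dicts standing in for `MetaData` objects:

```python
from typing import Optional

class Registry:
    _instance: "Optional[Registry]" = None
    _registry: dict[Optional[str], dict] = {}

    def __new__(cls) -> "Registry":
        # Every instantiation returns the single shared instance.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def get(self, bind_key: Optional[str] = None) -> dict:
        # setdefault lazily creates one entry per bind key, as
        # MetadataRegistry does with MetaData(naming_convention=...)
        return self._registry.setdefault(bind_key, {})

print(Registry() is Registry())  # True
print(Registry().get("db1") is Registry().get("db1"))  # True
```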
class AdvancedDeclarativeBase(DeclarativeBase):
"""A subclass of declarative base that allows for overriding of the registry.
Inherits from :class:`sqlalchemy.orm.DeclarativeBase`.
Attributes:
registry (:class:`sqlalchemy.orm.registry`): The registry for the declarative base.
__metadata_registry__ (:class:`~advanced_alchemy.base.MetadataRegistry`): The metadata registry.
__bind_key__ (Optional[:class:`str`]): The bind key for the metadata.
"""
registry = orm_registry
__abstract__ = True
__metadata_registry__: MetadataRegistry = MetadataRegistry()
__bind_key__: Optional[str] = None
def __init_subclass__(cls, **kwargs: Any) -> None:
bind_key = getattr(cls, "__bind_key__", None)
if bind_key is not None:
cls.metadata = cls.__metadata_registry__.get(bind_key)
elif None not in cls.__metadata_registry__ and getattr(cls, "metadata", None) is not None:
cls.__metadata_registry__[None] = cls.metadata
super().__init_subclass__(**kwargs)
class UUIDBase(UUIDPrimaryKey, CommonTableAttributes, AdvancedDeclarativeBase, AsyncAttrs):
"""Base for all SQLAlchemy declarative models with UUID v4 primary keys.
.. seealso::
:class:`CommonTableAttributes`
:class:`advanced_alchemy.mixins.UUIDPrimaryKey`
:class:`AdvancedDeclarativeBase`
:class:`AsyncAttrs`
"""
__abstract__ = True
class UUIDAuditBase(CommonTableAttributes, UUIDPrimaryKey, AuditColumns, AdvancedDeclarativeBase, AsyncAttrs):
"""Base for declarative models with UUID v4 primary keys and audit columns.
.. seealso::
:class:`CommonTableAttributes`
:class:`advanced_alchemy.mixins.UUIDPrimaryKey`
:class:`advanced_alchemy.mixins.AuditColumns`
:class:`AdvancedDeclarativeBase`
:class:`AsyncAttrs`
"""
__abstract__ = True
class UUIDv6Base(UUIDv6PrimaryKey, CommonTableAttributes, AdvancedDeclarativeBase, AsyncAttrs):
"""Base for all SQLAlchemy declarative models with UUID v6 primary keys.
.. seealso::
:class:`advanced_alchemy.mixins.UUIDv6PrimaryKey`
:class:`CommonTableAttributes`
:class:`AdvancedDeclarativeBase`
:class:`AsyncAttrs`
"""
__abstract__ = True
class UUIDv6AuditBase(CommonTableAttributes, UUIDv6PrimaryKey, AuditColumns, AdvancedDeclarativeBase, AsyncAttrs):
"""Base for declarative models with UUID v6 primary keys and audit columns.
.. seealso::
:class:`CommonTableAttributes`
:class:`advanced_alchemy.mixins.UUIDv6PrimaryKey`
:class:`advanced_alchemy.mixins.AuditColumns`
:class:`AdvancedDeclarativeBase`
:class:`AsyncAttrs`
"""
__abstract__ = True
class UUIDv7Base(UUIDv7PrimaryKey, CommonTableAttributes, AdvancedDeclarativeBase, AsyncAttrs):
"""Base for all SQLAlchemy declarative models with UUID v7 primary keys.
.. seealso::
:class:`advanced_alchemy.mixins.UUIDv7PrimaryKey`
:class:`CommonTableAttributes`
:class:`AdvancedDeclarativeBase`
:class:`AsyncAttrs`
"""
__abstract__ = True
class UUIDv7AuditBase(CommonTableAttributes, UUIDv7PrimaryKey, AuditColumns, AdvancedDeclarativeBase, AsyncAttrs):
"""Base for declarative models with UUID v7 primary keys and audit columns.
.. seealso::
:class:`CommonTableAttributes`
:class:`advanced_alchemy.mixins.UUIDv7PrimaryKey`
:class:`advanced_alchemy.mixins.AuditColumns`
:class:`AdvancedDeclarativeBase`
:class:`AsyncAttrs`
"""
__abstract__ = True
class NanoIDBase(NanoIDPrimaryKey, CommonTableAttributes, AdvancedDeclarativeBase, AsyncAttrs):
"""Base for all SQLAlchemy declarative models with Nano ID primary keys.
.. seealso::
:class:`advanced_alchemy.mixins.NanoIDPrimaryKey`
:class:`CommonTableAttributes`
:class:`AdvancedDeclarativeBase`
:class:`AsyncAttrs`
"""
__abstract__ = True
class NanoIDAuditBase(CommonTableAttributes, NanoIDPrimaryKey, AuditColumns, AdvancedDeclarativeBase, AsyncAttrs):
"""Base for declarative models with Nano ID primary keys and audit columns.
.. seealso::
:class:`CommonTableAttributes`
:class:`advanced_alchemy.mixins.NanoIDPrimaryKey`
:class:`advanced_alchemy.mixins.AuditColumns`
:class:`AdvancedDeclarativeBase`
:class:`AsyncAttrs`
"""
__abstract__ = True
class BigIntBase(BigIntPrimaryKey, CommonTableAttributes, AdvancedDeclarativeBase, AsyncAttrs):
"""Base for all SQLAlchemy declarative models with BigInt primary keys.
.. seealso::
:class:`advanced_alchemy.mixins.BigIntPrimaryKey`
:class:`CommonTableAttributes`
:class:`AdvancedDeclarativeBase`
:class:`AsyncAttrs`
"""
__abstract__ = True
class BigIntAuditBase(CommonTableAttributes, BigIntPrimaryKey, AuditColumns, AdvancedDeclarativeBase, AsyncAttrs):
"""Base for declarative models with BigInt primary keys and audit columns.
.. seealso::
:class:`CommonTableAttributes`
:class:`advanced_alchemy.mixins.BigIntPrimaryKey`
:class:`advanced_alchemy.mixins.AuditColumns`
:class:`AdvancedDeclarativeBase`
:class:`AsyncAttrs`
"""
__abstract__ = True
class DefaultBase(CommonTableAttributes, AdvancedDeclarativeBase, AsyncAttrs):
"""Base for all SQLAlchemy declarative models. No primary key is added.
.. seealso::
:class:`CommonTableAttributes`
:class:`AdvancedDeclarativeBase`
:class:`AsyncAttrs`
"""
__abstract__ = True
class SQLQuery(BasicAttributes, AdvancedDeclarativeBase, AsyncAttrs):
"""Base for all SQLAlchemy custom mapped objects.
.. seealso::
:class:`BasicAttributes`
:class:`AdvancedDeclarativeBase`
:class:`AsyncAttrs`
"""
__abstract__ = True
__allow_unmapped__ = True
# file: python-advanced-alchemy-1.0.1/advanced_alchemy/cli.py
from collections.abc import Sequence
from pathlib import Path
from typing import TYPE_CHECKING, Optional, Union, cast
if TYPE_CHECKING:
from click import Group
from advanced_alchemy.config import SQLAlchemyAsyncConfig, SQLAlchemySyncConfig
from alembic.migration import MigrationContext
from alembic.operations.ops import MigrationScript, UpgradeOps
__all__ = ("add_migration_commands", "get_alchemy_group")
def get_alchemy_group() -> "Group":
"""Get the Advanced Alchemy CLI group."""
from advanced_alchemy.exceptions import MissingDependencyError
try:
import rich_click as click
except ImportError:
try:
import click # type: ignore[no-redef]
except ImportError as e:
raise MissingDependencyError(package="click", install_package="cli") from e
@click.group(name="alchemy")
@click.option(
"--config",
help="Dotted path to SQLAlchemy config(s) (e.g. 'myapp.config.alchemy_configs')",
required=True,
type=str,
)
@click.pass_context
def alchemy_group(ctx: "click.Context", config: str) -> None:
"""Advanced Alchemy CLI commands."""
from rich import get_console
from advanced_alchemy.utils import module_loader
console = get_console()
ctx.ensure_object(dict)
try:
config_instance = module_loader.import_string(config)
if isinstance(config_instance, Sequence):
ctx.obj["configs"] = config_instance
else:
ctx.obj["configs"] = [config_instance]
except ImportError as e:
console.print(f"[red]Error loading config: {e}[/]")
ctx.exit(1)
return alchemy_group
def add_migration_commands(database_group: Optional["Group"] = None) -> "Group": # noqa: C901, PLR0915
"""Add migration commands to the database group."""
from advanced_alchemy.exceptions import MissingDependencyError
try:
import rich_click as click
except ImportError:
try:
import click # type: ignore[no-redef]
except ImportError as e:
raise MissingDependencyError(package="click", install_package="cli") from e
from rich import get_console
console = get_console()
if database_group is None:
database_group = get_alchemy_group()
bind_key_option = click.option(
"--bind-key",
help="Specify which SQLAlchemy config to use by bind key",
type=str,
default=None,
)
verbose_option = click.option(
"--verbose",
help="Enable verbose output.",
type=bool,
default=False,
is_flag=True,
)
no_prompt_option = click.option(
"--no-prompt",
help="Do not prompt for confirmation before executing the command.",
type=bool,
default=False,
required=False,
show_default=True,
is_flag=True,
)
def get_config_by_bind_key(
ctx: "click.Context", bind_key: Optional[str]
) -> "Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig]":
"""Get the SQLAlchemy config for the specified bind key."""
configs = ctx.obj["configs"]
if bind_key is None:
return cast("Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig]", configs[0])
for config in configs:
if config.bind_key == bind_key:
return cast("Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig]", config)
console.print(f"[red]No config found for bind key: {bind_key}[/]")
ctx.exit(1) # noqa: RET503
@database_group.command(
name="show-current-revision",
help="Shows the current revision for the database.",
)
@bind_key_option
@verbose_option
def show_database_revision(bind_key: Optional[str], verbose: bool) -> None: # pyright: ignore[reportUnusedFunction]
"""Show current database revision."""
from advanced_alchemy.alembic.commands import AlembicCommands
ctx = click.get_current_context()
console.rule("[yellow]Listing current revision[/]", align="left")
sqlalchemy_config = get_config_by_bind_key(ctx, bind_key)
alembic_commands = AlembicCommands(sqlalchemy_config=sqlalchemy_config)
alembic_commands.current(verbose=verbose)
@database_group.command(
name="downgrade",
help="Downgrade database to a specific revision.",
)
@bind_key_option
@click.option("--sql", type=bool, help="Generate SQL output for offline migrations.", default=False, is_flag=True)
@click.option(
"--tag",
    help="An arbitrary 'tag' that can be intercepted by custom env.py scripts via the EnvironmentContext.get_tag_argument() method.",
type=str,
default=None,
)
@no_prompt_option
@click.argument(
"revision",
type=str,
default="-1",
)
def downgrade_database( # pyright: ignore[reportUnusedFunction]
bind_key: Optional[str], revision: str, sql: bool, tag: Optional[str], no_prompt: bool
) -> None:
    """Downgrade the database to the specified revision."""
from rich.prompt import Confirm
from advanced_alchemy.alembic.commands import AlembicCommands
ctx = click.get_current_context()
console.rule("[yellow]Starting database downgrade process[/]", align="left")
input_confirmed = (
True
if no_prompt
else Confirm.ask(f"Are you sure you want to downgrade the database to the `{revision}` revision?")
)
if input_confirmed:
sqlalchemy_config = get_config_by_bind_key(ctx, bind_key)
alembic_commands = AlembicCommands(sqlalchemy_config=sqlalchemy_config)
alembic_commands.downgrade(revision=revision, sql=sql, tag=tag)
@database_group.command(
name="upgrade",
help="Upgrade database to a specific revision.",
)
@bind_key_option
@click.option("--sql", type=bool, help="Generate SQL output for offline migrations.", default=False, is_flag=True)
@click.option(
"--tag",
    help="An arbitrary 'tag' that can be intercepted by custom env.py scripts via the EnvironmentContext.get_tag_argument() method.",
type=str,
default=None,
)
@no_prompt_option
@click.argument(
"revision",
type=str,
default="head",
)
def upgrade_database( # pyright: ignore[reportUnusedFunction]
bind_key: Optional[str], revision: str, sql: bool, tag: Optional[str], no_prompt: bool
) -> None:
    """Upgrade the database to the specified revision."""
from rich.prompt import Confirm
from advanced_alchemy.alembic.commands import AlembicCommands
ctx = click.get_current_context()
console.rule("[yellow]Starting database upgrade process[/]", align="left")
input_confirmed = (
True
if no_prompt
        else Confirm.ask(f"[bold]Are you sure you want to migrate the database to the `{revision}` revision?[/]")
)
if input_confirmed:
sqlalchemy_config = get_config_by_bind_key(ctx, bind_key)
alembic_commands = AlembicCommands(sqlalchemy_config=sqlalchemy_config)
alembic_commands.upgrade(revision=revision, sql=sql, tag=tag)
@database_group.command(
name="init",
help="Initialize migrations for the project.",
)
@bind_key_option
@click.argument(
"directory",
default=None,
required=False,
)
@click.option("--multidb", is_flag=True, default=False, help="Support multiple databases")
@click.option("--package", is_flag=True, default=True, help="Create `__init__.py` for created folder")
@no_prompt_option
def init_alembic( # pyright: ignore[reportUnusedFunction]
bind_key: Optional[str], directory: Optional[str], multidb: bool, package: bool, no_prompt: bool
) -> None:
"""Initialize the database migrations."""
from rich.prompt import Confirm
from advanced_alchemy.alembic.commands import AlembicCommands
ctx = click.get_current_context()
console.rule("[yellow]Initializing database migrations.", align="left")
input_confirmed = (
        True if no_prompt else Confirm.ask("[bold]Are you sure you want to initialize migrations for the project?[/]")
)
if input_confirmed:
configs = [get_config_by_bind_key(ctx, bind_key)] if bind_key is not None else ctx.obj["configs"]
for config in configs:
directory = config.alembic_config.script_location if directory is None else directory
alembic_commands = AlembicCommands(sqlalchemy_config=config)
alembic_commands.init(directory=cast("str", directory), multidb=multidb, package=package)
@database_group.command(
name="make-migrations",
help="Create a new migration revision.",
)
@bind_key_option
@click.option("-m", "--message", default=None, help="Revision message")
@click.option(
"--autogenerate/--no-autogenerate", default=True, help="Automatically populate revision with detected changes"
)
@click.option("--sql", is_flag=True, default=False, help="Export to `.sql` instead of writing to the database.")
@click.option("--head", default="head", help="Specify head revision to use as base for new revision.")
@click.option(
"--splice", is_flag=True, default=False, help='Allow a non-head revision as the "head" to splice onto'
)
@click.option("--branch-label", default=None, help="Specify a branch label to apply to the new revision")
@click.option("--version-path", default=None, help="Specify specific path from config for version file")
    @click.option("--rev-id", default=None, help="Specify an ID to use for the revision.")
@no_prompt_option
def create_revision( # pyright: ignore[reportUnusedFunction]
bind_key: Optional[str],
message: Optional[str],
autogenerate: bool,
sql: bool,
head: str,
splice: bool,
branch_label: Optional[str],
version_path: Optional[str],
rev_id: Optional[str],
no_prompt: bool,
) -> None:
"""Create a new database revision."""
from rich.prompt import Prompt
from advanced_alchemy.alembic.commands import AlembicCommands
def process_revision_directives(
context: "MigrationContext", # noqa: ARG001
revision: tuple[str], # noqa: ARG001
directives: list["MigrationScript"],
) -> None:
"""Handle revision directives."""
if autogenerate and cast("UpgradeOps", directives[0].upgrade_ops).is_empty():
console.rule(
"[magenta]The generation of a migration file is being skipped because it would result in an empty file.",
style="magenta",
align="left",
)
console.rule(
"[magenta]More information can be found here. https://alembic.sqlalchemy.org/en/latest/autogenerate.html#what-does-autogenerate-detect-and-what-does-it-not-detect",
style="magenta",
align="left",
)
console.rule(
"[magenta]If you intend to create an empty migration file, use the --no-autogenerate option.",
style="magenta",
align="left",
)
directives.clear()
ctx = click.get_current_context()
        console.rule("[yellow]Starting database revision creation[/]", align="left")
if message is None:
message = "autogenerated" if no_prompt else Prompt.ask("Please enter a message describing this revision")
sqlalchemy_config = get_config_by_bind_key(ctx, bind_key)
alembic_commands = AlembicCommands(sqlalchemy_config=sqlalchemy_config)
alembic_commands.revision(
message=message,
autogenerate=autogenerate,
sql=sql,
head=head,
splice=splice,
branch_label=branch_label,
version_path=version_path,
rev_id=rev_id,
process_revision_directives=process_revision_directives, # type: ignore[arg-type]
)
@database_group.command(name="drop-all", help="Drop all tables from the database.")
@bind_key_option
@no_prompt_option
def drop_all(bind_key: Optional[str], no_prompt: bool) -> None: # pyright: ignore[reportUnusedFunction]
"""Drop all tables from the database."""
from anyio import run
from rich.prompt import Confirm
from advanced_alchemy.alembic.utils import drop_all
from advanced_alchemy.base import metadata_registry
ctx = click.get_current_context()
console.rule("[yellow]Dropping all tables from the database[/]", align="left")
input_confirmed = no_prompt or Confirm.ask(
"[bold red]Are you sure you want to drop all tables from the database?"
)
async def _drop_all(
configs: "Sequence[Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig]]",
) -> None:
for config in configs:
engine = config.get_engine()
await drop_all(engine, config.alembic_config.version_table_name, metadata_registry.get(config.bind_key))
if input_confirmed:
configs = [get_config_by_bind_key(ctx, bind_key)] if bind_key is not None else ctx.obj["configs"]
run(_drop_all, configs)
@database_group.command(name="dump-data", help="Dump specified tables from the database to JSON files.")
@bind_key_option
@click.option(
"--table",
"table_names",
help="Name of the table to dump. Multiple tables can be specified. Use '*' to dump all tables.",
type=str,
required=True,
multiple=True,
)
@click.option(
"--dir",
"dump_dir",
help="Directory to save the JSON files. Defaults to WORKDIR/fixtures",
type=click.Path(path_type=Path),
default=Path.cwd() / "fixtures",
required=False,
)
def dump_table_data(bind_key: Optional[str], table_names: tuple[str, ...], dump_dir: Path) -> None: # pyright: ignore[reportUnusedFunction]
"""Dump table data to JSON files."""
from anyio import run
from rich.prompt import Confirm
from advanced_alchemy.alembic.utils import dump_tables
from advanced_alchemy.base import metadata_registry, orm_registry
ctx = click.get_current_context()
all_tables = "*" in table_names
if all_tables and not Confirm.ask(
"[yellow bold]You have specified '*'. Are you sure you want to dump all tables from the database?",
):
return console.rule("[red bold]No data was dumped.", style="red", align="left")
async def _dump_tables() -> None:
configs = [get_config_by_bind_key(ctx, bind_key)] if bind_key is not None else ctx.obj["configs"]
for config in configs:
target_tables = set(metadata_registry.get(config.bind_key).tables)
if not all_tables:
for table_name in set(table_names) - target_tables:
console.rule(
f"[red bold]Skipping table '{table_name}' because it is not available in the default registry",
style="red",
align="left",
)
target_tables.intersection_update(table_names)
else:
console.rule("[yellow bold]Dumping all tables", style="yellow", align="left")
models = [
mapper.class_ for mapper in orm_registry.mappers if mapper.class_.__table__.name in target_tables
]
await dump_tables(dump_dir, config.get_session(), models)
console.rule("[green bold]Data dump complete", align="left")
return run(_dump_tables)
return database_group
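The `--config` option on the `alchemy` group above is resolved with `module_loader.import_string`, which turns a dotted path such as `myapp.config.alchemy_configs` into the object it names. The real helper lives in `advanced_alchemy.utils.module_loader`; the following is a simplified stdlib-only stand-in to illustrate the resolution strategy (longest importable module prefix, then attribute traversal):

```python
import importlib


def import_string(dotted_path: str):
    """Resolve 'pkg.module.attr' to the object it names.

    Tries to import the longest possible module prefix, then walks the
    remaining segments with getattr.
    """
    parts = dotted_path.split(".")
    # Find the longest importable module prefix.
    for split in range(len(parts), 0, -1):
        try:
            obj = importlib.import_module(".".join(parts[:split]))
            break
        except ImportError:
            continue
    else:
        raise ImportError(f"Could not import {dotted_path!r}")
    # Traverse the remaining attribute segments.
    for attr in parts[split:]:
        obj = getattr(obj, attr)
    return obj


# e.g. import_string("math.pi") returns the float math.pi
```

This is why the CLI can accept either a single config object or a sequence of them: whatever object the dotted path resolves to is stored on `ctx.obj["configs"]`, wrapped in a list if needed.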
# advanced_alchemy/config/__init__.py
from advanced_alchemy.config.asyncio import AlembicAsyncConfig, AsyncSessionConfig, SQLAlchemyAsyncConfig
from advanced_alchemy.config.common import (
ConnectionT,
EngineT,
GenericAlembicConfig,
GenericSessionConfig,
GenericSQLAlchemyConfig,
SessionMakerT,
SessionT,
)
from advanced_alchemy.config.engine import EngineConfig
from advanced_alchemy.config.sync import AlembicSyncConfig, SQLAlchemySyncConfig, SyncSessionConfig
from advanced_alchemy.config.types import CommitStrategy, TypeDecodersSequence, TypeEncodersMap
__all__ = (
"AlembicAsyncConfig",
"AlembicSyncConfig",
"AsyncSessionConfig",
"CommitStrategy",
"ConnectionT",
"EngineConfig",
"EngineT",
"GenericAlembicConfig",
"GenericSQLAlchemyConfig",
"GenericSessionConfig",
"SQLAlchemyAsyncConfig",
"SQLAlchemySyncConfig",
"SessionMakerT",
"SessionT",
"SyncSessionConfig",
"TypeDecodersSequence",
"TypeEncodersMap",
)
# advanced_alchemy/config/asyncio.py
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Callable, Union
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
from advanced_alchemy.config.common import (
GenericAlembicConfig,
GenericSessionConfig,
GenericSQLAlchemyConfig,
)
from advanced_alchemy.utils.dataclass import Empty
if TYPE_CHECKING:
from sqlalchemy.orm import Session
from advanced_alchemy.utils.dataclass import EmptyType
__all__ = (
"AlembicAsyncConfig",
"AsyncSessionConfig",
"SQLAlchemyAsyncConfig",
)
@dataclass
class AsyncSessionConfig(GenericSessionConfig[AsyncConnection, AsyncEngine, AsyncSession]):
"""SQLAlchemy async session config."""
sync_session_class: "Union[type[Session], None, EmptyType]" = Empty
    """A :class:`Session <sqlalchemy.orm.Session>` subclass or other callable which will be used to construct the
    :class:`Session <sqlalchemy.orm.Session>` which will be proxied. This parameter may be used to provide custom
    :class:`Session <sqlalchemy.orm.Session>` subclasses. Defaults to the
    :attr:`AsyncSession.sync_session_class <sqlalchemy.ext.asyncio.AsyncSession.sync_session_class>` class-level
    attribute."""
@dataclass
class AlembicAsyncConfig(GenericAlembicConfig):
"""Configuration for an Async Alembic's Config class.
.. seealso::
https://alembic.sqlalchemy.org/en/latest/api/config.html
"""
@dataclass
class SQLAlchemyAsyncConfig(GenericSQLAlchemyConfig[AsyncEngine, AsyncSession, async_sessionmaker[AsyncSession]]):
"""Async SQLAlchemy Configuration.
Note:
The alembic configuration options are documented in the Alembic documentation.
"""
create_engine_callable: "Callable[[str], AsyncEngine]" = create_async_engine
    """Callable that creates an :class:`AsyncEngine <sqlalchemy.ext.asyncio.AsyncEngine>` instance or an instance of its
subclass.
"""
session_config: AsyncSessionConfig = field(default_factory=AsyncSessionConfig) # pyright: ignore[reportIncompatibleVariableOverride]
"""Configuration options for the :class:`async_sessionmaker`."""
session_maker_class: "type[async_sessionmaker[AsyncSession]]" = async_sessionmaker # pyright: ignore[reportIncompatibleVariableOverride]
"""Sessionmaker class to use."""
alembic_config: "AlembicAsyncConfig" = field(default_factory=AlembicAsyncConfig)
"""Configuration for the SQLAlchemy Alembic migrations.
The configuration options are documented in the Alembic documentation.
"""
def __hash__(self) -> int:
return super().__hash__()
def __eq__(self, other: object) -> bool:
return super().__eq__(other)
@asynccontextmanager
async def get_session(
self,
) -> AsyncGenerator[AsyncSession, None]:
"""Get a session from the session maker.
Returns:
AsyncGenerator[AsyncSession, None]: An async context manager that yields an AsyncSession.
"""
session_maker = self.create_session_maker()
async with session_maker() as session:
yield session
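`SQLAlchemyAsyncConfig.get_session` above wraps the sessionmaker in `@asynccontextmanager` so callers get a session that is cleaned up automatically when the block exits. A toy sketch of the same shape, with no database required (`FakeSession` is an invented stand-in for `AsyncSession`, used purely for illustration):

```python
import asyncio
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager


class FakeSession:
    """Stand-in for AsyncSession: records whether it was closed."""

    def __init__(self) -> None:
        self.closed = False

    async def __aenter__(self) -> "FakeSession":
        return self

    async def __aexit__(self, *exc) -> None:
        self.closed = True  # mirrors AsyncSession closing on context exit


@asynccontextmanager
async def get_session() -> AsyncGenerator[FakeSession, None]:
    # create_session_maker() would normally build this factory once
    async with FakeSession() as session:
        yield session


async def main() -> FakeSession:
    async with get_session() as session:
        assert not session.closed  # usable inside the block
        kept = session
    return kept


session = asyncio.run(main())
print(session.closed)  # True: closed when the context exited
```

The nested `async with` is the key design point: the session's own context manager handles commit/rollback/close, while `@asynccontextmanager` lets the config expose that as a single reusable dependency.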
# advanced_alchemy/config/common.py
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, Optional, Union, cast
from typing_extensions import TypeVar
from advanced_alchemy.base import metadata_registry
from advanced_alchemy.config.engine import EngineConfig
from advanced_alchemy.exceptions import ImproperConfigurationError
from advanced_alchemy.utils.dataclass import Empty, simple_asdict
if TYPE_CHECKING:
from sqlalchemy import Connection, Engine, MetaData
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, AsyncSession, async_sessionmaker
from sqlalchemy.orm import Mapper, Query, Session, sessionmaker
from sqlalchemy.orm.session import JoinTransactionMode
from sqlalchemy.sql import TableClause
from advanced_alchemy.utils.dataclass import EmptyType
__all__ = (
"ALEMBIC_TEMPLATE_PATH",
"ConnectionT",
"EngineT",
"GenericAlembicConfig",
"GenericSQLAlchemyConfig",
"GenericSessionConfig",
"SessionMakerT",
"SessionT",
)
ALEMBIC_TEMPLATE_PATH = f"{Path(__file__).parent.parent}/alembic/templates"
"""Path to the Alembic templates."""
ConnectionT = TypeVar("ConnectionT", bound="Union[Connection, AsyncConnection]")
"""Type variable for SQLAlchemy connection types.
.. seealso::
:class:`sqlalchemy.Connection`
:class:`sqlalchemy.ext.asyncio.AsyncConnection`
"""
EngineT = TypeVar("EngineT", bound="Union[Engine, AsyncEngine]")
"""Type variable for a SQLAlchemy engine.
.. seealso::
:class:`sqlalchemy.Engine`
:class:`sqlalchemy.ext.asyncio.AsyncEngine`
"""
SessionT = TypeVar("SessionT", bound="Union[Session, AsyncSession]")
"""Type variable for a SQLAlchemy session.
.. seealso::
:class:`sqlalchemy.Session`
:class:`sqlalchemy.ext.asyncio.AsyncSession`
"""
SessionMakerT = TypeVar("SessionMakerT", bound="Union[sessionmaker[Session], async_sessionmaker[AsyncSession]]")
"""Type variable for a SQLAlchemy sessionmaker.
.. seealso::
:class:`sqlalchemy.orm.sessionmaker`
:class:`sqlalchemy.ext.asyncio.async_sessionmaker`
"""
@dataclass
class GenericSessionConfig(Generic[ConnectionT, EngineT, SessionT]):
"""SQLAlchemy async session config.
Types:
ConnectionT: :class:`sqlalchemy.Connection` | :class:`sqlalchemy.ext.asyncio.AsyncConnection`
EngineT: :class:`sqlalchemy.Engine` | :class:`sqlalchemy.ext.asyncio.AsyncEngine`
SessionT: :class:`sqlalchemy.Session` | :class:`sqlalchemy.ext.asyncio.AsyncSession`
"""
autobegin: "Union[bool, EmptyType]" = Empty
"""Automatically start transactions when database access is requested by an operation.
    Bool or :class:`Empty <advanced_alchemy.utils.dataclass.Empty>`
"""
autoflush: "Union[bool, EmptyType]" = Empty
    """When ``True``, all query operations will issue a flush call to this :class:`Session <sqlalchemy.orm.Session>`
    before proceeding."""
    bind: "Union[EngineT, ConnectionT, None, EmptyType]" = Empty
    """The :class:`Engine <sqlalchemy.engine.Engine>` or :class:`Connection <sqlalchemy.engine.Connection>` that new
    :class:`Session <sqlalchemy.orm.Session>` objects will be bound to."""
    binds: "Union[dict[Union[type[Any], Mapper[Any], TableClause, str], Union[EngineT, ConnectionT]], None, EmptyType]" = Empty
    """A dictionary which may specify any number of :class:`Engine <sqlalchemy.engine.Engine>` or :class:`Connection
    <sqlalchemy.engine.Connection>` objects as the source of connectivity for SQL operations on a per-entity basis. The
    keys of the dictionary consist of any series of mapped classes, arbitrary Python classes that are bases for mapped
    classes, :class:`Table <sqlalchemy.schema.Table>` objects and :class:`Mapper <sqlalchemy.orm.Mapper>` objects. The
    values of the dictionary are then instances of :class:`Engine <sqlalchemy.engine.Engine>` or less commonly
    :class:`Connection <sqlalchemy.engine.Connection>` objects."""
    class_: "Union[type[SessionT], EmptyType]" = Empty
    """Class to use in order to create new :class:`Session <sqlalchemy.orm.Session>` objects."""
expire_on_commit: "Union[bool, EmptyType]" = Empty
"""If ``True``, all instances will be expired after each commit."""
info: "Union[dict[str, Any], None, EmptyType]" = Empty
"""Optional dictionary of information that will be available via the
    :attr:`Session.info <sqlalchemy.orm.Session.info>`"""
join_transaction_mode: "Union[JoinTransactionMode, EmptyType]" = Empty
"""Describes the transactional behavior to take when a given bind is a Connection that has already begun a
transaction outside the scope of this Session; in other words the
    :attr:`Connection.in_transaction() <sqlalchemy.engine.Connection.in_transaction>` method returns True."""
query_cls: "Union[type[Query], None, EmptyType]" = Empty # pyright: ignore[reportMissingTypeArgument]
"""Class which should be used to create new Query objects, as returned by the
    :attr:`Session.query() <sqlalchemy.orm.Session.query>` method."""
    twophase: "Union[bool, EmptyType]" = Empty
    """When ``True``, all transactions will be started as a "two phase" transaction, i.e. using the "two phase"
    semantics of the database in use along with an XID. During a :attr:`commit() <sqlalchemy.orm.Session.commit>`, after
    :attr:`flush() <sqlalchemy.orm.Session.flush>` has been issued for all attached databases, the
    :attr:`TwoPhaseTransaction.prepare() <sqlalchemy.engine.TwoPhaseTransaction.prepare>` method on each database's
    :class:`TwoPhaseTransaction <sqlalchemy.engine.TwoPhaseTransaction>` will be called. This allows each database to
    roll back the entire transaction, before each transaction is committed."""
@dataclass
class GenericSQLAlchemyConfig(Generic[EngineT, SessionT, SessionMakerT]):
"""Common SQLAlchemy Configuration.
Types:
EngineT: :class:`sqlalchemy.Engine` or :class:`sqlalchemy.ext.asyncio.AsyncEngine`
SessionT: :class:`sqlalchemy.Session` or :class:`sqlalchemy.ext.asyncio.AsyncSession`
SessionMakerT: :class:`sqlalchemy.orm.sessionmaker` or :class:`sqlalchemy.ext.asyncio.async_sessionmaker`
"""
create_engine_callable: "Callable[[str], EngineT]"
    """Callable that creates an :class:`Engine <sqlalchemy.engine.Engine>` or
    :class:`AsyncEngine <sqlalchemy.ext.asyncio.AsyncEngine>` instance, or an instance of a
    subclass.
"""
session_config: "GenericSessionConfig[Any, Any, Any]"
    """Configuration options for either the :class:`async_sessionmaker <sqlalchemy.ext.asyncio.async_sessionmaker>`
    or :class:`sessionmaker <sqlalchemy.orm.sessionmaker>`.
"""
session_maker_class: "type[Union[sessionmaker[Session], async_sessionmaker[AsyncSession]]]"
"""Sessionmaker class to use.
.. seealso::
:class:`sqlalchemy.orm.sessionmaker`
:class:`sqlalchemy.ext.asyncio.async_sessionmaker`
"""
connection_string: "Union[str, None]" = field(default=None)
"""Database connection string in one of the formats supported by SQLAlchemy.
Notes:
- For async connections, the connection string must include the correct async prefix.
        e.g. ``'postgresql+asyncpg://...'`` instead of ``'postgresql://'``, and for sync connections it's the opposite.
"""
engine_config: "EngineConfig" = field(default_factory=EngineConfig)
"""Configuration for the SQLAlchemy engine.
The configuration options are documented in the SQLAlchemy documentation.
"""
session_maker: "Union[Callable[[], SessionT], None]" = None
"""Callable that returns a session.
If provided, the plugin will use this rather than instantiate a sessionmaker.
"""
engine_instance: "Union[EngineT, None]" = None
"""Optional engine to use.
If set, the plugin will use the provided instance rather than instantiate an engine.
"""
create_all: bool = False
"""If true, all models are automatically created on engine creation."""
metadata: "Union[MetaData, None]" = None
"""Optional metadata to use.
If set, the plugin will use the provided instance rather than the default metadata."""
enable_touch_updated_timestamp_listener: bool = True
"""Enable Created/Updated Timestamp event listener.
This is a listener that will update ``created_at`` and ``updated_at`` columns on record modification.
Disable if you plan to bring your own update mechanism for these columns"""
bind_key: "Union[str, None]" = None
"""Bind key to register a metadata to a specific engine configuration."""
    _SESSION_SCOPE_KEY_REGISTRY: "ClassVar[set[str]]" = field(init=False, default=cast("set[str]", set()))
    """Internal registry for ensuring unique identification of session scope keys in the class."""
    _ENGINE_APP_STATE_KEY_REGISTRY: "ClassVar[set[str]]" = field(init=False, default=cast("set[str]", set()))
    """Internal registry for ensuring unique identification of engine app state keys in the class."""
    _SESSIONMAKER_APP_STATE_KEY_REGISTRY: "ClassVar[set[str]]" = field(init=False, default=cast("set[str]", set()))
    """Internal registry for ensuring unique identification of sessionmaker state keys in the class."""
def __post_init__(self) -> None:
if self.connection_string is not None and self.engine_instance is not None:
msg = "Only one of 'connection_string' or 'engine_instance' can be provided."
raise ImproperConfigurationError(msg)
if self.metadata is None:
self.metadata = metadata_registry.get(self.bind_key)
else:
metadata_registry.set(self.bind_key, self.metadata)
if self.enable_touch_updated_timestamp_listener:
from sqlalchemy import event
from sqlalchemy.orm import Session
from advanced_alchemy._listeners import touch_updated_timestamp
event.listen(Session, "before_flush", touch_updated_timestamp)
def __hash__(self) -> int: # pragma: no cover
return hash(
(
self.__class__.__qualname__,
self.connection_string,
self.engine_config.__class__.__qualname__,
self.bind_key,
)
)
def __eq__(self, other: object) -> bool:
return self.__hash__() == other.__hash__()
@property
def engine_config_dict(self) -> dict[str, Any]:
"""Return the engine configuration as a dict.
Returns:
        A string keyed dict of config kwargs for the SQLAlchemy :func:`create_engine <sqlalchemy.create_engine>`
        function.
"""
return simple_asdict(self.engine_config, exclude_empty=True)
@property
def session_config_dict(self) -> dict[str, Any]:
"""Return the session configuration as a dict.
Returns:
A string keyed dict of config kwargs for the SQLAlchemy :class:`sqlalchemy.orm.sessionmaker`
class.
"""
return simple_asdict(self.session_config, exclude_empty=True)
def get_engine(self) -> EngineT:
"""Return an engine. If none exists yet, create one.
Returns:
:class:`sqlalchemy.Engine` or :class:`sqlalchemy.ext.asyncio.AsyncEngine` instance used by the plugin.
"""
if self.engine_instance:
return self.engine_instance
if self.connection_string is None:
msg = "One of 'connection_string' or 'engine_instance' must be provided."
raise ImproperConfigurationError(msg)
engine_config = self.engine_config_dict
try:
return self.create_engine_callable(self.connection_string, **engine_config)
except TypeError:
# likely due to a dialect that doesn't support json type
del engine_config["json_deserializer"]
del engine_config["json_serializer"]
return self.create_engine_callable(self.connection_string, **engine_config)
def create_session_maker(self) -> "Callable[[], SessionT]": # pragma: no cover
"""Get a session maker. If none exists yet, create one.
Returns:
:class:`sqlalchemy.orm.sessionmaker` or :class:`sqlalchemy.ext.asyncio.async_sessionmaker` factory used by the plugin.
"""
if self.session_maker:
return self.session_maker
session_kws = self.session_config_dict
if session_kws.get("bind") is None:
session_kws["bind"] = self.get_engine()
return cast("Callable[[], SessionT]", self.session_maker_class(**session_kws))
@dataclass
class GenericAlembicConfig:
"""Configuration for Alembic's :class:`Config `.
For details see: https://alembic.sqlalchemy.org/en/latest/api/config.html
"""
script_config: str = "alembic.ini"
"""A path to the Alembic configuration file such as ``alembic.ini``. If left unset, the default configuration
will be used.
"""
version_table_name: str = "alembic_versions"
"""Configure the name of the table used to hold the applied alembic revisions.
Defaults to ``alembic_versions``.
"""
    version_table_schema: "Optional[str]" = None
    """Configure the schema to use for the alembic revisions table.
    If unset, it defaults to the connection's default schema."""
script_location: str = "migrations"
"""A path to save generated migrations.
"""
user_module_prefix: "Optional[str]" = "sa."
"""User module prefix."""
render_as_batch: bool = True
"""Render as batch."""
compare_type: bool = False
"""Compare type."""
template_path: str = ALEMBIC_TEMPLATE_PATH
"""Template path."""
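The `Empty` sentinel used throughout `GenericSessionConfig` exists so that `session_config_dict` and `engine_config_dict` can distinguish "never set" from legitimate values like `None` or `False`, and pass only explicitly configured kwargs through to SQLAlchemy. A minimal self-contained sketch of that pattern (the names `Empty` and `simple_asdict` mirror the library's, but this is a simplified stand-in, not the actual implementation in `advanced_alchemy.utils.dataclass`):

```python
import dataclasses
from typing import Any


class _EmptyEnum:
    """Sentinel distinct from None: 'the user never set this field'."""

    def __repr__(self) -> str:
        return "Empty"


Empty = _EmptyEnum()


@dataclasses.dataclass
class SessionConfig:
    # Every field defaults to the sentinel, not to None.
    autoflush: Any = Empty
    expire_on_commit: Any = Empty
    info: Any = Empty


def simple_asdict(obj: Any, exclude_empty: bool = True) -> "dict[str, Any]":
    """Like dataclasses.asdict, but drop fields still set to Empty."""
    out: "dict[str, Any]" = {}
    for f in dataclasses.fields(obj):
        value = getattr(obj, f.name)
        if exclude_empty and value is Empty:
            continue  # user never touched this option; let SQLAlchemy use its default
        out[f.name] = value
    return out


cfg = SessionConfig(expire_on_commit=False)
print(simple_asdict(cfg))  # {'expire_on_commit': False}
```

This is why `expire_on_commit=False` reaches `sessionmaker(**kwargs)` while untouched options are omitted entirely, rather than overriding SQLAlchemy's defaults with `None`.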
# advanced_alchemy/config/engine.py
from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, Literal, Union
from advanced_alchemy._serialization import decode_json, encode_json
from advanced_alchemy.utils.dataclass import Empty
if TYPE_CHECKING:
from collections.abc import Mapping
from typing import Any
from sqlalchemy.engine.interfaces import IsolationLevel
from sqlalchemy.pool import Pool
from typing_extensions import TypeAlias
from advanced_alchemy.utils.dataclass import EmptyType
_EchoFlagType: "TypeAlias" = 'Union[None, bool, Literal["debug"]]'
_ParamStyle = Literal["qmark", "numeric", "named", "format", "pyformat", "numeric_dollar"]
__all__ = ("EngineConfig",)
@dataclass
class EngineConfig:
"""Configuration for SQLAlchemy's Engine.
This class provides configuration options for SQLAlchemy engine creation.
See: https://docs.sqlalchemy.org/en/20/core/engines.html
"""
connect_args: "Union[dict[Any, Any], EmptyType]" = Empty
"""A dictionary of arguments which will be passed directly to the DBAPI's ``connect()`` method as keyword arguments.
"""
echo: "Union[_EchoFlagType, EmptyType]" = Empty
"""If ``True``, the Engine will log all statements as well as a ``repr()`` of their parameter lists to the default
log handler, which defaults to ``sys.stdout`` for output. If set to the string "debug", result rows will be printed
to the standard output as well. The echo attribute of Engine can be modified at any time to turn logging on and off;
direct control of logging is also available using the standard Python logging module.
"""
echo_pool: "Union[_EchoFlagType, EmptyType]" = Empty
"""If ``True``, the connection pool will log informational output such as when connections are invalidated as well
as when connections are recycled to the default log handler, which defaults to sys.stdout for output. If set to the
string "debug", the logging will include pool checkouts and checkins. Direct control of logging is also available
using the standard Python logging module."""
enable_from_linting: "Union[bool, EmptyType]" = Empty
"""Defaults to True. Will emit a warning if a given SELECT statement is found to have un-linked FROM elements which
would cause a cartesian product."""
execution_options: "Union[Mapping[str, Any], EmptyType]" = Empty
"""Dictionary execution options which will be applied to all connections. See
    :attr:`Connection.execution_options() <sqlalchemy.engine.Connection.execution_options>` for details."""
hide_parameters: "Union[bool, EmptyType]" = Empty
"""Boolean, when set to ``True``, SQL statement parameters will not be displayed in INFO logging nor will they be
    formatted into the string representation of :class:`StatementError <sqlalchemy.exc.StatementError>` objects."""
insertmanyvalues_page_size: "Union[int, EmptyType]" = Empty
    """Number of rows to format into an INSERT statement when the statement uses "insertmanyvalues" mode, which is a
paged form of bulk insert that is used for many backends when using executemany execution typically in conjunction
with RETURNING. Defaults to 1000, but may also be subject to dialect-specific limiting factors which may override
this value on a per-statement basis."""
isolation_level: "Union[IsolationLevel, EmptyType]" = Empty
"""Optional string name of an isolation level which will be set on all new connections unconditionally. Isolation
levels are typically some subset of the string names "SERIALIZABLE", "REPEATABLE READ", "READ COMMITTED",
"READ UNCOMMITTED" and "AUTOCOMMIT" based on backend."""
json_deserializer: "Callable[[str], Any]" = decode_json
    """For dialects that support the :class:`JSON <sqlalchemy.types.JSON>` datatype, this is a Python callable that will
    convert a JSON string to a Python object. By default, this is set to Advanced Alchemy's
    :func:`decode_json <advanced_alchemy._serialization.decode_json>` function."""
    json_serializer: "Callable[[Any], str]" = encode_json
    """For dialects that support the JSON datatype, this is a Python callable that will render a given object as JSON.
    By default, Advanced Alchemy's :func:`encode_json <advanced_alchemy._serialization.encode_json>` is used."""
    label_length: "Union[int, None, EmptyType]" = Empty
    """Optional integer value which limits the size of dynamically generated column labels to that many characters. If
    less than 6, labels are generated as "_(counter)". If ``None``, the value of ``dialect.max_identifier_length``,
    which may be affected via the
    :paramref:`create_engine.max_identifier_length <sqlalchemy.create_engine.params.max_identifier_length>` parameter, is
    used instead. The value of
    :paramref:`create_engine.label_length <sqlalchemy.create_engine.params.label_length>` may not be larger than that of
    :paramref:`create_engine.max_identifier_length <sqlalchemy.create_engine.params.max_identifier_length>`."""
    logging_name: "Union[str, EmptyType]" = Empty
    """String identifier which will be used within the "name" field of logging records generated within the
    "sqlalchemy.engine" logger. Defaults to a hexstring of the object's id."""
    max_identifier_length: "Union[int, None, EmptyType]" = Empty
    """Override the max_identifier_length determined by the dialect. If ``None`` or ``0``, has no effect. This is the
    database's configured maximum number of characters that may be used in a SQL identifier such as a table name, column
    name, or label name. All dialects determine this value automatically, however in the case of a new database version
    for which this value has changed but SQLAlchemy's dialect has not been adjusted, the value may be passed here."""
max_overflow: "Union[int, EmptyType]" = Empty
"""The number of connections to allow in connection pool โoverflowโ, that is connections that can be opened above
and beyond the pool_size setting, which defaults to five. This is only used with
:class:`QueuePool `."""
module: "Union[Any, None, EmptyType]" = Empty
"""Reference to a Python module object (the module itself, not its string name). Specifies an alternate DBAPI module
to be used by the engine`s dialect. Each sub-dialect references a specific DBAPI which will be imported before first
connect. This parameter causes the import to be bypassed, and the given module to be used instead. Can be used for
testing of DBAPIs as well as to inject โmockโ DBAPI implementations into the
:class:`Engine `."""
paramstyle: "Union[_ParamStyle, None, EmptyType]" = Empty
"""The paramstyle to use when rendering bound parameters. This style defaults to the one recommended by the DBAPI
itself, which is retrieved from the ``.paramstyle`` attribute of the DBAPI. However, most DBAPIs accept more than
    one paramstyle, and in particular it may be desirable to change a "named" paramstyle into a "positional" one, or
vice versa. When this attribute is passed, it should be one of the values "qmark", "numeric", "named", "format" or
"pyformat", and should correspond to a parameter style known to be supported by the DBAPI in use."""
pool: "Union[Pool, None, EmptyType]" = Empty
"""An already-constructed instance of :class:`Pool `, such as a
:class:`QueuePool ` instance. If non-None, this pool will be used directly as the
underlying connection pool for the engine, bypassing whatever connection parameters are present in the URL argument.
For information on constructing connection pools manually, see
`Connection Pooling `_."""
poolclass: "Union[type[Pool], None, EmptyType]" = Empty
"""A :class:`Pool ` subclass, which will be used to create a connection pool instance using
the connection parameters given in the URL. Note this differs from pool in that you don`t actually instantiate the
pool in this case, you just indicate what type of pool to be used."""
pool_logging_name: "Union[str, EmptyType]" = Empty
"""String identifier which will be used within the โnameโ field of logging records generated within the
โsqlalchemy.poolโ logger. Defaults to a hexstring of the object`s id."""
pool_pre_ping: "Union[bool, EmptyType]" = Empty
"""If True will enable the connection pool โpre-pingโ feature that tests connections for liveness upon each
checkout."""
pool_size: "Union[int, EmptyType]" = Empty
"""The number of connections to keep open inside the connection pool. This used with
:class:`QueuePool ` as well as
:class:`SingletonThreadPool `. With
:class:`QueuePool `, a pool_size setting of ``0`` indicates no limit; to disable pooling,
set ``poolclass`` to :class:`NullPool ` instead."""
pool_recycle: "Union[int, EmptyType]" = Empty
"""This setting causes the pool to recycle connections after the given number of seconds has passed. It defaults to
``-1``, or no timeout. For example, setting to ``3600`` means connections will be recycled after one hour. Note that
MySQL in particular will disconnect automatically if no activity is detected on a connection for eight hours
(although this is configurable with the MySQLDB connection itself and the server configuration as well)."""
pool_reset_on_return: 'Union[Literal["rollback", "commit"], EmptyType]' = Empty
"""Set the :attr:`Pool.reset_on_return ` object, which can be set to the values ``"rollback"``, ``"commit"``, or
``None``."""
pool_timeout: "Union[int, EmptyType]" = Empty
"""Number of seconds to wait before giving up on getting a connection from the pool. This is only used with
    :class:`QueuePool <sqlalchemy.pool.QueuePool>`. This can be a float but is subject to the limitations of Python time
functions which may not be reliable in the tens of milliseconds."""
pool_use_lifo: "Union[bool, EmptyType]" = Empty
"""Use LIFO (last-in-first-out) when retrieving connections from :class:`QueuePool `
instead of FIFO (first-in-first-out). Using LIFO, a server-side timeout scheme can reduce the number of connections
used during non-peak periods of use. When planning for server-side timeouts, ensure that a recycle or pre-ping
strategy is in use to gracefully handle stale connections."""
plugins: "Union[list[str], EmptyType]" = Empty
"""String list of plugin names to load. See :class:`CreateEnginePlugin ` for
background."""
query_cache_size: "Union[int, EmptyType]" = Empty
"""Size of the cache used to cache the SQL string form of queries. Set to zero to disable caching.
See :attr:`query_cache_size ` for more info.
"""
use_insertmanyvalues: "Union[bool, EmptyType]" = Empty
"""``True`` by default, use the โinsertmanyvaluesโ execution style for INSERT..RETURNING statements by default."""
# advanced_alchemy/config/sync.py
"""Sync SQLAlchemy configuration module."""
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
from sqlalchemy import Connection, Engine, create_engine
from sqlalchemy.orm import Session, sessionmaker
from advanced_alchemy.config.common import GenericAlembicConfig, GenericSessionConfig, GenericSQLAlchemyConfig
if TYPE_CHECKING:
from collections.abc import Generator
from typing import Callable
__all__ = (
"AlembicSyncConfig",
"SQLAlchemySyncConfig",
"SyncSessionConfig",
)
@dataclass
class SyncSessionConfig(GenericSessionConfig[Connection, Engine, Session]):
"""Configuration for synchronous SQLAlchemy sessions."""
@dataclass
class AlembicSyncConfig(GenericAlembicConfig):
"""Configuration for Alembic's synchronous migrations.
For details see: https://alembic.sqlalchemy.org/en/latest/api/config.html
"""
@dataclass
class SQLAlchemySyncConfig(GenericSQLAlchemyConfig[Engine, Session, sessionmaker[Session]]):
"""Synchronous SQLAlchemy Configuration.
Note:
The alembic configuration options are documented in the Alembic documentation.
"""
create_engine_callable: "Callable[[str], Engine]" = create_engine
"""Callable that creates an :class:`Engine ` instance or instance of its subclass."""
session_config: SyncSessionConfig = field(default_factory=SyncSessionConfig) # pyright: ignore[reportIncompatibleVariableOverride]
"""Configuration options for the :class:`sessionmaker`."""
session_maker_class: type[sessionmaker[Session]] = sessionmaker # pyright: ignore[reportIncompatibleVariableOverride]
"""Sessionmaker class to use."""
alembic_config: AlembicSyncConfig = field(default_factory=AlembicSyncConfig)
"""Configuration for the SQLAlchemy Alembic migrations.
The configuration options are documented in the Alembic documentation.
"""
def __hash__(self) -> int:
return super().__hash__()
def __eq__(self, other: object) -> bool:
return super().__eq__(other)
@contextmanager
def get_session(self) -> "Generator[Session, None, None]":
"""Get a session context manager.
Yields:
Generator[sqlalchemy.orm.Session, None, None]: A context manager yielding an active SQLAlchemy Session.
Examples:
Using the session context manager:
>>> with config.get_session() as session:
... session.execute(...)
"""
session_maker = self.create_session_maker()
with session_maker() as session:
yield session
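The ``get_session`` pattern above — build the session factory lazily, then delegate to the session's own context manager — can be sketched without SQLAlchemy installed (``FakeSession`` is a hypothetical stand-in for ``sqlalchemy.orm.Session``):

```python
from contextlib import contextmanager
from typing import Generator


class FakeSession:
    """Hypothetical stand-in for sqlalchemy.orm.Session."""

    def __init__(self) -> None:
        self.closed = False

    def __enter__(self) -> "FakeSession":
        return self

    def __exit__(self, *exc: object) -> None:
        self.closed = True  # mirrors Session.close() running on context exit


@contextmanager
def get_session() -> Generator[FakeSession, None, None]:
    # create_session_maker() is stubbed as the FakeSession class itself
    session_maker = FakeSession
    with session_maker() as session:
        yield session


with get_session() as session:
    print(session.closed)  # False
print(session.closed)  # True
```

Because the inner ``with`` wraps the yield, the session is closed even if the caller's block raises.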
# advanced_alchemy/config/types.py
"""Type aliases and constants used in the package config."""
from collections.abc import Mapping, Sequence
from typing import Any, Callable, Literal
from typing_extensions import TypeAlias
TypeEncodersMap: TypeAlias = Mapping[Any, Callable[[Any], Any]]
"""Type alias for a mapping of type encoders.
Maps types to their encoder functions.
"""
TypeDecodersSequence: TypeAlias = Sequence[tuple[Callable[[Any], bool], Callable[[Any, Any], Any]]]
"""Type alias for a sequence of type decoders.
Each tuple contains a type check predicate and its corresponding decoder function.
"""
CommitStrategy: TypeAlias = Literal["always", "match_status"]
"""Commit strategy for SQLAlchemy sessions.
Values:
always: Always commit the session after operations
match_status: Only commit if the HTTP status code indicates success
"""
# advanced_alchemy/exceptions.py
import re
from collections.abc import Generator
from contextlib import contextmanager
from typing import Any, Callable, Optional, TypedDict, Union, cast
from sqlalchemy.exc import IntegrityError as SQLAlchemyIntegrityError
from sqlalchemy.exc import InvalidRequestError as SQLAlchemyInvalidRequestError
from sqlalchemy.exc import MultipleResultsFound, SQLAlchemyError, StatementError
__all__ = (
"AdvancedAlchemyError",
"DuplicateKeyError",
"ErrorMessages",
"ForeignKeyError",
"ImproperConfigurationError",
"IntegrityError",
"MissingDependencyError",
"MultipleResultsFoundError",
"NotFoundError",
"RepositoryError",
"SerializationError",
"wrap_sqlalchemy_exception",
)
DUPLICATE_KEY_REGEXES = {
"postgresql": [
re.compile(
r"^.*duplicate\s+key.*\"(?P[^\"]+)\"\s*\n.*Key\s+\((?P.*)\)=\((?P.*)\)\s+already\s+exists.*$",
),
re.compile(r"^.*duplicate\s+key.*\"(?P[^\"]+)\"\s*\n.*$"),
],
"sqlite": [
re.compile(r"^.*columns?(?P[^)]+)(is|are)\s+not\s+unique$"),
re.compile(r"^.*UNIQUE\s+constraint\s+failed:\s+(?P.+)$"),
re.compile(r"^.*PRIMARY\s+KEY\s+must\s+be\s+unique.*$"),
],
"mysql": [
re.compile(r"^.*\b1062\b.*Duplicate entry '(?P.*)' for key '(?P[^']+)'.*$"),
re.compile(r"^.*\b1062\b.*Duplicate entry \\'(?P.*)\\' for key \\'(?P.+)\\'.*$"),
],
"oracle": [],
"spanner+spanner": [],
"duckdb": [],
"mssql": [],
"bigquery": [],
"cockroach": [],
}
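These patterns are matched against the driver's error message to classify the failure and extract the offending columns. For example, the SQLite UNIQUE-constraint pattern (the ``columns`` group name here is illustrative):

```python
import re

# SQLite-style duplicate key message, as surfaced by sqlite3 through SQLAlchemy
pattern = re.compile(r"^.*UNIQUE\s+constraint\s+failed:\s+(?P<columns>.+)$")

msg = "(sqlite3.IntegrityError) UNIQUE constraint failed: user_account.email"
match = pattern.match(msg)
assert match is not None
print(match.group("columns"))  # user_account.email
```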
FOREIGN_KEY_REGEXES = {
    "postgresql": [
        re.compile(
            r".*on table \"(?P<table>[^\"]+)\" violates foreign key constraint \"(?P<constraint>[^\"]+)\"",
        ),
    ],
    "sqlite": [],
    "mysql": [],
    "oracle": [],
    "spanner+spanner": [],
    "duckdb": [],
    "mssql": [],
    "bigquery": [],
    "cockroach": [],
}
CHECK_CONSTRAINT_REGEXES = {
    "postgresql": [
        re.compile(r".*new row for relation \"(?P<table>.+)\" violates check constraint (?P<constraint>.+)"),
    ],
    "sqlite": [],
    "mysql": [],
    "oracle": [],
    "spanner+spanner": [],
    "duckdb": [],
    "mssql": [],
    "bigquery": [],
    "cockroach": [],
}
class AdvancedAlchemyError(Exception):
"""Base exception class from which all Advanced Alchemy exceptions inherit."""
detail: str
def __init__(self, *args: Any, detail: str = "") -> None:
"""Initialize ``AdvancedAlchemyException``.
Args:
*args: args are converted to :class:`str` before passing to :class:`Exception`
detail: detail of the exception.
"""
str_args = [str(arg) for arg in args if arg]
if not detail:
if str_args:
detail, *str_args = str_args
elif hasattr(self, "detail"):
detail = self.detail
self.detail = detail
super().__init__(*str_args)
def __repr__(self) -> str:
if self.detail:
return f"{self.__class__.__name__} - {self.detail}"
return self.__class__.__name__
def __str__(self) -> str:
return " ".join((*self.args, self.detail)).strip()
class MissingDependencyError(AdvancedAlchemyError, ImportError):
"""Missing optional dependency.
This exception is raised when a module depends on a dependency that has not been installed.
Args:
package: Name of the missing package.
install_package: Optional alternative package name to install.
"""
def __init__(self, package: str, install_package: Optional[str] = None) -> None:
super().__init__(
f"Package {package!r} is not installed but required. You can install it by running "
f"'pip install advanced_alchemy[{install_package or package}]' to install advanced_alchemy with the required extra "
f"or 'pip install {install_package or package}' to install the package separately",
)
class ImproperConfigurationError(AdvancedAlchemyError):
"""Improper Configuration error.
This exception is raised when there is an issue with the configuration of a module.
Args:
*args: Variable length argument list passed to parent class.
detail: Detailed error message.
"""
class SerializationError(AdvancedAlchemyError):
"""Encoding or decoding error.
This exception is raised when serialization or deserialization of an object fails.
Args:
*args: Variable length argument list passed to parent class.
detail: Detailed error message.
"""
class RepositoryError(AdvancedAlchemyError):
"""Base repository exception type.
Args:
*args: Variable length argument list passed to parent class.
detail: Detailed error message.
"""
class IntegrityError(RepositoryError):
"""Data integrity error.
Args:
*args: Variable length argument list passed to parent class.
detail: Detailed error message.
"""
class DuplicateKeyError(IntegrityError):
"""Duplicate key error.
Args:
*args: Variable length argument list passed to parent class.
detail: Detailed error message.
"""
class ForeignKeyError(IntegrityError):
"""Foreign key error.
Args:
*args: Variable length argument list passed to parent class.
detail: Detailed error message.
"""
class NotFoundError(RepositoryError):
"""Not found error.
This exception is raised when a requested resource is not found.
Args:
*args: Variable length argument list passed to parent class.
detail: Detailed error message.
"""
class MultipleResultsFoundError(RepositoryError):
"""Multiple results found error.
This exception is raised when a single result was expected but multiple were found.
Args:
*args: Variable length argument list passed to parent class.
detail: Detailed error message.
"""
class InvalidRequestError(RepositoryError):
"""Invalid request error.
This exception is raised when SQLAlchemy is unable to complete the request due to a runtime error
Args:
*args: Variable length argument list passed to parent class.
detail: Detailed error message.
"""
class ErrorMessages(TypedDict, total=False):
duplicate_key: Union[str, Callable[[Exception], str]]
integrity: Union[str, Callable[[Exception], str]]
foreign_key: Union[str, Callable[[Exception], str]]
multiple_rows: Union[str, Callable[[Exception], str]]
check_constraint: Union[str, Callable[[Exception], str]]
other: Union[str, Callable[[Exception], str]]
not_found: Union[str, Callable[[Exception], str]]
def _get_error_message(error_messages: ErrorMessages, key: str, exc: Exception) -> str:
template: Union[str, Callable[[Exception], str]] = error_messages.get(key, f"{key} error: {exc}") # type: ignore[assignment]
if callable(template): # pyright: ignore[reportUnknownArgumentType]
template = template(exc) # pyright: ignore[reportUnknownVariableType]
return template # pyright: ignore[reportUnknownVariableType]
@contextmanager
def wrap_sqlalchemy_exception( # noqa: C901, PLR0915
error_messages: Optional[ErrorMessages] = None,
dialect_name: Optional[str] = None,
wrap_exceptions: bool = True,
) -> Generator[None, None, None]:
"""Do something within context to raise a ``RepositoryError`` chained
from an original ``SQLAlchemyError``.
Args:
error_messages: Error messages to use for the exception.
dialect_name: The name of the dialect to use for the exception.
wrap_exceptions: Wrap SQLAlchemy exceptions in a ``RepositoryError``. When set to ``False``, the original exception will be raised.
>>> try:
... with wrap_sqlalchemy_exception():
... raise SQLAlchemyError("Original Exception")
... except RepositoryError as exc:
... print(
... f"caught repository exception from {type(exc.__context__)}"
... )
        caught repository exception from <class 'sqlalchemy.exc.SQLAlchemyError'>
"""
try:
yield
except NotFoundError as exc:
if wrap_exceptions is False:
raise
if error_messages is not None:
msg = _get_error_message(error_messages=error_messages, key="not_found", exc=exc)
else:
msg = "No rows matched the specified data"
raise NotFoundError(detail=msg) from exc
except MultipleResultsFound as exc:
if wrap_exceptions is False:
raise
if error_messages is not None:
msg = _get_error_message(error_messages=error_messages, key="multiple_rows", exc=exc)
else:
msg = "Multiple rows matched the specified data"
raise MultipleResultsFoundError(detail=msg) from exc
except SQLAlchemyIntegrityError as exc:
if wrap_exceptions is False:
raise
if error_messages is not None and dialect_name is not None:
_keys_to_regex = {
"duplicate_key": (DUPLICATE_KEY_REGEXES.get(dialect_name, []), DuplicateKeyError),
"check_constraint": (CHECK_CONSTRAINT_REGEXES.get(dialect_name, []), IntegrityError),
"foreign_key": (FOREIGN_KEY_REGEXES.get(dialect_name, []), ForeignKeyError),
}
detail = " - ".join(str(exc_arg) for exc_arg in exc.orig.args) if exc.orig.args else "" # type: ignore[union-attr] # pyright: ignore[reportArgumentType,reportOptionalMemberAccess]
for key, (regexes, exception) in _keys_to_regex.items():
for regex in regexes:
if (match := regex.findall(detail)) and match[0]:
raise exception(
detail=_get_error_message(error_messages=error_messages, key=key, exc=exc),
) from exc
raise IntegrityError(
detail=_get_error_message(error_messages=error_messages, key="integrity", exc=exc),
) from exc
raise IntegrityError(detail=f"An integrity error occurred: {exc}") from exc
except SQLAlchemyInvalidRequestError as exc:
if wrap_exceptions is False:
raise
raise InvalidRequestError(detail="An invalid request was made.") from exc
except StatementError as exc:
if wrap_exceptions is False:
raise
raise IntegrityError(
detail=cast("str", getattr(exc.orig, "detail", "There was an issue processing the statement."))
) from exc
except SQLAlchemyError as exc:
if wrap_exceptions is False:
raise
if error_messages is not None:
msg = _get_error_message(error_messages=error_messages, key="other", exc=exc)
else:
msg = f"An exception occurred: {exc}"
raise RepositoryError(detail=msg) from exc
except AttributeError as exc:
if wrap_exceptions is False:
raise
if error_messages is not None:
msg = _get_error_message(error_messages=error_messages, key="other", exc=exc)
else:
msg = f"An attribute error occurred during processing: {exc}"
raise RepositoryError(detail=msg) from exc
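The wrapping pattern above — yield, catch, re-raise a domain exception chained with ``from`` — can be shown in miniature without SQLAlchemy (``FakeSQLAlchemyError`` is a hypothetical stand-in for ``sqlalchemy.exc.SQLAlchemyError``):

```python
from contextlib import contextmanager
from typing import Generator


class RepositoryError(Exception):
    """Domain-level exception, mirroring the RepositoryError above."""


class FakeSQLAlchemyError(Exception):
    """Hypothetical stand-in for sqlalchemy.exc.SQLAlchemyError."""


@contextmanager
def wrap_exception(wrap_exceptions: bool = True) -> Generator[None, None, None]:
    try:
        yield
    except FakeSQLAlchemyError as exc:
        if not wrap_exceptions:
            raise  # pass the original through untouched
        # Chain with `from` so exc is preserved as __cause__ for debugging.
        raise RepositoryError(f"An exception occurred: {exc}") from exc


try:
    with wrap_exception():
        raise FakeSQLAlchemyError("connection lost")
except RepositoryError as exc:
    print(type(exc.__cause__).__name__)  # FakeSQLAlchemyError
```

Callers catch one stable exception family while the original driver error stays attached to the chain.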
# advanced_alchemy/extensions/__init__.py  (empty)
# advanced_alchemy/extensions/fastapi/__init__.py
"""FastAPI extension for Advanced Alchemy.
This module provides FastAPI integration for Advanced Alchemy, including session management,
database migrations, and service utilities.
"""
from advanced_alchemy import base, exceptions, filters, mixins, operations, repository, service, types, utils
from advanced_alchemy.alembic.commands import AlembicCommands
from advanced_alchemy.config import AlembicAsyncConfig, AlembicSyncConfig, AsyncSessionConfig, SyncSessionConfig
from advanced_alchemy.extensions.fastapi.cli import get_database_migration_plugin
from advanced_alchemy.extensions.fastapi.config import EngineConfig, SQLAlchemyAsyncConfig, SQLAlchemySyncConfig
from advanced_alchemy.extensions.fastapi.extension import AdvancedAlchemy, assign_cli_group
__all__ = (
"AdvancedAlchemy",
"AlembicAsyncConfig",
"AlembicCommands",
"AlembicSyncConfig",
"AsyncSessionConfig",
"EngineConfig",
"SQLAlchemyAsyncConfig",
"SQLAlchemySyncConfig",
"SyncSessionConfig",
"assign_cli_group",
"base",
"exceptions",
"filters",
"get_database_migration_plugin",
"mixins",
"operations",
"repository",
"service",
"types",
"utils",
)
# advanced_alchemy/extensions/fastapi/cli.py
from typing import TYPE_CHECKING, Optional, cast
try:
import rich_click as click
except ImportError:
import click # type: ignore[no-redef]
from advanced_alchemy.cli import add_migration_commands
if TYPE_CHECKING:
from fastapi import FastAPI
from advanced_alchemy.extensions.fastapi.extension import AdvancedAlchemy
def get_database_migration_plugin(app: "FastAPI") -> "AdvancedAlchemy": # pragma: no cover
"""Retrieve the Advanced Alchemy extension from a FastAPI application instance."""
from advanced_alchemy.exceptions import ImproperConfigurationError
extension = cast("Optional[AdvancedAlchemy]", getattr(app.state, "advanced_alchemy", None))
if extension is None:
msg = "Failed to initialize database CLI. The Advanced Alchemy extension is not properly configured."
raise ImproperConfigurationError(msg)
return extension
def register_database_commands(app: "FastAPI") -> click.Group: # pragma: no cover
@click.group(name="database")
@click.pass_context
def database_group(ctx: click.Context) -> None:
"""Manage SQLAlchemy database components."""
ctx.ensure_object(dict)
ctx.obj["configs"] = get_database_migration_plugin(app).config
add_migration_commands(database_group)
return database_group
# advanced_alchemy/extensions/fastapi/config.py
from advanced_alchemy.extensions.starlette import EngineConfig, SQLAlchemyAsyncConfig, SQLAlchemySyncConfig
__all__ = (
"EngineConfig",
"SQLAlchemyAsyncConfig",
"SQLAlchemySyncConfig",
)
# advanced_alchemy/extensions/fastapi/extension.py
from typing import TYPE_CHECKING, Optional, Union
from advanced_alchemy.extensions.fastapi.cli import register_database_commands
from advanced_alchemy.extensions.starlette import AdvancedAlchemy as StarletteAdvancedAlchemy
if TYPE_CHECKING:
from collections.abc import Sequence
from fastapi import FastAPI
from advanced_alchemy.extensions.fastapi.config import SQLAlchemyAsyncConfig, SQLAlchemySyncConfig
__all__ = ("AdvancedAlchemy",)
def assign_cli_group(app: "FastAPI") -> None: # pragma: no cover
try:
from fastapi_cli.cli import app as fastapi_cli_app # pyright: ignore[reportUnknownVariableType]
from typer.main import get_group
except ImportError:
print("FastAPI CLI is not installed. Skipping CLI registration.") # noqa: T201
return
click_app = get_group(fastapi_cli_app) # pyright: ignore[reportUnknownArgumentType]
click_app.add_command(register_database_commands(app))
class AdvancedAlchemy(StarletteAdvancedAlchemy):
"""AdvancedAlchemy integration for FastAPI applications.
This class manages SQLAlchemy sessions and engine lifecycle within a FastAPI application.
It provides middleware for handling transactions based on commit strategies.
"""
def __init__(
self,
config: "Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig, Sequence[Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig]]]",
app: "Optional[FastAPI]" = None,
) -> None:
super().__init__(config, app)
# advanced_alchemy/extensions/flask/__init__.py
"""Flask extension for Advanced Alchemy.
This module provides Flask integration for Advanced Alchemy, including session management,
database migrations, and service utilities.
"""
from advanced_alchemy import base, exceptions, filters, mixins, operations, repository, service, types, utils
from advanced_alchemy.alembic.commands import AlembicCommands
from advanced_alchemy.config import AlembicAsyncConfig, AlembicSyncConfig, AsyncSessionConfig, SyncSessionConfig
from advanced_alchemy.extensions.flask.cli import get_database_migration_plugin
from advanced_alchemy.extensions.flask.config import EngineConfig, SQLAlchemyAsyncConfig, SQLAlchemySyncConfig
from advanced_alchemy.extensions.flask.extension import AdvancedAlchemy
from advanced_alchemy.extensions.flask.utils import FlaskServiceMixin
__all__ = (
"AdvancedAlchemy",
"AlembicAsyncConfig",
"AlembicCommands",
"AlembicSyncConfig",
"AsyncSessionConfig",
"EngineConfig",
"FlaskServiceMixin",
"SQLAlchemyAsyncConfig",
"SQLAlchemySyncConfig",
"SyncSessionConfig",
"base",
"exceptions",
"filters",
"get_database_migration_plugin",
"mixins",
"operations",
"repository",
"service",
"types",
"utils",
)
# advanced_alchemy/extensions/flask/cli.py
"""Command-line interface utilities for Flask integration.
This module provides CLI commands for database management in Flask applications.
"""
from contextlib import suppress
from typing import TYPE_CHECKING, cast
from flask.cli import with_appcontext
from advanced_alchemy.cli import add_migration_commands
try:
import rich_click as click
except ImportError:
import click # type: ignore[no-redef]
if TYPE_CHECKING:
from flask import Flask
from advanced_alchemy.extensions.flask.extension import AdvancedAlchemy
def get_database_migration_plugin(app: "Flask") -> "AdvancedAlchemy":
"""Retrieve the Advanced Alchemy extension from the Flask application.
Args:
app: The :class:`flask.Flask` application instance.
Returns:
:class:`AdvancedAlchemy`: The Advanced Alchemy extension instance.
Raises:
:exc:`advanced_alchemy.exceptions.ImproperConfigurationError`: If the extension is not found.
"""
from advanced_alchemy.exceptions import ImproperConfigurationError
with suppress(KeyError):
return cast("AdvancedAlchemy", app.extensions["advanced_alchemy"])
msg = "Failed to initialize database migrations. The Advanced Alchemy extension is not properly configured."
raise ImproperConfigurationError(msg)
@click.group(name="database")
@with_appcontext
def database_group() -> None:
"""Manage SQLAlchemy database components.
This command group provides database management commands like migrations.
"""
ctx = click.get_current_context()
app = ctx.obj.load_app()
ctx.obj = {"app": app, "configs": get_database_migration_plugin(app).config}
add_migration_commands(database_group)
# advanced_alchemy/extensions/flask/config.py
"""Configuration classes for Flask integration.
This module provides configuration classes for integrating SQLAlchemy with Flask applications,
including both synchronous and asynchronous database configurations.
"""
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union, cast
from click import echo
from flask import g, has_request_context
from sqlalchemy.exc import OperationalError
from typing_extensions import Literal
from advanced_alchemy._serialization import decode_json, encode_json
from advanced_alchemy.base import metadata_registry
from advanced_alchemy.config import EngineConfig as _EngineConfig
from advanced_alchemy.config.asyncio import SQLAlchemyAsyncConfig as _SQLAlchemyAsyncConfig
from advanced_alchemy.config.sync import SQLAlchemySyncConfig as _SQLAlchemySyncConfig
from advanced_alchemy.exceptions import ImproperConfigurationError
from advanced_alchemy.service import schema_dump
if TYPE_CHECKING:
from flask import Flask, Response
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from advanced_alchemy.utils.portals import Portal
__all__ = ("EngineConfig", "SQLAlchemyAsyncConfig", "SQLAlchemySyncConfig")
ConfigT = TypeVar("ConfigT", bound="Union[SQLAlchemySyncConfig, SQLAlchemyAsyncConfig]")
def serializer(value: "Any") -> str:
"""Serialize JSON field values.
    Calls :func:`schema_dump` to convert the value to a built-in type before encoding.
Args:
value: Any JSON serializable value.
Returns:
str: JSON string representation of the value.
"""
return encode_json(schema_dump(value))
@dataclass
class EngineConfig(_EngineConfig):
"""Configuration for SQLAlchemy's Engine.
This class extends the base EngineConfig with Flask-specific JSON serialization options.
For details see: https://docs.sqlalchemy.org/en/20/core/engines.html
Attributes:
json_deserializer: Callable for converting JSON strings to Python objects.
json_serializer: Callable for converting Python objects to JSON strings.
"""
json_deserializer: "Callable[[str], Any]" = decode_json
"""For dialects that support the :class:`~sqlalchemy.types.JSON` datatype, this is a Python callable that will
convert a JSON string to a Python object."""
json_serializer: "Callable[[Any], str]" = serializer
"""For dialects that support the JSON datatype, this is a Python callable that will render a given object as JSON."""
@dataclass
class SQLAlchemySyncConfig(_SQLAlchemySyncConfig):
"""Flask-specific synchronous SQLAlchemy configuration.
Attributes:
app: The Flask application instance.
commit_mode: The commit mode to use for database sessions.
"""
app: "Optional[Flask]" = None
"""The Flask application instance."""
commit_mode: Literal["manual", "autocommit", "autocommit_include_redirect"] = "manual"
"""The commit mode to use for database sessions."""
def create_session_maker(self) -> "Callable[[], Session]":
"""Get a session maker. If none exists yet, create one.
Returns:
Callable[[], Session]: Session factory used by the plugin.
"""
if self.session_maker:
return self.session_maker
session_kws = self.session_config_dict
if self.engine_instance is None:
self.engine_instance = self.get_engine()
if session_kws.get("bind") is None:
session_kws["bind"] = self.engine_instance
self.session_maker = self.session_maker_class(**session_kws)
return self.session_maker
def init_app(self, app: "Flask", portal: "Optional[Portal]" = None) -> None:
"""Initialize the Flask application with this configuration.
Args:
app: The Flask application instance.
portal: The portal to use for thread-safe communication. Unused in synchronous configurations.
"""
self.app = app
self.bind_key = self.bind_key or "default"
if self.create_all:
self.create_all_metadata()
if self.commit_mode != "manual":
self._setup_session_handling(app)
def _setup_session_handling(self, app: "Flask") -> None:
"""Set up the session handling for the Flask application.
Args:
app: The Flask application instance.
"""
@app.after_request
def handle_db_session(response: "Response") -> "Response": # pyright: ignore[reportUnusedFunction]
"""Commit the session if the response meets the commit criteria."""
if not has_request_context():
return response
db_session = cast("Optional[Session]", g.pop(f"advanced_alchemy_session_{self.bind_key}", None))
if db_session is not None:
if (self.commit_mode == "autocommit" and 200 <= response.status_code < 300) or ( # noqa: PLR2004
self.commit_mode == "autocommit_include_redirect" and 200 <= response.status_code < 400 # noqa: PLR2004
):
db_session.commit()
db_session.close()
return response
def close_engines(self, portal: "Portal") -> None:
"""Close the engines.
Args:
portal: The portal to use for thread-safe communication.
"""
if self.engine_instance is not None:
self.engine_instance.dispose()
def create_all_metadata(self) -> None: # pragma: no cover
"""Create all metadata tables in the database."""
if self.engine_instance is None:
self.engine_instance = self.get_engine()
with self.engine_instance.begin() as conn:
try:
metadata_registry.get(None if self.bind_key == "default" else self.bind_key).create_all(conn)
except OperationalError as exc:
echo(f" * Could not create target metadata. Reason: {exc}")
else:
echo(" * Created target metadata.")
@dataclass
class SQLAlchemyAsyncConfig(_SQLAlchemyAsyncConfig):
"""Flask-specific asynchronous SQLAlchemy configuration.
Attributes:
app: The Flask application instance.
commit_mode: The commit mode to use for database sessions.
"""
app: "Optional[Flask]" = None
"""The Flask application instance."""
commit_mode: Literal["manual", "autocommit", "autocommit_include_redirect"] = "manual"
"""The commit mode to use for database sessions."""
def create_session_maker(self) -> "Callable[[], AsyncSession]":
"""Get a session maker. If none exists yet, create one.
Returns:
Callable[[], AsyncSession]: Session factory used by the plugin.
"""
if self.session_maker:
return self.session_maker
session_kws = self.session_config_dict
if self.engine_instance is None:
self.engine_instance = self.get_engine()
if session_kws.get("bind") is None:
session_kws["bind"] = self.engine_instance
self.session_maker = self.session_maker_class(**session_kws)
return self.session_maker
def init_app(self, app: "Flask", portal: "Optional[Portal]" = None) -> None:
"""Initialize the Flask application with this configuration.
Args:
app: The Flask application instance.
portal: The portal to use for thread-safe communication.
Raises:
ImproperConfigurationError: If portal is not provided for async configuration.
"""
self.app = app
self.bind_key = self.bind_key or "default"
if portal is None:
msg = "Portal is required for asynchronous configurations"
raise ImproperConfigurationError(msg)
if self.create_all:
_ = portal.call(self.create_all_metadata)
self._setup_session_handling(app, portal)
def _setup_session_handling(self, app: "Flask", portal: "Portal") -> None:
"""Set up the session handling for the Flask application.
Args:
app: The Flask application instance.
portal: The portal to use for thread-safe communication.
"""
@app.after_request
def handle_db_session(response: "Response") -> "Response": # pyright: ignore[reportUnusedFunction]
"""Commit the session if the response meets the commit criteria."""
if not has_request_context():
return response
db_session = cast("Optional[AsyncSession]", g.pop(f"advanced_alchemy_session_{self.bind_key}", None))
if db_session is not None:
p = getattr(db_session, "_session_portal", None) or portal
if (self.commit_mode == "autocommit" and 200 <= response.status_code < 300) or ( # noqa: PLR2004
self.commit_mode == "autocommit_include_redirect" and 200 <= response.status_code < 400 # noqa: PLR2004
):
_ = p.call(db_session.commit)
_ = p.call(db_session.close)
return response
@app.teardown_appcontext
def close_db_session(_: "Optional[BaseException]" = None) -> None: # pyright: ignore[reportUnusedFunction]
"""Close the session at the end of the request."""
db_session = cast("Optional[AsyncSession]", g.pop(f"advanced_alchemy_session_{self.bind_key}", None))
if db_session is not None:
p = getattr(db_session, "_session_portal", None) or portal
_ = p.call(db_session.close)
def close_engines(self, portal: "Portal") -> None:
"""Close the engines.
Args:
portal: The portal to use for thread-safe communication.
"""
if self.engine_instance is not None:
_ = portal.call(self.engine_instance.dispose)
async def create_all_metadata(self) -> None: # pragma: no cover
"""Create all metadata tables in the database."""
if self.engine_instance is None:
self.engine_instance = self.get_engine()
async with self.engine_instance.begin() as conn:
try:
await conn.run_sync(
metadata_registry.get(None if self.bind_key == "default" else self.bind_key).create_all
)
await conn.commit()
except OperationalError as exc:
echo(f" * Could not create target metadata. Reason: {exc}")
else:
echo(" * Created target metadata.")
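Both the sync and async `handle_db_session` hooks above commit based on the same status-code rule. A minimal standalone sketch of that rule (`should_commit` is a hypothetical name, not part of the library):

```python
def should_commit(commit_mode: str, status_code: int) -> bool:
    """Mirror of the commit criteria used by the after_request handlers."""
    if commit_mode == "autocommit":
        return 200 <= status_code < 300  # 2xx only
    if commit_mode == "autocommit_include_redirect":
        return 200 <= status_code < 400  # 2xx and 3xx
    return False  # "manual" mode never auto-commits
```

In `"autocommit"` mode only 2xx responses commit, `"autocommit_include_redirect"` also commits on redirects, and `"manual"` leaves commits entirely to the caller.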
# File: advanced_alchemy/extensions/flask/extension.py
# ruff: noqa: SLF001, ARG001
"""Flask extension for Advanced Alchemy."""
from collections.abc import Generator, Sequence
from contextlib import contextmanager, suppress
from typing import TYPE_CHECKING, Callable, Optional, Union, cast
from flask import g
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from advanced_alchemy.exceptions import ImproperConfigurationError
from advanced_alchemy.extensions.flask.cli import database_group
from advanced_alchemy.extensions.flask.config import SQLAlchemyAsyncConfig, SQLAlchemySyncConfig
from advanced_alchemy.utils.portals import Portal, PortalProvider
if TYPE_CHECKING:
from flask import Flask
class AdvancedAlchemy:
"""Flask extension for Advanced Alchemy."""
__slots__ = (
"_config",
"_has_async_config",
"_session_makers",
"portal_provider",
)
def __init__(
self,
config: "Union[SQLAlchemySyncConfig, SQLAlchemyAsyncConfig, Sequence[Union[SQLAlchemySyncConfig, SQLAlchemyAsyncConfig]]]",
app: "Optional[Flask]" = None,
*,
portal_provider: "Optional[PortalProvider]" = None,
) -> None:
"""Initialize the extension."""
self.portal_provider = portal_provider if portal_provider is not None else PortalProvider()
self._config = config if isinstance(config, Sequence) else [config]
self._has_async_config = any(isinstance(c, SQLAlchemyAsyncConfig) for c in self.config)
self._session_makers: dict[str, Callable[..., Union[AsyncSession, Session]]] = {}
if app is not None:
self.init_app(app)
@property
def portal(self) -> "Portal":
"""Get the portal."""
return self.portal_provider.portal
@property
def config(self) -> "Sequence[Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig]]":
"""Get the SQLAlchemy configuration(s)."""
return self._config
@property
def is_async_enabled(self) -> bool:
"""Return True if any of the database configs are async."""
return self._has_async_config
def init_app(self, app: "Flask") -> None:
"""Initialize the Flask application.
Args:
app: The Flask application to initialize.
Raises:
ImproperConfigurationError: If the extension is already registered on the Flask application.
"""
if "advanced_alchemy" in app.extensions:
msg = "Advanced Alchemy extension is already registered on this Flask application."
raise ImproperConfigurationError(msg)
if self._has_async_config:
self.portal_provider.start()
# Create tables for async configs
for cfg in self._config:
if isinstance(cfg, SQLAlchemyAsyncConfig):
self.portal_provider.portal.call(cfg.create_all_metadata)
# Register shutdown handler for the portal
@app.teardown_appcontext
def shutdown_portal(exception: "Optional[BaseException]" = None) -> None: # pyright: ignore[reportUnusedFunction]
"""Stop the portal when the application shuts down."""
if not app.debug: # Don't stop portal in debug mode
with suppress(Exception):
self.portal_provider.stop()
# Initialize each config with the app
for config in self.config:
config.init_app(app, self.portal_provider.portal)
bind_key = config.bind_key if config.bind_key is not None else "default"
session_maker = config.create_session_maker()
self._session_makers[bind_key] = session_maker
# Register session cleanup only
app.teardown_appcontext(self._teardown_appcontext)
app.extensions["advanced_alchemy"] = self
app.cli.add_command(database_group)
def _teardown_appcontext(self, exception: "Optional[BaseException]" = None) -> None:
"""Clean up resources when the application context ends."""
for key in list(g):
if key.startswith("advanced_alchemy_session_"):
session = getattr(g, key)
if isinstance(session, AsyncSession):
# Close async sessions through the portal
with suppress(ImproperConfigurationError):
self.portal_provider.portal.call(session.close)
else:
session.close()
delattr(g, key)
def get_session(self, bind_key: str = "default") -> "Union[AsyncSession, Session]":
"""Get a new session from the configured session factory.
Args:
bind_key: The bind key to use for the session.
Returns:
A new session from the configured session factory.
Raises:
ImproperConfigurationError: If no session maker is found for the bind key.
"""
if bind_key == "default" and len(self.config) == 1:
bind_key = self.config[0].bind_key if self.config[0].bind_key is not None else "default"
session_key = f"advanced_alchemy_session_{bind_key}"
if hasattr(g, session_key):
return cast("Union[AsyncSession, Session]", getattr(g, session_key))
session_maker = self._session_makers.get(bind_key)
if session_maker is None:
msg = f'No session maker found for bind key "{bind_key}"'
raise ImproperConfigurationError(msg)
session = session_maker()
if self._has_async_config:
# Ensure portal is started
if not self.portal_provider.is_running:
self.portal_provider.start()
setattr(session, "_session_portal", self.portal_provider.portal)
setattr(g, session_key, session)
return session
def get_async_session(self, bind_key: str = "default") -> AsyncSession:
"""Get an async session from the configured session factory."""
session = self.get_session(bind_key)
if not isinstance(session, AsyncSession):
msg = f"Expected async session for bind key {bind_key}, but got {type(session)}"
raise ImproperConfigurationError(msg)
return session
def get_sync_session(self, bind_key: str = "default") -> Session:
"""Get a sync session from the configured session factory."""
session = self.get_session(bind_key)
if not isinstance(session, Session):
msg = f"Expected sync session for bind key {bind_key}, but got {type(session)}"
raise ImproperConfigurationError(msg)
return session
@contextmanager
def with_session(  # pragma: no cover
self, bind_key: str = "default"
) -> "Generator[Union[AsyncSession, Session], None, None]":
"""Provide a transactional scope around a series of operations.
Args:
bind_key: The bind key to use for the session.
Yields:
A session.
"""
session = self.get_session(bind_key)
try:
yield session
finally:
if isinstance(session, AsyncSession):
with suppress(ImproperConfigurationError):
self.portal_provider.portal.call(session.close)
else:
session.close()
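`get_session` above caches one session per bind key on `flask.g` for the lifetime of the request context. The caching pattern, sketched against a plain dict standing in for `g` (all names here are illustrative, not library API):

```python
from typing import Any, Callable

def cached_session(scope: "dict[str, Any]", bind_key: str,
                   makers: "dict[str, Callable[[], Any]]") -> Any:
    """Return the per-request session for bind_key, creating it on first use."""
    key = f"advanced_alchemy_session_{bind_key}"
    if key in scope:  # reuse the session created earlier in this request
        return scope[key]
    maker = makers.get(bind_key)
    if maker is None:
        raise LookupError(f'No session maker found for bind key "{bind_key}"')
    scope[key] = maker()
    return scope[key]
```

Repeated calls within one request return the same object; the teardown handler then closes and removes it.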
# File: advanced_alchemy/extensions/flask/utils.py
"""Flask-specific service classes.
This module provides Flask-specific service mixins and utilities for integrating
with the Advanced Alchemy service layer.
"""
from typing import Any
from flask import Response, current_app
from advanced_alchemy.extensions.flask.config import serializer
class FlaskServiceMixin:
"""Flask service mixin.
This mixin provides Flask-specific functionality for services.
"""
def jsonify(
self,
data: Any,
*args: Any,
status_code: int = 200,
**kwargs: Any,
) -> Response:
"""Convert data to a Flask JSON response.
Args:
data: Data to serialize to JSON.
*args: Additional positional arguments passed to Flask's response class.
status_code: HTTP status code for the response. Defaults to 200.
**kwargs: Additional keyword arguments passed to Flask's response class.
Returns:
:class:`flask.Response`: A Flask response with JSON content type.
"""
return current_app.response_class(
serializer(data),
status=status_code,
mimetype="application/json",
)
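`jsonify` above defers to `current_app.response_class` with the configured serializer. The same response shape without Flask, using the stdlib serializer as a stand-in (`make_json_response` is illustrative only):

```python
import json

def make_json_response(data: "dict", status_code: int = 200) -> "tuple[str, int, dict]":
    """Serialize data and attach status and content type, mirroring jsonify's shape."""
    return json.dumps(data), status_code, {"Content-Type": "application/json"}
```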
# File: advanced_alchemy/extensions/litestar/__init__.py
from advanced_alchemy import base, exceptions, filters, mixins, operations, repository, service, types, utils
from advanced_alchemy.alembic.commands import AlembicCommands
from advanced_alchemy.config import AlembicAsyncConfig, AlembicSyncConfig, AsyncSessionConfig, SyncSessionConfig
from advanced_alchemy.extensions.litestar import providers
from advanced_alchemy.extensions.litestar.cli import get_database_migration_plugin
from advanced_alchemy.extensions.litestar.dto import SQLAlchemyDTO, SQLAlchemyDTOConfig
from advanced_alchemy.extensions.litestar.plugins import (
EngineConfig,
SQLAlchemyAsyncConfig,
SQLAlchemyInitPlugin,
SQLAlchemyPlugin,
SQLAlchemySerializationPlugin,
SQLAlchemySyncConfig,
)
from advanced_alchemy.extensions.litestar.plugins.init.config.asyncio import (
autocommit_before_send_handler as async_autocommit_before_send_handler,
)
from advanced_alchemy.extensions.litestar.plugins.init.config.asyncio import (
autocommit_handler_maker as async_autocommit_handler_maker,
)
from advanced_alchemy.extensions.litestar.plugins.init.config.asyncio import (
default_before_send_handler as async_default_before_send_handler,
)
from advanced_alchemy.extensions.litestar.plugins.init.config.asyncio import (
default_handler_maker as async_default_handler_maker,
)
from advanced_alchemy.extensions.litestar.plugins.init.config.sync import (
autocommit_before_send_handler as sync_autocommit_before_send_handler,
)
from advanced_alchemy.extensions.litestar.plugins.init.config.sync import (
autocommit_handler_maker as sync_autocommit_handler_maker,
)
from advanced_alchemy.extensions.litestar.plugins.init.config.sync import (
default_before_send_handler as sync_default_before_send_handler,
)
from advanced_alchemy.extensions.litestar.plugins.init.config.sync import (
default_handler_maker as sync_default_handler_maker,
)
__all__ = (
"AlembicAsyncConfig",
"AlembicCommands",
"AlembicSyncConfig",
"AsyncSessionConfig",
"EngineConfig",
"SQLAlchemyAsyncConfig",
"SQLAlchemyDTO",
"SQLAlchemyDTOConfig",
"SQLAlchemyInitPlugin",
"SQLAlchemyPlugin",
"SQLAlchemySerializationPlugin",
"SQLAlchemySyncConfig",
"SyncSessionConfig",
"async_autocommit_before_send_handler",
"async_autocommit_handler_maker",
"async_default_before_send_handler",
"async_default_handler_maker",
"base",
"exceptions",
"filters",
"get_database_migration_plugin",
"mixins",
"operations",
"providers",
"repository",
"service",
"sync_autocommit_before_send_handler",
"sync_autocommit_handler_maker",
"sync_default_before_send_handler",
"sync_default_handler_maker",
"types",
"utils",
)
# File: advanced_alchemy/extensions/litestar/_utils.py
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from litestar.types import Scope
__all__ = (
"delete_aa_scope_state",
"get_aa_scope_state",
"set_aa_scope_state",
)
_SCOPE_NAMESPACE = "_aa_connection_state"
def get_aa_scope_state(scope: "Scope", key: str, default: Any = None, pop: bool = False) -> Any:
"""Get an internal value from connection scope state.
Note:
If called with ``pop=True``, the value is removed from the namespace before being
returned (like ``dict.pop()``).
If called with a default value, the default is returned when the key is absent,
without being stored in the namespace (like ``dict.get()``).
Args:
scope: The connection scope.
key: Key to get from internal namespace in scope state.
default: Default value to return.
pop: Boolean flag dictating whether the value should be deleted from the state.
Returns:
Value mapped to ``key`` in internal connection scope namespace.
"""
namespace = scope.setdefault(_SCOPE_NAMESPACE, {}) # type: ignore[misc]
return namespace.pop(key, default) if pop else namespace.get(key, default) # pyright: ignore[reportUnknownVariableType,reportUnknownMemberType]
def set_aa_scope_state(scope: "Scope", key: str, value: Any) -> None:
"""Set an internal value in connection scope state.
Args:
scope: The connection scope.
key: Key to set under internal namespace in scope state.
value: Value for key.
"""
scope.setdefault(_SCOPE_NAMESPACE, {})[key] = value # type: ignore[misc]
def delete_aa_scope_state(scope: "Scope", key: str) -> None:
"""Delete an internal value from connection scope state.
Args:
scope: The connection scope.
key: Key to set under internal namespace in scope state.
"""
del scope.setdefault(_SCOPE_NAMESPACE, {})[key] # type: ignore[misc]
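The three helpers above share one private namespace inside the ASGI scope mapping. Their semantics can be exercised with an ordinary dict in place of a real Litestar `Scope` (helper names shortened for the sketch):

```python
_SCOPE_NAMESPACE = "_aa_connection_state"

def get_state(scope: dict, key: str, default=None, pop: bool = False):
    """Read (or pop) a value from the private namespace inside the scope dict."""
    namespace = scope.setdefault(_SCOPE_NAMESPACE, {})
    return namespace.pop(key, default) if pop else namespace.get(key, default)

def set_state(scope: dict, key: str, value) -> None:
    """Write a value into the private namespace, creating it on first use."""
    scope.setdefault(_SCOPE_NAMESPACE, {})[key] = value
```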
# File: advanced_alchemy/extensions/litestar/cli.py
from contextlib import suppress
from typing import TYPE_CHECKING
from litestar.cli._utils import LitestarGroup
from advanced_alchemy.cli import add_migration_commands
try:
import rich_click as click
except ImportError:
import click # type: ignore[no-redef]
if TYPE_CHECKING:
from litestar import Litestar
from advanced_alchemy.extensions.litestar.plugins import SQLAlchemyInitPlugin
def get_database_migration_plugin(app: "Litestar") -> "SQLAlchemyInitPlugin":
"""Retrieve a database migration plugin from the Litestar application's plugins."""
from advanced_alchemy.exceptions import ImproperConfigurationError
from advanced_alchemy.extensions.litestar.plugins import SQLAlchemyInitPlugin
with suppress(KeyError):
return app.plugins.get(SQLAlchemyInitPlugin)
msg = "Failed to initialize database migrations. The required plugin (SQLAlchemyPlugin or SQLAlchemyInitPlugin) is missing."
raise ImproperConfigurationError(msg)
@click.group(cls=LitestarGroup, name="database")
def database_group(ctx: "click.Context") -> None:
"""Manage SQLAlchemy database components."""
ctx.obj = {"app": ctx.obj, "configs": get_database_migration_plugin(ctx.obj.app).config}
add_migration_commands(database_group)
# File: advanced_alchemy/extensions/litestar/dto.py
from collections.abc import Collection, Generator
from collections.abc import Set as AbstractSet
from dataclasses import asdict, dataclass, field, replace
from functools import singledispatchmethod
from typing import (
Any,
ClassVar,
Generic,
Literal,
Optional,
Union,
)
from litestar.dto.base_dto import AbstractDTO
from litestar.dto.config import DTOConfig
from litestar.dto.data_structures import DTOFieldDefinition
from litestar.dto.field import DTO_FIELD_META_KEY, DTOField, Mark
from litestar.types.empty import Empty
from litestar.typing import FieldDefinition
from litestar.utils.signature import ParsedSignature
from sqlalchemy import Column, inspect, orm, sql
from sqlalchemy.ext.associationproxy import AssociationProxy, AssociationProxyExtensionType
from sqlalchemy.ext.hybrid import HybridExtensionType, hybrid_property
from sqlalchemy.orm import (
ColumnProperty,
CompositeProperty,
DeclarativeBase,
DynamicMapped,
InspectionAttr,
InstrumentedAttribute,
Mapped,
MappedColumn,
NotExtension,
QueryableAttribute,
Relationship,
RelationshipDirection,
RelationshipProperty,
WriteOnlyMapped,
)
from sqlalchemy.sql.expression import ColumnClause, Label
from typing_extensions import TypeAlias, TypeVar
from advanced_alchemy.exceptions import ImproperConfigurationError
__all__ = ("SQLAlchemyDTO",)
T = TypeVar("T", bound="Union[DeclarativeBase, Collection[DeclarativeBase]]")
ElementType: TypeAlias = Union[
"Column[Any]", "RelationshipProperty[Any]", "CompositeProperty[Any]", "ColumnClause[Any]", "Label[Any]"
]
SQLA_NS = {**vars(orm), **vars(sql)}
@dataclass(frozen=True)
class SQLAlchemyDTOConfig(DTOConfig):
"""Additional controls for the generated SQLAlchemy DTO."""
exclude: AbstractSet[Union[str, InstrumentedAttribute[Any]]] = field(default_factory=set) # type: ignore[assignment] # pyright: ignore[reportIncompatibleVariableOverride]
"""Explicitly exclude fields from the generated DTO.
If exclude is specified, all fields not specified in exclude will be included by default.
Notes:
- The field names are dot-separated paths to nested fields, e.g. ``"address.street"`` will
exclude the ``"street"`` field from a nested ``"address"`` model.
- 'exclude' is mutually exclusive with 'include' - specifying both values will raise an
``ImproperlyConfiguredException``.
"""
include: AbstractSet[Union[str, InstrumentedAttribute[Any]]] = field(default_factory=set) # type: ignore[assignment] # pyright: ignore[reportIncompatibleVariableOverride]
"""Explicitly include fields in the generated DTO.
If include is specified, all fields not specified in include will be excluded by default.
Notes:
- The field names are dot-separated paths to nested fields, e.g. ``"address.street"`` will
include the ``"street"`` field from a nested ``"address"`` model.
- 'include' is mutually exclusive with 'exclude' - specifying both values will raise an
``ImproperlyConfiguredException``.
"""
rename_fields: dict[Union[str, InstrumentedAttribute[Any]], str] = field(default_factory=dict) # type: ignore[assignment] # pyright: ignore[reportIncompatibleVariableOverride]
"""Mapping of field names, to new name."""
include_implicit_fields: Union[bool, Literal["hybrid-only"]] = True
"""Fields that are implicitly mapped are included.
Turning this off excludes all fields not using the ``Mapped`` annotation.
When set to ``hybrid-only``, all implicitly mapped fields are excluded
except hybrid properties.
"""
def __post_init__(self) -> None:
super().__post_init__()
object.__setattr__(
self, "exclude", {f.key if isinstance(f, InstrumentedAttribute) else f for f in self.exclude}
)
object.__setattr__(
self, "include", {f.key if isinstance(f, InstrumentedAttribute) else f for f in self.include}
)
object.__setattr__(
self,
"rename_fields",
{f.key if isinstance(f, InstrumentedAttribute) else f: v for f, v in self.rename_fields.items()},
)
class SQLAlchemyDTO(AbstractDTO[T], Generic[T]):
"""Support for domain modelling with SQLAlchemy."""
config: ClassVar[SQLAlchemyDTOConfig]
@staticmethod
def _ensure_sqla_dto_config(config: Union[DTOConfig, SQLAlchemyDTOConfig]) -> SQLAlchemyDTOConfig:
if not isinstance(config, SQLAlchemyDTOConfig):
return SQLAlchemyDTOConfig(**asdict(config))
return config
def __init_subclass__(cls, **kwargs: Any) -> None:
super().__init_subclass__(**kwargs)
if hasattr(cls, "config"):
cls.config = cls._ensure_sqla_dto_config(cls.config) # pyright: ignore[reportIncompatibleVariableOverride]
@singledispatchmethod
@classmethod
def handle_orm_descriptor(
cls,
extension_type: Union[NotExtension, AssociationProxyExtensionType, HybridExtensionType],
orm_descriptor: InspectionAttr,
key: str,
model_type_hints: dict[str, FieldDefinition],
model_name: str,
) -> list[DTOFieldDefinition]:
msg = f"Unsupported extension type: {extension_type}"
raise NotImplementedError(msg)
@handle_orm_descriptor.register(NotExtension)
@classmethod
def _(
cls,
extension_type: NotExtension,
key: str,
orm_descriptor: InspectionAttr,
model_type_hints: dict[str, FieldDefinition],
model_name: str,
) -> list[DTOFieldDefinition]:
if not isinstance(orm_descriptor, QueryableAttribute): # pragma: no cover
msg = f"Unexpected descriptor type for '{extension_type}': '{orm_descriptor}'"
raise NotImplementedError(msg)
elem: ElementType
if isinstance(
orm_descriptor.property, # pyright: ignore[reportUnknownMemberType]
ColumnProperty, # pragma: no cover
):
if not isinstance(
orm_descriptor.property.expression, # pyright: ignore[reportUnknownMemberType]
(Column, ColumnClause, Label),
):
msg = f"Expected 'Column', got: '{orm_descriptor.property.expression}, {type(orm_descriptor.property.expression)}'" # pyright: ignore[reportUnknownMemberType]
raise NotImplementedError(msg)
elem = orm_descriptor.property.expression # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
elif isinstance(orm_descriptor.property, (RelationshipProperty, CompositeProperty)): # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
elem = orm_descriptor.property # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
else: # pragma: no cover
msg = f"Unhandled property type: '{orm_descriptor.property}'" # pyright: ignore[reportUnknownMemberType]
raise NotImplementedError(msg)
default, default_factory = _detect_defaults(elem)
try:
if (field_definition := model_type_hints[key]).origin in {
Mapped,
WriteOnlyMapped,
DynamicMapped,
Relationship,
}:
(field_definition,) = field_definition.inner_types
else: # pragma: no cover
msg = f"Expected 'Mapped' origin, got: '{field_definition.origin}'"
raise NotImplementedError(msg)
except KeyError:
field_definition = parse_type_from_element(elem, orm_descriptor) # pyright: ignore[reportUnknownArgumentType]
dto_field = elem.info.get(DTO_FIELD_META_KEY) if hasattr(elem, "info") else None # pyright: ignore[reportArgumentMemberType]
if dto_field is None and isinstance(orm_descriptor, InstrumentedAttribute) and hasattr(orm_descriptor, "info"): # pyright: ignore[reportUnknownArgumentType]
dto_field = orm_descriptor.info.get(DTO_FIELD_META_KEY) # pyright: ignore[reportArgumentMemberType]
if dto_field is None:
dto_field = DTOField()
return [
DTOFieldDefinition.from_field_definition(
field_definition=replace(
field_definition,
name=key,
default=default,
),
default_factory=default_factory,
dto_field=dto_field,
model_name=model_name,
),
]
@handle_orm_descriptor.register(AssociationProxyExtensionType)
@classmethod
def _(
cls,
extension_type: AssociationProxyExtensionType,
key: str,
orm_descriptor: InspectionAttr,
model_type_hints: dict[str, FieldDefinition],
model_name: str,
) -> list[DTOFieldDefinition]:
if not isinstance(orm_descriptor, AssociationProxy): # pragma: no cover
msg = f"Unexpected descriptor type '{orm_descriptor}' for '{extension_type}'"
raise NotImplementedError(msg)
if (field_definition := model_type_hints[key]).origin is AssociationProxy:
(field_definition,) = field_definition.inner_types
else: # pragma: no cover
msg = f"Expected 'AssociationProxy' origin, got: '{field_definition.origin}'"
raise NotImplementedError(msg)
return [
DTOFieldDefinition.from_field_definition(
field_definition=replace(
field_definition,
name=key,
default=Empty,
),
default_factory=None,
dto_field=orm_descriptor.info.get(DTO_FIELD_META_KEY, DTOField(mark=Mark.READ_ONLY)),
model_name=model_name,
),
]
@handle_orm_descriptor.register(HybridExtensionType)
@classmethod
def _(
cls,
extension_type: HybridExtensionType,
key: str,
orm_descriptor: InspectionAttr,
model_type_hints: dict[str, FieldDefinition],
model_name: str,
) -> list[DTOFieldDefinition]:
if not isinstance(orm_descriptor, hybrid_property):
msg = f"Unexpected descriptor type '{orm_descriptor}' for '{extension_type}'"
raise NotImplementedError(msg)
getter_sig = ParsedSignature.from_fn(orm_descriptor.fget, {}) # pyright: ignore[reportUnknownArgumentType,reportUnknownMemberType,reportAttributeAccessIssue]
field_defs = [
DTOFieldDefinition.from_field_definition(
field_definition=replace(
getter_sig.return_type,
name=orm_descriptor.__name__,
default=Empty,
),
default_factory=None,
dto_field=orm_descriptor.info.get(DTO_FIELD_META_KEY, DTOField(mark=Mark.READ_ONLY)),
model_name=model_name,
),
]
if orm_descriptor.fset is not None: # pyright: ignore[reportUnknownMemberType]
setter_sig = ParsedSignature.from_fn(orm_descriptor.fset, {}) # pyright: ignore[reportUnknownArgumentType,reportUnknownMemberType]
field_defs.append(
DTOFieldDefinition.from_field_definition(
field_definition=replace(
next(iter(setter_sig.parameters.values())),
name=orm_descriptor.__name__,
default=Empty,
),
default_factory=None,
dto_field=orm_descriptor.info.get(DTO_FIELD_META_KEY, DTOField(mark=Mark.WRITE_ONLY)),
model_name=model_name,
),
)
return field_defs
@classmethod
def generate_field_definitions(cls, model_type: type[DeclarativeBase]) -> Generator[DTOFieldDefinition, None, None]:
"""Generate DTO field definitions from a SQLAlchemy model.
Args:
model_type (typing.Type[sqlalchemy.orm.DeclarativeBase]): The SQLAlchemy model type to generate field definitions from.
Yields:
collections.abc.Generator[litestar.dto.data_structures.DTOFieldDefinition, None, None]: A generator yielding DTO field definitions.
Raises:
RuntimeError: If the mapper cannot be found for the model type.
NotImplementedError: If an unsupported property or extension type is encountered.
ImproperConfigurationError: If a type cannot be parsed from an element.
"""
if (mapper := inspect(model_type)) is None: # pragma: no cover # pyright: ignore[reportUnnecessaryComparison]
msg = "Unexpected `None` value for mapper." # type: ignore[unreachable]
raise RuntimeError(msg)
# includes SQLAlchemy names and other mapped class names in the forward reference resolution namespace
namespace = {**SQLA_NS, **{m.class_.__name__: m.class_ for m in mapper.registry.mappers if m is not mapper}}
model_type_hints = cls.get_model_type_hints(model_type, namespace=namespace)
model_name = model_type.__name__
include_implicit_fields = cls.config.include_implicit_fields
# the same hybrid property descriptor can be included in `all_orm_descriptors` multiple times, once
# for each method name it is bound to. We only need to see it once, so track views of it here.
seen_hybrid_descriptors: set[hybrid_property] = set() # pyright: ignore[reportUnknownVariableType,reportMissingTypeArgument]
skipped_descriptors: set[str] = set()
for composite_property in mapper.composites: # pragma: no cover
for attr in composite_property.attrs:
if isinstance(attr, (MappedColumn, Column)):
skipped_descriptors.add(attr.name)
elif isinstance(attr, str):
skipped_descriptors.add(attr)
for key, orm_descriptor in mapper.all_orm_descriptors.items():
if is_hybrid_property := isinstance(orm_descriptor, hybrid_property):
if orm_descriptor in seen_hybrid_descriptors:
continue
seen_hybrid_descriptors.add(orm_descriptor) # pyright: ignore[reportUnknownMemberType]
if key in skipped_descriptors:
continue
should_skip_descriptor = False
dto_field: Optional[DTOField] = None
if hasattr(orm_descriptor, "property"): # pyright: ignore[reportUnknownArgumentType]
dto_field = orm_descriptor.property.info.get(DTO_FIELD_META_KEY) # pyright: ignore # noqa: PGH003
# Case 1
is_field_marked_not_private = dto_field and dto_field.mark is not Mark.PRIVATE # pyright: ignore[reportUnknownVariableType,reportUnknownMemberType]
# Case 2
should_exclude_anything_implicit = not include_implicit_fields and key not in model_type_hints
# Case 3
should_exclude_non_hybrid_only = (
not is_hybrid_property and include_implicit_fields == "hybrid-only" and key not in model_type_hints
)
# Descriptor is marked with either Mark.READ_ONLY or Mark.WRITE_ONLY (see Case 1):
# - always include it regardless of anything else.
# Descriptor is not marked:
# - It's implicit BUT config excludes anything implicit (see Case 2): exclude
# - It's implicit AND not hybrid BUT config includes hybrid-only implicit descriptors (Case 3): exclude
should_skip_descriptor = not is_field_marked_not_private and (
should_exclude_anything_implicit or should_exclude_non_hybrid_only
)
if should_skip_descriptor:
continue
yield from cls.handle_orm_descriptor(
orm_descriptor.extension_type,
key,
orm_descriptor,
model_type_hints,
model_name,
)
@classmethod
def detect_nested_field(cls, field_definition: FieldDefinition) -> bool:
return field_definition.is_subclass_of(DeclarativeBase)
def _detect_defaults(elem: ElementType) -> tuple[Any, Any]:
default: Any = Empty
default_factory: Any = None # pyright:ignore # noqa: PGH003
if sqla_default := getattr(elem, "default", None):
if sqla_default.is_scalar:
default = sqla_default.arg
elif sqla_default.is_callable:
def default_factory(d: Any = sqla_default) -> Any:
return d.arg({})
elif sqla_default.is_sequence or sqla_default.is_sentinel:
# SQLAlchemy sequences represent server side defaults
# so we cannot infer a reasonable default value for
# them on the client side
pass
else:
msg = "Unexpected default type"
raise ValueError(msg)
elif (isinstance(elem, RelationshipProperty) and detect_nullable_relationship(elem)) or getattr(
elem, "nullable", False
):
default = None
return default, default_factory
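`_detect_defaults` classifies a SQLAlchemy column default into a scalar default, a default factory, or nothing (server-side defaults). A condensed sketch of that classification against a duck-typed default object, using `None` in place of litestar's `Empty` (`split_default` is an illustrative name):

```python
from typing import Any, Callable, Optional

def split_default(sqla_default: Any) -> "tuple[Any, Optional[Callable[[], Any]]]":
    """Classify a column default the way _detect_defaults does (simplified)."""
    if sqla_default is None:
        return None, None
    if sqla_default.is_scalar:
        return sqla_default.arg, None            # plain value default
    if sqla_default.is_callable:
        return None, lambda: sqla_default.arg({})  # wrap the callable
    if sqla_default.is_sequence or sqla_default.is_sentinel:
        return None, None                         # server-side; no client value
    msg = "Unexpected default type"
    raise ValueError(msg)
```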
def parse_type_from_element(elem: ElementType, orm_descriptor: InspectionAttr) -> FieldDefinition: # noqa: PLR0911
"""Parses a type from a SQLAlchemy element.
Args:
elem: The SQLAlchemy element to parse.
orm_descriptor: The attribute `elem` was extracted from.
Returns:
FieldDefinition: The parsed type.
Raises:
ImproperlyConfiguredException: If the type cannot be parsed.
"""
if isinstance(elem, Column):
if elem.nullable:
return FieldDefinition.from_annotation(Optional[elem.type.python_type])
return FieldDefinition.from_annotation(elem.type.python_type)
if isinstance(elem, RelationshipProperty):
if elem.direction in (RelationshipDirection.ONETOMANY, RelationshipDirection.MANYTOMANY):
collection_type = FieldDefinition.from_annotation(elem.collection_class or list) # pyright: ignore[reportUnknownMemberType]
return FieldDefinition.from_annotation(collection_type.safe_generic_origin[elem.mapper.class_])
if detect_nullable_relationship(elem):
return FieldDefinition.from_annotation(Optional[elem.mapper.class_])
return FieldDefinition.from_annotation(elem.mapper.class_)
if isinstance(elem, CompositeProperty):
return FieldDefinition.from_annotation(elem.composite_class)
if isinstance(orm_descriptor, InstrumentedAttribute):
return FieldDefinition.from_annotation(orm_descriptor.type.python_type)
msg = f"Unable to parse type from element '{elem}'. Consider adding a type hint."
raise ImproperConfigurationError(
msg,
)
def detect_nullable_relationship(elem: RelationshipProperty[Any]) -> bool:
"""Detects if a relationship is nullable.
This attempts to decide if we should allow a ``None`` default value for a relationship by looking at the
foreign key fields. If all foreign key fields are nullable, then we allow a ``None`` default value.
Args:
elem: The relationship to check.
Returns:
bool: ``True`` if the relationship is nullable, ``False`` otherwise.
"""
return elem.direction == RelationshipDirection.MANYTOONE and all(c.nullable for c in elem.local_columns)
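The rule in `detect_nullable_relationship` — a many-to-one relationship is treated as nullable only when every local foreign-key column is nullable — can be sketched without SQLAlchemy (the `FakeColumn` class below is a stand-in, not the library's API):

```python
from dataclasses import dataclass


@dataclass
class FakeColumn:
    """Minimal stand-in for a SQLAlchemy Column's nullability flag."""

    nullable: bool


def is_nullable_many_to_one(local_columns: list[FakeColumn]) -> bool:
    # Mirrors detect_nullable_relationship: a None default is only safe
    # when ALL foreign-key columns backing the relationship are nullable.
    return all(c.nullable for c in local_columns)


print(is_nullable_many_to_one([FakeColumn(True), FakeColumn(True)]))   # True
print(is_nullable_many_to_one([FakeColumn(True), FakeColumn(False)]))  # False
```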
# python-advanced-alchemy-1.0.1/advanced_alchemy/extensions/litestar/exception_handler.py
from typing import TYPE_CHECKING, Any
from litestar.connection import Request
from litestar.connection.base import AuthT, StateT, UserT
from litestar.exceptions import (
ClientException,
HTTPException,
InternalServerException,
NotFoundException,
)
from litestar.exceptions.responses import (
create_debug_response, # pyright: ignore[reportUnknownVariableType]
create_exception_response, # pyright: ignore[reportUnknownVariableType]
)
from litestar.response import Response
from litestar.status_codes import (
HTTP_409_CONFLICT,
)
from advanced_alchemy.exceptions import (
DuplicateKeyError,
ForeignKeyError,
IntegrityError,
NotFoundError,
RepositoryError,
)
if TYPE_CHECKING:
from litestar.connection import Request
from litestar.connection.base import AuthT, StateT, UserT
from litestar.response import Response
class ConflictError(ClientException):
"""Request conflict with the current state of the target resource."""
status_code: int = HTTP_409_CONFLICT
def exception_to_http_response(request: "Request[UserT, AuthT, StateT]", exc: "RepositoryError") -> "Response[Any]":
"""Handler for all exceptions subclassed from HTTPException."""
if isinstance(exc, NotFoundError):
http_exc: type[HTTPException] = NotFoundException
elif isinstance(exc, (DuplicateKeyError, IntegrityError, ForeignKeyError)):
http_exc = ConflictError
else:
http_exc = InternalServerException
if request.app.debug:
return create_debug_response(request, exc) # pyright: ignore[reportUnknownVariableType]
return create_exception_response(request, http_exc(detail=str(exc.detail))) # pyright: ignore[reportUnknownVariableType]
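The dispatch in `exception_to_http_response` reduces to a small mapping from repository errors to HTTP status codes; a framework-free sketch of that mapping (status values taken from the handler above, exception classes defined locally as stand-ins):

```python
# Stand-in exception hierarchy mirroring advanced_alchemy.exceptions
class RepositoryError(Exception): ...
class NotFoundError(RepositoryError): ...
class DuplicateKeyError(RepositoryError): ...
class IntegrityError(RepositoryError): ...
class ForeignKeyError(RepositoryError): ...


def status_for(exc: RepositoryError) -> int:
    """Map a repository error to the HTTP status the handler would use."""
    if isinstance(exc, NotFoundError):
        return 404  # NotFoundException
    if isinstance(exc, (DuplicateKeyError, IntegrityError, ForeignKeyError)):
        return 409  # ConflictError: conflict with the target resource state
    return 500      # InternalServerException for anything else


print(status_for(NotFoundError()))      # 404
print(status_for(DuplicateKeyError()))  # 409
print(status_for(RepositoryError()))    # 500
```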
# python-advanced-alchemy-1.0.1/advanced_alchemy/extensions/litestar/plugins/__init__.py
from collections.abc import Sequence
from typing import Union
from litestar.config.app import AppConfig
from litestar.plugins import InitPluginProtocol
from advanced_alchemy.extensions.litestar.plugins import _slots_base
from advanced_alchemy.extensions.litestar.plugins.init import (
EngineConfig,
SQLAlchemyAsyncConfig,
SQLAlchemyInitPlugin,
SQLAlchemySyncConfig,
)
from advanced_alchemy.extensions.litestar.plugins.serialization import SQLAlchemySerializationPlugin
class SQLAlchemyPlugin(InitPluginProtocol, _slots_base.SlotsBase):
"""A plugin that provides SQLAlchemy integration."""
def __init__(
self,
config: Union[
SQLAlchemyAsyncConfig, SQLAlchemySyncConfig, Sequence[Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig]]
],
) -> None:
"""Initialize ``SQLAlchemyPlugin``.
Args:
config: configure DB connection and hook handlers and dependencies.
"""
self._config = config if isinstance(config, Sequence) else [config]
@property
def config(
self,
) -> Sequence[Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig]]:
return self._config
def on_app_init(self, app_config: AppConfig) -> AppConfig:
"""Configure application for use with SQLAlchemy.
Args:
app_config: The :class:`AppConfig <.config.app.AppConfig>` instance.
"""
app_config.plugins.extend([SQLAlchemyInitPlugin(config=self._config), SQLAlchemySerializationPlugin()])
return app_config
__all__ = (
"EngineConfig",
"SQLAlchemyAsyncConfig",
"SQLAlchemyInitPlugin",
"SQLAlchemyPlugin",
"SQLAlchemySerializationPlugin",
"SQLAlchemySyncConfig",
)
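`SQLAlchemyPlugin.__init__` accepts either a single config or a sequence of configs and normalizes to a list with the same `isinstance(config, Sequence)` check shown above; isolated (the `FakeConfig` class is a placeholder):

```python
from collections.abc import Sequence


class FakeConfig:
    """Placeholder for SQLAlchemyAsyncConfig / SQLAlchemySyncConfig."""


def normalize(config):
    # One config or many: downstream code always sees a list.
    return list(config) if isinstance(config, Sequence) else [config]


print(len(normalize(FakeConfig())))                  # 1
print(len(normalize([FakeConfig(), FakeConfig()])))  # 2
```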
# python-advanced-alchemy-1.0.1/advanced_alchemy/extensions/litestar/plugins/_slots_base.py
"""Base class that aggregates slots for all SQLAlchemy plugins.
See: https://stackoverflow.com/questions/53060607/python-3-6-5-multiple-bases-have-instance-lay-out-conflict-when-multi-inherit
"""
class SlotsBase:
__slots__ = (
"_config",
"_type_dto_map",
)
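The lay-out conflict that `SlotsBase` exists to avoid can be demonstrated directly: two bases that each declare non-empty `__slots__` have incompatible instance layouts, while plugins sharing one slot-aggregating base combine cleanly.

```python
# Two bases that each declare non-empty __slots__ cannot be combined:
class A:
    __slots__ = ("_config",)


class B:
    __slots__ = ("_type_dto_map",)


try:
    type("Bad", (A, B), {"__slots__": ()})
except TypeError as exc:
    print(exc)  # multiple bases have instance lay-out conflict


# Aggregating every slot in one shared base sidesteps the conflict,
# which is what SlotsBase does for the plugin classes:
class SlotsBase:
    __slots__ = ("_config", "_type_dto_map")


class PluginA(SlotsBase):
    __slots__ = ()


class PluginB(SlotsBase):
    __slots__ = ()


class Combined(PluginA, PluginB):  # no lay-out conflict
    __slots__ = ()
```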
# python-advanced-alchemy-1.0.1/advanced_alchemy/extensions/litestar/plugins/init/__init__.py
from advanced_alchemy.extensions.litestar.plugins.init.config import (
EngineConfig,
SQLAlchemyAsyncConfig,
SQLAlchemySyncConfig,
)
from advanced_alchemy.extensions.litestar.plugins.init.plugin import SQLAlchemyInitPlugin
__all__ = (
"EngineConfig",
"SQLAlchemyAsyncConfig",
"SQLAlchemyInitPlugin",
"SQLAlchemySyncConfig",
)
# python-advanced-alchemy-1.0.1/advanced_alchemy/extensions/litestar/plugins/init/config/__init__.py
from advanced_alchemy.extensions.litestar.plugins.init.config.asyncio import SQLAlchemyAsyncConfig
from advanced_alchemy.extensions.litestar.plugins.init.config.engine import EngineConfig
from advanced_alchemy.extensions.litestar.plugins.init.config.sync import SQLAlchemySyncConfig
__all__ = (
"EngineConfig",
"SQLAlchemyAsyncConfig",
"SQLAlchemySyncConfig",
)
# python-advanced-alchemy-1.0.1/advanced_alchemy/extensions/litestar/plugins/init/config/asyncio.py
from collections.abc import AsyncGenerator, Coroutine
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union, cast
from litestar.cli._utils import console
from litestar.constants import HTTP_RESPONSE_START
from sqlalchemy.exc import OperationalError
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession
from advanced_alchemy.base import metadata_registry
from advanced_alchemy.config.asyncio import SQLAlchemyAsyncConfig as _SQLAlchemyAsyncConfig
from advanced_alchemy.extensions.litestar._utils import (
delete_aa_scope_state,
get_aa_scope_state,
set_aa_scope_state,
)
from advanced_alchemy.extensions.litestar.plugins.init.config.common import (
SESSION_SCOPE_KEY,
SESSION_TERMINUS_ASGI_EVENTS,
)
from advanced_alchemy.extensions.litestar.plugins.init.config.engine import EngineConfig
if TYPE_CHECKING:
from collections.abc import AsyncGenerator, Coroutine
from litestar import Litestar
from litestar.datastructures.state import State
from litestar.types import BeforeMessageSendHookHandler, Message, Scope
# noinspection PyUnresolvedReferences
__all__ = (
"SQLAlchemyAsyncConfig",
"autocommit_before_send_handler",
"autocommit_handler_maker",
"default_before_send_handler",
"default_handler_maker",
)
def default_handler_maker(
session_scope_key: str = SESSION_SCOPE_KEY,
) -> "Callable[[Message, Scope], Coroutine[Any, Any, None]]":
"""Set up the handler to issue a transaction commit or rollback based on specified status codes
Args:
session_scope_key: The key to use within the application state
Returns:
The handler callable
"""
async def handler(message: "Message", scope: "Scope") -> None:
"""Handle commit/rollback, closing and cleaning up sessions before sending.
Args:
message: ASGI-``Message``
scope: An ASGI-``Scope``
Returns:
None
"""
session = cast("Optional[AsyncSession]", get_aa_scope_state(scope, session_scope_key))
if session and message["type"] in SESSION_TERMINUS_ASGI_EVENTS:
await session.close()
delete_aa_scope_state(scope, session_scope_key)
return handler
default_before_send_handler = default_handler_maker()
def autocommit_handler_maker(
commit_on_redirect: bool = False,
extra_commit_statuses: Optional[set[int]] = None,
extra_rollback_statuses: Optional[set[int]] = None,
session_scope_key: str = SESSION_SCOPE_KEY,
) -> "Callable[[Message, Scope], Coroutine[Any, Any, None]]":
"""Set up the handler to issue a transaction commit or rollback based on specified status codes
Args:
commit_on_redirect: Issue a commit when the response status is a redirect (``3XX``)
extra_commit_statuses: A set of additional status codes that trigger a commit
extra_rollback_statuses: A set of additional status codes that trigger a rollback
session_scope_key: The key to use within the application state
Returns:
The handler callable
"""
if extra_commit_statuses is None:
extra_commit_statuses = set()
if extra_rollback_statuses is None:
extra_rollback_statuses = set()
if len(extra_commit_statuses & extra_rollback_statuses) > 0:
msg = "Extra rollback statuses and commit statuses must not share any status codes"
raise ValueError(msg)
commit_range = range(200, 400 if commit_on_redirect else 300)
async def handler(message: "Message", scope: "Scope") -> None:
"""Handle commit/rollback, closing and cleaning up sessions before sending.
Args:
message: ASGI-``litestar.types.Message``
scope: An ASGI-``litestar.types.Scope``
Returns:
None
"""
session = cast("Optional[AsyncSession]", get_aa_scope_state(scope, session_scope_key))
try:
if session is not None and message["type"] == HTTP_RESPONSE_START:
if (message["status"] in commit_range or message["status"] in extra_commit_statuses) and message[
"status"
] not in extra_rollback_statuses:
await session.commit()
else:
await session.rollback()
finally:
if session and message["type"] in SESSION_TERMINUS_ASGI_EVENTS:
await session.close()
delete_aa_scope_state(scope, session_scope_key)
return handler
autocommit_before_send_handler = autocommit_handler_maker()
@dataclass
class SQLAlchemyAsyncConfig(_SQLAlchemyAsyncConfig):
"""Litestar Async SQLAlchemy Configuration."""
before_send_handler: Optional[
Union["BeforeMessageSendHookHandler", Literal["autocommit", "autocommit_include_redirects"]]
] = None
"""Handler to call before the ASGI message is sent.
The handler should handle closing the session stored in the ASGI scope, if it's still open, and committing any
uncommitted data.
"""
engine_dependency_key: str = "db_engine"
"""Key to use for the dependency injection of database engines."""
session_dependency_key: str = "db_session"
"""Key to use for the dependency injection of database sessions."""
engine_app_state_key: str = "db_engine"
"""Key under which to store the SQLAlchemy engine in the application :class:`State `
instance.
"""
session_maker_app_state_key: str = "session_maker_class"
"""Key under which to store the SQLAlchemy :class:`sessionmaker ` in the application
:class:`State ` instance.
"""
session_scope_key: str = SESSION_SCOPE_KEY
"""Key under which to store the SQLAlchemy scope in the application."""
engine_config: EngineConfig = field(default_factory=EngineConfig) # pyright: ignore[reportIncompatibleVariableOverride]
"""Configuration for the SQLAlchemy engine.
The configuration options are documented in the SQLAlchemy documentation.
"""
set_default_exception_handler: bool = True
"""Sets the default exception handler on application start."""
def _ensure_unique(self, registry_name: str, key: str, new_key: Optional[str] = None, _iter: int = 0) -> str:
new_key = new_key if new_key is not None else key
if new_key in getattr(self.__class__, registry_name, {}):
_iter += 1
new_key = self._ensure_unique(registry_name, key, f"{key}_{_iter}", _iter)
return new_key
def __post_init__(self) -> None:
self.session_scope_key = self._ensure_unique("_SESSION_SCOPE_KEY_REGISTRY", self.session_scope_key)
self.engine_app_state_key = self._ensure_unique("_ENGINE_APP_STATE_KEY_REGISTRY", self.engine_app_state_key)
self.session_maker_app_state_key = self._ensure_unique(
"_SESSIONMAKER_APP_STATE_KEY_REGISTRY",
self.session_maker_app_state_key,
)
self.__class__._SESSION_SCOPE_KEY_REGISTRY.add(self.session_scope_key) # noqa: SLF001
self.__class__._ENGINE_APP_STATE_KEY_REGISTRY.add(self.engine_app_state_key) # noqa: SLF001
self.__class__._SESSIONMAKER_APP_STATE_KEY_REGISTRY.add(self.session_maker_app_state_key) # noqa: SLF001
if self.before_send_handler is None:
self.before_send_handler = default_handler_maker(session_scope_key=self.session_scope_key)
if self.before_send_handler == "autocommit":
self.before_send_handler = autocommit_handler_maker(session_scope_key=self.session_scope_key)
if self.before_send_handler == "autocommit_include_redirects":
self.before_send_handler = autocommit_handler_maker(
session_scope_key=self.session_scope_key,
commit_on_redirect=True,
)
super().__post_init__()
def create_session_maker(self) -> "Callable[[], AsyncSession]":
"""Get a session maker. If none exists yet, create one.
Returns:
Session factory used by the plugin.
"""
if self.session_maker:
return self.session_maker
session_kws = self.session_config_dict
if session_kws.get("bind") is None:
session_kws["bind"] = self.get_engine()
return self.session_maker_class(**session_kws) # pyright: ignore[reportUnknownVariableType,reportUnknownMemberType]
@asynccontextmanager
async def lifespan(
self,
app: "Litestar",
) -> "AsyncGenerator[None, None]":
deps = self.create_app_state_items()
app.state.update(deps)
try:
if self.create_all:
await self.create_all_metadata(app)
yield
finally:
if self.engine_dependency_key in deps:
engine = deps[self.engine_dependency_key]
if hasattr(engine, "dispose"):
await cast("AsyncEngine", engine).dispose()
def provide_engine(self, state: "State") -> "AsyncEngine":
"""Create an engine instance.
Args:
state: The ``Litestar.state`` instance.
Returns:
An engine instance.
"""
return cast("AsyncEngine", state.get(self.engine_app_state_key))
def provide_session(self, state: "State", scope: "Scope") -> "AsyncSession":
"""Create a session instance.
Args:
state: The ``Litestar.state`` instance.
scope: The current connection's scope.
Returns:
A session instance.
"""
session = cast("Optional[AsyncSession]", get_aa_scope_state(scope, self.session_scope_key))
if session is None:
session_maker = cast("Callable[[], AsyncSession]", state[self.session_maker_app_state_key])
session = session_maker()
set_aa_scope_state(scope, self.session_scope_key, session)
return session
@property
def signature_namespace(self) -> dict[str, Any]:
"""Return the plugin's signature namespace.
Returns:
A string keyed dict of names to be added to the namespace for signature forward reference resolution.
"""
return {"AsyncEngine": AsyncEngine, "AsyncSession": AsyncSession}
async def create_all_metadata(self, app: "Litestar") -> None:
"""Create all metadata
Args:
app (Litestar): The ``Litestar`` instance
"""
async with self.get_engine().begin() as conn:
try:
await conn.run_sync(metadata_registry.get(self.bind_key).create_all)
except OperationalError as exc:
console.print(f"[bold red] * Could not create target metadata. Reason: {exc}")
def create_app_state_items(self) -> dict[str, Any]:
"""Key/value pairs to be stored in application state."""
return {
self.engine_app_state_key: self.get_engine(),
self.session_maker_app_state_key: self.create_session_maker(),
}
def update_app_state(self, app: "Litestar") -> None:
"""Set the app state with engine and session.
Args:
app: The ``Litestar`` instance.
"""
app.state.update(self.create_app_state_items())
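The commit-versus-rollback decision in `autocommit_handler_maker` depends only on the response status code; the rule in isolation, as a standalone sketch:

```python
def should_commit(
    status: int,
    commit_on_redirect: bool = False,
    extra_commit: frozenset = frozenset(),
    extra_rollback: frozenset = frozenset(),
) -> bool:
    """Mirror the status-code rule used by the autocommit before-send handler."""
    # 2XX always commits; 3XX commits only when commit_on_redirect is set.
    commit_range = range(200, 400 if commit_on_redirect else 300)
    return (status in commit_range or status in extra_commit) and status not in extra_rollback


print(should_commit(201))                           # True
print(should_commit(302))                           # False
print(should_commit(302, commit_on_redirect=True))  # True
print(should_commit(500))                           # False
```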
# python-advanced-alchemy-1.0.1/advanced_alchemy/extensions/litestar/plugins/init/config/common.py
from litestar.constants import HTTP_DISCONNECT, HTTP_RESPONSE_START, WEBSOCKET_CLOSE, WEBSOCKET_DISCONNECT
SESSION_SCOPE_KEY = "_sqlalchemy_db_session"
"""Session scope key."""
SESSION_TERMINUS_ASGI_EVENTS = {HTTP_RESPONSE_START, HTTP_DISCONNECT, WEBSOCKET_DISCONNECT, WEBSOCKET_CLOSE}
"""ASGI events that terminate a session scope."""
# python-advanced-alchemy-1.0.1/advanced_alchemy/extensions/litestar/plugins/init/config/engine.py
from dataclasses import dataclass
from typing import Any, Callable
from litestar.serialization import decode_json, encode_json
from advanced_alchemy.config import EngineConfig as _EngineConfig
__all__ = ("EngineConfig",)
def serializer(value: Any) -> str:
"""Serialize JSON field values.
Args:
value: Any json serializable value.
Returns:
JSON string.
"""
return encode_json(value).decode("utf-8")
@dataclass
class EngineConfig(_EngineConfig):
"""Configuration for SQLAlchemy's :class:`Engine `.
For details see: https://docs.sqlalchemy.org/en/20/core/engines.html
"""
json_deserializer: Callable[[str], Any] = decode_json
"""For dialects that support the :class:`JSON ` datatype, this is a Python callable that will
convert a JSON string to a Python object. By default, this is set to Litestar's decode_json function."""
json_serializer: Callable[[Any], str] = serializer
"""For dialects that support the JSON datatype, this is a Python callable that will render a given object as JSON.
By default, Litestar's encode_json function is used."""
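The `serializer` wrapper above exists because SQLAlchemy expects `json_serializer` to return `str`, while Litestar's `encode_json` returns `bytes`. With the stdlib standing in for Litestar's encoder, the same adapter looks like:

```python
import json


def encode_json(value) -> bytes:
    # Stand-in for litestar.serialization.encode_json, which returns bytes.
    return json.dumps(value).encode("utf-8")


def serializer(value) -> str:
    """Adapt a bytes-returning JSON encoder to SQLAlchemy's str contract."""
    return encode_json(value).decode("utf-8")


print(serializer({"a": 1}))  # {"a": 1}
```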
# python-advanced-alchemy-1.0.1/advanced_alchemy/extensions/litestar/plugins/init/config/sync.py
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union, cast
from litestar.cli._utils import console
from litestar.constants import HTTP_RESPONSE_START
from sqlalchemy import Engine
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import Session
from advanced_alchemy.base import metadata_registry
from advanced_alchemy.config.sync import SQLAlchemySyncConfig as _SQLAlchemySyncConfig
from advanced_alchemy.extensions.litestar._utils import (
delete_aa_scope_state,
get_aa_scope_state,
set_aa_scope_state,
)
from advanced_alchemy.extensions.litestar.plugins.init.config.common import (
SESSION_SCOPE_KEY,
SESSION_TERMINUS_ASGI_EVENTS,
)
from advanced_alchemy.extensions.litestar.plugins.init.config.engine import EngineConfig
if TYPE_CHECKING:
from collections.abc import AsyncGenerator
from litestar import Litestar
from litestar.datastructures.state import State
from litestar.types import BeforeMessageSendHookHandler, Message, Scope
__all__ = (
"SQLAlchemySyncConfig",
"autocommit_before_send_handler",
"autocommit_handler_maker",
"default_before_send_handler",
"default_handler_maker",
)
def default_handler_maker(
session_scope_key: str = SESSION_SCOPE_KEY,
) -> "Callable[[Message, Scope], None]":
"""Set up the handler to issue a transaction commit or rollback based on specified status codes
Args:
session_scope_key: The key to use within the application state
Returns:
The handler callable
"""
def handler(message: "Message", scope: "Scope") -> None:
"""Handle commit/rollback, closing and cleaning up sessions before sending.
Args:
message: ASGI-``Message``
scope: An ASGI-``Scope``
Returns:
None
"""
session = cast("Optional[Session]", get_aa_scope_state(scope, session_scope_key))
if session and message["type"] in SESSION_TERMINUS_ASGI_EVENTS:
session.close()
delete_aa_scope_state(scope, session_scope_key)
return handler
default_before_send_handler = default_handler_maker()
def autocommit_handler_maker(
commit_on_redirect: bool = False,
extra_commit_statuses: "Optional[set[int]]" = None,
extra_rollback_statuses: "Optional[set[int]]" = None,
session_scope_key: str = SESSION_SCOPE_KEY,
) -> "Callable[[Message, Scope], None]":
"""Set up the handler to issue a transaction commit or rollback based on specified status codes
Args:
commit_on_redirect: Issue a commit when the response status is a redirect (``3XX``)
extra_commit_statuses: A set of additional status codes that trigger a commit
extra_rollback_statuses: A set of additional status codes that trigger a rollback
session_scope_key: The key to use within the application state
Returns:
The handler callable
"""
if extra_commit_statuses is None:
extra_commit_statuses = set()
if extra_rollback_statuses is None:
extra_rollback_statuses = set()
if len(extra_commit_statuses & extra_rollback_statuses) > 0:
msg = "Extra rollback statuses and commit statuses must not share any status codes"
raise ValueError(msg)
commit_range = range(200, 400 if commit_on_redirect else 300)
def handler(message: "Message", scope: "Scope") -> None:
"""Handle commit/rollback, closing and cleaning up sessions before sending.
Args:
message: ASGI-``Message``
scope: An ASGI-``Scope``
Returns:
None
"""
session = cast("Optional[Session]", get_aa_scope_state(scope, session_scope_key))
try:
if session is not None and message["type"] == HTTP_RESPONSE_START:
if (message["status"] in commit_range or message["status"] in extra_commit_statuses) and message[
"status"
] not in extra_rollback_statuses:
session.commit()
else:
session.rollback()
finally:
if session and message["type"] in SESSION_TERMINUS_ASGI_EVENTS:
session.close()
delete_aa_scope_state(scope, session_scope_key)
return handler
autocommit_before_send_handler = autocommit_handler_maker()
@dataclass
class SQLAlchemySyncConfig(_SQLAlchemySyncConfig):
"""Litestar Sync SQLAlchemy Configuration."""
before_send_handler: Optional[
Union["BeforeMessageSendHookHandler", Literal["autocommit", "autocommit_include_redirects"]]
] = None
"""Handler to call before the ASGI message is sent.
The handler should handle closing the session stored in the ASGI scope, if it's still open, and committing any
uncommitted data.
"""
engine_dependency_key: str = "db_engine"
"""Key to use for the dependency injection of database engines."""
session_dependency_key: str = "db_session"
"""Key to use for the dependency injection of database sessions."""
engine_app_state_key: str = "db_engine"
"""Key under which to store the SQLAlchemy engine in the application :class:`State <.datastructures.State>`
instance.
"""
session_maker_app_state_key: str = "session_maker_class"
"""Key under which to store the SQLAlchemy :class:`sessionmaker ` in the application
:class:`State <.datastructures.State>` instance.
"""
session_scope_key: str = SESSION_SCOPE_KEY
"""Key under which to store the SQLAlchemy scope in the application."""
engine_config: EngineConfig = field(default_factory=EngineConfig) # pyright: ignore[reportIncompatibleVariableOverride]
"""Configuration for the SQLAlchemy engine.
The configuration options are documented in the SQLAlchemy documentation.
"""
set_default_exception_handler: bool = True
"""Sets the default exception handler on application start."""
def _ensure_unique(self, registry_name: str, key: str, new_key: Optional[str] = None, _iter: int = 0) -> str:
new_key = new_key if new_key is not None else key
if new_key in getattr(self.__class__, registry_name, {}):
_iter += 1
new_key = self._ensure_unique(registry_name, key, f"{key}_{_iter}", _iter)
return new_key
def __post_init__(self) -> None:
self.session_scope_key = self._ensure_unique("_SESSION_SCOPE_KEY_REGISTRY", self.session_scope_key)
self.engine_app_state_key = self._ensure_unique("_ENGINE_APP_STATE_KEY_REGISTRY", self.engine_app_state_key)
self.session_maker_app_state_key = self._ensure_unique(
"_SESSIONMAKER_APP_STATE_KEY_REGISTRY",
self.session_maker_app_state_key,
)
self.__class__._SESSION_SCOPE_KEY_REGISTRY.add(self.session_scope_key) # noqa: SLF001
self.__class__._ENGINE_APP_STATE_KEY_REGISTRY.add(self.engine_app_state_key) # noqa: SLF001
self.__class__._SESSIONMAKER_APP_STATE_KEY_REGISTRY.add(self.session_maker_app_state_key) # noqa: SLF001
if self.before_send_handler is None:
self.before_send_handler = default_handler_maker(session_scope_key=self.session_scope_key)
if self.before_send_handler == "autocommit":
self.before_send_handler = autocommit_handler_maker(session_scope_key=self.session_scope_key)
if self.before_send_handler == "autocommit_include_redirects":
self.before_send_handler = autocommit_handler_maker(
session_scope_key=self.session_scope_key,
commit_on_redirect=True,
)
super().__post_init__()
def create_session_maker(self) -> "Callable[[], Session]":
"""Get a session maker. If none exists yet, create one.
Returns:
Session factory used by the plugin.
"""
if self.session_maker:
return self.session_maker
session_kws = self.session_config_dict
if session_kws.get("bind") is None:
session_kws["bind"] = self.get_engine()
return self.session_maker_class(**session_kws)
@asynccontextmanager
async def lifespan(
self,
app: "Litestar",
) -> "AsyncGenerator[None, None]":
deps = self.create_app_state_items()
app.state.update(deps)
try:
if self.create_all:
self.create_all_metadata(app)
yield
finally:
if self.engine_dependency_key in deps:
engine = deps[self.engine_dependency_key]
if hasattr(engine, "dispose"):
cast("Engine", engine).dispose()
def provide_engine(self, state: "State") -> "Engine":
"""Create an engine instance.
Args:
state: The ``Litestar.state`` instance.
Returns:
An engine instance.
"""
return cast("Engine", state.get(self.engine_app_state_key))
def provide_session(self, state: "State", scope: "Scope") -> "Session":
"""Create a session instance.
Args:
state: The ``Litestar.state`` instance.
scope: The current connection's scope.
Returns:
A session instance.
"""
session = cast("Optional[Session]", get_aa_scope_state(scope, self.session_scope_key))
if session is None:
session_maker = cast("Callable[[], Session]", state[self.session_maker_app_state_key])
session = session_maker()
set_aa_scope_state(scope, self.session_scope_key, session)
return session
@property
def signature_namespace(self) -> "dict[str, Any]":
"""Return the plugin's signature namespace.
Returns:
A string keyed dict of names to be added to the namespace for signature forward reference resolution.
"""
return {"Engine": Engine, "Session": Session}
def create_all_metadata(self, app: "Litestar") -> None:
"""Create all metadata
Args:
app (Litestar): The ``Litestar`` instance
"""
with self.get_engine().begin() as conn:
try:
metadata_registry.get(self.bind_key).create_all(bind=conn)
except OperationalError as exc:
console.print(f"[bold red] * Could not create target metadata. Reason: {exc}")
def create_app_state_items(self) -> "dict[str, Any]":
"""Key/value pairs to be stored in application state."""
return {
self.engine_app_state_key: self.get_engine(),
self.session_maker_app_state_key: self.create_session_maker(),
}
def update_app_state(self, app: "Litestar") -> None:
"""Set the app state with engine and session.
Args:
app: The ``Litestar`` instance.
"""
app.state.update(self.create_app_state_items())
# python-advanced-alchemy-1.0.1/advanced_alchemy/extensions/litestar/plugins/init/plugin.py
import contextlib
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, Union, cast
from litestar.di import Provide
from litestar.dto import DTOData
from litestar.params import Dependency, Parameter
from litestar.plugins import CLIPlugin, InitPluginProtocol
from sqlalchemy.ext.asyncio import AsyncSession, async_scoped_session
from sqlalchemy.orm import Session, scoped_session
from advanced_alchemy.exceptions import ImproperConfigurationError, RepositoryError
from advanced_alchemy.extensions.litestar.exception_handler import exception_to_http_response
from advanced_alchemy.extensions.litestar.plugins import _slots_base
from advanced_alchemy.filters import (
BeforeAfter,
CollectionFilter,
FilterTypes,
LimitOffset,
NotInCollectionFilter,
NotInSearchFilter,
OnBeforeAfter,
OrderBy,
SearchFilter,
)
from advanced_alchemy.service import ModelDictListT, ModelDictT, ModelDTOT, ModelOrRowMappingT, ModelT, OffsetPagination
if TYPE_CHECKING:
from click import Group
from litestar.config.app import AppConfig
from litestar.types import BeforeMessageSendHookHandler
from advanced_alchemy.extensions.litestar.plugins.init.config import SQLAlchemyAsyncConfig, SQLAlchemySyncConfig
__all__ = ("SQLAlchemyInitPlugin",)
signature_namespace_values: dict[str, Any] = {
"BeforeAfter": BeforeAfter,
"OnBeforeAfter": OnBeforeAfter,
"CollectionFilter": CollectionFilter,
"LimitOffset": LimitOffset,
"OrderBy": OrderBy,
"SearchFilter": SearchFilter,
"NotInCollectionFilter": NotInCollectionFilter,
"NotInSearchFilter": NotInSearchFilter,
"FilterTypes": FilterTypes,
"OffsetPagination": OffsetPagination,
"Parameter": Parameter,
"Dependency": Dependency,
"DTOData": DTOData,
"Sequence": Sequence,
"ModelT": ModelT,
"ModelDictT": ModelDictT,
"ModelDTOT": ModelDTOT,
"ModelDictListT": ModelDictListT,
"ModelOrRowMappingT": ModelOrRowMappingT,
"Session": Session,
"scoped_session": scoped_session,
"AsyncSession": AsyncSession,
"async_scoped_session": async_scoped_session,
}
class SQLAlchemyInitPlugin(InitPluginProtocol, CLIPlugin, _slots_base.SlotsBase):
"""SQLAlchemy application lifecycle configuration."""
def __init__(
self,
config: Union[
"SQLAlchemyAsyncConfig",
"SQLAlchemySyncConfig",
"Sequence[Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig]]",
],
) -> None:
"""Initialize ``SQLAlchemyPlugin``.
Args:
config: configure DB connection and hook handlers and dependencies.
"""
self._config = config
@property
def config(self) -> "Sequence[Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig]]":
return self._config if isinstance(self._config, Sequence) else [self._config]
def on_cli_init(self, cli: "Group") -> None:
from advanced_alchemy.extensions.litestar.cli import database_group
cli.add_command(database_group)
def _validate_config(self) -> None:
configs = self._config if isinstance(self._config, Sequence) else [self._config]
scope_keys = {config.session_scope_key for config in configs}
engine_keys = {config.engine_dependency_key for config in configs}
session_keys = {config.session_dependency_key for config in configs}
if len(configs) > 1 and any(len(i) != len(configs) for i in (scope_keys, engine_keys, session_keys)):
raise ImproperConfigurationError(
detail="When using multiple configurations, please ensure the `session_dependency_key` and `engine_dependency_key` settings are unique across all configs. Additionally, iF you are using a custom `before_send` handler, ensure `session_scope_key` is unique.",
)
def on_app_init(self, app_config: "AppConfig") -> "AppConfig":
"""Configure application for use with SQLAlchemy.
Args:
app_config: The :class:`AppConfig <.config.app.AppConfig>` instance.
"""
self._validate_config()
with contextlib.suppress(ImportError):
from asyncpg.pgproto import pgproto # pyright: ignore[reportMissingImports]
signature_namespace_values.update({"pgproto.UUID": pgproto.UUID})
app_config.type_encoders = {pgproto.UUID: str, **(app_config.type_encoders or {})}
with contextlib.suppress(ImportError):
import uuid_utils # pyright: ignore[reportMissingImports]
signature_namespace_values.update({"uuid_utils.UUID": uuid_utils.UUID}) # pyright: ignore[reportUnknownMemberType]
app_config.type_encoders = {uuid_utils.UUID: str, **(app_config.type_encoders or {})} # pyright: ignore[reportUnknownMemberType]
app_config.type_decoders = [
(lambda x: x is uuid_utils.UUID, lambda t, v: t(str(v))), # pyright: ignore[reportUnknownMemberType]
*(app_config.type_decoders or []),
]
configure_exception_handler = False
for config in self.config:
if config.set_default_exception_handler:
configure_exception_handler = True
signature_namespace_values.update(config.signature_namespace)
app_config.lifespan.append(config.lifespan) # pyright: ignore[reportUnknownMemberType]
app_config.dependencies.update(
{
config.engine_dependency_key: Provide(config.provide_engine, sync_to_thread=False),
config.session_dependency_key: Provide(config.provide_session, sync_to_thread=False),
},
)
app_config.before_send.append(cast("BeforeMessageSendHookHandler", config.before_send_handler))
app_config.signature_namespace.update(signature_namespace_values)
if configure_exception_handler and not any(
isinstance(exc, int) or issubclass(exc, RepositoryError)
for exc in app_config.exception_handlers # pyright: ignore[reportUnknownMemberType]
):
app_config.exception_handlers.update({RepositoryError: exception_to_http_response}) # pyright: ignore[reportUnknownMemberType]
return app_config
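The uniqueness check in `_validate_config` relies on set sizes: if any derived key set has fewer members than there are configs, two configs share a key. A minimal sketch of that technique (plain dicts stand in for the config objects, an assumption for illustration only):

```python
# Plain dicts stand in for the SQLAlchemy config objects; the set-size
# comparison is the actual technique used above.
configs = [
    {"session_dependency_key": "db_session", "engine_dependency_key": "db_engine"},
    {"session_dependency_key": "db_session", "engine_dependency_key": "other_engine"},
]
session_keys = {c["session_dependency_key"] for c in configs}
engine_keys = {c["engine_dependency_key"] for c in configs}
# A collision collapses a set below len(configs), flagging the conflict.
has_conflict = len(configs) > 1 and any(
    len(keys) != len(configs) for keys in (session_keys, engine_keys)
)
print(has_conflict)  # True: both configs use "db_session"
```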
# advanced_alchemy/extensions/litestar/plugins/serialization.py
from typing import Any
from litestar.plugins import SerializationPlugin
from litestar.typing import FieldDefinition
from sqlalchemy.orm import DeclarativeBase
from advanced_alchemy.extensions.litestar.dto import SQLAlchemyDTO
from advanced_alchemy.extensions.litestar.plugins import _slots_base
class SQLAlchemySerializationPlugin(SerializationPlugin, _slots_base.SlotsBase):
def __init__(self) -> None:
self._type_dto_map: dict[type[DeclarativeBase], type[SQLAlchemyDTO[Any]]] = {}
def supports_type(self, field_definition: FieldDefinition) -> bool:
return (
field_definition.is_collection and field_definition.has_inner_subclass_of(DeclarativeBase)
) or field_definition.is_subclass_of(DeclarativeBase)
def create_dto_for_type(self, field_definition: FieldDefinition) -> type[SQLAlchemyDTO[Any]]:
# assumes that the type is a container of SQLAlchemy models or a single SQLAlchemy model
annotation = next(
(
inner_type.annotation
for inner_type in field_definition.inner_types
if inner_type.is_subclass_of(DeclarativeBase)
),
field_definition.annotation,
)
if annotation in self._type_dto_map:
return self._type_dto_map[annotation]
self._type_dto_map[annotation] = dto_type = SQLAlchemyDTO[annotation] # type:ignore[valid-type]
return dto_type
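The plugin memoizes one DTO per model class in `_type_dto_map`, so repeated handlers reuse the same generated type. A hedged, dependency-free sketch of that memoization pattern (a string stands in for the `SQLAlchemyDTO` type):

```python
# Stand-in memoized factory: one cached "DTO" per model class, created on
# first request and reused thereafter (mirrors _type_dto_map above).
_type_map: dict[type, str] = {}

def dto_for(model_type: type) -> str:
    if model_type in _type_map:
        return _type_map[model_type]
    _type_map[model_type] = dto = f"DTO[{model_type.__name__}]"
    return dto

class User:
    pass

first = dto_for(User)
second = dto_for(User)  # cache hit: the very same object comes back
```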
# advanced_alchemy/extensions/litestar/providers.py
# ruff: noqa: B008, PGH003
"""Application dependency providers generators.
This module contains functions to create dependency providers for services and filters.
You should rarely need to modify this module; it is invoked automatically under normal usage.
"""
import datetime
import inspect
from collections.abc import AsyncGenerator, Callable, Generator
from typing import (
TYPE_CHECKING,
Any,
Literal,
Optional,
TypedDict,
TypeVar,
Union,
cast,
overload,
)
from uuid import UUID
from litestar.di import Provide
from litestar.params import Dependency, Parameter
from typing_extensions import NotRequired
from advanced_alchemy.filters import (
BeforeAfter,
CollectionFilter,
FilterTypes,
LimitOffset,
OrderBy,
SearchFilter,
)
from advanced_alchemy.service import (
Empty,
EmptyType,
ErrorMessages,
LoadSpec,
ModelT,
SQLAlchemyAsyncRepositoryService,
SQLAlchemySyncRepositoryService,
)
if TYPE_CHECKING:
from sqlalchemy import Select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from advanced_alchemy.config import SQLAlchemyAsyncConfig, SQLAlchemySyncConfig
DTorNone = Optional[datetime.datetime]
StringOrNone = Optional[str]
UuidOrNone = Optional[UUID]
IntOrNone = Optional[int]
BooleanOrNone = Optional[bool]
SortOrder = Literal["asc", "desc"]
SortOrderOrNone = Optional[SortOrder]
AsyncServiceT_co = TypeVar("AsyncServiceT_co", bound=SQLAlchemyAsyncRepositoryService[Any], covariant=True)
SyncServiceT_co = TypeVar("SyncServiceT_co", bound=SQLAlchemySyncRepositoryService[Any], covariant=True)
class DependencyDefaults:
FILTERS_DEPENDENCY_KEY: str = "filters"
"""Key for the filters dependency."""
CREATED_FILTER_DEPENDENCY_KEY: str = "created_filter"
"""Key for the created filter dependency."""
ID_FILTER_DEPENDENCY_KEY: str = "id_filter"
"""Key for the id filter dependency."""
LIMIT_OFFSET_DEPENDENCY_KEY: str = "limit_offset"
"""Key for the limit offset dependency."""
UPDATED_FILTER_DEPENDENCY_KEY: str = "updated_filter"
"""Key for the updated filter dependency."""
ORDER_BY_DEPENDENCY_KEY: str = "order_by"
"""Key for the order by dependency."""
SEARCH_FILTER_DEPENDENCY_KEY: str = "search_filter"
"""Key for the search filter dependency."""
DEFAULT_PAGINATION_SIZE: int = 20
"""Default pagination size."""
DEPENDENCY_DEFAULTS = DependencyDefaults()
class FilterConfig(TypedDict):
"""Configuration for generating dynamic filters."""
id_filter: NotRequired[type[Union[UUID, int]]]
"""Indicates that the id filter should be enabled. When set, the type specified will be used for the :class:`CollectionFilter`."""
id_field: NotRequired[str]
"""The field on the model that stores the primary key or identifier."""
sort_field: NotRequired[str]
"""The default field to use for the sort filter."""
sort_order: NotRequired[SortOrder]
"""The default order to use for the sort filter."""
pagination_type: NotRequired[Literal["limit_offset"]]
"""When set, pagination is enabled based on the type specified."""
pagination_size: NotRequired[int]
"""The size of the pagination."""
search: NotRequired[str]
"""When set, search is enabled for the specified fields."""
search_ignore_case: NotRequired[bool]
"""When set, search is case insensitive by default."""
created_at: NotRequired[bool]
"""When set, created_at filter is enabled."""
updated_at: NotRequired[bool]
"""When set, updated_at filter is enabled."""
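A `FilterConfig` is an ordinary dict at runtime. The values below are hypothetical, but show how the keys above combine to enable id, date, pagination, search, and sort filters:

```python
from uuid import UUID

# Hypothetical configuration values; every key mirrors the TypedDict above.
filter_config = {
    "id_filter": UUID,                  # CollectionFilter keyed by UUID
    "id_field": "id",
    "created_at": True,                 # BeforeAfter on created_at
    "updated_at": True,                 # BeforeAfter on updated_at
    "pagination_type": "limit_offset",  # LimitOffset pagination
    "pagination_size": 25,
    "search": "name,email",             # comma-separated searchable fields
    "search_ignore_case": True,
    "sort_field": "created_at",
    "sort_order": "desc",
}
```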
class SingletonMeta(type):
"""Metaclass for singleton pattern."""
_instances: dict[type, Any] = {}
def __call__(cls, *args: Any, **kwargs: Any) -> Any:
if cls not in cls._instances: # pyright: ignore[reportUnnecessaryContains]
cls._instances[cls] = super().__call__(*args, **kwargs)
return cls._instances[cls]
class DependencyCache(metaclass=SingletonMeta):
"""Simple dependency cache for the application. This is used to help memoize dependencies that are generated dynamically."""
def __init__(self) -> None:
self.dependencies: dict[Union[int, str], dict[str, Provide]] = {}
def add_dependencies(self, key: Union[int, str], dependencies: dict[str, Provide]) -> None:
self.dependencies[key] = dependencies
def get_dependencies(self, key: Union[int, str]) -> Optional[dict[str, Provide]]:
return self.dependencies.get(key)
dep_cache = DependencyCache()
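`SingletonMeta` intercepts construction in `__call__`, so every `DependencyCache()` call returns the same instance. A self-contained demonstration of the metaclass pattern:

```python
from typing import Any

class SingletonMeta(type):
    """Metaclass: construct once, return the cached instance thereafter."""
    _instances: dict[type, Any] = {}

    def __call__(cls, *args: Any, **kwargs: Any) -> Any:
        if cls not in cls._instances:
            cls._instances[cls] = super().__call__(*args, **kwargs)
        return cls._instances[cls]

class Cache(metaclass=SingletonMeta):
    def __init__(self) -> None:
        self.store: dict = {}

a = Cache()
a.store["k"] = 1
b = Cache()  # __init__ does not run again; the same object comes back
```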
@overload
def create_service_provider(
service_class: type["AsyncServiceT_co"],
/,
statement: "Optional[Select[tuple[ModelT]]]" = None,
config: "Optional[SQLAlchemyAsyncConfig]" = None,
error_messages: "Optional[Union[ErrorMessages, EmptyType]]" = Empty,
load: "Optional[LoadSpec]" = None,
execution_options: "Optional[dict[str, Any]]" = None,
uniquify: Optional[bool] = None,
count_with_window_function: Optional[bool] = None,
) -> Callable[..., AsyncGenerator[AsyncServiceT_co, None]]: ...
@overload
def create_service_provider(
service_class: type["SyncServiceT_co"],
/,
statement: "Optional[Select[tuple[ModelT]]]" = None,
config: "Optional[SQLAlchemySyncConfig]" = None,
error_messages: "Optional[Union[ErrorMessages, EmptyType]]" = Empty,
load: "Optional[LoadSpec]" = None,
execution_options: "Optional[dict[str, Any]]" = None,
uniquify: Optional[bool] = None,
count_with_window_function: Optional[bool] = None,
) -> Callable[..., Generator[SyncServiceT_co, None, None]]: ...
def create_service_provider(
service_class: type[Union["AsyncServiceT_co", "SyncServiceT_co"]],
/,
statement: "Optional[Select[tuple[ModelT]]]" = None,
config: "Optional[Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig]]" = None,
error_messages: "Optional[Union[ErrorMessages, EmptyType]]" = Empty,
load: "Optional[LoadSpec]" = None,
execution_options: "Optional[dict[str, Any]]" = None,
uniquify: Optional[bool] = None,
count_with_window_function: Optional[bool] = None,
) -> Callable[..., Union["AsyncGenerator[AsyncServiceT_co, None]", "Generator[SyncServiceT_co,None, None]"]]:
"""Create a dependency provider for a service."""
if issubclass(service_class, SQLAlchemyAsyncRepositoryService) or service_class is SQLAlchemyAsyncRepositoryService: # type: ignore[comparison-overlap]
async def provide_async_service(
db_session: "Optional[AsyncSession]" = None,
) -> "AsyncGenerator[AsyncServiceT_co, None]": # type: ignore[union-attr,unused-ignore]
async with service_class.new( # type: ignore[union-attr,unused-ignore]
session=db_session, # type: ignore[arg-type, unused-ignore]
statement=statement,
config=cast("Optional[SQLAlchemyAsyncConfig]", config), # type: ignore[arg-type]
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=uniquify,
count_with_window_function=count_with_window_function,
) as service:
yield service
return provide_async_service
def provide_sync_service(
db_session: "Optional[Session]" = None,
) -> "Generator[SyncServiceT_co, None, None]":
with service_class.new(
session=db_session, # type: ignore[arg-type, unused-ignore]
statement=statement,
config=cast("Optional[SQLAlchemySyncConfig]", config),
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=uniquify,
count_with_window_function=count_with_window_function,
) as service:
yield service
return provide_sync_service
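The factory above returns a generator function whose closure captures the service configuration; the framework calls it per request, and the `yield` scopes the service to the request lifetime. A dependency-free sketch of that closure-plus-generator shape (the `"service:<name>"` strings stand in for `service_class.new(...)` instances):

```python
from collections.abc import Generator

def create_provider(name: str):
    # The outer call captures configuration; the inner generator is what the
    # framework invokes per request.
    def provide_service() -> Generator[str, None, None]:
        service = f"service:{name}"  # stand-in for service_class.new(...)
        try:
            yield service            # framework injects this value
        finally:
            pass                     # session/connection cleanup runs here
    return provide_service

gen = create_provider("users")()
service = next(gen)
```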
def create_service_dependencies(
service_class: type[Union["AsyncServiceT_co", "SyncServiceT_co"]],
/,
key: str,
statement: "Optional[Select[tuple[ModelT]]]" = None,
config: "Optional[Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig]]" = None,
error_messages: "Optional[Union[ErrorMessages, EmptyType]]" = Empty,
load: "Optional[LoadSpec]" = None,
execution_options: "Optional[dict[str, Any]]" = None,
filters: "Optional[FilterConfig]" = None,
uniquify: Optional[bool] = None,
count_with_window_function: Optional[bool] = None,
dep_defaults: "DependencyDefaults" = DEPENDENCY_DEFAULTS,
) -> dict[str, Provide]:
"""Create dependency providers for a service and its optional filters.
Args:
key: The key to use for the dependency provider.
service_class: The service class to create a dependency provider for.
statement: The statement to use for the service.
config: The configuration to use for the service.
error_messages: The error messages to use for the service.
load: The load spec to use for the service.
execution_options: The execution options to use for the service.
filters: The filter configuration to use for the service.
uniquify: Whether to uniquify the service.
count_with_window_function: Whether to count with a window function.
dep_defaults: The dependency defaults to use for the service.
Returns:
A dictionary of dependency providers for the service.
"""
if issubclass(service_class, SQLAlchemyAsyncRepositoryService) or service_class is SQLAlchemyAsyncRepositoryService: # type: ignore[comparison-overlap]
svc = create_service_provider( # type: ignore[type-var,misc,unused-ignore]
service_class,
statement,
cast("Optional[SQLAlchemyAsyncConfig]", config),
error_messages,
load,
execution_options,
uniquify,
count_with_window_function,
)
deps = {key: Provide(svc)}
else:
svc = create_service_provider( # type: ignore[assignment]
service_class,
statement,
cast("Optional[SQLAlchemySyncConfig]", config),
error_messages,
load,
execution_options,
uniquify,
count_with_window_function,
)
deps = {key: Provide(svc, sync_to_thread=False)}
if filters:
deps.update(create_filter_dependencies(filters, dep_defaults))
return deps
def create_filter_dependencies(
config: FilterConfig, dep_defaults: DependencyDefaults = DEPENDENCY_DEFAULTS
) -> dict[str, Provide]:
"""Create a dependency provider for the combined filter function.
Args:
config: FilterConfig instance with desired settings.
dep_defaults: Dependency defaults to use for the filter dependencies
Returns:
A dictionary of filter dependency providers, memoized by configuration.
"""
cache_key = sum(map(hash, config.items()))
deps = dep_cache.get_dependencies(cache_key)
if deps is not None:
return deps
deps = _create_statement_filters(config, dep_defaults)
dep_cache.add_dependencies(cache_key, deps)
return deps
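The cache key sums the hashes of the config's items, which makes it insensitive to key order: two `FilterConfig` dicts with the same entries memoize to the same dependency set. Demonstration:

```python
# Order-independent cache key: summing item hashes commutes, so insertion
# order does not matter (values must be hashable, as FilterConfig's are).
c1 = {"created_at": True, "updated_at": True, "sort_field": "id"}
c2 = {"sort_field": "id", "updated_at": True, "created_at": True}

key1 = sum(map(hash, c1.items()))
key2 = sum(map(hash, c2.items()))
```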
def _create_statement_filters(
config: FilterConfig, dep_defaults: DependencyDefaults = DEPENDENCY_DEFAULTS
) -> dict[str, Provide]:
"""Create filter dependencies based on configuration.
Args:
config (FilterConfig): Configuration dictionary specifying which filters to enable
dep_defaults (DependencyDefaults): Dependency defaults to use for the filter dependencies
Returns:
dict[str, Provide]: Dictionary of filter provider functions
"""
filters: dict[str, Provide] = {}
if config.get("id_filter", False):
def provide_id_filter( # pyright: ignore[reportUnknownParameterType]
ids: Optional[list[str]] = Parameter(query="ids", default=None, required=False),
) -> CollectionFilter: # pyright: ignore[reportMissingTypeArgument]
return CollectionFilter(field_name=config.get("id_field", "id"), values=ids)
filters[dep_defaults.ID_FILTER_DEPENDENCY_KEY] = Provide(provide_id_filter, sync_to_thread=False) # pyright: ignore[reportUnknownArgumentType]
if config.get("created_at", False):
def provide_created_filter(
before: DTorNone = Parameter(query="createdBefore", default=None, required=False),
after: DTorNone = Parameter(query="createdAfter", default=None, required=False),
) -> BeforeAfter:
return BeforeAfter("created_at", before, after)
filters[dep_defaults.CREATED_FILTER_DEPENDENCY_KEY] = Provide(provide_created_filter, sync_to_thread=False)
if config.get("updated_at", False):
def provide_updated_filter(
before: DTorNone = Parameter(query="updatedBefore", default=None, required=False),
after: DTorNone = Parameter(query="updatedAfter", default=None, required=False),
) -> BeforeAfter:
return BeforeAfter("updated_at", before, after)
filters[dep_defaults.UPDATED_FILTER_DEPENDENCY_KEY] = Provide(provide_updated_filter, sync_to_thread=False)
if config.get("pagination_type") == "limit_offset":
def provide_limit_offset_pagination(
current_page: int = Parameter(ge=1, query="currentPage", default=1, required=False),
page_size: int = Parameter(
query="pageSize",
ge=1,
default=config.get("pagination_size", dep_defaults.DEFAULT_PAGINATION_SIZE),
required=False,
),
) -> LimitOffset:
return LimitOffset(page_size, page_size * (current_page - 1))
filters[dep_defaults.LIMIT_OFFSET_DEPENDENCY_KEY] = Provide(
provide_limit_offset_pagination, sync_to_thread=False
)
if search_fields := config.get("search"):
def provide_search_filter(
search_string: StringOrNone = Parameter(
title="Field to search",
query="searchString",
default=None,
required=False,
),
ignore_case: BooleanOrNone = Parameter(
title="Search should be case insensitive",
query="searchIgnoreCase",
default=config.get("search_ignore_case", False),
required=False,
),
) -> SearchFilter:
return SearchFilter(
field_name=set(search_fields.split(",")),
value=search_string, # type: ignore[arg-type]
ignore_case=ignore_case or False,
)
filters[dep_defaults.SEARCH_FILTER_DEPENDENCY_KEY] = Provide(provide_search_filter, sync_to_thread=False)
if sort_field := config.get("sort_field"):
def provide_order_by(
field_name: StringOrNone = Parameter(
title="Order by field",
query="orderBy",
default=sort_field,
required=False,
),
sort_order: SortOrderOrNone = Parameter(
title="Sort order",
query="sortOrder",
default=config.get("sort_order", "desc"),
required=False,
),
) -> OrderBy:
return OrderBy(field_name=field_name, sort_order=sort_order) # type: ignore[arg-type]
filters[dep_defaults.ORDER_BY_DEPENDENCY_KEY] = Provide(provide_order_by, sync_to_thread=False)
if filters:
filters[dep_defaults.FILTERS_DEPENDENCY_KEY] = Provide(
_create_filter_aggregate_function(config), sync_to_thread=False
)
return filters
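The pagination provider converts a 1-based page number into a row offset: `offset = page_size * (current_page - 1)`. A small sketch of that arithmetic:

```python
def limit_offset(current_page: int, page_size: int) -> tuple[int, int]:
    """Return (limit, offset) as provide_limit_offset_pagination computes them."""
    return page_size, page_size * (current_page - 1)

# Page 1 starts at row 0; page 3 with 20-row pages starts at row 40.
assert limit_offset(1, 20) == (20, 0)
assert limit_offset(3, 20) == (20, 40)
```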
def _create_filter_aggregate_function(config: FilterConfig) -> Callable[..., list[FilterTypes]]:
"""Create a filter function based on the provided configuration.
Args:
config: The filter configuration.
Returns:
A function that returns a list of filters based on the configuration.
"""
parameters: dict[str, inspect.Parameter] = {}
annotations: dict[str, Any] = {}
# Build parameters based on config
if cls := config.get("id_filter"):
parameters["id_filter"] = inspect.Parameter(
name="id_filter",
kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
default=Dependency(skip_validation=True),
annotation=CollectionFilter[cls], # type: ignore[valid-type]
)
annotations["id_filter"] = CollectionFilter[cls] # type: ignore[valid-type]
if config.get("created_at"):
parameters["created_filter"] = inspect.Parameter(
name="created_filter",
kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
default=Dependency(skip_validation=True),
annotation=BeforeAfter,
)
annotations["created_filter"] = BeforeAfter
if config.get("updated_at"):
parameters["updated_filter"] = inspect.Parameter(
name="updated_filter",
kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
default=Dependency(skip_validation=True),
annotation=BeforeAfter,
)
annotations["updated_filter"] = BeforeAfter
if config.get("search"):
parameters["search_filter"] = inspect.Parameter(
name="search_filter",
kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
default=Dependency(skip_validation=True),
annotation=SearchFilter,
)
annotations["search_filter"] = SearchFilter
if config.get("pagination_type") == "limit_offset":
parameters["limit_offset"] = inspect.Parameter(
name="limit_offset",
kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
default=Dependency(skip_validation=True),
annotation=LimitOffset,
)
annotations["limit_offset"] = LimitOffset
if config.get("sort_field"):
parameters["order_by"] = inspect.Parameter(
name="order_by",
kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
default=Dependency(skip_validation=True),
annotation=OrderBy,
)
annotations["order_by"] = OrderBy
def provide_filters(**kwargs: FilterTypes) -> list[FilterTypes]:
"""Provide filter dependencies based on configuration.
Args:
**kwargs: Filter parameters dynamically provided based on configuration.
Returns:
list[FilterTypes]: List of configured filters.
"""
filters: list[FilterTypes] = []
if id_filter := kwargs.get("id_filter"):
filters.append(id_filter)
if created_filter := kwargs.get("created_filter"):
filters.append(created_filter)
if limit_offset := kwargs.get("limit_offset"):
filters.append(limit_offset)
if updated_filter := kwargs.get("updated_filter"):
filters.append(updated_filter)
if (
(search_filter := cast("Optional[SearchFilter]", kwargs.get("search_filter")))
and search_filter is not None # pyright: ignore[reportUnnecessaryComparison]
and search_filter.field_name is not None # pyright: ignore[reportUnnecessaryComparison]
and search_filter.value is not None # pyright: ignore[reportUnnecessaryComparison]
):
filters.append(search_filter)
if (
(order_by := cast("Optional[OrderBy]", kwargs.get("order_by")))
and order_by is not None # pyright: ignore[reportUnnecessaryComparison]
and order_by.field_name is not None # pyright: ignore[reportUnnecessaryComparison]
):
filters.append(order_by)
return filters
# Set both signature and annotations
provide_filters.__signature__ = inspect.Signature( # type: ignore
parameters=list(parameters.values()),
return_annotation=list[FilterTypes],
)
provide_filters.__annotations__ = annotations
provide_filters.__annotations__["return"] = list[FilterTypes]
return provide_filters
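`provide_filters` accepts `**kwargs`, so its real parameter list is invisible to introspection; attaching an `inspect.Signature` (and matching `__annotations__`) is what lets the framework discover the dynamically chosen parameters. A minimal sketch of the technique, with a hypothetical two-parameter set standing in for the config-driven list:

```python
import inspect

# Hypothetical parameter set; the real code builds these from FilterConfig.
params = [
    inspect.Parameter("id_filter", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=list),
    inspect.Parameter("limit_offset", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=tuple),
]

def provide_filters(**kwargs):
    return [v for v in kwargs.values() if v]

# Introspection now reports the synthetic parameters instead of **kwargs.
provide_filters.__signature__ = inspect.Signature(parameters=params, return_annotation=list)
provide_filters.__annotations__ = {"id_filter": list, "limit_offset": tuple, "return": list}

names = [p.name for p in inspect.signature(provide_filters).parameters.values()]
```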
# advanced_alchemy/extensions/sanic/__init__.py
from advanced_alchemy import base, exceptions, filters, mixins, operations, repository, service, types, utils
from advanced_alchemy.alembic.commands import AlembicCommands
from advanced_alchemy.config import (
AlembicAsyncConfig,
AlembicSyncConfig,
AsyncSessionConfig,
SyncSessionConfig,
)
from advanced_alchemy.extensions.sanic.config import EngineConfig, SQLAlchemyAsyncConfig, SQLAlchemySyncConfig
from advanced_alchemy.extensions.sanic.extension import AdvancedAlchemy
__all__ = (
"AdvancedAlchemy",
"AlembicAsyncConfig",
"AlembicCommands",
"AlembicSyncConfig",
"AsyncSessionConfig",
"EngineConfig",
"SQLAlchemyAsyncConfig",
"SQLAlchemySyncConfig",
"SyncSessionConfig",
"base",
"exceptions",
"filters",
"mixins",
"operations",
"repository",
"service",
"types",
"utils",
)
# advanced_alchemy/extensions/sanic/config.py
"""Configuration classes for Sanic integration.
This module provides configuration classes for integrating SQLAlchemy with Sanic applications,
including both synchronous and asynchronous database configurations.
"""
import asyncio
import contextlib
from dataclasses import dataclass, field
from typing import Any, Callable, Optional, cast
from click import echo
from sanic import HTTPResponse, Request, Sanic
from sqlalchemy.exc import OperationalError
from advanced_alchemy.exceptions import ImproperConfigurationError
try:
from sanic_ext import Extend
SANIC_INSTALLED = True
except ModuleNotFoundError: # pragma: no cover
SANIC_INSTALLED = False # pyright: ignore[reportConstantRedefinition]
Extend = type("Extend", (), {}) # type: ignore # noqa: PGH003
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker
from sqlalchemy.orm import Session, sessionmaker
from typing_extensions import Literal
from advanced_alchemy._serialization import decode_json, encode_json
from advanced_alchemy.base import metadata_registry
from advanced_alchemy.config import EngineConfig as _EngineConfig
from advanced_alchemy.config.asyncio import SQLAlchemyAsyncConfig as _SQLAlchemyAsyncConfig
from advanced_alchemy.config.sync import SQLAlchemySyncConfig as _SQLAlchemySyncConfig
from advanced_alchemy.service import schema_dump
def _make_unique_context_key(app: "Sanic[Any, Any]", key: str) -> str: # pragma: no cover
"""Generates a unique context key for the Sanic application.
Ensures that the key does not already exist in the application's state.
Args:
app (sanic.Sanic): The Sanic application instance.
key (str): The base key name.
Returns:
str: A unique key name.
"""
i = 0
while True:
if not hasattr(app.ctx, key):
return key
key = f"{key}_{i}"
i += 1
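The key-probing loop can be sketched without Sanic: probe attributes on a context object, suffixing a counter until an unused name is found (a plain class stands in for `app.ctx`, an assumption for illustration):

```python
class Ctx:  # stand-in for Sanic's app.ctx
    pass

def make_unique_key(ctx: object, key: str) -> str:
    # Keep appending a counter suffix while the name is taken.
    i = 0
    while hasattr(ctx, key):
        key = f"{key}_{i}"
        i += 1
    return key

ctx = Ctx()
ctx.db_session = object()  # simulate an already-registered key
```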
def serializer(value: Any) -> str:
"""Serialize JSON field values.
Args:
value: Any JSON serializable value.
Returns:
str: JSON string representation of the value.
"""
return encode_json(schema_dump(value))
@dataclass
class EngineConfig(_EngineConfig):
"""Configuration for SQLAlchemy's Engine.
This class extends the base EngineConfig with Sanic-specific JSON serialization options.
For details see: https://docs.sqlalchemy.org/en/20/core/engines.html
Attributes:
json_deserializer: Callable for converting JSON strings to Python objects.
json_serializer: Callable for converting Python objects to JSON strings.
"""
json_deserializer: Callable[[str], Any] = decode_json
"""For dialects that support the :class:`~sqlalchemy.types.JSON` datatype, this is a Python callable that will
convert a JSON string to a Python object. By default, the built-in deserializer is used."""
json_serializer: Callable[[Any], str] = serializer
"""For dialects that support the JSON datatype, this is a Python callable that will render a given object as JSON.
By default, the built-in serializer is used."""
@dataclass
class SQLAlchemyAsyncConfig(_SQLAlchemyAsyncConfig):
"""SQLAlchemy Async config for Sanic."""
_app: "Optional[Sanic[Any, Any]]" = None
"""The Sanic application instance."""
commit_mode: Literal["manual", "autocommit", "autocommit_include_redirect"] = "manual"
"""The commit mode to use for database sessions."""
engine_key: str = "db_engine"
"""Key to use for the dependency injection of database engines."""
session_key: str = "db_session"
"""Key to use for the dependency injection of database sessions."""
session_maker_key: str = "session_maker_class"
"""Key under which to store the SQLAlchemy :class:`sessionmaker <sqlalchemy.orm.sessionmaker>` in the application state instance.
"""
engine_config: EngineConfig = field(default_factory=EngineConfig) # pyright: ignore[reportIncompatibleVariableOverride]
"""Configuration for the SQLAlchemy engine.
The configuration options are documented in the SQLAlchemy documentation.
"""
async def create_all_metadata(self) -> None: # pragma: no cover
"""Create all metadata tables in the database."""
if self.engine_instance is None:
self.engine_instance = self.get_engine()
async with self.engine_instance.begin() as conn:
try:
await conn.run_sync(
metadata_registry.get(None if self.bind_key == "default" else self.bind_key).create_all
)
await conn.commit()
except OperationalError as exc:
echo(f" * Could not create target metadata. Reason: {exc}")
else:
echo(" * Created target metadata.")
@property
def app(self) -> "Sanic[Any, Any]":
"""The Sanic application instance."""
if self._app is None:
msg = "The Sanic application instance is not set."
raise ImproperConfigurationError(msg)
return self._app
def init_app(self, app: "Sanic[Any, Any]", bootstrap: "Extend") -> None: # pyright: ignore[reportUnknownParameterType,reportInvalidTypeForm]
"""Initialize the Sanic application with this configuration.
Args:
app: The Sanic application instance.
bootstrap: The Sanic extension bootstrap.
"""
self._app = app
self.bind_key = self.bind_key or "default"
_ = self.create_session_maker()
self.session_key = _make_unique_context_key(app, f"advanced_alchemy_async_session_{self.session_key}")
self.engine_key = _make_unique_context_key(app, f"advanced_alchemy_async_engine_{self.engine_key}")
self.session_maker_key = _make_unique_context_key(
app, f"advanced_alchemy_async_session_maker_{self.session_maker_key}"
)
self.startup(bootstrap) # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType]
def startup(self, bootstrap: "Extend") -> None: # pyright: ignore[reportUnknownParameterType,reportInvalidTypeForm]
"""Initialize the Sanic application with this configuration.
Args:
bootstrap: The Sanic extension bootstrap.
"""
@self.app.before_server_start # pyright: ignore[reportUnknownMemberType]
async def on_startup(_: Any) -> None: # pyright: ignore[reportUnusedFunction]
setattr(self.app.ctx, self.engine_key, self.get_engine()) # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType,reportOptionalMemberAccess]
setattr(self.app.ctx, self.session_maker_key, self.create_session_maker()) # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType,reportOptionalMemberAccess]
bootstrap.add_dependency( # pyright: ignore[reportUnknownMemberType]
AsyncEngine,
self.get_engine_from_request,
)
bootstrap.add_dependency( # pyright: ignore[reportUnknownMemberType]
async_sessionmaker[AsyncSession],
self.get_sessionmaker_from_request,
)
bootstrap.add_dependency( # pyright: ignore[reportUnknownMemberType]
AsyncSession,
self.get_session_from_request,
)
await self.on_startup()
@self.app.after_server_stop # pyright: ignore[reportUnknownMemberType]
async def on_shutdown(_: Any) -> None: # pyright: ignore[reportUnusedFunction]
if self.engine_instance is not None:
await self.engine_instance.dispose()
if hasattr(self.app.ctx, self.engine_key): # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType,reportOptionalMemberAccess]
delattr(self.app.ctx, self.engine_key) # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType,reportOptionalMemberAccess]
if hasattr(self.app.ctx, self.session_maker_key): # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType,reportOptionalMemberAccess]
delattr(self.app.ctx, self.session_maker_key) # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType,reportOptionalMemberAccess]
@self.app.middleware("request") # pyright: ignore[reportUnknownMemberType]
async def on_request(request: Request) -> None: # pyright: ignore[reportUnusedFunction]
session = cast("Optional[AsyncSession]", getattr(request.ctx, self.session_key, None))
if session is None:
setattr(request.ctx, self.session_key, self.get_session())
@self.app.middleware("response") # type: ignore[arg-type]
async def on_response(request: Request, response: HTTPResponse) -> None: # pyright: ignore[reportUnusedFunction]
session = cast("Optional[AsyncSession]", getattr(request.ctx, self.session_key, None))
if session is not None:
await self.session_handler(session=session, request=request, response=response)
async def on_startup(self) -> None:
"""Initialize the Sanic application with this configuration."""
if self.create_all:
await self.create_all_metadata()
def create_session_maker(self) -> Callable[[], "AsyncSession"]:
"""Get a session maker. If none exists yet, create one.
Returns:
Callable[[], AsyncSession]: Session factory used by the plugin.
"""
if self.session_maker:
return self.session_maker
session_kws = self.session_config_dict
if self.engine_instance is None:
self.engine_instance = self.get_engine()
if session_kws.get("bind") is None:
session_kws["bind"] = self.engine_instance
self.session_maker = self.session_maker_class(**session_kws)
return self.session_maker
async def session_handler(
self, session: "AsyncSession", request: "Request", response: "HTTPResponse"
) -> None: # pragma: no cover
"""Handles the session after a request is processed.
Applies the commit strategy and ensures the session is closed.
Args:
session (sqlalchemy.ext.asyncio.AsyncSession):
The database session.
request (sanic.Request):
The incoming HTTP request.
response (sanic.HTTPResponse):
The outgoing HTTP response.
Returns:
None
"""
try:
if (self.commit_mode == "autocommit" and 200 <= response.status < 300) or ( # noqa: PLR2004
self.commit_mode == "autocommit_include_redirect" and 200 <= response.status < 400 # noqa: PLR2004
):
await session.commit()
else:
await session.rollback()
finally:
await session.close()
with contextlib.suppress(AttributeError, KeyError):
delattr(request.ctx, self.session_key)
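The commit decision reduces to status-code ranges: `autocommit` commits on 2xx, `autocommit_include_redirect` also commits on 3xx, and `manual` never commits automatically. Sketch of that branch:

```python
def should_commit(mode: str, status: int) -> bool:
    """Mirror session_handler's branch: commit on success (and, optionally, redirect)."""
    if mode == "autocommit":
        return 200 <= status < 300
    if mode == "autocommit_include_redirect":
        return 200 <= status < 400
    return False  # "manual": the caller commits explicitly

# A 302 redirect only commits under the redirect-inclusive mode.
assert should_commit("autocommit", 302) is False
assert should_commit("autocommit_include_redirect", 302) is True
```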
def get_engine_from_request(self, request: "Request") -> AsyncEngine:
"""Retrieve the engine from the request context.
Args:
request (sanic.Request): The incoming request.
Returns:
AsyncEngine: The SQLAlchemy engine.
"""
return cast("AsyncEngine", getattr(request.app.ctx, self.engine_key, self.get_engine())) # pragma: no cover
def get_sessionmaker_from_request(self, request: "Request") -> async_sessionmaker[AsyncSession]:
"""Retrieve the session maker from the request context.
Args:
request (sanic.Request): The incoming request.
Returns:
async_sessionmaker[AsyncSession]: The session maker.
"""
return cast(
"async_sessionmaker[AsyncSession]", getattr(request.app.ctx, self.session_maker_key, None)
) # pragma: no cover
def get_session_from_request(self, request: Request) -> AsyncSession:
"""Retrieve the session from the request context.
Args:
request (sanic.Request): The incoming request.
Returns:
AsyncSession: The session associated with the request.
"""
return cast("AsyncSession", getattr(request.ctx, self.session_key, None)) # pragma: no cover
async def close_engine(self) -> None: # pragma: no cover
"""Close the engine."""
if self.engine_instance is not None:
await self.engine_instance.dispose()
async def on_shutdown(self) -> None: # pragma: no cover
"""Handles the shutdown event by disposing of the SQLAlchemy engine.
Ensures that all connections are properly closed during application shutdown.
Returns:
None
"""
await self.close_engine()
if hasattr(self.app.ctx, self.engine_key): # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType,reportOptionalMemberAccess]
delattr(self.app.ctx, self.engine_key) # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType,reportOptionalMemberAccess]
if hasattr(self.app.ctx, self.session_maker_key): # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType,reportOptionalMemberAccess]
delattr(self.app.ctx, self.session_maker_key) # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType,reportOptionalMemberAccess]
@dataclass
class SQLAlchemySyncConfig(_SQLAlchemySyncConfig):
"""SQLAlchemy Sync config for Starlette."""
_app: "Optional[Sanic[Any, Any]]" = None
"""The Sanic application instance."""
commit_mode: Literal["manual", "autocommit", "autocommit_include_redirect"] = "manual"
"""The commit mode to use for database sessions."""
engine_key: str = "db_engine"
"""Key to use for the dependency injection of database engines."""
session_key: str = "db_session"
"""Key to use for the dependency injection of database sessions."""
session_maker_key: str = "session_maker_class"
"""Key under which to store the SQLAlchemy :class:`sessionmaker ` in the application state instance.
"""
engine_config: EngineConfig = field(default_factory=EngineConfig) # pyright: ignore[reportIncompatibleVariableOverride]
"""Configuration for the SQLAlchemy engine.
The configuration options are documented in the SQLAlchemy documentation.
"""
@property
def app(self) -> "Sanic[Any, Any]":
"""The Sanic application instance."""
if self._app is None:
msg = "The Sanic application instance is not set."
raise ImproperConfigurationError(msg)
return self._app
async def create_all_metadata(self) -> None: # pragma: no cover
"""Create all metadata tables in the database."""
if self.engine_instance is None:
self.engine_instance = self.get_engine()
with self.engine_instance.begin() as conn:
try:
loop = asyncio.get_event_loop()
await loop.run_in_executor(
None, metadata_registry.get(None if self.bind_key == "default" else self.bind_key).create_all, conn
)
except OperationalError as exc:
echo(f" * Could not create target metadata. Reason: {exc}")
def init_app(self, app: "Sanic[Any, Any]", bootstrap: "Extend") -> None: # pyright: ignore[reportUnknownParameterType,reportInvalidTypeForm]
"""Initialize the Sanic application with this configuration.
Args:
app: The Sanic application instance.
bootstrap: The Sanic extension bootstrap.
"""
self._app = app
self.bind_key = self.bind_key or "default"
_ = self.create_session_maker()
self.session_key = _make_unique_context_key(app, f"advanced_alchemy_sync_session_{self.session_key}")
self.engine_key = _make_unique_context_key(app, f"advanced_alchemy_sync_engine_{self.engine_key}")
self.session_maker_key = _make_unique_context_key(
app, f"advanced_alchemy_sync_session_maker_{self.session_maker_key}"
)
self.startup(bootstrap) # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType]
def startup(self, bootstrap: "Extend") -> None: # pyright: ignore[reportUnknownParameterType,reportInvalidTypeForm]
"""Initialize the Sanic application with this configuration.
Args:
bootstrap: The Sanic extension bootstrap.
"""
@self.app.before_server_start # pyright: ignore[reportUnknownMemberType]
async def on_startup(_: Any) -> None: # pyright: ignore[reportUnusedFunction]
setattr(self.app.ctx, self.engine_key, self.get_engine()) # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType,reportOptionalMemberAccess]
setattr(self.app.ctx, self.session_maker_key, self.create_session_maker()) # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType,reportOptionalMemberAccess]
bootstrap.add_dependency( # pyright: ignore[reportUnknownMemberType]
Engine,
self.get_engine_from_request,
)
bootstrap.add_dependency( # pyright: ignore[reportUnknownMemberType]
sessionmaker[Session],
self.get_sessionmaker_from_request,
)
bootstrap.add_dependency( # pyright: ignore[reportUnknownMemberType]
Session,
self.get_session_from_request,
)
await self.on_startup()
@self.app.after_server_stop # pyright: ignore[reportUnknownMemberType]
async def on_shutdown(_: Any) -> None: # pyright: ignore[reportUnusedFunction]
await self.on_shutdown()
@self.app.middleware("request") # pyright: ignore[reportUnknownMemberType]
async def on_request(request: Request) -> None: # pyright: ignore[reportUnusedFunction]
session = cast("Optional[Session]", getattr(request.ctx, self.session_key, None))
if session is None:
setattr(request.ctx, self.session_key, self.get_session())
@self.app.middleware("response") # type: ignore[arg-type]
async def on_response(request: Request, response: HTTPResponse) -> None: # pyright: ignore[reportUnusedFunction]
session = cast("Optional[Session]", getattr(request.ctx, self.session_key, None))
if session is not None:
await self.session_handler(session=session, request=request, response=response)
async def on_startup(self) -> None:
"""Initialize the Sanic application with this configuration."""
if self.create_all:
await self.create_all_metadata()
def create_session_maker(self) -> Callable[[], "Session"]:
"""Get a session maker. If none exists yet, create one.
Returns:
Callable[[], Session]: Session factory used by the plugin.
"""
if self.session_maker:
return self.session_maker
session_kws = self.session_config_dict
if self.engine_instance is None:
self.engine_instance = self.get_engine()
if session_kws.get("bind") is None:
session_kws["bind"] = self.engine_instance
self.session_maker = self.session_maker_class(**session_kws)
return self.session_maker
async def session_handler(
self, session: "Session", request: "Request", response: "HTTPResponse"
) -> None: # pragma: no cover
"""Handles the session after a request is processed.
Applies the commit strategy and ensures the session is closed.
Args:
session (sqlalchemy.orm.Session):
The database session.
request (sanic.Request):
The incoming HTTP request.
response (sanic.HTTPResponse):
The outgoing HTTP response.
Returns:
None
"""
loop = asyncio.get_event_loop()
try:
if (self.commit_mode == "autocommit" and 200 <= response.status < 300) or ( # noqa: PLR2004
self.commit_mode == "autocommit_include_redirect" and 200 <= response.status < 400 # noqa: PLR2004
):
await loop.run_in_executor(None, session.commit)
else:
await loop.run_in_executor(None, session.rollback)
finally:
await loop.run_in_executor(None, session.close)
with contextlib.suppress(AttributeError, KeyError):
delattr(request.ctx, self.session_key)
def get_engine_from_request(self, request: Request) -> "Engine":
"""Retrieve the engine from the request context.
Args:
request (sanic.Request): The incoming request.
Returns:
Engine: The SQLAlchemy engine.
"""
return cast("Engine", getattr(request.app.ctx, self.engine_key, self.get_engine())) # pragma: no cover
def get_sessionmaker_from_request(self, request: Request) -> sessionmaker[Session]:
"""Retrieve the session maker from the request context.
Args:
request (sanic.Request): The incoming request.
Returns:
sessionmaker[Session]: The session maker.
"""
return cast("sessionmaker[Session]", getattr(request.app.ctx, self.session_maker_key, None)) # pragma: no cover
def get_session_from_request(self, request: Request) -> "Session":
"""Retrieve the session from the request context.
Args:
request (sanic.Request): The incoming request.
Returns:
Session: The session associated with the request.
"""
return cast("Session", getattr(request.ctx, self.session_key, None)) # pragma: no cover
async def close_engine(self) -> None: # pragma: no cover
"""Close the engine."""
if self.engine_instance is not None:
loop = asyncio.get_event_loop()
await loop.run_in_executor(None, self.engine_instance.dispose)
async def on_shutdown(self) -> None: # pragma: no cover
"""Handles the shutdown event by disposing of the SQLAlchemy engine.
Ensures that all connections are properly closed during application shutdown.
Returns:
None
"""
await self.close_engine()
if hasattr(self.app.ctx, self.engine_key): # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType,reportOptionalMemberAccess]
delattr(self.app.ctx, self.engine_key) # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType,reportOptionalMemberAccess]
if hasattr(self.app.ctx, self.session_maker_key): # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType,reportOptionalMemberAccess]
delattr(self.app.ctx, self.session_maker_key) # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType,reportOptionalMemberAccess]
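The `commit_mode` handling in `session_handler` above reduces to a predicate over the response status code. A stand-alone sketch of that decision, with the `should_commit` helper name chosen for illustration (it is not part of the library):

```python
# Illustrative distillation of the commit-strategy check in session_handler.
def should_commit(commit_mode: str, status: int) -> bool:
    """Decide whether the session should be committed for this response."""
    if commit_mode == "autocommit":
        return 200 <= status < 300
    if commit_mode == "autocommit_include_redirect":
        return 200 <= status < 400
    return False  # "manual": commits are left to application code


assert should_commit("autocommit", 201)
assert not should_commit("autocommit", 307)
assert should_commit("autocommit_include_redirect", 307)
assert not should_commit("manual", 200)
```

Anything outside the accepted range (including errors) is rolled back, so partially applied work never survives a failed request.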
# advanced_alchemy/extensions/sanic/extension.py
from collections.abc import AsyncGenerator, Generator, Sequence
from contextlib import asynccontextmanager, contextmanager
from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast, overload
from sanic import Request, Sanic
from advanced_alchemy.exceptions import ImproperConfigurationError, MissingDependencyError
from advanced_alchemy.extensions.sanic.config import SQLAlchemyAsyncConfig, SQLAlchemySyncConfig
try:
from sanic_ext import Extend
from sanic_ext.extensions.base import Extension
SANIC_INSTALLED = True
except ModuleNotFoundError: # pragma: no cover
SANIC_INSTALLED = False # pyright: ignore[reportConstantRedefinition]
Extension = type("Extension", (), {}) # type: ignore # noqa: PGH003
Extend = type("Extend", (), {}) # type: ignore # noqa: PGH003
if TYPE_CHECKING:
from sanic import Sanic
from sqlalchemy import Engine
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession
from sqlalchemy.orm import Session
__all__ = ("AdvancedAlchemy",)
class AdvancedAlchemy(Extension): # type: ignore[no-untyped-call] # pyright: ignore[reportGeneralTypeIssues,reportUntypedBaseClass]
"""Sanic extension for integrating Advanced Alchemy with SQLAlchemy.
Args:
sqlalchemy_config: One or more SQLAlchemy configurations.
sanic_app: The Sanic application instance.
"""
name = "AdvancedAlchemy"
def __init__(
self,
*,
sqlalchemy_config: Union[
"SQLAlchemyAsyncConfig",
"SQLAlchemySyncConfig",
Sequence[Union["SQLAlchemyAsyncConfig", "SQLAlchemySyncConfig"]],
],
sanic_app: Optional["Sanic[Any, Any]"] = None,
) -> None:
if not SANIC_INSTALLED: # pragma: no cover
msg = "Could not locate either Sanic or Sanic Extensions. Both libraries must be installed to use Advanced Alchemy. Try: pip install sanic[ext]"
raise MissingDependencyError(msg)
self._config = sqlalchemy_config if isinstance(sqlalchemy_config, Sequence) else [sqlalchemy_config]
self._mapped_configs: dict[str, Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig]] = self.map_configs()
self._app = sanic_app
self._initialized = False
if self._app is not None:
self.register(self._app)
def register(self, sanic_app: "Sanic[Any, Any]") -> None:
"""Initialize the extension with the given Sanic app."""
self._app = sanic_app
Extend.register(self) # pyright: ignore[reportUnknownMemberType,reportAttributeAccessIssue]
self._initialized = True
@property
def sanic_app(self) -> "Sanic[Any, Any]":
"""The Sanic app."""
if self._app is None: # pragma: no cover
msg = "AdvancedAlchemy has not been initialized with a Sanic app."
raise ImproperConfigurationError(msg)
return self._app
@property
def sqlalchemy_config(self) -> Sequence[Union["SQLAlchemyAsyncConfig", "SQLAlchemySyncConfig"]]:
"""Current Advanced Alchemy configuration."""
return self._config
def startup(self, bootstrap: "Extend") -> None: # pyright: ignore[reportUnknownParameterType,reportInvalidTypeForm]
"""Advanced Alchemy Sanic extension startup hook.
Args:
bootstrap (sanic_ext.Extend): The Sanic extension bootstrap.
"""
for config in self.sqlalchemy_config:
config.init_app(self.sanic_app, bootstrap) # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType]
def map_configs(self) -> dict[str, Union["SQLAlchemyAsyncConfig", "SQLAlchemySyncConfig"]]:
"""Maps the configs to the session bind keys."""
mapped_configs: dict[str, Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig]] = {}
for config in self.sqlalchemy_config:
if config.bind_key is None:
config.bind_key = "default"
mapped_configs[config.bind_key] = config
return mapped_configs
def get_config(self, key: Optional[str] = None) -> Union["SQLAlchemyAsyncConfig", "SQLAlchemySyncConfig"]:
"""Get the config for the given key."""
if key is None:
key = "default"
if key == "default" and len(self.sqlalchemy_config) == 1:
key = self.sqlalchemy_config[0].bind_key or "default"
config = self._mapped_configs.get(key)
if config is None: # pragma: no cover
msg = f"Config with key {key} not found"
raise ImproperConfigurationError(msg)
return config
def get_async_config(self, key: Optional[str] = None) -> "SQLAlchemyAsyncConfig":
"""Get the async config for the given key."""
config = self.get_config(key)
if not isinstance(config, SQLAlchemyAsyncConfig): # pragma: no cover
msg = "Expected an async config, but got a sync config"
raise ImproperConfigurationError(msg)
return config
def get_sync_config(self, key: Optional[str] = None) -> "SQLAlchemySyncConfig":
"""Get the sync config for the given key."""
config = self.get_config(key)
if not isinstance(config, SQLAlchemySyncConfig): # pragma: no cover
msg = "Expected a sync config, but got an async config"
raise ImproperConfigurationError(msg)
return config
@asynccontextmanager
async def with_async_session(
self, key: Optional[str] = None
) -> AsyncGenerator["AsyncSession", None]: # pragma: no cover
"""Context manager for getting an async session."""
config = self.get_async_config(key)
async with config.get_session() as session:
yield session
@contextmanager
def with_sync_session(self, key: Optional[str] = None) -> Generator["Session", None]: # pragma: no cover
"""Context manager for getting a sync session."""
config = self.get_sync_config(key)
with config.get_session() as session:
yield session
@overload
@staticmethod
def _get_session_from_request(request: "Request", config: "SQLAlchemyAsyncConfig") -> "AsyncSession": ...
@overload
@staticmethod
def _get_session_from_request(request: "Request", config: "SQLAlchemySyncConfig") -> "Session": ...
@staticmethod
def _get_session_from_request(
request: "Request",
config: Union["SQLAlchemyAsyncConfig", "SQLAlchemySyncConfig"], # pragma: no cover
) -> Union["Session", "AsyncSession"]: # pragma: no cover
"""Get the session for the given key."""
session = getattr(request.ctx, config.session_key, None)
if session is None:
session = config.get_session()
setattr(request.ctx, config.session_key, session)
return cast("Union[Session, AsyncSession]", session)
def get_session(
self, request: "Request", key: Optional[str] = None
) -> Union["Session", "AsyncSession"]: # pragma: no cover
"""Get the session for the given key."""
config = self.get_config(key)
return self._get_session_from_request(request, config)
def get_async_session(self, request: "Request", key: Optional[str] = None) -> "AsyncSession": # pragma: no cover
"""Get the async session for the given key."""
config = self.get_async_config(key)
return self._get_session_from_request(request, config)
def get_sync_session(self, request: "Request", key: Optional[str] = None) -> "Session": # pragma: no cover
"""Get the sync session for the given key."""
config = self.get_sync_config(key)
return self._get_session_from_request(request, config)
def provide_session(
self, key: Optional[str] = None
) -> Callable[["Request"], Union["Session", "AsyncSession"]]: # pragma: no cover
"""Get the session for the given key."""
config = self.get_config(key)
def _get_session(request: "Request") -> Union["Session", "AsyncSession"]:
return self._get_session_from_request(request, config)
return _get_session
def provide_async_session(
self, key: Optional[str] = None
) -> Callable[["Request"], "AsyncSession"]: # pragma: no cover
"""Get the async session for the given key."""
config = self.get_async_config(key)
def _get_session(request: Request) -> "AsyncSession":
return self._get_session_from_request(request, config)
return _get_session
def provide_sync_session(self, key: Optional[str] = None) -> Callable[[Request], "Session"]: # pragma: no cover
"""Get the sync session for the given key."""
config = self.get_sync_config(key)
def _get_session(request: Request) -> "Session":
return self._get_session_from_request(request, config)
return _get_session
def get_engine(self, key: Optional[str] = None) -> Union["Engine", "AsyncEngine"]: # pragma: no cover
"""Get the engine for the given key."""
config = self.get_config(key)
return config.get_engine()
def get_async_engine(self, key: Optional[str] = None) -> "AsyncEngine": # pragma: no cover
"""Get the async engine for the given key."""
config = self.get_async_config(key)
return config.get_engine()
def get_sync_engine(self, key: Optional[str] = None) -> "Engine": # pragma: no cover
"""Get the sync engine for the given key."""
config = self.get_sync_config(key)
return config.get_engine()
def provide_engine(
self, key: Optional[str] = None
) -> Callable[[], Union["Engine", "AsyncEngine"]]: # pragma: no cover
"""Get the engine for the given key."""
config = self.get_config(key)
def _get_engine() -> Union["Engine", "AsyncEngine"]:
return config.get_engine()
return _get_engine
def provide_async_engine(self, key: Optional[str] = None) -> Callable[[], "AsyncEngine"]: # pragma: no cover
"""Get the async engine for the given key."""
config = self.get_async_config(key)
def _get_engine() -> "AsyncEngine":
return config.get_engine()
return _get_engine
def provide_sync_engine(self, key: Optional[str] = None) -> Callable[[], "Engine"]: # pragma: no cover
"""Get the sync engine for the given key."""
config = self.get_sync_config(key)
def _get_engine() -> "Engine":
return config.get_engine()
return _get_engine
def add_session_dependency(
self, session_type: type[Union["Session", "AsyncSession"]], key: Optional[str] = None
) -> None:
"""Add a session dependency to the Sanic app."""
self.sanic_app.ext.add_dependency(session_type, self.provide_session(key)) # pyright: ignore[reportUnknownMemberType]
def add_engine_dependency(
self, engine_type: type[Union["Engine", "AsyncEngine"]], key: Optional[str] = None
) -> None:
"""Add an engine dependency to the Sanic app."""
self.sanic_app.ext.add_dependency(engine_type, self.provide_engine(key)) # pyright: ignore[reportUnknownMemberType]
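The `map_configs`/`get_config` pair above keys each configuration by its `bind_key`, defaulting unnamed configs to `"default"`. That lookup logic can be exercised in isolation with a stub config; `StubConfig` and `map_configs` below are stand-ins, not library names:

```python
# Stub reproduction of how AdvancedAlchemy maps configs by bind_key.
from dataclasses import dataclass
from typing import Optional


@dataclass
class StubConfig:
    bind_key: Optional[str] = None


def map_configs(configs: list[StubConfig]) -> dict[str, StubConfig]:
    mapped: dict[str, StubConfig] = {}
    for config in configs:
        if config.bind_key is None:
            config.bind_key = "default"  # unnamed configs share the default slot
        mapped[config.bind_key] = config
    return mapped


mapped = map_configs([StubConfig(), StubConfig(bind_key="analytics")])
assert set(mapped) == {"default", "analytics"}
assert mapped["default"].bind_key == "default"
```

Note that a second unnamed config would silently overwrite the `"default"` slot, which is why multi-config setups should give every config an explicit `bind_key`.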
# advanced_alchemy/extensions/starlette/__init__.py
"""Starlette extension for Advanced Alchemy.
This module provides Starlette integration for Advanced Alchemy, including session management and service utilities.
"""
from advanced_alchemy import base, exceptions, filters, mixins, operations, repository, service, types, utils
from advanced_alchemy.alembic.commands import AlembicCommands
from advanced_alchemy.config import AlembicAsyncConfig, AlembicSyncConfig, AsyncSessionConfig, SyncSessionConfig
from advanced_alchemy.extensions.starlette.config import EngineConfig, SQLAlchemyAsyncConfig, SQLAlchemySyncConfig
from advanced_alchemy.extensions.starlette.extension import AdvancedAlchemy
__all__ = (
"AdvancedAlchemy",
"AlembicAsyncConfig",
"AlembicCommands",
"AlembicSyncConfig",
"AsyncSessionConfig",
"EngineConfig",
"SQLAlchemyAsyncConfig",
"SQLAlchemySyncConfig",
"SyncSessionConfig",
"base",
"exceptions",
"filters",
"mixins",
"operations",
"repository",
"service",
"types",
"utils",
)
# advanced_alchemy/extensions/starlette/config.py
"""Configuration classes for Starlette integration.
This module provides configuration classes for integrating SQLAlchemy with Starlette applications,
including both synchronous and asynchronous database configurations.
"""
import contextlib
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Callable, Optional, cast
from click import echo
from sqlalchemy.exc import OperationalError
from starlette.concurrency import run_in_threadpool
from starlette.middleware.base import BaseHTTPMiddleware
from typing_extensions import Literal
from advanced_alchemy._serialization import decode_json, encode_json
from advanced_alchemy.base import metadata_registry
from advanced_alchemy.config import EngineConfig as _EngineConfig
from advanced_alchemy.config.asyncio import SQLAlchemyAsyncConfig as _SQLAlchemyAsyncConfig
from advanced_alchemy.config.sync import SQLAlchemySyncConfig as _SQLAlchemySyncConfig
from advanced_alchemy.service import schema_dump
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from starlette.applications import Starlette
from starlette.middleware.base import RequestResponseEndpoint
from starlette.requests import Request
from starlette.responses import Response
def _make_unique_state_key(app: "Starlette", key: str) -> str: # pragma: no cover
"""Generates a unique state key for the Starlette application.
Ensures that the key does not already exist in the application's state.
Args:
app (starlette.applications.Starlette): The Starlette application instance.
key (str): The base key name.
Returns:
str: A unique key name.
"""
i = 0
while True:
if not hasattr(app.state, key):
return key
key = f"{key}_{i}"
i += 1
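The uniquifying loop above can be checked against a plain namespace standing in for `app.state`; `make_unique_state_key` here is a self-contained mirror written for illustration, not the module's function:

```python
# Self-contained mirror of the state-key uniquifier, using SimpleNamespace for app.state.
from types import SimpleNamespace


def make_unique_state_key(state: SimpleNamespace, key: str) -> str:
    i = 0
    while hasattr(state, key):
        key = f"{key}_{i}"  # suffixes accumulate until the name is free
        i += 1
    return key


state = SimpleNamespace(db_session=1, db_session_0=1)
assert make_unique_state_key(state, "db_session") == "db_session_0_1"
assert make_unique_state_key(state, "other_key") == "other_key"
```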
def serializer(value: Any) -> str:
"""Serialize JSON field values.
Args:
value: Any JSON serializable value.
Returns:
str: JSON string representation of the value.
"""
return encode_json(schema_dump(value))
@dataclass
class EngineConfig(_EngineConfig):
"""Configuration for SQLAlchemy's Engine.
This class extends the base EngineConfig with Starlette-specific JSON serialization options.
For details see: https://docs.sqlalchemy.org/en/20/core/engines.html
Attributes:
json_deserializer: Callable for converting JSON strings to Python objects.
json_serializer: Callable for converting Python objects to JSON strings.
"""
json_deserializer: Callable[[str], Any] = decode_json
"""For dialects that support the :class:`~sqlalchemy.types.JSON` datatype, this is a Python callable that will
convert a JSON string to a Python object. By default, the built-in deserializer is used."""
json_serializer: Callable[[Any], str] = serializer
"""For dialects that support the JSON datatype, this is a Python callable that will render a given object as JSON.
By default, the built-in serializer is used."""
@dataclass
class SQLAlchemyAsyncConfig(_SQLAlchemyAsyncConfig):
"""SQLAlchemy Async config for Starlette."""
app: "Optional[Starlette]" = None
"""The Starlette application instance."""
commit_mode: Literal["manual", "autocommit", "autocommit_include_redirect"] = "manual"
"""The commit mode to use for database sessions."""
engine_key: str = "db_engine"
"""Key to use for the dependency injection of database engines."""
session_key: str = "db_session"
"""Key to use for the dependency injection of database sessions."""
session_maker_key: str = "session_maker_class"
"""Key under which to store the SQLAlchemy :class:`sessionmaker ` in the application state instance.
"""
engine_config: EngineConfig = field(default_factory=EngineConfig) # pyright: ignore[reportIncompatibleVariableOverride]
"""Configuration for the SQLAlchemy engine.
The configuration options are documented in the SQLAlchemy documentation.
"""
async def create_all_metadata(self) -> None: # pragma: no cover
"""Create all metadata tables in the database."""
if self.engine_instance is None:
self.engine_instance = self.get_engine()
async with self.engine_instance.begin() as conn:
try:
await conn.run_sync(
metadata_registry.get(None if self.bind_key == "default" else self.bind_key).create_all
)
await conn.commit()
except OperationalError as exc:
echo(f" * Could not create target metadata. Reason: {exc}")
else:
echo(" * Created target metadata.")
def init_app(self, app: "Starlette") -> None:
"""Initialize the Starlette application with this configuration.
Args:
app: The Starlette application instance.
"""
self.app = app
self.bind_key = self.bind_key or "default"
_ = self.create_session_maker()
self.session_key = _make_unique_state_key(app, f"advanced_alchemy_async_session_{self.session_key}")
self.engine_key = _make_unique_state_key(app, f"advanced_alchemy_async_engine_{self.engine_key}")
self.session_maker_key = _make_unique_state_key(
app, f"advanced_alchemy_async_session_maker_{self.session_maker_key}"
)
app.add_middleware(BaseHTTPMiddleware, dispatch=self.middleware_dispatch)
async def on_startup(self) -> None:
"""Initialize the Starlette application with this configuration."""
if self.create_all:
await self.create_all_metadata()
def create_session_maker(self) -> Callable[[], "AsyncSession"]:
"""Get a session maker. If none exists yet, create one.
Returns:
Callable[[], Session]: Session factory used by the plugin.
"""
if self.session_maker:
return self.session_maker
session_kws = self.session_config_dict
if self.engine_instance is None:
self.engine_instance = self.get_engine()
if session_kws.get("bind") is None:
session_kws["bind"] = self.engine_instance
self.session_maker = self.session_maker_class(**session_kws)
return self.session_maker
async def session_handler(
self, session: "AsyncSession", request: "Request", response: "Response"
) -> None: # pragma: no cover
"""Handles the session after a request is processed.
Applies the commit strategy and ensures the session is closed.
Args:
session (sqlalchemy.ext.asyncio.AsyncSession):
The database session.
request (starlette.requests.Request):
The incoming HTTP request.
response (starlette.responses.Response):
The outgoing HTTP response.
Returns:
None
"""
try:
if (self.commit_mode == "autocommit" and 200 <= response.status_code < 300) or ( # noqa: PLR2004
self.commit_mode == "autocommit_include_redirect" and 200 <= response.status_code < 400 # noqa: PLR2004
):
await session.commit()
else:
await session.rollback()
finally:
await session.close()
with contextlib.suppress(AttributeError, KeyError):
delattr(request.state, self.session_key)
async def middleware_dispatch(
self, request: "Request", call_next: "RequestResponseEndpoint"
) -> "Response": # pragma: no cover
"""Middleware dispatch function to handle requests and responses.
Processes the request, invokes the next middleware or route handler, and
applies the session handler after the response is generated.
Args:
request (starlette.requests.Request): The incoming HTTP request.
call_next (starlette.middleware.base.RequestResponseEndpoint):
The next middleware or route handler.
Returns:
starlette.responses.Response: The HTTP response.
"""
response = await call_next(request)
session = cast("Optional[AsyncSession]", getattr(request.state, self.session_key, None))
if session is not None:
await self.session_handler(session=session, request=request, response=response)
return response
async def close_engine(self) -> None: # pragma: no cover
"""Close the engine."""
if self.engine_instance is not None:
await self.engine_instance.dispose()
async def on_shutdown(self) -> None: # pragma: no cover
"""Handles the shutdown event by disposing of the SQLAlchemy engine.
Ensures that all connections are properly closed during application shutdown.
Returns:
None
"""
await self.close_engine()
if self.app is not None:
with contextlib.suppress(AttributeError, KeyError):
delattr(self.app.state, self.engine_key)
delattr(self.app.state, self.session_maker_key)
delattr(self.app.state, self.session_key)
@dataclass
class SQLAlchemySyncConfig(_SQLAlchemySyncConfig):
"""SQLAlchemy Sync config for Starlette."""
app: "Optional[Starlette]" = None
"""The Starlette application instance."""
commit_mode: Literal["manual", "autocommit", "autocommit_include_redirect"] = "manual"
"""The commit mode to use for database sessions."""
engine_key: str = "db_engine"
"""Key to use for the dependency injection of database engines."""
session_key: str = "db_session"
"""Key to use for the dependency injection of database sessions."""
session_maker_key: str = "session_maker_class"
"""Key under which to store the SQLAlchemy :class:`sessionmaker ` in the application state instance.
"""
engine_config: EngineConfig = field(default_factory=EngineConfig) # pyright: ignore[reportIncompatibleVariableOverride]
"""Configuration for the SQLAlchemy engine.
The configuration options are documented in the SQLAlchemy documentation.
"""
async def create_all_metadata(self) -> None: # pragma: no cover
"""Create all metadata tables in the database."""
if self.engine_instance is None:
self.engine_instance = self.get_engine()
with self.engine_instance.begin() as conn:
try:
await run_in_threadpool(
metadata_registry.get(None if self.bind_key == "default" else self.bind_key).create_all, conn
)
except OperationalError as exc:
echo(f" * Could not create target metadata. Reason: {exc}")
def init_app(self, app: "Starlette") -> None:
"""Initialize the Starlette application with this configuration.
Args:
app: The Starlette application instance.
"""
self.app = app
self.bind_key = self.bind_key or "default"
self.session_key = _make_unique_state_key(app, f"advanced_alchemy_sync_session_{self.session_key}")
self.engine_key = _make_unique_state_key(app, f"advanced_alchemy_sync_engine_{self.engine_key}")
self.session_maker_key = _make_unique_state_key(
app, f"advanced_alchemy_sync_session_maker_{self.session_maker_key}"
)
_ = self.create_session_maker()
app.add_middleware(BaseHTTPMiddleware, dispatch=self.middleware_dispatch)
async def on_startup(self) -> None:
"""Initialize the Starlette application with this configuration."""
if self.create_all:
await self.create_all_metadata()
def create_session_maker(self) -> Callable[[], "Session"]:
"""Get a session maker. If none exists yet, create one.
Returns:
Callable[[], Session]: Session factory used by the plugin.
"""
if self.session_maker:
return self.session_maker
session_kws = self.session_config_dict
if self.engine_instance is None:
self.engine_instance = self.get_engine()
if session_kws.get("bind") is None:
session_kws["bind"] = self.engine_instance
self.session_maker = self.session_maker_class(**session_kws)
return self.session_maker
async def session_handler(
self, session: "Session", request: "Request", response: "Response"
) -> None: # pragma: no cover
"""Handles the session after a request is processed.
Applies the commit strategy and ensures the session is closed.
Args:
session (sqlalchemy.orm.Session):
The database session.
request (starlette.requests.Request):
The incoming HTTP request.
response (starlette.responses.Response):
The outgoing HTTP response.
Returns:
None
"""
try:
if (self.commit_mode == "autocommit" and 200 <= response.status_code < 300) or ( # noqa: PLR2004
self.commit_mode == "autocommit_include_redirect" and 200 <= response.status_code < 400 # noqa: PLR2004
):
await run_in_threadpool(session.commit)
else:
await run_in_threadpool(session.rollback)
finally:
await run_in_threadpool(session.close)
with contextlib.suppress(AttributeError, KeyError):
delattr(request.state, self.session_key)
async def middleware_dispatch(
self, request: "Request", call_next: "RequestResponseEndpoint"
) -> "Response": # pragma: no cover
"""Middleware dispatch function to handle requests and responses.
Processes the request, invokes the next middleware or route handler, and
applies the session handler after the response is generated.
Args:
request (starlette.requests.Request): The incoming HTTP request.
call_next (starlette.middleware.base.RequestResponseEndpoint):
The next middleware or route handler.
Returns:
starlette.responses.Response: The HTTP response.
"""
response = await call_next(request)
session = cast("Optional[Session]", getattr(request.state, self.session_key, None))
if session is not None:
await self.session_handler(session=session, request=request, response=response)
return response
async def close_engine(self) -> None: # pragma: no cover
"""Close the engines."""
if self.engine_instance is not None:
await run_in_threadpool(self.engine_instance.dispose)
async def on_shutdown(self) -> None: # pragma: no cover
"""Handles the shutdown event by disposing of the SQLAlchemy engine.
Ensures that all connections are properly closed during application shutdown.
Returns:
None
"""
await self.close_engine()
if self.app is not None:
with contextlib.suppress(AttributeError, KeyError):
delattr(self.app.state, self.engine_key)
delattr(self.app.state, self.session_maker_key)
delattr(self.app.state, self.session_key)
# python-advanced-alchemy-1.0.1/advanced_alchemy/extensions/starlette/extension.py
# ruff: noqa: ARG001
import contextlib
from collections.abc import AsyncGenerator, Generator, Sequence
from contextlib import asynccontextmanager, contextmanager
from typing import (
TYPE_CHECKING,
Any,
Callable,
Optional,
Union,
cast,
overload,
)
from starlette.requests import Request
from advanced_alchemy.exceptions import ImproperConfigurationError
from advanced_alchemy.extensions.starlette.config import SQLAlchemyAsyncConfig, SQLAlchemySyncConfig
if TYPE_CHECKING:
from sqlalchemy import Engine
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession
from sqlalchemy.orm import Session
from starlette.applications import Starlette
class AdvancedAlchemy:
"""AdvancedAlchemy integration for Starlette applications.
This class manages SQLAlchemy sessions and engine lifecycle within a Starlette application.
It provides middleware for handling transactions based on commit strategies.
Args:
config (advanced_alchemy.config.asyncio.SQLAlchemyAsyncConfig | advanced_alchemy.config.sync.SQLAlchemySyncConfig):
The SQLAlchemy configuration.
app (starlette.applications.Starlette | None):
The Starlette application instance. Defaults to None.
"""
def __init__(
self,
config: Union[
SQLAlchemyAsyncConfig, SQLAlchemySyncConfig, Sequence[Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig]]
],
app: Optional["Starlette"] = None,
) -> None:
self._config = config if isinstance(config, Sequence) else [config]
self._mapped_configs: dict[str, Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig]] = self.map_configs()
self._app = cast("Optional[Starlette]", None)
if app is not None:
self.init_app(app)
@property
def config(self) -> Sequence[Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig]]:
"""Current Advanced Alchemy configuration."""
return self._config
def init_app(self, app: "Starlette") -> None:
"""Initializes the Starlette application with SQLAlchemy engine and sessionmaker.
Sets up middleware and shutdown handlers for managing the database engine.
Args:
app (starlette.applications.Starlette): The Starlette application instance.
"""
self._app = app
unique_bind_keys = {config.bind_key for config in self.config}
if len(unique_bind_keys) != len(self.config): # pragma: no cover
msg = "Please ensure that each config has a unique name for the `bind_key` attribute. The default is `default` and can only be bound to a single engine."
raise ImproperConfigurationError(msg)
for config in self.config:
config.init_app(app)
app.state.advanced_alchemy = self
original_lifespan = app.router.lifespan_context
@asynccontextmanager
async def wrapped_lifespan(app: "Starlette") -> AsyncGenerator[Any, None]: # pragma: no cover
async with self.lifespan(app), original_lifespan(app) as state:
yield state
app.router.lifespan_context = wrapped_lifespan
@asynccontextmanager
async def lifespan(self, app: "Starlette") -> AsyncGenerator[Any, None]: # pragma: no cover
"""Context manager for lifespan events.
Args:
app: The starlette application.
Yields:
None
"""
await self.on_startup()
try:
yield
finally:
await self.on_shutdown()
@property
def app(self) -> "Starlette": # pragma: no cover
"""Returns the Starlette application instance.
Raises:
advanced_alchemy.exceptions.ImproperConfigurationError:
If the application is not initialized.
Returns:
starlette.applications.Starlette: The Starlette application instance.
"""
if self._app is None: # pragma: no cover
msg = "Application not initialized. Did you forget to call init_app?"
raise ImproperConfigurationError(msg)
return self._app
async def on_startup(self) -> None: # pragma: no cover
"""Initializes the database."""
for config in self.config:
await config.on_startup()
async def on_shutdown(self) -> None: # pragma: no cover
"""Handles the shutdown event by disposing of the SQLAlchemy engine.
Ensures that all connections are properly closed during application shutdown.
Returns:
None
"""
for config in self.config:
await config.on_shutdown()
with contextlib.suppress(AttributeError, KeyError):
delattr(self.app.state, "advanced_alchemy")
def map_configs(self) -> dict[str, Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig]]:
"""Maps the configs to the session bind keys."""
mapped_configs: dict[str, Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig]] = {}
for config in self.config:
if config.bind_key is None:
config.bind_key = "default"
mapped_configs[config.bind_key] = config
return mapped_configs
def get_config(self, key: Optional[str] = None) -> Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig]:
"""Get the config for the given key."""
if key is None:
key = "default"
if key == "default" and len(self.config) == 1:
key = self.config[0].bind_key or "default"
config = self._mapped_configs.get(key)
if config is None: # pragma: no cover
msg = f"Config with key {key} not found"
raise ImproperConfigurationError(msg)
return config
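The lookup in `get_config` has one subtlety: when only a single config is registered, a `None` or `"default"` key falls back to that config's actual bind key. A standalone sketch of the resolution over a plain dict (helper name illustrative):

```python
from typing import Optional


def resolve_bind_key(mapped_configs: dict, key: Optional[str] = None) -> str:
    # None means "default"; with exactly one registered config, "default"
    # falls back to that config's actual bind key.
    key = key or "default"
    if key == "default" and len(mapped_configs) == 1:
        key = next(iter(mapped_configs))
    if key not in mapped_configs:
        raise KeyError(f"Config with key {key} not found")
    return key
```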
def get_async_config(self, key: Optional[str] = None) -> SQLAlchemyAsyncConfig:
"""Get the async config for the given key."""
config = self.get_config(key)
if not isinstance(config, SQLAlchemyAsyncConfig): # pragma: no cover
msg = "Expected an async config, but got a sync config"
raise ImproperConfigurationError(msg)
return config
def get_sync_config(self, key: Optional[str] = None) -> SQLAlchemySyncConfig:
"""Get the sync config for the given key."""
config = self.get_config(key)
if not isinstance(config, SQLAlchemySyncConfig): # pragma: no cover
msg = "Expected a sync config, but got an async config"
raise ImproperConfigurationError(msg)
return config
@asynccontextmanager
async def with_async_session(
self, key: Optional[str] = None
) -> AsyncGenerator["AsyncSession", None]: # pragma: no cover
"""Context manager for getting an async session."""
config = self.get_async_config(key)
async with config.get_session() as session:
yield session
@contextmanager
def with_sync_session(self, key: Optional[str] = None) -> Generator["Session", None]: # pragma: no cover
"""Context manager for getting a sync session."""
config = self.get_sync_config(key)
with config.get_session() as session:
yield session
@overload
@staticmethod
def _get_session_from_request(request: Request, config: SQLAlchemyAsyncConfig) -> "AsyncSession": ...
@overload
@staticmethod
def _get_session_from_request(request: Request, config: SQLAlchemySyncConfig) -> "Session": ...
@staticmethod
def _get_session_from_request(
request: Request,
config: Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig], # pragma: no cover
) -> Union["Session", "AsyncSession"]: # pragma: no cover
"""Get the session for the given key."""
session = getattr(request.state, config.session_key, None)
if session is None:
session = config.create_session_maker()()
setattr(request.state, config.session_key, session)
return session
def get_session(
self, request: Request, key: Optional[str] = None
) -> Union["Session", "AsyncSession"]: # pragma: no cover
"""Get the session for the given key."""
config = self.get_config(key)
return self._get_session_from_request(request, config)
def get_async_session(self, request: Request, key: Optional[str] = None) -> "AsyncSession": # pragma: no cover
"""Get the async session for the given key."""
config = self.get_async_config(key)
return self._get_session_from_request(request, config)
def get_sync_session(self, request: Request, key: Optional[str] = None) -> "Session": # pragma: no cover
"""Get the sync session for the given key."""
config = self.get_sync_config(key)
return self._get_session_from_request(request, config)
def provide_session(
self, key: Optional[str] = None
) -> Callable[[Request], Union["Session", "AsyncSession"]]: # pragma: no cover
"""Get the session for the given key."""
config = self.get_config(key)
def _get_session(request: Request) -> Union["Session", "AsyncSession"]:
return self._get_session_from_request(request, config)
return _get_session
def provide_async_session(
self, key: Optional[str] = None
) -> Callable[[Request], "AsyncSession"]: # pragma: no cover
"""Get the async session for the given key."""
config = self.get_async_config(key)
def _get_session(request: Request) -> "AsyncSession":
return self._get_session_from_request(request, config)
return _get_session
def provide_sync_session(self, key: Optional[str] = None) -> Callable[[Request], "Session"]: # pragma: no cover
"""Get the sync session for the given key."""
config = self.get_sync_config(key)
def _get_session(request: Request) -> "Session":
return self._get_session_from_request(request, config)
return _get_session
def get_engine(self, key: Optional[str] = None) -> Union["Engine", "AsyncEngine"]: # pragma: no cover
"""Get the engine for the given key."""
config = self.get_config(key)
return config.get_engine()
def get_async_engine(self, key: Optional[str] = None) -> "AsyncEngine": # pragma: no cover
"""Get the async engine for the given key."""
config = self.get_async_config(key)
return config.get_engine()
def get_sync_engine(self, key: Optional[str] = None) -> "Engine": # pragma: no cover
"""Get the sync engine for the given key."""
config = self.get_sync_config(key)
return config.get_engine()
def provide_engine(
self, key: Optional[str] = None
) -> Callable[[], Union["Engine", "AsyncEngine"]]: # pragma: no cover
"""Get the engine for the given key."""
config = self.get_config(key)
def _get_engine() -> Union["Engine", "AsyncEngine"]:
return config.get_engine()
return _get_engine
def provide_async_engine(self, key: Optional[str] = None) -> Callable[[], "AsyncEngine"]: # pragma: no cover
"""Get the async engine for the given key."""
config = self.get_async_config(key)
def _get_engine() -> "AsyncEngine":
return config.get_engine()
return _get_engine
def provide_sync_engine(self, key: Optional[str] = None) -> Callable[[], "Engine"]: # pragma: no cover
"""Get the sync engine for the given key."""
config = self.get_sync_config(key)
def _get_engine() -> "Engine":
return config.get_engine()
return _get_engine
# python-advanced-alchemy-1.0.1/advanced_alchemy/filters.py
"""SQLAlchemy filter constructs for advanced query operations.
This module provides a comprehensive collection of filter data structures designed to
enhance SQLAlchemy query construction. It implements type-safe, reusable filter patterns
for common database query operations.
Features:
Type-safe filter construction, datetime range filtering, collection-based filtering,
pagination support, search operations, and customizable ordering.
Example:
Basic usage with a datetime filter::
import datetime
from sqlalchemy import select
from advanced_alchemy.filters import BeforeAfter
filter = BeforeAfter(
field_name="created_at",
before=datetime.datetime.now(),
after=datetime.datetime(2023, 1, 1),
)
statement = filter.append_to_statement(select(Model), Model)
Note:
All filter classes implement the :class:`StatementFilter` ABC, ensuring consistent
interface across different filter types.
See Also:
- :class:`sqlalchemy.sql.expression.Select`: Core SQLAlchemy select expression
- :class:`sqlalchemy.orm.Query`: SQLAlchemy ORM query interface
- :mod:`advanced_alchemy.base`: Base model definitions
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass
from operator import attrgetter
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Optional, Union, cast
from sqlalchemy import BinaryExpression, ColumnElement, Delete, Select, Update, and_, any_, or_, text
from sqlalchemy.orm import InstrumentedAttribute
from typing_extensions import TypeAlias, TypeVar
if TYPE_CHECKING:
import datetime
from collections import abc
from sqlalchemy.sql.dml import ReturningDelete, ReturningUpdate
from advanced_alchemy import base
__all__ = (
"BeforeAfter",
"CollectionFilter",
"FilterTypes",
"InAnyFilter",
"LimitOffset",
"NotInCollectionFilter",
"NotInSearchFilter",
"OnBeforeAfter",
"OrderBy",
"PaginationFilter",
"SearchFilter",
"StatementFilter",
"StatementFilterT",
"StatementTypeT",
)
T = TypeVar("T")
ModelT = TypeVar("ModelT", bound="base.ModelProtocol")
StatementFilterT = TypeVar("StatementFilterT", bound="StatementFilter")
StatementTypeT = TypeVar(
"StatementTypeT",
bound="Union[ReturningDelete[tuple[Any]], ReturningUpdate[tuple[Any]], Select[tuple[Any]], Select[Any], Update, Delete]",
)
FilterTypes: TypeAlias = "Union[BeforeAfter, OnBeforeAfter, CollectionFilter[Any], LimitOffset, OrderBy, SearchFilter, NotInCollectionFilter[Any], NotInSearchFilter]"
"""Aggregate type alias of the types supported for collection filtering."""
class StatementFilter(ABC):
"""Abstract base class for SQLAlchemy statement filters.
This class defines the interface for all filter types in the system. Each filter
implementation must provide a method to append its filtering logic to an existing
SQLAlchemy statement.
"""
@abstractmethod
def append_to_statement(
self, statement: StatementTypeT, model: type[ModelT], *args: Any, **kwargs: Any
) -> StatementTypeT:
"""Append filter conditions to a SQLAlchemy statement.
Args:
statement: The SQLAlchemy statement to modify
model: The SQLAlchemy model class
*args: Additional positional arguments
**kwargs: Additional keyword arguments
Returns:
StatementTypeT: Modified SQLAlchemy statement with filter conditions applied
Raises:
NotImplementedError: If the concrete class doesn't implement this method
Note:
This method must be implemented by all concrete filter classes.
See Also:
:meth:`sqlalchemy.sql.expression.Select.where`: SQLAlchemy where clause
"""
return statement
@staticmethod
def _get_instrumented_attr(model: Any, key: "Union[str, InstrumentedAttribute[Any]]") -> InstrumentedAttribute[Any]:
"""Get SQLAlchemy instrumented attribute from model.
Args:
model: SQLAlchemy model class or instance
key: Attribute name or instrumented attribute
Returns:
InstrumentedAttribute[Any]: SQLAlchemy instrumented attribute
See Also:
:class:`sqlalchemy.orm.attributes.InstrumentedAttribute`: SQLAlchemy attribute
"""
if isinstance(key, str):
return cast("InstrumentedAttribute[Any]", getattr(model, key))
return key
@dataclass
class BeforeAfter(StatementFilter):
"""DateTime range filter with exclusive bounds.
This filter creates date/time range conditions using < and > operators,
excluding the boundary values.
If either `before` or `after` is None, that boundary condition is not applied.
See Also:
:class:`OnBeforeAfter`: Inclusive datetime range filtering
"""
field_name: str
"""Name of the model attribute to filter on."""
before: "Optional[datetime.datetime]"
"""Filter results where field is earlier than this value."""
after: "Optional[datetime.datetime]"
"""Filter results where field is later than this value."""
def append_to_statement(self, statement: StatementTypeT, model: type[ModelT]) -> StatementTypeT:
"""Apply datetime range conditions to statement.
Args:
statement: The SQLAlchemy statement to modify
model: The SQLAlchemy model class
Returns:
StatementTypeT: Modified statement with datetime range conditions
"""
field = self._get_instrumented_attr(model, self.field_name)
if self.before is not None:
statement = cast("StatementTypeT", statement.where(field < self.before))
if self.after is not None:
statement = cast("StatementTypeT", statement.where(field > self.after))
return statement
@dataclass
class OnBeforeAfter(StatementFilter):
"""DateTime range filter with inclusive bounds.
This filter creates date/time range conditions using <= and >= operators,
including the boundary values.
If either `on_or_before` or `on_or_after` is None, that boundary condition
is not applied.
See Also:
:class:`BeforeAfter`: Exclusive datetime range filtering
"""
field_name: str
"""Name of the model attribute to filter on."""
on_or_before: "Optional[datetime.datetime]"
"""Filter results where field is on or earlier than this value."""
on_or_after: "Optional[datetime.datetime]"
"""Filter results where field is on or later than this value."""
def append_to_statement(self, statement: StatementTypeT, model: type[ModelT]) -> StatementTypeT:
"""Apply inclusive datetime range conditions to statement.
Args:
statement: The SQLAlchemy statement to modify
model: The SQLAlchemy model class
Returns:
StatementTypeT: Modified statement with inclusive datetime range conditions
"""
field = self._get_instrumented_attr(model, self.field_name)
if self.on_or_before is not None:
statement = cast("StatementTypeT", statement.where(field <= self.on_or_before))
if self.on_or_after is not None:
statement = cast("StatementTypeT", statement.where(field >= self.on_or_after))
return statement
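`BeforeAfter` and `OnBeforeAfter` differ only in bound strictness (`<`/`>` versus `<=`/`>=`). A pure-Python sketch of the two semantics, with an illustrative helper name:

```python
import datetime


def in_range(ts, before=None, after=None, inclusive=False):
    # inclusive=False mirrors BeforeAfter (strict bounds);
    # inclusive=True mirrors OnBeforeAfter (bounds included).
    if before is not None:
        if not (ts <= before if inclusive else ts < before):
            return False
    if after is not None:
        if not (ts >= after if inclusive else ts > after):
            return False
    return True
```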
class InAnyFilter(StatementFilter, ABC):
"""Base class for filters using IN or ANY operators.
This abstract class provides common functionality for filters that check
membership in a collection using either the SQL IN operator or the ANY operator.
"""
@dataclass
class CollectionFilter(InAnyFilter, Generic[T]):
"""Data required to construct a WHERE ... IN (...) clause.
This filter restricts records based on a field's presence in a collection of values.
The filter supports both ``IN`` and ``ANY`` operators for collection membership testing.
Use ``prefer_any=True`` in ``append_to_statement`` to use the ``ANY`` operator.
"""
field_name: str
"""Name of the model attribute to filter on."""
values: "Union[abc.Collection[T], None]"
"""Values for the ``IN`` clause. If this is None, no filter is applied.
An empty collection forces an empty result set (``WHERE 1=-1``)."""
def append_to_statement(
self,
statement: StatementTypeT,
model: type[ModelT],
prefer_any: bool = False,
) -> StatementTypeT:
"""Apply a WHERE ... IN or WHERE ... ANY (...) clause to the statement.
Args:
statement: The SQLAlchemy statement to modify
model: The SQLAlchemy model class
prefer_any: If True, uses the SQLAlchemy :func:`any_` operator instead of :func:`in_` for the filter condition
Returns:
StatementTypeT: Modified statement with the appropriate IN conditions
"""
field = self._get_instrumented_attr(model, self.field_name)
if self.values is None:
return statement
if not self.values:
# Return empty result set by forcing a false condition
return cast("StatementTypeT", statement.where(text("1=-1")))
if prefer_any:
return cast("StatementTypeT", statement.where(any_(self.values) == field)) # type: ignore[arg-type]
return cast("StatementTypeT", statement.where(field.in_(self.values)))
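The three-way `values` handling above (None, empty, non-empty) is easy to get wrong. A dependency-free sketch of the same semantics applied to plain rows (dicts here stand in for model instances; the helper is illustrative):

```python
def apply_collection_filter(rows, field, values):
    if values is None:          # no filter applied
        return list(rows)
    if not values:              # empty collection forces an empty result,
        return []               # matching the WHERE 1=-1 clause above
    return [row for row in rows if row[field] in values]
```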
@dataclass
class NotInCollectionFilter(InAnyFilter, Generic[T]):
"""Data required to construct a WHERE ... NOT IN (...) clause.
This filter restricts records based on a field's absence in a collection of values.
The filter supports both ``NOT IN`` and ``!= ANY`` operators for collection exclusion.
Use ``prefer_any=True`` in ``append_to_statement`` to use the ``ANY`` operator.
Args:
field_name: Name of the model attribute to filter on
values: Values for the ``NOT IN`` clause. If this is None or empty, the filter is not applied.
field_name: str
"""Name of the model attribute to filter on."""
values: "Union[abc.Collection[T], None]"
"""Values for the ``NOT IN`` clause. If None or empty, no filter is applied."""
def append_to_statement(
self,
statement: StatementTypeT,
model: type[ModelT],
prefer_any: bool = False,
) -> StatementTypeT:
"""Apply a WHERE ... NOT IN or WHERE ... != ANY(...) clause to the statement.
Args:
statement: The SQLAlchemy statement to modify
model: The SQLAlchemy model class
prefer_any: If True, uses the SQLAlchemy :func:`any_` operator instead of :func:`notin_` for the filter condition
Returns:
StatementTypeT: Modified statement with the appropriate NOT IN conditions
"""
field = self._get_instrumented_attr(model, self.field_name)
if not self.values:
# If None or empty, we do not modify the statement
return statement
if prefer_any:
return cast("StatementTypeT", statement.where(any_(self.values) != field)) # type: ignore[arg-type]
return cast("StatementTypeT", statement.where(field.notin_(self.values)))
class PaginationFilter(StatementFilter, ABC):
"""Abstract base class for pagination filters.
Subclasses should implement pagination logic, such as limit/offset or
cursor-based pagination.
"""
@dataclass
class LimitOffset(PaginationFilter):
"""Limit and offset pagination filter.
Implements traditional pagination using SQL LIMIT and OFFSET clauses.
Only applies to SELECT statements; other statement types are returned unmodified.
Note:
This filter only modifies SELECT statements. For other statement types
(UPDATE, DELETE), the statement is returned unchanged.
See Also:
- :meth:`sqlalchemy.sql.expression.Select.limit`: SQLAlchemy LIMIT clause
- :meth:`sqlalchemy.sql.expression.Select.offset`: SQLAlchemy OFFSET clause
"""
limit: int
"""Maximum number of rows to return."""
offset: int
"""Number of rows to skip before returning results."""
def append_to_statement(self, statement: StatementTypeT, model: type[ModelT]) -> StatementTypeT:
"""Apply LIMIT/OFFSET pagination to the statement.
Args:
statement: The SQLAlchemy statement to modify
model: The SQLAlchemy model class
Returns:
StatementTypeT: Modified statement with limit and offset applied
Note:
Only modifies SELECT statements. Other statement types are returned as-is.
See Also:
:class:`sqlalchemy.sql.expression.Select`: SQLAlchemy SELECT statement
"""
if isinstance(statement, Select):
return cast("StatementTypeT", statement.limit(self.limit).offset(self.offset))
return statement
@dataclass
class OrderBy(StatementFilter):
"""Order by a specific field.
Appends an ORDER BY clause to SELECT statements, sorting records by the
specified field in ascending or descending order.
Note:
This filter only modifies SELECT statements. For other statement types,
the statement is returned unchanged.
See Also:
- :meth:`sqlalchemy.sql.expression.Select.order_by`: SQLAlchemy ORDER BY clause
- :meth:`sqlalchemy.sql.expression.ColumnElement.asc`: Ascending order
- :meth:`sqlalchemy.sql.expression.ColumnElement.desc`: Descending order
"""
field_name: str
"""Name of the model attribute to sort on."""
sort_order: Literal["asc", "desc"] = "asc"
"""Sort direction ("asc" or "desc")."""
def append_to_statement(self, statement: StatementTypeT, model: type[ModelT]) -> StatementTypeT:
"""Append an ORDER BY clause to the statement.
Args:
statement: The SQLAlchemy statement to modify
model: The SQLAlchemy model class
Returns:
StatementTypeT: Modified statement with an ORDER BY clause
Note:
Only modifies SELECT statements. Other statement types are returned as-is.
See Also:
:meth:`sqlalchemy.sql.expression.Select.order_by`: SQLAlchemy ORDER BY
"""
if not isinstance(statement, Select):
return statement
field = self._get_instrumented_attr(model, self.field_name)
if self.sort_order == "desc":
return cast("StatementTypeT", statement.order_by(field.desc()))
return cast("StatementTypeT", statement.order_by(field.asc()))
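Combined with `LimitOffset`, `OrderBy` produces classic page queries. The equivalent in-memory behavior, as a sketch with illustrative names:

```python
def order_and_page(rows, field, sort_order="asc", limit=10, offset=0):
    # ORDER BY field ASC/DESC, then LIMIT/OFFSET over the ordered rows.
    ordered = sorted(rows, key=lambda row: row[field], reverse=(sort_order == "desc"))
    return ordered[offset:offset + limit]
```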
@dataclass
class SearchFilter(StatementFilter):
"""Case-sensitive or case-insensitive substring matching filter.
Implements text search using SQL LIKE or ILIKE operators. Can search across
multiple fields using OR conditions.
Note:
The search pattern automatically adds wildcards before and after the search
value, equivalent to SQL pattern '%value%'.
See Also:
- :class:`.NotInSearchFilter`: Opposite filter using NOT LIKE/ILIKE
- :meth:`sqlalchemy.sql.expression.ColumnOperators.like`: Case-sensitive LIKE
- :meth:`sqlalchemy.sql.expression.ColumnOperators.ilike`: Case-insensitive LIKE
"""
field_name: "Union[str, set[str]]"
"""Name or set of names of model attributes to search on."""
value: str
"""Text to match within the field(s)."""
ignore_case: "Optional[bool]" = False
"""Whether to use case-insensitive matching."""
@property
def _operator(self) -> "Callable[..., ColumnElement[bool]]":
"""Return the SQL operator for combining multiple search clauses.
Returns:
Callable[..., ColumnElement[bool]]: The `or_` operator for OR conditions
See Also:
:func:`sqlalchemy.sql.expression.or_`: SQLAlchemy OR operator
"""
return or_
@property
def _func(self) -> "attrgetter[Callable[[str], BinaryExpression[bool]]]":
"""Return the appropriate LIKE or ILIKE operator as a function.
Returns:
attrgetter: Bound method for LIKE or ILIKE operations
See Also:
- :meth:`sqlalchemy.sql.expression.ColumnOperators.like`: LIKE operator
- :meth:`sqlalchemy.sql.expression.ColumnOperators.ilike`: ILIKE operator
"""
return attrgetter("ilike" if self.ignore_case else "like")
@property
def normalized_field_names(self) -> set[str]:
"""Convert field_name to a set if it's a single string.
Returns:
set[str]: Set of field names to be searched
"""
return {self.field_name} if isinstance(self.field_name, str) else self.field_name
def get_search_clauses(self, model: type[ModelT]) -> list["BinaryExpression[bool]"]:
"""Generate the LIKE/ILIKE clauses for all specified fields.
Args:
model: The SQLAlchemy model class
Returns:
list[BinaryExpression[bool]]: List of text matching expressions
See Also:
:class:`sqlalchemy.sql.expression.BinaryExpression`: SQLAlchemy expression
"""
search_clause: list[BinaryExpression[bool]] = []
for field_name in self.normalized_field_names:
field = self._get_instrumented_attr(model, field_name)
search_text = f"%{self.value}%"
search_clause.append(self._func(field)(search_text))
return search_clause
def append_to_statement(self, statement: StatementTypeT, model: type[ModelT]) -> StatementTypeT:
"""Append a LIKE/ILIKE clause to the statement.
Args:
statement: The SQLAlchemy statement to modify
model: The SQLAlchemy model class
Returns:
StatementTypeT: Modified statement with text search clauses
See Also:
:meth:`sqlalchemy.sql.expression.Select.where`: SQLAlchemy WHERE clause
"""
where_clause = self._operator(*self.get_search_clauses(model))
return cast("StatementTypeT", statement.where(where_clause))
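`SearchFilter` ORs a `%value%` pattern across every listed field, with case folding standing in for ILIKE. A pure-Python sketch of the matching rule (helper name illustrative):

```python
def search_matches(row, field_names, value, ignore_case=False):
    # Accept a single field name or a set of names, as SearchFilter does.
    names = {field_names} if isinstance(field_names, str) else field_names
    needle = value.lower() if ignore_case else value
    for name in names:
        haystack = str(row[name])
        if ignore_case:
            haystack = haystack.lower()
        if needle in haystack:   # substring match, i.e. LIKE '%value%'
            return True          # OR semantics: any field matching suffices
    return False
```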
@dataclass
class NotInSearchFilter(SearchFilter):
"""Filter for excluding records that match a substring.
Implements negative text search using SQL NOT LIKE or NOT ILIKE operators.
Can exclude across multiple fields using AND conditions.
Args:
field_name: Name or set of names of model attributes to search on
value: Text to exclude from the field(s)
ignore_case: If True, uses NOT ILIKE for case-insensitive matching
Note:
Uses AND for multiple fields, meaning records matching any field will be excluded.
See Also:
- :class:`.SearchFilter`: Opposite filter using LIKE/ILIKE
- :meth:`sqlalchemy.sql.expression.ColumnOperators.notlike`: NOT LIKE operator
- :meth:`sqlalchemy.sql.expression.ColumnOperators.notilike`: NOT ILIKE operator
"""
@property
def _operator(self) -> Callable[..., ColumnElement[bool]]:
"""Return the SQL operator for combining multiple negated search clauses.
Returns:
Callable[..., ColumnElement[bool]]: The `and_` operator for AND conditions
See Also:
:func:`sqlalchemy.sql.expression.and_`: SQLAlchemy AND operator
"""
return and_
@property
def _func(self) -> "attrgetter[Callable[[str], BinaryExpression[bool]]]":
"""Return the appropriate NOT LIKE or NOT ILIKE operator as a function.
Returns:
attrgetter: Bound method for NOT LIKE or NOT ILIKE operations
See Also:
- :meth:`sqlalchemy.sql.expression.ColumnOperators.notlike`: NOT LIKE
- :meth:`sqlalchemy.sql.expression.ColumnOperators.notilike`: NOT ILIKE
"""
return attrgetter("not_ilike" if self.ignore_case else "not_like")
# python-advanced-alchemy-1.0.1/advanced_alchemy/mixins/__init__.py
from advanced_alchemy.mixins.audit import AuditColumns
from advanced_alchemy.mixins.bigint import BigIntPrimaryKey
from advanced_alchemy.mixins.nanoid import NanoIDPrimaryKey
from advanced_alchemy.mixins.sentinel import SentinelMixin
from advanced_alchemy.mixins.slug import SlugKey
from advanced_alchemy.mixins.unique import UniqueMixin
from advanced_alchemy.mixins.uuid import UUIDPrimaryKey, UUIDv6PrimaryKey, UUIDv7PrimaryKey
__all__ = (
"AuditColumns",
"BigIntPrimaryKey",
"NanoIDPrimaryKey",
"SentinelMixin",
"SlugKey",
"UUIDPrimaryKey",
"UUIDv6PrimaryKey",
"UUIDv7PrimaryKey",
"UniqueMixin",
)
# python-advanced-alchemy-1.0.1/advanced_alchemy/mixins/audit.py
import datetime
from sqlalchemy.orm import Mapped, declarative_mixin, mapped_column, validates
from advanced_alchemy.types import DateTimeUTC
@declarative_mixin
class AuditColumns:
"""Created/Updated At Fields Mixin."""
created_at: Mapped[datetime.datetime] = mapped_column(
DateTimeUTC(timezone=True),
default=lambda: datetime.datetime.now(datetime.timezone.utc),
)
"""Date/time of instance creation."""
updated_at: Mapped[datetime.datetime] = mapped_column(
DateTimeUTC(timezone=True),
default=lambda: datetime.datetime.now(datetime.timezone.utc),
onupdate=lambda: datetime.datetime.now(datetime.timezone.utc),
)
"""Date/time of instance last update."""
@validates("created_at", "updated_at")
def validate_tz_info(self, _: str, value: datetime.datetime) -> datetime.datetime:
if value.tzinfo is None:
value = value.replace(tzinfo=datetime.timezone.utc)
return value
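The `validate_tz_info` validator above coerces naive datetimes to UTC rather than rejecting them. The coercion in isolation, as a sketch (helper name illustrative):

```python
import datetime


def ensure_utc(value: datetime.datetime) -> datetime.datetime:
    # Naive datetimes are assumed to already be in UTC and are tagged as
    # such; timezone-aware datetimes pass through unchanged.
    if value.tzinfo is None:
        return value.replace(tzinfo=datetime.timezone.utc)
    return value
```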
# python-advanced-alchemy-1.0.1/advanced_alchemy/mixins/bigint.py
from sqlalchemy import Sequence
from sqlalchemy.orm import Mapped, declarative_mixin, declared_attr, mapped_column
from advanced_alchemy.types import BigIntIdentity
@declarative_mixin
class BigIntPrimaryKey:
"""BigInt Primary Key Field Mixin."""
@declared_attr
def id(cls) -> Mapped[int]:
"""BigInt Primary key column."""
return mapped_column(
BigIntIdentity,
Sequence(f"{cls.__tablename__}_id_seq", optional=False), # type: ignore[attr-defined]
primary_key=True,
)
# python-advanced-alchemy-1.0.1/advanced_alchemy/mixins/nanoid.py
from typing import TYPE_CHECKING
from sqlalchemy.orm import Mapped, declarative_mixin, mapped_column
from advanced_alchemy.mixins.sentinel import SentinelMixin
from advanced_alchemy.types import NANOID_INSTALLED
if NANOID_INSTALLED and not TYPE_CHECKING:
from fastnanoid import ( # type: ignore[import-not-found,unused-ignore] # pyright: ignore[reportMissingImports]
generate as nanoid,
)
else:
from uuid import uuid4 as nanoid # type: ignore[assignment,unused-ignore]
@declarative_mixin
class NanoIDPrimaryKey(SentinelMixin):
"""Nano ID Primary Key Field Mixin."""
id: Mapped[str] = mapped_column(default=nanoid, primary_key=True)
"""Nano ID Primary key column."""
# python-advanced-alchemy-1.0.1/advanced_alchemy/mixins/sentinel.py
from sqlalchemy.orm import Mapped, declarative_mixin, declared_attr, orm_insert_sentinel
@declarative_mixin
class SentinelMixin:
"""Mixin to add a sentinel column for SQLAlchemy models."""
@declared_attr
def _sentinel(cls) -> Mapped[int]:
return orm_insert_sentinel(name="sa_orm_sentinel")
# python-advanced-alchemy-1.0.1/advanced_alchemy/mixins/slug.py
from typing import TYPE_CHECKING, Any
from sqlalchemy import Index, String, UniqueConstraint
from sqlalchemy.orm import Mapped, declarative_mixin, declared_attr, mapped_column
if TYPE_CHECKING:
from sqlalchemy.orm.decl_base import _TableArgsType as TableArgsType # pyright: ignore[reportPrivateUsage]
@declarative_mixin
class SlugKey:
"""Slug unique Field Model Mixin."""
@declared_attr
def slug(cls) -> Mapped[str]:
"""Slug field."""
return mapped_column(
String(length=100),
nullable=False,
)
@staticmethod
def _create_unique_slug_index(*_: Any, **kwargs: Any) -> bool:
return bool(kwargs["dialect"].name.startswith("spanner"))
@staticmethod
def _create_unique_slug_constraint(*_: Any, **kwargs: Any) -> bool:
return not kwargs["dialect"].name.startswith("spanner")
@declared_attr.directive
@classmethod
def __table_args__(cls) -> "TableArgsType":
return (
UniqueConstraint(
cls.slug,
name=f"uq_{cls.__tablename__}_slug", # type: ignore[attr-defined]
).ddl_if(callable_=cls._create_unique_slug_constraint),
Index(
f"ix_{cls.__tablename__}_slug_unique", # type: ignore[attr-defined]
cls.slug,
unique=True,
).ddl_if(callable_=cls._create_unique_slug_index),
)
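The two predicates above route the uniqueness DDL by dialect: Spanner receives a unique index (it expresses uniqueness through indexes rather than `UNIQUE` constraints), while every other backend receives a unique constraint. The dispatch logic in isolation, with dialect names as SQLAlchemy reports them:

```python
# Stand-alone sketch of the SlugKey dialect dispatch. The dialect-name strings
# are illustrative; SQLAlchemy reports e.g. "postgresql" or "spanner+spanner".
def use_unique_index(dialect_name: str) -> bool:
    return dialect_name.startswith("spanner")

def use_unique_constraint(dialect_name: str) -> bool:
    return not dialect_name.startswith("spanner")

print(use_unique_index("spanner+spanner"))   # True
print(use_unique_constraint("postgresql"))   # True
```

Exactly one of the two `ddl_if` guards fires for any given dialect, so the slug column always ends up unique but never doubly so.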
# advanced_alchemy/mixins/unique.py
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Optional, Union
from sqlalchemy import ColumnElement, select
from sqlalchemy.orm import declarative_mixin
from typing_extensions import Self
from advanced_alchemy.exceptions import wrap_sqlalchemy_exception
if TYPE_CHECKING:
from collections.abc import Hashable, Iterator
from sqlalchemy import Select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.ext.asyncio.scoping import async_scoped_session
from sqlalchemy.orm import Session
from sqlalchemy.orm.scoping import scoped_session
__all__ = ("UniqueMixin",)
@declarative_mixin
class UniqueMixin:
"""Mixin for instantiating objects while ensuring uniqueness on some field(s).
This is a slightly modified implementation derived from https://github.com/sqlalchemy/sqlalchemy/wiki/UniqueObject
"""
@classmethod
@contextmanager
def _prevent_autoflush(
cls,
session: "Union[AsyncSession, async_scoped_session[AsyncSession], Session, scoped_session[Session]]",
) -> "Iterator[None]":
with session.no_autoflush, wrap_sqlalchemy_exception():
yield
@classmethod
def _check_uniqueness(
cls,
cache: "Optional[dict[tuple[type[Self], Hashable], Self]]",
session: "Union[AsyncSession, async_scoped_session[AsyncSession], Session, scoped_session[Session]]",
key: "tuple[type[Self], Hashable]",
*args: Any,
**kwargs: Any,
) -> "tuple[dict[tuple[type[Self], Hashable], Self], Select[tuple[Self]], Optional[Self]]":
if cache is None:
cache = {}
setattr(session, "_unique_cache", cache)
statement = select(cls).where(cls.unique_filter(*args, **kwargs)).limit(2)
return cache, statement, cache.get(key)
@classmethod
async def as_unique_async(
cls,
session: "Union[AsyncSession, async_scoped_session[AsyncSession]]",
*args: Any,
**kwargs: Any,
) -> Self:
"""Instantiate and return a unique object within the provided session based on the given arguments.
If an object with the same unique identifier already exists in the session, it is returned from the cache.
Args:
session (AsyncSession | async_scoped_session[AsyncSession]): SQLAlchemy async session
*args (Any): Values used to instantiate the instance if no duplicate exists
**kwargs (Any): Values used to instantiate the instance if no duplicate exists
Returns:
Self: The unique object instance.
"""
key = cls, cls.unique_hash(*args, **kwargs)
cache, statement, obj = cls._check_uniqueness(
getattr(session, "_unique_cache", None),
session,
key,
*args,
**kwargs,
)
if obj:
return obj
with cls._prevent_autoflush(session):
if (obj := (await session.execute(statement)).scalar_one_or_none()) is None:
session.add(obj := cls(*args, **kwargs))
cache[key] = obj
return obj
@classmethod
def as_unique_sync(
cls,
session: "Union[Session, scoped_session[Session]]",
*args: Any,
**kwargs: Any,
) -> Self:
"""Instantiate and return a unique object within the provided session based on the given arguments.
If an object with the same unique identifier already exists in the session, it is returned from the cache.
Args:
session (Session | scoped_session[Session]): SQLAlchemy sync session
*args (Any): Values used to instantiate the instance if no duplicate exists
**kwargs (Any): Values used to instantiate the instance if no duplicate exists
Returns:
Self: The unique object instance.
"""
key = cls, cls.unique_hash(*args, **kwargs)
cache, statement, obj = cls._check_uniqueness(
getattr(session, "_unique_cache", None),
session,
key,
*args,
**kwargs,
)
if obj:
return obj
with cls._prevent_autoflush(session):
if (obj := session.execute(statement).scalar_one_or_none()) is None:
session.add(obj := cls(*args, **kwargs))
cache[key] = obj
return obj
@classmethod
def unique_hash(cls, *args: Any, **kwargs: Any) -> "Hashable":
"""Generate a unique key based on the provided arguments.
This method should be implemented in the subclass.
Args:
*args (Any): Values passed to the alternate classmethod constructors
**kwargs (Any): Values passed to the alternate classmethod constructors
Raises:
NotImplementedError: If not implemented in the subclass.
Returns:
Hashable: Any hashable object.
"""
msg = "Implement this in subclass"
raise NotImplementedError(msg)
@classmethod
def unique_filter(cls, *args: Any, **kwargs: Any) -> "ColumnElement[bool]":
"""Generate a filter condition for ensuring uniqueness.
This method should be implemented in the subclass.
Args:
*args (Any): Values passed to the alternate classmethod constructors
**kwargs (Any): Values passed to the alternate classmethod constructors
Raises:
NotImplementedError: If not implemented in the subclass.
Returns:
ColumnElement[bool]: Filter condition to establish the uniqueness.
"""
msg = "Implement this in subclass"
raise NotImplementedError(msg)
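Both `as_unique_async` and `as_unique_sync` resolve an object in the same three-step order: session-level cache, then a database lookup, then construction of a new instance. A stdlib sketch of that lookup order — `FakeSession` and `as_unique` here are hypothetical stand-ins, not SQLAlchemy objects:

```python
# Stdlib sketch of the UniqueMixin lookup order: session cache first, then the
# "database", then a freshly constructed instance that is cached for next time.
class FakeSession:
    def __init__(self) -> None:
        self.rows = {}             # pretend persisted rows, keyed like the cache
        self._unique_cache = {}    # same attribute name the mixin attaches

def as_unique(session, key, factory):
    if key in session._unique_cache:      # 1. already resolved in this session
        return session._unique_cache[key]
    obj = session.rows.get(key)           # 2. already persisted
    if obj is None:
        obj = factory()                   # 3. build and "add" a new instance
        session.rows[key] = obj
    session._unique_cache[key] = obj
    return obj

session = FakeSession()
a = as_unique(session, ("Tag", "python"), lambda: {"name": "python"})
b = as_unique(session, ("Tag", "python"), lambda: {"name": "python"})
print(a is b)  # True: the second call is served from the cache
```

The cache key is `(cls, unique_hash(...))`, which is why subclasses must supply both `unique_hash` (for the cache) and `unique_filter` (for the database lookup).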
# advanced_alchemy/mixins/uuid.py
from typing import TYPE_CHECKING
from uuid import UUID, uuid4
from sqlalchemy.orm import Mapped, declarative_mixin, mapped_column
from advanced_alchemy.mixins.sentinel import SentinelMixin
from advanced_alchemy.types import UUID_UTILS_INSTALLED
if UUID_UTILS_INSTALLED and not TYPE_CHECKING:
from uuid_utils.compat import ( # type: ignore[no-redef,unused-ignore] # pyright: ignore[reportMissingImports]
uuid4,
uuid6,
uuid7,
)
else:
from uuid import uuid4 # type: ignore[no-redef,unused-ignore]
uuid6 = uuid4 # type: ignore[assignment, unused-ignore]
uuid7 = uuid4 # type: ignore[assignment, unused-ignore]
@declarative_mixin
class UUIDPrimaryKey(SentinelMixin):
"""UUID Primary Key Field Mixin."""
id: Mapped[UUID] = mapped_column(default=uuid4, primary_key=True)
"""UUID Primary key column."""
@declarative_mixin
class UUIDv6PrimaryKey(SentinelMixin):
"""UUID v6 Primary Key Field Mixin."""
id: Mapped[UUID] = mapped_column(default=uuid6, primary_key=True)
"""UUID Primary key column."""
@declarative_mixin
class UUIDv7PrimaryKey(SentinelMixin):
"""UUID v7 Primary Key Field Mixin."""
id: Mapped[UUID] = mapped_column(default=uuid7, primary_key=True)
"""UUID Primary key column."""
# advanced_alchemy/operations.py
"""Advanced database operations for SQLAlchemy.
This module provides high-performance database operations that extend beyond basic CRUD
functionality. It implements specialized database operations optimized for bulk data
handling and schema management.
The operations module is designed to work seamlessly with SQLAlchemy Core and ORM,
providing efficient implementations for common database operations patterns.
Features
--------
- Table merging and upsert operations
- Dynamic table creation from SELECT statements
- Bulk data import/export operations
- Optimized copy operations for PostgreSQL
- Transaction-safe batch operations
Todo:
-----
- Implement merge operations with customizable conflict resolution
- Add CTAS (Create Table As Select) functionality
- Implement bulk copy operations (COPY TO/FROM) for PostgreSQL
- Add support for temporary table operations
- Implement materialized view operations
Notes:
------
This module is designed to be database-agnostic where possible, with specialized
optimizations for specific database backends where appropriate.
See Also:
---------
- :mod:`sqlalchemy.sql.expression` : SQLAlchemy Core expression language
- :mod:`sqlalchemy.orm` : SQLAlchemy ORM functionality
- :mod:`advanced_alchemy.extensions` : Additional database extensions
"""
# advanced_alchemy/py.typed (empty marker file)

# advanced_alchemy/repository/__init__.py
from advanced_alchemy.exceptions import ErrorMessages
from advanced_alchemy.repository._async import (
SQLAlchemyAsyncQueryRepository,
SQLAlchemyAsyncRepository,
SQLAlchemyAsyncRepositoryProtocol,
SQLAlchemyAsyncSlugRepository,
SQLAlchemyAsyncSlugRepositoryProtocol,
)
from advanced_alchemy.repository._sync import (
SQLAlchemySyncQueryRepository,
SQLAlchemySyncRepository,
SQLAlchemySyncRepositoryProtocol,
SQLAlchemySyncSlugRepository,
SQLAlchemySyncSlugRepositoryProtocol,
)
from advanced_alchemy.repository._util import (
DEFAULT_ERROR_MESSAGE_TEMPLATES,
FilterableRepository,
FilterableRepositoryProtocol,
LoadSpec,
get_instrumented_attr,
model_from_dict,
)
from advanced_alchemy.repository.typing import ModelOrRowMappingT, ModelT, OrderingPair
from advanced_alchemy.utils.dataclass import Empty, EmptyType
__all__ = (
"DEFAULT_ERROR_MESSAGE_TEMPLATES",
"Empty",
"EmptyType",
"ErrorMessages",
"FilterableRepository",
"FilterableRepositoryProtocol",
"LoadSpec",
"ModelOrRowMappingT",
"ModelT",
"OrderingPair",
"SQLAlchemyAsyncQueryRepository",
"SQLAlchemyAsyncRepository",
"SQLAlchemyAsyncRepositoryProtocol",
"SQLAlchemyAsyncSlugRepository",
"SQLAlchemyAsyncSlugRepositoryProtocol",
"SQLAlchemySyncQueryRepository",
"SQLAlchemySyncRepository",
"SQLAlchemySyncRepositoryProtocol",
"SQLAlchemySyncSlugRepository",
"SQLAlchemySyncSlugRepositoryProtocol",
"get_instrumented_attr",
"model_from_dict",
)
# advanced_alchemy/repository/_async.py
import random
import string
from collections.abc import Iterable, Sequence
from typing import (
TYPE_CHECKING,
Any,
Final,
Literal,
Optional,
Protocol,
Union,
cast,
runtime_checkable,
)
from sqlalchemy import (
Delete,
Result,
Row,
Select,
TextClause,
Update,
any_,
delete,
over,
select,
text,
update,
)
from sqlalchemy import func as sql_func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.ext.asyncio.scoping import async_scoped_session
from sqlalchemy.orm import InstrumentedAttribute
from sqlalchemy.orm.strategy_options import _AbstractLoad # pyright: ignore[reportPrivateUsage]
from sqlalchemy.sql import ColumnElement
from sqlalchemy.sql.dml import ReturningDelete, ReturningUpdate
from advanced_alchemy.exceptions import ErrorMessages, NotFoundError, RepositoryError, wrap_sqlalchemy_exception
from advanced_alchemy.filters import StatementFilter, StatementTypeT
from advanced_alchemy.repository._util import (
DEFAULT_ERROR_MESSAGE_TEMPLATES,
FilterableRepository,
FilterableRepositoryProtocol,
LoadSpec,
get_abstract_loader_options,
get_instrumented_attr,
)
from advanced_alchemy.repository.typing import MISSING, ModelT, OrderingPair, T
from advanced_alchemy.utils.dataclass import Empty, EmptyType
from advanced_alchemy.utils.text import slugify
if TYPE_CHECKING:
from sqlalchemy.engine.interfaces import _CoreSingleExecuteParams # pyright: ignore[reportPrivateUsage]
DEFAULT_INSERTMANYVALUES_MAX_PARAMETERS: Final = 950
POSTGRES_VERSION_SUPPORTING_MERGE: Final = 15
@runtime_checkable
class SQLAlchemyAsyncRepositoryProtocol(FilterableRepositoryProtocol[ModelT], Protocol[ModelT]):
"""Base Protocol"""
id_attribute: str
match_fields: Optional[Union[list[str], str]] = None
statement: Select[tuple[ModelT]]
session: Union[AsyncSession, async_scoped_session[AsyncSession]]
auto_expunge: bool
auto_refresh: bool
auto_commit: bool
order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None
error_messages: Optional[ErrorMessages] = None
wrap_exceptions: bool = True
def __init__(
self,
*,
statement: Optional[Select[tuple[ModelT]]] = None,
session: Union[AsyncSession, async_scoped_session[AsyncSession]],
auto_expunge: bool = False,
auto_refresh: bool = True,
auto_commit: bool = False,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
wrap_exceptions: bool = True,
**kwargs: Any,
) -> None: ...
@classmethod
def get_id_attribute_value(
cls,
item: Union[ModelT, type[ModelT]],
id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
) -> Any: ...
@classmethod
def set_id_attribute_value(
cls,
item_id: Any,
item: ModelT,
id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
) -> ModelT: ...
@staticmethod
def check_not_found(item_or_none: Optional[ModelT]) -> ModelT: ...
async def add(
self,
data: ModelT,
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
) -> ModelT: ...
async def add_many(
self,
data: list[ModelT],
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
) -> Sequence[ModelT]: ...
async def delete(
self,
item_id: Any,
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
) -> ModelT: ...
async def delete_many(
self,
item_ids: list[Any],
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
chunk_size: Optional[int] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
) -> Sequence[ModelT]: ...
async def delete_where(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
load: Optional[LoadSpec] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
execution_options: Optional[dict[str, Any]] = None,
sanity_check: bool = True,
**kwargs: Any,
) -> Sequence[ModelT]: ...
async def exists(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
load: Optional[LoadSpec] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
execution_options: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> bool: ...
async def get(
self,
item_id: Any,
*,
auto_expunge: Optional[bool] = None,
statement: Optional[Select[tuple[ModelT]]] = None,
id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
) -> ModelT: ...
async def get_one(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
auto_expunge: Optional[bool] = None,
statement: Optional[Select[tuple[ModelT]]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> ModelT: ...
async def get_one_or_none(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
auto_expunge: Optional[bool] = None,
statement: Optional[Select[tuple[ModelT]]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> Optional[ModelT]: ...
async def get_or_upsert(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
match_fields: Optional[Union[list[str], str]] = None,
upsert: bool = True,
attribute_names: Optional[Iterable[str]] = None,
with_for_update: Optional[bool] = None,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> tuple[ModelT, bool]: ...
async def get_and_update(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
match_fields: Optional[Union[list[str], str]] = None,
attribute_names: Optional[Iterable[str]] = None,
with_for_update: Optional[bool] = None,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> tuple[ModelT, bool]: ...
async def count(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
statement: Optional[Select[tuple[ModelT]]] = None,
load: Optional[LoadSpec] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
execution_options: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> int: ...
async def update(
self,
data: ModelT,
*,
attribute_names: Optional[Iterable[str]] = None,
with_for_update: Optional[bool] = None,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
) -> ModelT: ...
async def update_many(
self,
data: list[ModelT],
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
) -> list[ModelT]: ...
def _get_update_many_statement(
self,
model_type: type[ModelT],
supports_returning: bool,
loader_options: Optional[list[_AbstractLoad]],
execution_options: Optional[dict[str, Any]],
) -> Union[Update, ReturningUpdate[tuple[ModelT]]]: ...
async def upsert(
self,
data: ModelT,
*,
attribute_names: Optional[Iterable[str]] = None,
with_for_update: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_commit: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
match_fields: Optional[Union[list[str], str]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
) -> ModelT: ...
async def upsert_many(
self,
data: list[ModelT],
*,
auto_expunge: Optional[bool] = None,
auto_commit: Optional[bool] = None,
no_merge: bool = False,
match_fields: Optional[Union[list[str], str]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
) -> list[ModelT]: ...
async def list_and_count(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
auto_expunge: Optional[bool] = None,
statement: Optional[Select[tuple[ModelT]]] = None,
count_with_window_function: Optional[bool] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None,
**kwargs: Any,
) -> tuple[list[ModelT], int]: ...
async def list(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
auto_expunge: Optional[bool] = None,
statement: Optional[Select[tuple[ModelT]]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None,
**kwargs: Any,
) -> list[ModelT]: ...
@classmethod
async def check_health(cls, session: Union[AsyncSession, async_scoped_session[AsyncSession]]) -> bool: ...
@runtime_checkable
class SQLAlchemyAsyncSlugRepositoryProtocol(SQLAlchemyAsyncRepositoryProtocol[ModelT], Protocol[ModelT]):
"""Protocol for SQLAlchemy repositories that support slug-based operations.
Extends the base repository protocol to add slug-related functionality.
Type Parameters:
ModelT: The SQLAlchemy model type this repository handles.
"""
async def get_by_slug(
self,
slug: str,
*,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> Optional[ModelT]:
"""Get a model instance by its slug.
Args:
slug: The slug value to search for.
error_messages: Optional custom error message templates.
load: Specification for eager loading of relationships.
execution_options: Options for statement execution.
**kwargs: Additional filtering criteria.
Returns:
ModelT | None: The found model instance or None if not found.
"""
...
async def get_available_slug(
self,
value_to_slugify: str,
**kwargs: Any,
) -> str:
"""Generate a unique slug for a given value.
Args:
value_to_slugify: The string to convert to a slug.
**kwargs: Additional parameters for slug generation.
Returns:
str: A unique slug derived from the input value.
"""
...
class SQLAlchemyAsyncRepository(SQLAlchemyAsyncRepositoryProtocol[ModelT], FilterableRepository[ModelT]):
"""Async SQLAlchemy repository implementation.
Provides a complete implementation of async database operations using SQLAlchemy,
including CRUD operations, filtering, and relationship loading.
Type Parameters:
ModelT: The SQLAlchemy model type this repository handles.
.. seealso::
:class:`~advanced_alchemy.repository._util.FilterableRepository`
"""
id_attribute: str = "id"
"""Name of the unique identifier for the model."""
loader_options: Optional[LoadSpec] = None
"""Default loader options for the repository."""
error_messages: Optional[ErrorMessages] = None
"""Default error messages for the repository."""
wrap_exceptions: bool = True
"""Wrap SQLAlchemy exceptions in a ``RepositoryError``. When set to ``False``, the original exception will be raised."""
inherit_lazy_relationships: bool = True
"""Optionally ignore the default ``lazy`` configuration for model relationships. This is useful for when you want to
replace instead of merge the model's loaded relationships with the ones specified in the ``load`` or ``default_loader_options`` configuration."""
merge_loader_options: bool = True
"""Merges the default loader options with the loader options specified in the ``load`` argument. This is useful for when you want to totally
replace instead of merge the model's loaded relationships with the ones specified in the ``load`` or ``default_loader_options`` configuration."""
execution_options: Optional[dict[str, Any]] = None
"""Default execution options for the repository."""
match_fields: Optional[Union[list[str], str]] = None
"""Field name(s) used to match existing records in ``get_or_upsert`` and ``upsert`` style operations."""
uniquify: bool = False
"""Optionally apply the ``unique()`` method to results before returning.
This is useful for certain SQLAlchemy uses cases such as applying ``contains_eager`` to a query containing a one-to-many relationship
"""
count_with_window_function: bool = True
"""Use an analytical window function to count results. This allows the count to be performed in a single query.
"""
def __init__(
self,
*,
statement: Optional[Select[tuple[ModelT]]] = None,
session: Union[AsyncSession, async_scoped_session[AsyncSession]],
auto_expunge: bool = False,
auto_refresh: bool = True,
auto_commit: bool = False,
order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
wrap_exceptions: bool = True,
uniquify: Optional[bool] = None,
count_with_window_function: Optional[bool] = None,
**kwargs: Any,
) -> None:
"""Repository for SQLAlchemy models.
Args:
statement: To facilitate customization of the underlying select query.
session: Session managing the unit-of-work for the operation.
auto_expunge: Remove object from session before returning.
auto_refresh: Refresh object from session before returning.
auto_commit: Commit objects before returning.
order_by: Set default order options for queries.
load: Set default relationships to be loaded
execution_options: Set default execution options
error_messages: A set of custom error messages to use for operations
wrap_exceptions: Wrap SQLAlchemy exceptions in a ``RepositoryError``. When set to ``False``, the original exception will be raised.
uniquify: Optionally apply the ``unique()`` method to results before returning.
count_with_window_function: When false, list and count will use two queries instead of an analytical window function.
**kwargs: Additional arguments.
"""
self.auto_expunge = auto_expunge
self.auto_refresh = auto_refresh
self.auto_commit = auto_commit
self.order_by = order_by
self.session = session
self.error_messages = self._get_error_messages(
error_messages=error_messages, default_messages=self.error_messages
)
self.wrap_exceptions = wrap_exceptions
self.uniquify = self._get_uniquify(uniquify)
self.count_with_window_function = (
count_with_window_function if count_with_window_function is not None else self.count_with_window_function
)
self._default_loader_options, self._loader_options_have_wildcards = get_abstract_loader_options(
loader_options=load if load is not None else self.loader_options,
inherit_lazy_relationships=self.inherit_lazy_relationships,
merge_with_default=self.merge_loader_options,
)
execution_options = execution_options if execution_options is not None else self.execution_options
self._default_execution_options = execution_options or {}
self.statement = select(self.model_type) if statement is None else statement
self._dialect = self.session.bind.dialect if self.session.bind is not None else self.session.get_bind().dialect
self._prefer_any = any(self._dialect.name == engine_type for engine_type in self.prefer_any_dialects or ())
def _get_uniquify(self, uniquify: Optional[bool] = None) -> bool:
return bool(uniquify) if uniquify is not None else self.uniquify
@staticmethod
def _get_error_messages(
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
default_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
) -> Optional[ErrorMessages]:
if error_messages == Empty:
error_messages = None
if default_messages == Empty:
default_messages = None
messages = dict(DEFAULT_ERROR_MESSAGE_TEMPLATES)  # copy so the module-level defaults are never mutated in place
if default_messages and isinstance(default_messages, dict):
messages.update(default_messages)
if error_messages:
messages.update(cast("ErrorMessages", error_messages))
return messages
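`_get_error_messages` layers three dictionaries with later layers winning: package defaults, then class-level defaults, then per-call overrides. The merge in isolation (the template names below are illustrative):

```python
# Sketch of the three-layer template resolution: package defaults are copied,
# then class-level defaults and per-call overrides are applied on top.
DEFAULTS = {"not_found": "No record found", "duplicate_key": "Duplicate record"}

def resolve(class_level=None, call_level=None):
    messages = dict(DEFAULTS)          # copy keeps the shared defaults pristine
    messages.update(class_level or {})
    messages.update(call_level or {})
    return messages

print(resolve(call_level={"not_found": "Missing!"})["not_found"])  # Missing!
print(DEFAULTS["not_found"])  # No record found (unchanged)
```

Starting from a copy is what keeps one repository's custom messages from leaking into every other repository sharing the defaults.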
@classmethod
def get_id_attribute_value(
cls,
item: Union[ModelT, type[ModelT]],
id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
) -> Any:
"""Get value of attribute named as :attr:`id_attribute` on ``item``.
Args:
item: Anything that should have an attribute named as :attr:`id_attribute` value.
id_attribute: Allows customization of the unique identifier to use for model fetching.
Defaults to `None`, but can reference any surrogate or candidate key for the table.
Returns:
The value of attribute on ``item`` named as :attr:`id_attribute`.
"""
if isinstance(id_attribute, InstrumentedAttribute):
id_attribute = id_attribute.key
return getattr(item, id_attribute if id_attribute is not None else cls.id_attribute)
@classmethod
def set_id_attribute_value(
cls,
item_id: Any,
item: ModelT,
id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
) -> ModelT:
"""Return the ``item`` after the ID is set to the appropriate attribute.
Args:
item_id: Value of ID to be set on instance
item: Anything that should have an attribute named as :attr:`id_attribute` value.
id_attribute: Allows customization of the unique identifier to use for model fetching.
Defaults to `None`, but can reference any surrogate or candidate key for the table.
Returns:
Item with ``item_id`` set to :attr:`id_attribute`
"""
if isinstance(id_attribute, InstrumentedAttribute):
id_attribute = id_attribute.key
setattr(item, id_attribute if id_attribute is not None else cls.id_attribute, item_id)
return item
@staticmethod
def check_not_found(item_or_none: Optional[ModelT]) -> ModelT:
"""Raise :exc:`advanced_alchemy.exceptions.NotFoundError` if ``item_or_none`` is ``None``.
Args:
item_or_none: Item (:class:`ModelT`) to be tested for existence.
Raises:
NotFoundError: If ``item_or_none`` is ``None``.
Returns:
The item, if it exists.
"""
if item_or_none is None:
msg = "No item found when one was expected"
raise NotFoundError(msg)
return item_or_none
def _get_execution_options(
self,
execution_options: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
if execution_options is None:
return self._default_execution_options
return execution_options
def _get_loader_options(
self,
loader_options: Optional[LoadSpec],
) -> Union[tuple[list[_AbstractLoad], bool], tuple[None, bool]]:
if loader_options is None:
# use the defaults set at initialization
return self._default_loader_options, self._loader_options_have_wildcards or self.uniquify
return get_abstract_loader_options(
loader_options=loader_options,
default_loader_options=self._default_loader_options,
default_options_have_wildcards=self._loader_options_have_wildcards or self.uniquify,
inherit_lazy_relationships=self.inherit_lazy_relationships,
merge_with_default=self.merge_loader_options,
)
async def add(
self,
data: ModelT,
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
) -> ModelT:
"""Add ``data`` to the collection.
Args:
data: Instance to be added to the collection.
auto_expunge: Remove object from session before returning.
auto_refresh: Refresh object from session before returning.
auto_commit: Commit objects before returning.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
Returns:
The added instance.
"""
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
instance = await self._attach_to_session(data)
await self._flush_or_commit(auto_commit=auto_commit)
await self._refresh(instance, auto_refresh=auto_refresh)
self._expunge(instance, auto_expunge=auto_expunge)
return instance
async def add_many(
self,
data: list[ModelT],
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
) -> Sequence[ModelT]:
"""Add many `data` to the collection.
Args:
data: list of Instances to be added to the collection.
auto_expunge: Remove object from session before returning.
auto_commit: Commit objects before returning.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
Returns:
The added instances.
"""
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
self.session.add_all(data)
await self._flush_or_commit(auto_commit=auto_commit)
for datum in data:
self._expunge(datum, auto_expunge=auto_expunge)
return data
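# Usage sketch for ``add_many`` (hypothetical names; the whole batch is flushed or
# committed in a single operation):
#     items = await item_repo.add_many(
#         [Item(name="a"), Item(name="b")],
#         auto_commit=True,
#     )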
async def delete(
self,
item_id: Any,
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> ModelT:
"""Delete instance identified by ``item_id``.
Args:
item_id: Identifier of instance to be deleted.
auto_expunge: Remove object from session before returning.
auto_commit: Commit objects before returning.
id_attribute: Allows customization of the unique identifier to use for model fetching.
Defaults to `id`, but can reference any surrogate or candidate key for the table.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
Returns:
The deleted instance.
Raises:
NotFoundError: If no instance found identified by ``item_id``.
"""
self.uniquify = self._get_uniquify(uniquify)
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
execution_options = self._get_execution_options(execution_options)
instance = await self.get(
item_id,
id_attribute=id_attribute,
load=load,
execution_options=execution_options,
)
await self.session.delete(instance)
await self._flush_or_commit(auto_commit=auto_commit)
self._expunge(instance, auto_expunge=auto_expunge)
return instance
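# Usage sketch for ``delete`` (hypothetical names; raises ``NotFoundError`` when no
# row matches, and ``id_attribute`` can point the lookup at any candidate key):
#     deleted = await item_repo.delete(item_id, auto_commit=True)
#     deleted = await item_repo.delete("widget", id_attribute="name")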
async def delete_many(
self,
item_ids: list[Any],
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
chunk_size: Optional[int] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> Sequence[ModelT]:
"""Delete instance identified by `item_id`.
Args:
item_ids: Identifier of instance to be deleted.
auto_expunge: Remove object from session before returning.
auto_commit: Commit objects before returning.
id_attribute: Allows customization of the unique identifier to use for model fetching.
Defaults to `id`, but can reference any surrogate or candidate key for the table.
chunk_size: Allows customization of the ``insertmanyvalues_max_parameters`` setting for the driver.
Defaults to `950` if left unset.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
Returns:
The deleted instances.
"""
self.uniquify = self._get_uniquify(uniquify)
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
execution_options = self._get_execution_options(execution_options)
loader_options, _loader_options_have_wildcard = self._get_loader_options(load)
id_attribute = get_instrumented_attr(
self.model_type,
id_attribute if id_attribute is not None else self.id_attribute,
)
instances: list[ModelT] = []
if self._prefer_any:
chunk_size = len(item_ids) + 1
chunk_size = self._get_insertmanyvalues_max_parameters(chunk_size)
for idx in range(0, len(item_ids), chunk_size):
chunk = item_ids[idx : min(idx + chunk_size, len(item_ids))]
if self._dialect.delete_executemany_returning:
instances.extend(
await self.session.scalars(
self._get_delete_many_statement(
statement_type="delete",
model_type=self.model_type,
id_attribute=id_attribute,
id_chunk=chunk,
supports_returning=self._dialect.delete_executemany_returning,
loader_options=loader_options,
execution_options=execution_options,
),
),
)
else:
instances.extend(
await self.session.scalars(
self._get_delete_many_statement(
statement_type="select",
model_type=self.model_type,
id_attribute=id_attribute,
id_chunk=chunk,
supports_returning=self._dialect.delete_executemany_returning,
loader_options=loader_options,
execution_options=execution_options,
),
),
)
await self.session.execute(
self._get_delete_many_statement(
statement_type="delete",
model_type=self.model_type,
id_attribute=id_attribute,
id_chunk=chunk,
supports_returning=self._dialect.delete_executemany_returning,
loader_options=loader_options,
execution_options=execution_options,
),
)
await self._flush_or_commit(auto_commit=auto_commit)
for instance in instances:
self._expunge(instance, auto_expunge=auto_expunge)
return instances
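# Usage sketch for ``delete_many`` (hypothetical names; IDs are processed in chunks,
# using ``DELETE ... RETURNING`` where the dialect supports it and a select-then-delete
# fallback otherwise):
#     removed = await item_repo.delete_many([id_1, id_2, id_3], chunk_size=500)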
def _get_insertmanyvalues_max_parameters(self, chunk_size: Optional[int] = None) -> int:
return chunk_size if chunk_size is not None else DEFAULT_INSERTMANYVALUES_MAX_PARAMETERS
async def delete_where(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
sanity_check: bool = True,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> Sequence[ModelT]:
"""Delete instances specified by referenced kwargs and filters.
Args:
*filters: Types for specific filtering operations.
auto_expunge: Remove object from session before returning.
auto_commit: Commit objects before returning.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
sanity_check: When true, the length of selected instances is compared to the deleted row count
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
**kwargs: Arguments to apply to a delete
Returns:
The deleted instances.
"""
self.uniquify = self._get_uniquify(uniquify)
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
execution_options = self._get_execution_options(execution_options)
loader_options, _loader_options_have_wildcard = self._get_loader_options(load)
model_type = self.model_type
statement = self._get_base_stmt(
statement=delete(model_type),
loader_options=loader_options,
execution_options=execution_options,
)
statement = self._filter_select_by_kwargs(statement=statement, kwargs=kwargs)
statement = self._apply_filters(*filters, statement=statement, apply_pagination=False)
instances: list[ModelT] = []
if self._dialect.delete_executemany_returning:
instances.extend(await self.session.scalars(statement.returning(model_type)))
else:
instances.extend(
await self.list(
*filters,
load=load,
execution_options=execution_options,
auto_expunge=auto_expunge,
**kwargs,
),
)
result = await self.session.execute(statement)
row_count = getattr(result, "rowcount", -2)
if sanity_check and row_count >= 0 and len(instances) != row_count: # pyright: ignore # noqa: PGH003
# backends will return a -1 if they can't determine impacted rowcount
# only compare length of selected instances to results if it's >= 0
await self.session.rollback()
raise RepositoryError(detail="Deleted count does not match fetched count. Rollback issued.")
await self._flush_or_commit(auto_commit=auto_commit)
for instance in instances:
self._expunge(instance, auto_expunge=auto_expunge)
return instances
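# Usage sketch for ``delete_where`` (hypothetical names; keyword arguments become
# equality filters, and the sanity check rolls back if the deleted row count differs
# from the number of fetched instances):
#     removed = await item_repo.delete_where(Item.created_at < cutoff, status="archived")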
async def exists(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> bool:
"""Return true if the object specified by ``kwargs`` exists.
Args:
*filters: Types for specific filtering operations.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
**kwargs: Identifier of the instance to be retrieved.
Returns:
``True`` if the instance was found, ``False`` if not.
"""
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
existing = await self.count(
*filters,
load=load,
execution_options=execution_options,
error_messages=error_messages,
**kwargs,
)
return existing > 0
def _get_base_stmt(
self,
*,
statement: StatementTypeT,
loader_options: Optional[list[_AbstractLoad]],
execution_options: Optional[dict[str, Any]],
) -> StatementTypeT:
"""Get base statement with options applied.
Args:
statement: The select statement to modify
loader_options: Options for loading relationships
execution_options: Options for statement execution
Returns:
Modified select statement
"""
if loader_options:
statement = cast("StatementTypeT", statement.options(*loader_options))
if execution_options:
statement = cast("StatementTypeT", statement.execution_options(**execution_options))
return statement
def _get_delete_many_statement(
self,
*,
model_type: type[ModelT],
id_attribute: InstrumentedAttribute[Any],
id_chunk: list[Any],
supports_returning: bool,
statement_type: Literal["delete", "select"] = "delete",
loader_options: Optional[list[_AbstractLoad]],
execution_options: Optional[dict[str, Any]],
) -> Union[Select[tuple[ModelT]], Delete, ReturningDelete[tuple[ModelT]]]:
# Base statement is static
statement = self._get_base_stmt(
statement=delete(model_type) if statement_type == "delete" else select(model_type),
loader_options=loader_options,
execution_options=execution_options,
)
if supports_returning and statement_type != "select":
statement = cast("ReturningDelete[tuple[ModelT]]", statement.returning(model_type)) # type: ignore[union-attr,assignment] # pyright: ignore[reportUnknownLambdaType,reportUnknownMemberType,reportAttributeAccessIssue,reportUnknownVariableType]
if self._prefer_any:
return statement.where(any_(id_chunk) == id_attribute) # type: ignore[arg-type]
return statement.where(id_attribute.in_(id_chunk)) # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
async def get(
self,
item_id: Any,
*,
auto_expunge: Optional[bool] = None,
statement: Optional[Select[tuple[ModelT]]] = None,
id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> ModelT:
"""Get instance identified by `item_id`.
Args:
item_id: Identifier of the instance to be retrieved.
auto_expunge: Remove object from session before returning.
statement: To facilitate customization of the underlying select query.
id_attribute: Allows customization of the unique identifier to use for model fetching.
Defaults to `id`, but can reference any surrogate or candidate key for the table.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
Returns:
The retrieved instance.
Raises:
NotFoundError: If no instance found identified by `item_id`.
"""
self.uniquify = self._get_uniquify(uniquify)
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
execution_options = self._get_execution_options(execution_options)
statement = self.statement if statement is None else statement
loader_options, loader_options_have_wildcard = self._get_loader_options(load)
id_attribute = id_attribute if id_attribute is not None else self.id_attribute
statement = self._get_base_stmt(
statement=statement,
loader_options=loader_options,
execution_options=execution_options,
)
statement = self._filter_select_by_kwargs(statement, [(id_attribute, item_id)])
instance = (await self._execute(statement, uniquify=loader_options_have_wildcard)).scalar_one_or_none()
instance = self.check_not_found(instance)
self._expunge(instance, auto_expunge=auto_expunge)
return instance
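# Usage sketch for ``get`` (hypothetical names; raises ``NotFoundError`` when no row
# matches the identifier):
#     item = await item_repo.get(item_id)
#     by_name = await item_repo.get("widget", id_attribute="name")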
async def get_one(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
auto_expunge: Optional[bool] = None,
statement: Optional[Select[tuple[ModelT]]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> ModelT:
"""Get instance identified by ``kwargs``.
Args:
*filters: Types for specific filtering operations.
auto_expunge: Remove object from session before returning.
statement: To facilitate customization of the underlying select query.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
**kwargs: Identifier of the instance to be retrieved.
Returns:
The retrieved instance.
Raises:
NotFoundError: If no instance found identified by ``kwargs``.
"""
self.uniquify = self._get_uniquify(uniquify)
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
execution_options = self._get_execution_options(execution_options)
statement = self.statement if statement is None else statement
loader_options, loader_options_have_wildcard = self._get_loader_options(load)
statement = self._get_base_stmt(
statement=statement,
loader_options=loader_options,
execution_options=execution_options,
)
statement = self._apply_filters(*filters, apply_pagination=False, statement=statement)
statement = self._filter_select_by_kwargs(statement, kwargs)
instance = (await self._execute(statement, uniquify=loader_options_have_wildcard)).scalar_one_or_none()
instance = self.check_not_found(instance)
self._expunge(instance, auto_expunge=auto_expunge)
return instance
async def get_one_or_none(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
auto_expunge: Optional[bool] = None,
statement: Optional[Select[tuple[ModelT]]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> Union[ModelT, None]:
"""Get instance identified by ``kwargs`` or None if not found.
Args:
*filters: Types for specific filtering operations.
auto_expunge: Remove object from session before returning.
statement: To facilitate customization of the underlying select query.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
**kwargs: Identifier of the instance to be retrieved.
Returns:
The retrieved instance or None
"""
self.uniquify = self._get_uniquify(uniquify)
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
execution_options = self._get_execution_options(execution_options)
statement = self.statement if statement is None else statement
loader_options, loader_options_have_wildcard = self._get_loader_options(load)
statement = self._get_base_stmt(
statement=statement,
loader_options=loader_options,
execution_options=execution_options,
)
statement = self._apply_filters(*filters, apply_pagination=False, statement=statement)
statement = self._filter_select_by_kwargs(statement, kwargs)
instance = cast(
"Result[tuple[ModelT]]",
(await self._execute(statement, uniquify=loader_options_have_wildcard)),
).scalar_one_or_none()
if instance:
self._expunge(instance, auto_expunge=auto_expunge)
return instance
async def get_or_upsert(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
match_fields: Optional[Union[list[str], str]] = None,
upsert: bool = True,
attribute_names: Optional[Iterable[str]] = None,
with_for_update: Optional[bool] = None,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_refresh: Union[bool, None] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> tuple[ModelT, bool]:
"""Get instance identified by ``kwargs`` or create if it doesn't exist.
Args:
*filters: Types for specific filtering operations.
match_fields: a list of keys to use to match the existing model. When
empty, all fields are matched.
upsert: When using match_fields and actual model values differ from
`kwargs`, automatically perform an update operation on the model.
attribute_names: an iterable of attribute names to pass into the ``update``
method.
with_for_update: indicating FOR UPDATE should be used, or may be a
dictionary containing flags to indicate a more specific set of
FOR UPDATE flags for the SELECT
auto_expunge: Remove object from session before returning.
auto_refresh: Refresh object from session before returning.
auto_commit: Commit objects before returning.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
**kwargs: Identifier of the instance to be retrieved.
Returns:
A tuple that includes the instance and whether it needed to be created.
When using match_fields and actual model values differ from ``kwargs``, the
model value will be updated.
"""
self.uniquify = self._get_uniquify(uniquify)
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
if match_fields := self._get_match_fields(match_fields=match_fields):
match_filter = {
field_name: kwargs.get(field_name)
for field_name in match_fields
if kwargs.get(field_name) is not None
}
else:
match_filter = kwargs
existing = await self.get_one_or_none(
*filters,
**match_filter,
load=load,
execution_options=execution_options,
)
if not existing:
return (
await self.add(
self.model_type(**kwargs),
auto_commit=auto_commit,
auto_expunge=auto_expunge,
auto_refresh=auto_refresh,
),
True,
)
if upsert:
for field_name, new_field_value in kwargs.items():
field = getattr(existing, field_name, MISSING)
if field is not MISSING and field != new_field_value:
setattr(existing, field_name, new_field_value)
existing = await self._attach_to_session(existing, strategy="merge")
await self._flush_or_commit(auto_commit=auto_commit)
await self._refresh(
existing,
attribute_names=attribute_names,
with_for_update=with_for_update,
auto_refresh=auto_refresh,
)
self._expunge(existing, auto_expunge=auto_expunge)
return existing, False
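# Usage sketch for ``get_or_upsert`` (hypothetical names): match on ``name`` only; if a
# matching row exists and other kwargs differ, it is updated in place because ``upsert``
# defaults to True:
#     item, created = await item_repo.get_or_upsert(
#         name="widget",
#         price=10,
#         match_fields=["name"],
#     )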
async def get_and_update(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
match_fields: Optional[Union[list[str], str]] = None,
attribute_names: Optional[Iterable[str]] = None,
with_for_update: Optional[bool] = None,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> tuple[ModelT, bool]:
"""Get instance identified by ``kwargs`` and update the model if the arguments are different.
Args:
*filters: Types for specific filtering operations.
match_fields: a list of keys to use to match the existing model. When
empty, all fields are matched.
attribute_names: an iterable of attribute names to pass into the ``update``
method.
with_for_update: indicating FOR UPDATE should be used, or may be a
dictionary containing flags to indicate a more specific set of
FOR UPDATE flags for the SELECT
auto_expunge: Remove object from session before returning.
auto_refresh: Refresh object from session before returning.
auto_commit: Commit objects before returning.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
**kwargs: Identifier of the instance to be retrieved.
Returns:
A tuple that includes the instance and whether it needed to be updated.
When using match_fields and actual model values differ from ``kwargs``, the
model value will be updated.
Raises:
NotFoundError: If no instance found identified by ``kwargs``.
"""
self.uniquify = self._get_uniquify(uniquify)
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
if match_fields := self._get_match_fields(match_fields=match_fields):
match_filter = {
field_name: kwargs.get(field_name)
for field_name in match_fields
if kwargs.get(field_name) is not None
}
else:
match_filter = kwargs
existing = await self.get_one(*filters, **match_filter, load=load, execution_options=execution_options)
updated = False
for field_name, new_field_value in kwargs.items():
field = getattr(existing, field_name, MISSING)
if field is not MISSING and field != new_field_value:
updated = True
setattr(existing, field_name, new_field_value)
existing = await self._attach_to_session(existing, strategy="merge")
await self._flush_or_commit(auto_commit=auto_commit)
await self._refresh(
existing,
attribute_names=attribute_names,
with_for_update=with_for_update,
auto_refresh=auto_refresh,
)
self._expunge(existing, auto_expunge=auto_expunge)
return existing, updated
async def count(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
statement: Optional[Select[tuple[ModelT]]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> int:
"""Get the count of records returned by a query.
Args:
*filters: Types for specific filtering operations.
statement: To facilitate customization of the underlying select query.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
**kwargs: Instance attribute value filters.
Returns:
Count of records returned by query, ignoring pagination.
"""
self.uniquify = self._get_uniquify(uniquify)
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
execution_options = self._get_execution_options(execution_options)
statement = self.statement if statement is None else statement
loader_options, loader_options_have_wildcard = self._get_loader_options(load)
statement = self._get_base_stmt(
statement=statement,
loader_options=loader_options,
execution_options=execution_options,
)
statement = self._apply_filters(*filters, apply_pagination=False, statement=statement)
statement = self._filter_select_by_kwargs(statement, kwargs)
results = await self._execute(
statement=self._get_count_stmt(
statement=statement, loader_options=loader_options, execution_options=execution_options
),
uniquify=loader_options_have_wildcard,
)
return cast("int", results.scalar_one())
async def update(
self,
data: ModelT,
*,
attribute_names: Optional[Iterable[str]] = None,
with_for_update: Optional[bool] = None,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> ModelT:
"""Update instance with the attribute values present on `data`.
Args:
data: An instance that should have a value for `self.id_attribute` that
exists in the collection.
attribute_names: an iterable of attribute names to pass into the ``update``
method.
with_for_update: indicating FOR UPDATE should be used, or may be a
dictionary containing flags to indicate a more specific set of
FOR UPDATE flags for the SELECT
auto_expunge: Remove object from session before returning.
auto_refresh: Refresh object from session before returning.
auto_commit: Commit objects before returning.
id_attribute: Allows customization of the unique identifier to use for model fetching.
Defaults to `id`, but can reference any surrogate or candidate key for the table.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
Returns:
The updated instance.
Raises:
NotFoundError: If no instance found with same identifier as `data`.
"""
self.uniquify = self._get_uniquify(uniquify)
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
item_id = self.get_id_attribute_value(
data,
id_attribute=id_attribute,
)
# this will raise for not found, and will put the item in the session
await self.get(item_id, id_attribute=id_attribute, load=load, execution_options=execution_options)
# this will merge the inbound data to the instance we just put in the session
instance = await self._attach_to_session(data, strategy="merge")
await self._flush_or_commit(auto_commit=auto_commit)
await self._refresh(
instance,
attribute_names=attribute_names,
with_for_update=with_for_update,
auto_refresh=auto_refresh,
)
self._expunge(instance, auto_expunge=auto_expunge)
return instance
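# Usage sketch for ``update`` (hypothetical names; ``data`` must carry a value for the
# identifier attribute, and ``NotFoundError`` is raised if no matching row exists):
#     item.price = 12
#     item = await item_repo.update(item, auto_commit=True)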
async def update_many(
self,
data: list[ModelT],
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> list[ModelT]:
"""Update one or more instances with the attribute values present on `data`.
This function has an optimized bulk update based on the configured SQL dialect:
- For backends supporting `RETURNING` with `executemany`, a single bulk update with returning clause is executed.
- For other backends, it does a bulk update and then returns the updated data after a refresh.
Args:
data: A list of instances to update. Each should have a value for `self.id_attribute` that exists in the
collection.
auto_expunge: Remove object from session before returning.
auto_commit: Commit objects before returning.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
Returns:
The updated instances.
Raises:
NotFoundError: If no instance found with same identifier as `data`.
"""
self.uniquify = self._get_uniquify(uniquify)
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
data_to_update: list[dict[str, Any]] = [v.to_dict() if isinstance(v, self.model_type) else v for v in data] # type: ignore[misc]
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
execution_options = self._get_execution_options(execution_options)
loader_options = self._get_loader_options(load)[0]
supports_returning = self._dialect.update_executemany_returning and self._dialect.name != "oracle"
statement = self._get_update_many_statement(
self.model_type,
supports_returning,
loader_options=loader_options,
execution_options=execution_options,
)
if supports_returning:
instances = list(
await self.session.scalars(
statement,
cast("_CoreSingleExecuteParams", data_to_update), # this is not correct but the only way
# currently to deal with an SQLAlchemy typing issue. See
# https://github.com/sqlalchemy/sqlalchemy/discussions/9925
execution_options=execution_options,
),
)
await self._flush_or_commit(auto_commit=auto_commit)
for instance in instances:
self._expunge(instance, auto_expunge=auto_expunge)
return instances
await self.session.execute(statement, data_to_update, execution_options=execution_options)
await self._flush_or_commit(auto_commit=auto_commit)
return data
def _get_update_many_statement(
self,
model_type: type[ModelT],
supports_returning: bool,
loader_options: Union[list[_AbstractLoad], None],
execution_options: Union[dict[str, Any], None],
) -> Union[Update, ReturningUpdate[tuple[ModelT]]]:
# Base update statement is static
statement = self._get_base_stmt(
statement=update(table=model_type), loader_options=loader_options, execution_options=execution_options
)
if supports_returning:
return statement.returning(model_type)
return statement
async def list_and_count(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
statement: Optional[Select[tuple[ModelT]]] = None,
auto_expunge: Optional[bool] = None,
count_with_window_function: Optional[bool] = None,
order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> tuple[list[ModelT], int]:
"""List records with total count.
Args:
*filters: Types for specific filtering operations.
statement: To facilitate customization of the underlying select query.
auto_expunge: Remove object from session before returning.
count_with_window_function: When false, list and count will use two queries instead of an analytical window function.
order_by: Set default order options for queries.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
**kwargs: Instance attribute value filters.
Returns:
A tuple of the list of instances and the count of records returned by the query, ignoring pagination.
"""
count_with_window_function = (
count_with_window_function if count_with_window_function is not None else self.count_with_window_function
)
self.uniquify = self._get_uniquify(uniquify)
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
if self._dialect.name in {"spanner", "spanner+spanner"} or not count_with_window_function:
return await self._list_and_count_basic(
*filters,
auto_expunge=auto_expunge,
statement=statement,
load=load,
execution_options=execution_options,
order_by=order_by,
error_messages=error_messages,
**kwargs,
)
return await self._list_and_count_window(
*filters,
auto_expunge=auto_expunge,
statement=statement,
load=load,
execution_options=execution_options,
error_messages=error_messages,
order_by=order_by,
**kwargs,
)
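The dispatch above can be sketched as a small pure function: Spanner dialects cannot use the single-query ``count(*) OVER ()`` path, so they always fall back to two queries. The function name and return values here are illustrative, not part of the library API.

```python
def choose_count_strategy(dialect_name: str, count_with_window_function: bool) -> str:
    """Mirror the branch in ``list_and_count``: Spanner (or an explicit
    opt-out) uses two separate queries; everything else uses one query with
    a ``count(*) OVER ()`` window function."""
    if dialect_name in {"spanner", "spanner+spanner"} or not count_with_window_function:
        return "two-queries"
    return "window-function"
```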
def _expunge(self, instance: ModelT, auto_expunge: Optional[bool]) -> None:
if auto_expunge is None:
auto_expunge = self.auto_expunge
return self.session.expunge(instance) if auto_expunge else None
async def _flush_or_commit(self, auto_commit: Optional[bool]) -> None:
if auto_commit is None:
auto_commit = self.auto_commit
return await self.session.commit() if auto_commit else await self.session.flush()
async def _refresh(
self,
instance: ModelT,
auto_refresh: Optional[bool],
attribute_names: Optional[Iterable[str]] = None,
with_for_update: Optional[bool] = None,
) -> None:
if auto_refresh is None:
auto_refresh = self.auto_refresh
return (
await self.session.refresh(
instance=instance,
attribute_names=attribute_names,
with_for_update=with_for_update,
)
if auto_refresh
else None
)
async def _list_and_count_window(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
auto_expunge: Optional[bool] = None,
statement: Optional[Select[tuple[ModelT]]] = None,
order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> tuple[list[ModelT], int]:
"""List records with total count.
Args:
*filters: Types for specific filtering operations.
auto_expunge: Remove object from session before returning.
statement: To facilitate customization of the underlying select query.
order_by: Set default order options for queries.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set relationships to be loaded
execution_options: Set default execution options
**kwargs: Instance attribute value filters.
Returns:
A tuple of the list of instances and the total count, computed with an analytical window function, ignoring pagination.
"""
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
execution_options = self._get_execution_options(execution_options)
statement = self.statement if statement is None else statement
loader_options, loader_options_have_wildcard = self._get_loader_options(load)
statement = self._get_base_stmt(
statement=statement,
loader_options=loader_options,
execution_options=execution_options,
)
if order_by is None:
order_by = self.order_by or []
statement = self._apply_order_by(statement=statement, order_by=order_by)
statement = self._apply_filters(*filters, statement=statement)
statement = self._filter_select_by_kwargs(statement, kwargs)
result = await self._execute(
statement.add_columns(over(sql_func.count())), uniquify=loader_options_have_wildcard
)
count: int = 0
instances: list[ModelT] = []
for i, (instance, count_value) in enumerate(result):
self._expunge(instance, auto_expunge=auto_expunge)
instances.append(instance)
if i == 0:
count = count_value
return instances, count
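The read-back loop above relies on ``count(*) OVER ()`` attaching the same total to every row, so only the first row's value needs to be kept. A standalone sketch of that collection step, operating on plain ``(instance, total)`` pairs instead of a SQLAlchemy result:

```python
def collect_with_window_count(rows):
    """Collect (instance, total) pairs produced by a count(*) OVER ()
    column: every row carries the identical total, so only the first row's
    value is read; an empty result yields a count of 0."""
    count = 0
    instances = []
    for i, (instance, count_value) in enumerate(rows):
        instances.append(instance)
        if i == 0:
            count = count_value
    return instances, count
```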
async def _list_and_count_basic(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
auto_expunge: Optional[bool] = None,
statement: Optional[Select[tuple[ModelT]]] = None,
order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> tuple[list[ModelT], int]:
"""List records with total count.
Args:
*filters: Types for specific filtering operations.
auto_expunge: Remove object from session before returning.
statement: To facilitate customization of the underlying select query.
order_by: Set default order options for queries.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set relationships to be loaded
execution_options: Set default execution options
**kwargs: Instance attribute value filters.
Returns:
A tuple of the list of instances and the total count, computed using two separate queries, ignoring pagination.
"""
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
execution_options = self._get_execution_options(execution_options)
statement = self.statement if statement is None else statement
loader_options, loader_options_have_wildcard = self._get_loader_options(load)
statement = self._get_base_stmt(
statement=statement,
loader_options=loader_options,
execution_options=execution_options,
)
if order_by is None:
order_by = self.order_by or []
statement = self._apply_order_by(statement=statement, order_by=order_by)
statement = self._apply_filters(*filters, statement=statement)
statement = self._filter_select_by_kwargs(statement, kwargs)
count_result = await self.session.execute(
self._get_count_stmt(
statement,
loader_options=loader_options,
execution_options=execution_options,
),
)
count = count_result.scalar_one()
result = await self._execute(statement, uniquify=loader_options_have_wildcard)
instances: list[ModelT] = []
for (instance,) in result:
self._expunge(instance, auto_expunge=auto_expunge)
instances.append(instance)
return instances, count
def _get_count_stmt(
self,
statement: Select[tuple[ModelT]],
loader_options: Optional[list[_AbstractLoad]],
execution_options: Optional[dict[str, Any]],
) -> Select[tuple[int]]:
# Count statement transformations are static
return (
statement.with_only_columns(sql_func.count(text("1")), maintain_column_froms=True)
.limit(None)
.offset(None)
.order_by(None)
)
async def upsert(
self,
data: ModelT,
*,
attribute_names: Optional[Iterable[str]] = None,
with_for_update: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_commit: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
match_fields: Optional[Union[list[str], str]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> ModelT:
"""Modify or create instance.
Updates instance with the attribute values present on `data`, or creates a new instance if
one doesn't exist.
Args:
data: Instance to update existing, or be created. Identifier used to determine if an
existing instance exists is the value of an attribute on `data` named as value of
`self.id_attribute`.
attribute_names: an iterable of attribute names to pass into the ``update`` method.
with_for_update: indicating FOR UPDATE should be used, or may be a
dictionary containing flags to indicate a more specific set of
FOR UPDATE flags for the SELECT
auto_expunge: Remove object from session before returning.
auto_refresh: Refresh object from session before returning.
auto_commit: Commit objects before returning.
match_fields: a list of keys to use to match the existing model. When
empty, all fields are matched.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
Returns:
The updated or created instance.
Raises:
NotFoundError: If no instance found with same identifier as `data`.
"""
self.uniquify = self._get_uniquify(uniquify)
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
if match_fields := self._get_match_fields(match_fields=match_fields):
match_filter = {
field_name: getattr(data, field_name, None)
for field_name in match_fields
if getattr(data, field_name, None) is not None
}
elif getattr(data, self.id_attribute, None) is not None:
match_filter = {self.id_attribute: getattr(data, self.id_attribute, None)}
else:
match_filter = data.to_dict(exclude={self.id_attribute})
existing = await self.get_one_or_none(load=load, execution_options=execution_options, **match_filter)
if not existing:
return await self.add(
data,
auto_commit=auto_commit,
auto_expunge=auto_expunge,
auto_refresh=auto_refresh,
)
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
for field_name, new_field_value in data.to_dict(exclude={self.id_attribute}).items():
field = getattr(existing, field_name, MISSING)
if field is not MISSING and field != new_field_value:
setattr(existing, field_name, new_field_value)
instance = await self._attach_to_session(existing, strategy="merge")
await self._flush_or_commit(auto_commit=auto_commit)
await self._refresh(
instance,
attribute_names=attribute_names,
with_for_update=with_for_update,
auto_refresh=auto_refresh,
)
self._expunge(instance, auto_expunge=auto_expunge)
return instance
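The three-way match-filter selection in ``upsert`` can be illustrated with a pure-Python stand-in that works on plain dicts (``build_match_filter`` is a hypothetical helper, not a repository method): explicit ``match_fields`` win, then a populated primary key, then every field except the primary key.

```python
def build_match_filter(data, match_fields=None, id_attribute="id"):
    """Mirror the branches ``upsert`` uses to locate an existing row:
    explicit match fields first, then a populated primary key, then all
    fields except the primary key."""
    if match_fields:
        return {f: data.get(f) for f in match_fields if data.get(f) is not None}
    if data.get(id_attribute) is not None:
        return {id_attribute: data[id_attribute]}
    return {k: v for k, v in data.items() if k != id_attribute}
```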
async def upsert_many(
self,
data: list[ModelT],
*,
auto_expunge: Optional[bool] = None,
auto_commit: Optional[bool] = None,
no_merge: bool = False,
match_fields: Optional[Union[list[str], str]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> list[ModelT]:
"""Modify or create multiple instances.
Update instances with the attribute values present on `data`, or create a new instance if
one doesn't exist.
!!! tip
In most cases, you will want to set `match_fields` to the combination of attributes, excluding the primary key, that defines uniqueness for a row.
Args:
data: Instance to update existing, or be created. Identifier used to determine if an
existing instance exists is the value of an attribute on ``data`` named as value of
:attr:`id_attribute`.
auto_expunge: Remove object from session before returning.
auto_commit: Commit objects before returning.
no_merge: Skip the usage of optimized Merge statements
match_fields: a list of keys to use to match the existing model. When
empty, automatically uses ``self.id_attribute`` (`id` by default) to match.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
Returns:
The updated or created instances.
Raises:
NotFoundError: If no instance found with same identifier as ``data``.
"""
self.uniquify = self._get_uniquify(uniquify)
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
instances: list[ModelT] = []
data_to_update: list[ModelT] = []
data_to_insert: list[ModelT] = []
match_fields = self._get_match_fields(match_fields=match_fields)
if match_fields is None:
match_fields = [self.id_attribute]
match_filter: list[Union[StatementFilter, ColumnElement[bool]]] = []
if match_fields:
for field_name in match_fields:
field = get_instrumented_attr(self.model_type, field_name)
matched_values = [
field_data for datum in data if (field_data := getattr(datum, field_name)) is not None
]
match_filter.append(any_(matched_values) == field if self._prefer_any else field.in_(matched_values)) # type: ignore[arg-type]
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
existing_objs = await self.list(
*match_filter,
load=load,
execution_options=execution_options,
auto_expunge=False,
)
for field_name in match_fields:
field = get_instrumented_attr(self.model_type, field_name)
matched_values = list(
{getattr(datum, field_name) for datum in existing_objs if datum}, # ensure the list is unique
)
match_filter.append(any_(matched_values) == field if self._prefer_any else field.in_(matched_values)) # type: ignore[arg-type]
existing_ids = self._get_object_ids(existing_objs=existing_objs)
data = self._merge_on_match_fields(data, existing_objs, match_fields)
for datum in data:
if getattr(datum, self.id_attribute, None) in existing_ids:
data_to_update.append(datum)
else:
data_to_insert.append(datum)
if data_to_insert:
instances.extend(
await self.add_many(data_to_insert, auto_commit=False, auto_expunge=False),
)
if data_to_update:
instances.extend(
await self.update_many(
data_to_update,
auto_commit=False,
auto_expunge=False,
load=load,
execution_options=execution_options,
),
)
await self._flush_or_commit(auto_commit=auto_commit)
for instance in instances:
self._expunge(instance, auto_expunge=auto_expunge)
return instances
def _get_object_ids(self, existing_objs: list[ModelT]) -> list[Any]:
return [obj_id for datum in existing_objs if (obj_id := getattr(datum, self.id_attribute)) is not None]
def _get_match_fields(
self,
match_fields: Optional[Union[list[str], str]] = None,
id_attribute: Optional[str] = None,
) -> Optional[list[str]]:
id_attribute = id_attribute or self.id_attribute
match_fields = match_fields or self.match_fields
if isinstance(match_fields, str):
match_fields = [match_fields]
return match_fields
def _merge_on_match_fields(
self,
data: list[ModelT],
existing_data: list[ModelT],
match_fields: Optional[Union[list[str], str]] = None,
) -> list[ModelT]:
match_fields = self._get_match_fields(match_fields=match_fields)
if match_fields is None:
match_fields = [self.id_attribute]
for existing_datum in existing_data:
for _row_id, datum in enumerate(data):
match = all(
getattr(datum, field_name) == getattr(existing_datum, field_name) for field_name in match_fields
)
if match and getattr(existing_datum, self.id_attribute) is not None:
setattr(datum, self.id_attribute, getattr(existing_datum, self.id_attribute))
return data
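The merge step above copies the primary key from an existing row onto incoming data whenever every match field agrees, so the later insert/update split in ``upsert_many`` routes that row to ``update_many``. A dict-based sketch of the same logic (names illustrative):

```python
def merge_on_match_fields(data, existing, match_fields, id_attribute="id"):
    """Copy the primary key from an existing row onto incoming data when all
    match fields agree, so the row is treated as an update, not an insert."""
    for existing_datum in existing:
        for datum in data:
            match = all(datum.get(f) == existing_datum.get(f) for f in match_fields)
            if match and existing_datum.get(id_attribute) is not None:
                datum[id_attribute] = existing_datum[id_attribute]
    return data
```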
async def list(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
auto_expunge: Optional[bool] = None,
statement: Optional[Select[tuple[ModelT]]] = None,
order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> list[ModelT]:
"""Get a list of instances, optionally filtered.
Args:
*filters: Types for specific filtering operations.
auto_expunge: Remove object from session before returning.
statement: To facilitate customization of the underlying select query.
order_by: Set default order options for queries.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
**kwargs: Instance attribute value filters.
Returns:
The list of instances, after filtering applied.
"""
self.uniquify = self._get_uniquify(uniquify)
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
execution_options = self._get_execution_options(execution_options)
statement = self.statement if statement is None else statement
loader_options, loader_options_have_wildcard = self._get_loader_options(load)
statement = self._get_base_stmt(
statement=statement,
loader_options=loader_options,
execution_options=execution_options,
)
if order_by is None:
order_by = self.order_by or []
statement = self._apply_order_by(statement=statement, order_by=order_by)
statement = self._apply_filters(*filters, statement=statement)
statement = self._filter_select_by_kwargs(statement, kwargs)
result = await self._execute(statement, uniquify=loader_options_have_wildcard)
instances = list(result.scalars())
for instance in instances:
self._expunge(instance, auto_expunge=auto_expunge)
return cast("list[ModelT]", instances)
@classmethod
async def check_health(cls, session: Union[AsyncSession, async_scoped_session[AsyncSession]]) -> bool:
"""Perform a health check on the database.
Args:
session: through which we run a check statement
Returns:
``True`` if healthy.
"""
with wrap_sqlalchemy_exception():
return ( # type: ignore[no-any-return]
await session.execute(cls._get_health_check_statement(session))
).scalar_one() == 1
@staticmethod
def _get_health_check_statement(session: Union[AsyncSession, async_scoped_session[AsyncSession]]) -> TextClause:
if session.bind and session.bind.dialect.name == "oracle":
return text("SELECT 1 FROM DUAL")
return text("SELECT 1")
async def _attach_to_session(
self, model: ModelT, strategy: Literal["add", "merge"] = "add", load: bool = True
) -> ModelT:
"""Attach detached instance to the session.
Args:
model: The instance to be attached to the session.
strategy: How the instance should be attached.
- "add": New instance added to session
- "merge": Instance merged with existing, or new one added.
load: Boolean, when False, merge switches into
a "high performance" mode which causes it to forego emitting history
events as well as all database access. This flag is used for
cases such as transferring graphs of objects into a session
from a second level cache, or to transfer just-loaded objects
into the session owned by a worker thread or process
without re-querying the database.
Returns:
Instance attached to the session - if `"merge"` strategy, may not be same instance
that was provided.
"""
if strategy == "add":
self.session.add(model)
return model
if strategy == "merge":
return await self.session.merge(model, load=load)
msg = "Unexpected value for `strategy`, must be `'add'` or `'merge'`" # type: ignore[unreachable]
raise ValueError(msg)
async def _execute(
self,
statement: Select[Any],
uniquify: bool = False,
) -> Result[Any]:
result = await self.session.execute(statement)
if uniquify or self.uniquify:
result = result.unique()
return result
class SQLAlchemyAsyncSlugRepository(
SQLAlchemyAsyncRepository[ModelT],
SQLAlchemyAsyncSlugRepositoryProtocol[ModelT],
):
"""Extends the repository to include slug model features.."""
async def get_by_slug(
self,
slug: str,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> Optional[ModelT]:
"""Select record by slug value."""
return await self.get_one_or_none(
slug=slug,
load=load,
execution_options=execution_options,
error_messages=error_messages,
uniquify=uniquify,
)
async def get_available_slug(
self,
value_to_slugify: str,
**kwargs: Any,
) -> str:
"""Get a unique slug for the supplied value.
If the value is found to exist, a random 4-character alphanumeric suffix is appended to the end.
Override this method to change the default behavior.
Args:
value_to_slugify (str): A string that should be converted to a unique slug.
**kwargs: Additional keyword arguments (unused by the default implementation).
Returns:
str: a unique slug for the supplied value. This is safe for URLs and other unique identifiers.
"""
slug = slugify(value_to_slugify)
if await self._is_slug_unique(slug):
return slug
random_string = "".join(random.choices(string.ascii_lowercase + string.digits, k=4)) # noqa: S311
return f"{slug}-{random_string}"
async def _is_slug_unique(
self,
slug: str,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> bool:
return await self.exists(slug=slug, load=load, execution_options=execution_options, **kwargs) is False
class SQLAlchemyAsyncQueryRepository:
"""SQLAlchemy Query Repository.
This is a loosely typed helper for when you need to select data in ways that don't align with the normal repository pattern.
"""
error_messages: Optional[ErrorMessages] = None
def __init__(
self,
*,
session: Union[AsyncSession, async_scoped_session[AsyncSession]],
error_messages: Optional[ErrorMessages] = None,
**kwargs: Any,
) -> None:
"""Repository pattern for SQLAlchemy models.
Args:
session: Session managing the unit-of-work for the operation.
error_messages: A set of error messages to use for operations.
**kwargs: Additional arguments.
"""
super().__init__(**kwargs)
self.session = session
self.error_messages = error_messages
self._dialect = self.session.bind.dialect if self.session.bind is not None else self.session.get_bind().dialect
async def get_one(
self,
statement: Select[tuple[Any]],
**kwargs: Any,
) -> Row[Any]:
"""Get instance identified by ``kwargs``.
Args:
statement: To facilitate customization of the underlying select query.
**kwargs: Instance attribute value filters.
Returns:
The retrieved instance.
Raises:
NotFoundError: If no instance found matching the filters.
"""
with wrap_sqlalchemy_exception(error_messages=self.error_messages):
statement = self._filter_statement_by_kwargs(statement, **kwargs)
instance = (await self.execute(statement)).scalar_one_or_none()
return self.check_not_found(instance)
async def get_one_or_none(
self,
statement: Select[Any],
**kwargs: Any,
) -> Optional[Row[Any]]:
"""Get instance identified by ``kwargs`` or None if not found.
Args:
statement: To facilitate customization of the underlying select query.
**kwargs: Instance attribute value filters.
Returns:
The retrieved instance or None
"""
with wrap_sqlalchemy_exception(error_messages=self.error_messages):
statement = self._filter_statement_by_kwargs(statement, **kwargs)
instance = (await self.execute(statement)).scalar_one_or_none()
return instance or None
async def count(self, statement: Select[Any], **kwargs: Any) -> int:
"""Get the count of records returned by a query.
Args:
statement: To facilitate customization of the underlying select query.
**kwargs: Instance attribute value filters.
Returns:
Count of records returned by query, ignoring pagination.
"""
with wrap_sqlalchemy_exception(error_messages=self.error_messages):
statement = statement.with_only_columns(sql_func.count(text("1")), maintain_column_froms=True).order_by(
None,
)
statement = self._filter_statement_by_kwargs(statement, **kwargs)
results = await self.execute(statement)
return results.scalar_one() # type: ignore # noqa: PGH003
async def list_and_count(
self,
statement: Select[Any],
count_with_window_function: Optional[bool] = None,
**kwargs: Any,
) -> tuple[list[Row[Any]], int]:
"""List records with total count.
Args:
statement: To facilitate customization of the underlying select query.
count_with_window_function: When ``False``, list and count use two queries instead of an analytical window function.
**kwargs: Instance attribute value filters.
Returns:
A tuple of the list of rows and the total count, ignoring pagination.
"""
if self._dialect.name in {"spanner", "spanner+spanner"} or count_with_window_function is False:
return await self._list_and_count_basic(statement=statement, **kwargs)
return await self._list_and_count_window(statement=statement, **kwargs)
async def _list_and_count_window(
self,
statement: Select[Any],
**kwargs: Any,
) -> tuple[list[Row[Any]], int]:
"""List records with total count.
Args:
statement: To facilitate customization of the underlying select query.
**kwargs: Instance attribute value filters.
Returns:
A tuple of the list of rows and the total count, computed with an analytical window function, ignoring pagination.
"""
with wrap_sqlalchemy_exception(error_messages=self.error_messages):
statement = statement.add_columns(over(sql_func.count(text("1"))))
statement = self._filter_statement_by_kwargs(statement, **kwargs)
result = await self.execute(statement)
count: int = 0
instances: list[Row[Any]] = []
for i, (instance, count_value) in enumerate(result):
instances.append(instance)
if i == 0:
count = count_value
return instances, count
def _get_count_stmt(self, statement: Select[Any]) -> Select[Any]:
return statement.with_only_columns(sql_func.count(text("1")), maintain_column_froms=True).order_by(None) # pyright: ignore[reportUnknownVariable]
async def _list_and_count_basic(
self,
statement: Select[Any],
**kwargs: Any,
) -> tuple[list[Row[Any]], int]:
"""List records with total count.
Args:
statement: To facilitate customization of the underlying select query.
**kwargs: Instance attribute value filters.
Returns:
A tuple of the list of rows and the total count, computed using two separate queries, ignoring pagination.
"""
with wrap_sqlalchemy_exception(error_messages=self.error_messages):
statement = self._filter_statement_by_kwargs(statement, **kwargs)
count_result = await self.session.execute(self._get_count_stmt(statement))
count = count_result.scalar_one()
result = await self.execute(statement)
instances: list[Row[Any]] = []
for (instance,) in result:
instances.append(instance)
return instances, count
async def list(self, statement: Select[Any], **kwargs: Any) -> list[Row[Any]]:
"""Get a list of instances, optionally filtered.
Args:
statement: To facilitate customization of the underlying select query.
**kwargs: Instance attribute value filters.
Returns:
The list of instances, after filtering applied.
"""
with wrap_sqlalchemy_exception(error_messages=self.error_messages):
statement = self._filter_statement_by_kwargs(statement, **kwargs)
result = await self.execute(statement)
return list(result.all())
def _filter_statement_by_kwargs(
self,
statement: Select[Any],
/,
**kwargs: Any,
) -> Select[Any]:
"""Filter the collection by kwargs.
Args:
statement: statement to filter
**kwargs: key/value pairs such that objects remaining in the statement after filtering
have the property that their attribute named `key` has value equal to `value`.
"""
with wrap_sqlalchemy_exception(error_messages=self.error_messages):
return statement.filter_by(**kwargs)
# the following is all sqlalchemy implementation detail, and shouldn't be directly accessed
@staticmethod
def check_not_found(item_or_none: Optional[T]) -> T:
"""Raise :class:`NotFoundError` if ``item_or_none`` is ``None``.
Args:
item_or_none: Item to be tested for existence.
Returns:
The item, if it exists.
"""
if item_or_none is None:
msg = "No item found when one was expected"
raise NotFoundError(msg)
return item_or_none
async def execute(
self,
statement: Union[
ReturningDelete[tuple[Any]], ReturningUpdate[tuple[Any]], Select[tuple[Any]], Update, Delete, Select[Any]
],
) -> Result[Any]:
return await self.session.execute(statement)
# Do not edit this file directly. It has been autogenerated from
# advanced_alchemy/repository/_async.py
import random
import string
from collections.abc import Iterable, Sequence
from typing import (
TYPE_CHECKING,
Any,
Final,
Literal,
Optional,
Protocol,
Union,
cast,
runtime_checkable,
)
from sqlalchemy import (
Delete,
Result,
Row,
Select,
TextClause,
Update,
any_,
delete,
over,
select,
text,
update,
)
from sqlalchemy import func as sql_func
from sqlalchemy.orm import InstrumentedAttribute, Session
from sqlalchemy.orm.scoping import scoped_session
from sqlalchemy.orm.strategy_options import _AbstractLoad # pyright: ignore[reportPrivateUsage]
from sqlalchemy.sql import ColumnElement
from sqlalchemy.sql.dml import ReturningDelete, ReturningUpdate
from advanced_alchemy.exceptions import ErrorMessages, NotFoundError, RepositoryError, wrap_sqlalchemy_exception
from advanced_alchemy.filters import StatementFilter, StatementTypeT
from advanced_alchemy.repository._util import (
DEFAULT_ERROR_MESSAGE_TEMPLATES,
FilterableRepository,
FilterableRepositoryProtocol,
LoadSpec,
get_abstract_loader_options,
get_instrumented_attr,
)
from advanced_alchemy.repository.typing import MISSING, ModelT, OrderingPair, T
from advanced_alchemy.utils.dataclass import Empty, EmptyType
from advanced_alchemy.utils.text import slugify
if TYPE_CHECKING:
from sqlalchemy.engine.interfaces import _CoreSingleExecuteParams # pyright: ignore[reportPrivateUsage]
DEFAULT_INSERTMANYVALUES_MAX_PARAMETERS: Final = 950
POSTGRES_VERSION_SUPPORTING_MERGE: Final = 15
@runtime_checkable
class SQLAlchemySyncRepositoryProtocol(FilterableRepositoryProtocol[ModelT], Protocol[ModelT]):
"""Base Protocol"""
id_attribute: str
match_fields: Optional[Union[list[str], str]] = None
statement: Select[tuple[ModelT]]
session: Union[Session, scoped_session[Session]]
auto_expunge: bool
auto_refresh: bool
auto_commit: bool
order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None
error_messages: Optional[ErrorMessages] = None
wrap_exceptions: bool = True
def __init__(
self,
*,
statement: Optional[Select[tuple[ModelT]]] = None,
session: Union[Session, scoped_session[Session]],
auto_expunge: bool = False,
auto_refresh: bool = True,
auto_commit: bool = False,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
wrap_exceptions: bool = True,
**kwargs: Any,
) -> None: ...
@classmethod
def get_id_attribute_value(
cls,
item: Union[ModelT, type[ModelT]],
id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
) -> Any: ...
@classmethod
def set_id_attribute_value(
cls,
item_id: Any,
item: ModelT,
id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
) -> ModelT: ...
@staticmethod
def check_not_found(item_or_none: Optional[ModelT]) -> ModelT: ...
def add(
self,
data: ModelT,
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
) -> ModelT: ...
def add_many(
self,
data: list[ModelT],
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
) -> Sequence[ModelT]: ...
def delete(
self,
item_id: Any,
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
) -> ModelT: ...
def delete_many(
self,
item_ids: list[Any],
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
chunk_size: Optional[int] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
) -> Sequence[ModelT]: ...
def delete_where(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
load: Optional[LoadSpec] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
execution_options: Optional[dict[str, Any]] = None,
sanity_check: bool = True,
**kwargs: Any,
) -> Sequence[ModelT]: ...
def exists(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
load: Optional[LoadSpec] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
execution_options: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> bool: ...
def get(
self,
item_id: Any,
*,
auto_expunge: Optional[bool] = None,
statement: Optional[Select[tuple[ModelT]]] = None,
id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
) -> ModelT: ...
def get_one(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
auto_expunge: Optional[bool] = None,
statement: Optional[Select[tuple[ModelT]]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> ModelT: ...
def get_one_or_none(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
auto_expunge: Optional[bool] = None,
statement: Optional[Select[tuple[ModelT]]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> Optional[ModelT]: ...
def get_or_upsert(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
match_fields: Optional[Union[list[str], str]] = None,
upsert: bool = True,
attribute_names: Optional[Iterable[str]] = None,
with_for_update: Optional[bool] = None,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> tuple[ModelT, bool]: ...
def get_and_update(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
match_fields: Optional[Union[list[str], str]] = None,
attribute_names: Optional[Iterable[str]] = None,
with_for_update: Optional[bool] = None,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> tuple[ModelT, bool]: ...
def count(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
statement: Optional[Select[tuple[ModelT]]] = None,
load: Optional[LoadSpec] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
execution_options: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> int: ...
def update(
self,
data: ModelT,
*,
attribute_names: Optional[Iterable[str]] = None,
with_for_update: Optional[bool] = None,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
) -> ModelT: ...
def update_many(
self,
data: list[ModelT],
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
) -> list[ModelT]: ...
def _get_update_many_statement(
self,
model_type: type[ModelT],
supports_returning: bool,
loader_options: Optional[list[_AbstractLoad]],
execution_options: Optional[dict[str, Any]],
) -> Union[Update, ReturningUpdate[tuple[ModelT]]]: ...
def upsert(
self,
data: ModelT,
*,
attribute_names: Optional[Iterable[str]] = None,
with_for_update: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_commit: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
match_fields: Optional[Union[list[str], str]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
) -> ModelT: ...
def upsert_many(
self,
data: list[ModelT],
*,
auto_expunge: Optional[bool] = None,
auto_commit: Optional[bool] = None,
no_merge: bool = False,
match_fields: Optional[Union[list[str], str]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
) -> list[ModelT]: ...
def list_and_count(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
auto_expunge: Optional[bool] = None,
statement: Optional[Select[tuple[ModelT]]] = None,
count_with_window_function: Optional[bool] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None,
**kwargs: Any,
) -> tuple[list[ModelT], int]: ...
def list(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
auto_expunge: Optional[bool] = None,
statement: Optional[Select[tuple[ModelT]]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None,
**kwargs: Any,
) -> list[ModelT]: ...
@classmethod
def check_health(cls, session: Union[Session, scoped_session[Session]]) -> bool: ...
@runtime_checkable
class SQLAlchemySyncSlugRepositoryProtocol(SQLAlchemySyncRepositoryProtocol[ModelT], Protocol[ModelT]):
"""Protocol for SQLAlchemy repositories that support slug-based operations.
Extends the base repository protocol to add slug-related functionality.
Type Parameters:
ModelT: The SQLAlchemy model type this repository handles.
"""
def get_by_slug(
self,
slug: str,
*,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> Optional[ModelT]:
"""Get a model instance by its slug.
Args:
slug: The slug value to search for.
error_messages: Optional custom error message templates.
load: Specification for eager loading of relationships.
execution_options: Options for statement execution.
**kwargs: Additional filtering criteria.
Returns:
ModelT | None: The found model instance or None if not found.
"""
...
def get_available_slug(
self,
value_to_slugify: str,
**kwargs: Any,
) -> str:
"""Generate a unique slug for a given value.
Args:
value_to_slugify: The string to convert to a slug.
**kwargs: Additional parameters for slug generation.
Returns:
str: A unique slug derived from the input value.
"""
...
class SQLAlchemySyncRepository(SQLAlchemySyncRepositoryProtocol[ModelT], FilterableRepository[ModelT]):
"""Async SQLAlchemy repository implementation.
Provides a complete implementation of async database operations using SQLAlchemy,
including CRUD operations, filtering, and relationship loading.
Type Parameters:
ModelT: The SQLAlchemy model type this repository handles.
.. seealso::
:class:`~advanced_alchemy.repository._util.FilterableRepository`
"""
id_attribute: str = "id"
"""Name of the unique identifier for the model."""
loader_options: Optional[LoadSpec] = None
"""Default loader options for the repository."""
error_messages: Optional[ErrorMessages] = None
"""Default error messages for the repository."""
wrap_exceptions: bool = True
"""Wrap SQLAlchemy exceptions in a ``RepositoryError``. When set to ``False``, the original exception will be raised."""
inherit_lazy_relationships: bool = True
"""Inherit the model's default ``lazy`` relationship configuration. Set to ``False`` to ignore it, so that only the
relationships specified in the ``load`` or ``default_loader_options`` configuration are loaded."""
merge_loader_options: bool = True
"""Merge the default loader options with the loader options specified in the ``load`` argument. Set to ``False`` to
totally replace, rather than merge, the default loader options with the ones specified in ``load``."""
execution_options: Optional[dict[str, Any]] = None
"""Default execution options for the repository."""
match_fields: Optional[Union[list[str], str]] = None
"""List of dialects that prefer to use ``field.id = ANY(:1)`` instead of ``field.id IN (...)``."""
uniquify: bool = False
"""Optionally apply the ``unique()`` method to results before returning.
This is useful for certain SQLAlchemy use cases, such as applying ``contains_eager`` to a query containing a one-to-many relationship.
"""
count_with_window_function: bool = True
"""Use an analytical window function to count results. This allows the count to be performed in a single query.
"""
def __init__(
self,
*,
statement: Optional[Select[tuple[ModelT]]] = None,
session: Union[Session, scoped_session[Session]],
auto_expunge: bool = False,
auto_refresh: bool = True,
auto_commit: bool = False,
order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
wrap_exceptions: bool = True,
uniquify: Optional[bool] = None,
count_with_window_function: Optional[bool] = None,
**kwargs: Any,
) -> None:
"""Repository for SQLAlchemy models.
Args:
statement: To facilitate customization of the underlying select query.
session: Session managing the unit-of-work for the operation.
auto_expunge: Remove object from session before returning.
auto_refresh: Refresh object from session before returning.
auto_commit: Commit objects before returning.
order_by: Set default order options for queries.
load: Set default relationships to be loaded
execution_options: Set default execution options
error_messages: A set of custom error messages to use for operations
wrap_exceptions: Wrap SQLAlchemy exceptions in a ``RepositoryError``. When set to ``False``, the original exception will be raised.
uniquify: Optionally apply the ``unique()`` method to results before returning.
count_with_window_function: When false, list and count will use two queries instead of an analytical window function.
**kwargs: Additional arguments.
"""
self.auto_expunge = auto_expunge
self.auto_refresh = auto_refresh
self.auto_commit = auto_commit
self.order_by = order_by
self.session = session
self.error_messages = self._get_error_messages(
error_messages=error_messages, default_messages=self.error_messages
)
self.wrap_exceptions = wrap_exceptions
self.uniquify = self._get_uniquify(uniquify)
self.count_with_window_function = (
count_with_window_function if count_with_window_function is not None else self.count_with_window_function
)
self._default_loader_options, self._loader_options_have_wildcards = get_abstract_loader_options(
loader_options=load if load is not None else self.loader_options,
inherit_lazy_relationships=self.inherit_lazy_relationships,
merge_with_default=self.merge_loader_options,
)
execution_options = execution_options if execution_options is not None else self.execution_options
self._default_execution_options = execution_options or {}
self.statement = select(self.model_type) if statement is None else statement
self._dialect = self.session.bind.dialect if self.session.bind is not None else self.session.get_bind().dialect
self._prefer_any = any(self._dialect.name == engine_type for engine_type in self.prefer_any_dialects or ())
def _get_uniquify(self, uniquify: Optional[bool] = None) -> bool:
return bool(uniquify) if uniquify is not None else self.uniquify
@staticmethod
def _get_error_messages(
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
default_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
) -> Optional[ErrorMessages]:
if error_messages == Empty:
error_messages = None
if default_messages == Empty:
default_messages = None
messages = dict(DEFAULT_ERROR_MESSAGE_TEMPLATES)  # copy so the module-level defaults are not mutated
if default_messages and isinstance(default_messages, dict):
messages.update(default_messages)
if error_messages:
messages.update(cast("ErrorMessages", error_messages))
return messages
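The merge precedence above (library defaults, overlaid by class-level defaults, overlaid by per-call overrides) can be sketched with plain dicts; the template keys below are illustrative, not the library's actual message set:

```python
# Hypothetical default templates, standing in for DEFAULT_ERROR_MESSAGE_TEMPLATES.
DEFAULT_TEMPLATES = {"not_found": "No item found", "integrity": "Constraint violated"}


def merge_error_messages(error_messages=None, default_messages=None):
    # Copy first, so repeated calls never mutate the shared defaults.
    messages = dict(DEFAULT_TEMPLATES)
    if default_messages:
        messages.update(default_messages)  # class-level defaults win over library defaults
    if error_messages:
        messages.update(error_messages)  # per-call overrides win over everything
    return messages


merged = merge_error_messages(
    error_messages={"not_found": "Widget missing"},
    default_messages={"integrity": "Duplicate widget", "not_found": "ignored"},
)
```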
@classmethod
def get_id_attribute_value(
cls,
item: Union[ModelT, type[ModelT]],
id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
) -> Any:
"""Get value of attribute named as :attr:`id_attribute` on ``item``.
Args:
item: Anything that should have an attribute named as :attr:`id_attribute` value.
id_attribute: Allows customization of the unique identifier to use for model fetching.
Defaults to `None`, but can reference any surrogate or candidate key for the table.
Returns:
The value of attribute on ``item`` named as :attr:`id_attribute`.
"""
if isinstance(id_attribute, InstrumentedAttribute):
id_attribute = id_attribute.key
return getattr(item, id_attribute if id_attribute is not None else cls.id_attribute)
@classmethod
def set_id_attribute_value(
cls,
item_id: Any,
item: ModelT,
id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
) -> ModelT:
"""Return the ``item`` after the ID is set to the appropriate attribute.
Args:
item_id: Value of ID to be set on instance
item: Anything that should have an attribute named as :attr:`id_attribute` value.
id_attribute: Allows customization of the unique identifier to use for model fetching.
Defaults to `None`, but can reference any surrogate or candidate key for the table.
Returns:
Item with ``item_id`` set to :attr:`id_attribute`
"""
if isinstance(id_attribute, InstrumentedAttribute):
id_attribute = id_attribute.key
setattr(item, id_attribute if id_attribute is not None else cls.id_attribute, item_id)
return item
@staticmethod
def check_not_found(item_or_none: Optional[ModelT]) -> ModelT:
"""Raise :exc:`advanced_alchemy.exceptions.NotFoundError` if ``item_or_none`` is ``None``.
Args:
item_or_none: Item (:class:`ModelT`) to be tested for existence.
Returns:
The item, if it exists.
"""
if item_or_none is None:
msg = "No item found when one was expected"
raise NotFoundError(msg)
return item_or_none
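The guard above is a simple narrowing pattern: turn an ``Optional`` lookup result into a guaranteed value or a typed error. A minimal standalone sketch (with a stand-in exception class):

```python
class NotFoundError(Exception):
    """Raised when a lookup unexpectedly returns no row."""


def check_not_found(item_or_none):
    # Narrow Optional[T] to T: raise instead of letting None propagate.
    if item_or_none is None:
        raise NotFoundError("No item found when one was expected")
    return item_or_none


found = check_not_found({"id": 1})
```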
def _get_execution_options(
self,
execution_options: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
if execution_options is None:
return self._default_execution_options
return execution_options
def _get_loader_options(
self,
loader_options: Optional[LoadSpec],
) -> Union[tuple[list[_AbstractLoad], bool], tuple[None, bool]]:
if loader_options is None:
# use the defaults set at initialization
return self._default_loader_options, self._loader_options_have_wildcards or self.uniquify
return get_abstract_loader_options(
loader_options=loader_options,
default_loader_options=self._default_loader_options,
default_options_have_wildcards=self._loader_options_have_wildcards or self.uniquify,
inherit_lazy_relationships=self.inherit_lazy_relationships,
merge_with_default=self.merge_loader_options,
)
def add(
self,
data: ModelT,
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
) -> ModelT:
"""Add ``data`` to the collection.
Args:
data: Instance to be added to the collection.
auto_expunge: Remove object from session before returning.
auto_refresh: Refresh object from session before returning.
auto_commit: Commit objects before returning.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
Returns:
The added instance.
"""
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
instance = self._attach_to_session(data)
self._flush_or_commit(auto_commit=auto_commit)
self._refresh(instance, auto_refresh=auto_refresh)
self._expunge(instance, auto_expunge=auto_expunge)
return instance
def add_many(
self,
data: list[ModelT],
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
) -> Sequence[ModelT]:
"""Add many `data` to the collection.
Args:
data: list of Instances to be added to the collection.
auto_expunge: Remove object from session before returning.
auto_commit: Commit objects before returning.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
Returns:
The added instances.
"""
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
self.session.add_all(data)
self._flush_or_commit(auto_commit=auto_commit)
for datum in data:
self._expunge(datum, auto_expunge=auto_expunge)
return data
def delete(
self,
item_id: Any,
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> ModelT:
"""Delete instance identified by ``item_id``.
Args:
item_id: Identifier of instance to be deleted.
auto_expunge: Remove object from session before returning.
auto_commit: Commit objects before returning.
id_attribute: Allows customization of the unique identifier to use for model fetching.
Defaults to `id`, but can reference any surrogate or candidate key for the table.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
Returns:
The deleted instance.
Raises:
NotFoundError: If no instance found identified by ``item_id``.
"""
self.uniquify = self._get_uniquify(uniquify)
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
execution_options = self._get_execution_options(execution_options)
instance = self.get(
item_id,
id_attribute=id_attribute,
load=load,
execution_options=execution_options,
)
self.session.delete(instance)
self._flush_or_commit(auto_commit=auto_commit)
self._expunge(instance, auto_expunge=auto_expunge)
return instance
def delete_many(
self,
item_ids: list[Any],
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
chunk_size: Optional[int] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> Sequence[ModelT]:
"""Delete instance identified by `item_id`.
Args:
item_ids: Identifier of instance to be deleted.
auto_expunge: Remove object from session before returning.
auto_commit: Commit objects before returning.
id_attribute: Allows customization of the unique identifier to use for model fetching.
Defaults to `id`, but can reference any surrogate or candidate key for the table.
chunk_size: Allows customization of the ``insertmanyvalues_max_parameters`` setting for the driver.
Defaults to `950` if left unset.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
Returns:
The deleted instances.
"""
self.uniquify = self._get_uniquify(uniquify)
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
execution_options = self._get_execution_options(execution_options)
loader_options, _loader_options_have_wildcard = self._get_loader_options(load)
id_attribute = get_instrumented_attr(
self.model_type,
id_attribute if id_attribute is not None else self.id_attribute,
)
instances: list[ModelT] = []
if self._prefer_any:
chunk_size = len(item_ids) + 1
chunk_size = self._get_insertmanyvalues_max_parameters(chunk_size)
for idx in range(0, len(item_ids), chunk_size):
chunk = item_ids[idx : min(idx + chunk_size, len(item_ids))]
if self._dialect.delete_executemany_returning:
instances.extend(
self.session.scalars(
self._get_delete_many_statement(
statement_type="delete",
model_type=self.model_type,
id_attribute=id_attribute,
id_chunk=chunk,
supports_returning=self._dialect.delete_executemany_returning,
loader_options=loader_options,
execution_options=execution_options,
),
),
)
else:
instances.extend(
self.session.scalars(
self._get_delete_many_statement(
statement_type="select",
model_type=self.model_type,
id_attribute=id_attribute,
id_chunk=chunk,
supports_returning=self._dialect.delete_executemany_returning,
loader_options=loader_options,
execution_options=execution_options,
),
),
)
self.session.execute(
self._get_delete_many_statement(
statement_type="delete",
model_type=self.model_type,
id_attribute=id_attribute,
id_chunk=chunk,
supports_returning=self._dialect.delete_executemany_returning,
loader_options=loader_options,
execution_options=execution_options,
),
)
self._flush_or_commit(auto_commit=auto_commit)
for instance in instances:
self._expunge(instance, auto_expunge=auto_expunge)
return instances
def _get_insertmanyvalues_max_parameters(self, chunk_size: Optional[int] = None) -> int:
return chunk_size if chunk_size is not None else DEFAULT_INSERTMANYVALUES_MAX_PARAMETERS
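The chunking arithmetic that ``delete_many`` uses to stay under the driver's parameter limit (``chunk_size`` defaults to 950 when unset) can be sketched in isolation:

```python
def chunk_ids(item_ids, chunk_size):
    # Yield successive slices of at most chunk_size items, mirroring the
    # range()-based loop in delete_many; the final chunk may be shorter.
    for idx in range(0, len(item_ids), chunk_size):
        yield item_ids[idx : idx + chunk_size]


chunks = list(chunk_ids(list(range(7)), 3))
```

Each chunk then becomes one `DELETE ... WHERE id IN (...)` (or `= ANY(...)`) statement, so no single statement exceeds the backend's bind-parameter ceiling.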
def delete_where(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
sanity_check: bool = True,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> Sequence[ModelT]:
"""Delete instances specified by referenced kwargs and filters.
Args:
*filters: Types for specific filtering operations.
auto_expunge: Remove object from session before returning.
auto_commit: Commit objects before returning.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
sanity_check: When true, the length of selected instances is compared to the deleted row count
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
**kwargs: Arguments to apply to a delete
Returns:
The deleted instances.
"""
self.uniquify = self._get_uniquify(uniquify)
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
execution_options = self._get_execution_options(execution_options)
loader_options, _loader_options_have_wildcard = self._get_loader_options(load)
model_type = self.model_type
statement = self._get_base_stmt(
statement=delete(model_type),
loader_options=loader_options,
execution_options=execution_options,
)
statement = self._filter_select_by_kwargs(statement=statement, kwargs=kwargs)
statement = self._apply_filters(*filters, statement=statement, apply_pagination=False)
instances: list[ModelT] = []
if self._dialect.delete_executemany_returning:
instances.extend(self.session.scalars(statement.returning(model_type)))
else:
instances.extend(
self.list(
*filters,
load=load,
execution_options=execution_options,
auto_expunge=auto_expunge,
**kwargs,
),
)
result = self.session.execute(statement)
row_count = getattr(result, "rowcount", -2)
if sanity_check and row_count >= 0 and len(instances) != row_count: # pyright: ignore # noqa: PGH003
# backends will return a -1 if they can't determine impacted rowcount
# only compare length of selected instances to results if it's >= 0
self.session.rollback()
raise RepositoryError(detail="Deleted count does not match fetched count. Rollback issued.")
self._flush_or_commit(auto_commit=auto_commit)
for instance in instances:
self._expunge(instance, auto_expunge=auto_expunge)
return instances
def exists(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> bool:
"""Return true if the object specified by ``kwargs`` exists.
Args:
*filters: Types for specific filtering operations.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
**kwargs: Identifier of the instance to be retrieved.
Returns:
True if the instance was found, False if not.
"""
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
existing = self.count(
*filters,
load=load,
execution_options=execution_options,
error_messages=error_messages,
**kwargs,
)
return existing > 0
def _get_base_stmt(
self,
*,
statement: StatementTypeT,
loader_options: Optional[list[_AbstractLoad]],
execution_options: Optional[dict[str, Any]],
) -> StatementTypeT:
"""Get base statement with options applied.
Args:
statement: The select statement to modify
loader_options: Options for loading relationships
execution_options: Options for statement execution
Returns:
Modified select statement
"""
if loader_options:
statement = cast("StatementTypeT", statement.options(*loader_options))
if execution_options:
statement = cast("StatementTypeT", statement.execution_options(**execution_options))
return statement
def _get_delete_many_statement(
self,
*,
model_type: type[ModelT],
id_attribute: InstrumentedAttribute[Any],
id_chunk: list[Any],
supports_returning: bool,
statement_type: Literal["delete", "select"] = "delete",
loader_options: Optional[list[_AbstractLoad]],
execution_options: Optional[dict[str, Any]],
) -> Union[Select[tuple[ModelT]], Delete, ReturningDelete[tuple[ModelT]]]:
# Build the base statement with loader and execution options applied
statement = self._get_base_stmt(
statement=delete(model_type) if statement_type == "delete" else select(model_type),
loader_options=loader_options,
execution_options=execution_options,
)
if supports_returning and statement_type != "select":
statement = cast("ReturningDelete[tuple[ModelT]]", statement.returning(model_type)) # type: ignore[union-attr,assignment] # pyright: ignore[reportUnknownLambdaType,reportUnknownMemberType,reportAttributeAccessIssue,reportUnknownVariableType]
if self._prefer_any:
return statement.where(any_(id_chunk) == id_attribute) # type: ignore[arg-type]
return statement.where(id_attribute.in_(id_chunk)) # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
def get(
self,
item_id: Any,
*,
auto_expunge: Optional[bool] = None,
statement: Optional[Select[tuple[ModelT]]] = None,
id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> ModelT:
"""Get instance identified by `item_id`.
Args:
item_id: Identifier of the instance to be retrieved.
auto_expunge: Remove object from session before returning.
statement: To facilitate customization of the underlying select query.
id_attribute: Allows customization of the unique identifier to use for model fetching.
Defaults to `id`, but can reference any surrogate or candidate key for the table.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
Returns:
The retrieved instance.
Raises:
NotFoundError: If no instance found identified by `item_id`.
"""
self.uniquify = self._get_uniquify(uniquify)
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
execution_options = self._get_execution_options(execution_options)
statement = self.statement if statement is None else statement
loader_options, loader_options_have_wildcard = self._get_loader_options(load)
id_attribute = id_attribute if id_attribute is not None else self.id_attribute
statement = self._get_base_stmt(
statement=statement,
loader_options=loader_options,
execution_options=execution_options,
)
statement = self._filter_select_by_kwargs(statement, [(id_attribute, item_id)])
instance = (self._execute(statement, uniquify=loader_options_have_wildcard)).scalar_one_or_none()
instance = self.check_not_found(instance)
self._expunge(instance, auto_expunge=auto_expunge)
return instance
def get_one(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
auto_expunge: Optional[bool] = None,
statement: Optional[Select[tuple[ModelT]]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> ModelT:
"""Get instance identified by ``kwargs``.
Args:
*filters: Types for specific filtering operations.
auto_expunge: Remove object from session before returning.
statement: To facilitate customization of the underlying select query.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
**kwargs: Identifier of the instance to be retrieved.
Returns:
The retrieved instance.
Raises:
NotFoundError: If no instance is found matching the provided filters.
"""
self.uniquify = self._get_uniquify(uniquify)
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
execution_options = self._get_execution_options(execution_options)
statement = self.statement if statement is None else statement
loader_options, loader_options_have_wildcard = self._get_loader_options(load)
statement = self._get_base_stmt(
statement=statement,
loader_options=loader_options,
execution_options=execution_options,
)
statement = self._apply_filters(*filters, apply_pagination=False, statement=statement)
statement = self._filter_select_by_kwargs(statement, kwargs)
instance = (self._execute(statement, uniquify=loader_options_have_wildcard)).scalar_one_or_none()
instance = self.check_not_found(instance)
self._expunge(instance, auto_expunge=auto_expunge)
return instance
def get_one_or_none(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
auto_expunge: Optional[bool] = None,
statement: Optional[Select[tuple[ModelT]]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> Optional[ModelT]:
"""Get instance identified by ``kwargs`` or None if not found.
Args:
*filters: Types for specific filtering operations.
auto_expunge: Remove object from session before returning.
statement: To facilitate customization of the underlying select query.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
**kwargs: Identifier of the instance to be retrieved.
Returns:
The retrieved instance or None
"""
self.uniquify = self._get_uniquify(uniquify)
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
execution_options = self._get_execution_options(execution_options)
statement = self.statement if statement is None else statement
loader_options, loader_options_have_wildcard = self._get_loader_options(load)
statement = self._get_base_stmt(
statement=statement,
loader_options=loader_options,
execution_options=execution_options,
)
statement = self._apply_filters(*filters, apply_pagination=False, statement=statement)
statement = self._filter_select_by_kwargs(statement, kwargs)
instance = cast(
"Result[tuple[ModelT]]",
(self._execute(statement, uniquify=loader_options_have_wildcard)),
).scalar_one_or_none()
if instance:
self._expunge(instance, auto_expunge=auto_expunge)
return instance
def get_or_upsert(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
match_fields: Optional[Union[list[str], str]] = None,
upsert: bool = True,
attribute_names: Optional[Iterable[str]] = None,
with_for_update: Optional[bool] = None,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_refresh: Union[bool, None] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> tuple[ModelT, bool]:
"""Get instance identified by ``kwargs`` or create if it doesn't exist.
Args:
*filters: Types for specific filtering operations.
match_fields: a list of keys to use to match the existing model. When
empty, all fields are matched.
upsert: When using match_fields and actual model values differ from
`kwargs`, automatically perform an update operation on the model.
attribute_names: an iterable of attribute names to pass into the ``update``
method.
with_for_update: Indicates that FOR UPDATE should be used, or may be a
dictionary containing flags to indicate a more specific set of
FOR UPDATE flags for the SELECT
auto_expunge: Remove object from session before returning.
auto_refresh: Refresh object from session before returning.
auto_commit: Commit objects before returning.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
**kwargs: Identifier of the instance to be retrieved.
Returns:
a tuple that includes the instance and whether it needed to be created.
When using match_fields and actual model values differ from ``kwargs``, the
model value will be updated.
"""
self.uniquify = self._get_uniquify(uniquify)
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
if match_fields := self._get_match_fields(match_fields=match_fields):
match_filter = {
field_name: kwargs.get(field_name)
for field_name in match_fields
if kwargs.get(field_name) is not None
}
else:
match_filter = kwargs
existing = self.get_one_or_none(
*filters,
**match_filter,
load=load,
execution_options=execution_options,
)
if not existing:
return (
self.add(
self.model_type(**kwargs),
auto_commit=auto_commit,
auto_expunge=auto_expunge,
auto_refresh=auto_refresh,
),
True,
)
if upsert:
for field_name, new_field_value in kwargs.items():
field = getattr(existing, field_name, MISSING)
if field is not MISSING and field != new_field_value:
setattr(existing, field_name, new_field_value)
existing = self._attach_to_session(existing, strategy="merge")
self._flush_or_commit(auto_commit=auto_commit)
self._refresh(
existing,
attribute_names=attribute_names,
with_for_update=with_for_update,
auto_refresh=auto_refresh,
)
self._expunge(existing, auto_expunge=auto_expunge)
return existing, False
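When ``match_fields`` is provided, ``get_or_upsert`` narrows ``kwargs`` down to only those fields (skipping any that are ``None``) before looking up an existing row; otherwise every keyword participates in the lookup. A minimal standalone sketch of that narrowing step (the function name ``build_match_filter`` is illustrative, not part of the library):

```python
from typing import Any, Optional


def build_match_filter(
    match_fields: Optional[list[str]],
    kwargs: dict[str, Any],
) -> dict[str, Any]:
    """Mirror the match-filter narrowing used by ``get_or_upsert``."""
    if match_fields:
        # Keep only match fields that were actually supplied (non-None).
        return {
            field_name: kwargs.get(field_name)
            for field_name in match_fields
            if kwargs.get(field_name) is not None
        }
    # With no match fields, every keyword is used as a filter.
    return dict(kwargs)
```

For example, ``build_match_filter(["email"], {"email": "a@b.io", "name": "Ann"})`` keeps only the ``email`` key, so the existence check ignores fields that are merely payload.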
def get_and_update(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
match_fields: Optional[Union[list[str], str]] = None,
attribute_names: Optional[Iterable[str]] = None,
with_for_update: Optional[bool] = None,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> tuple[ModelT, bool]:
"""Get instance identified by ``kwargs`` and update the model if the arguments are different.
Args:
*filters: Types for specific filtering operations.
match_fields: a list of keys to use to match the existing model. When
empty, all fields are matched.
attribute_names: an iterable of attribute names to pass into the ``update``
method.
with_for_update: Indicates that FOR UPDATE should be used, or may be a
dictionary containing flags to indicate a more specific set of
FOR UPDATE flags for the SELECT
auto_expunge: Remove object from session before returning.
auto_refresh: Refresh object from session before returning.
auto_commit: Commit objects before returning.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
**kwargs: Identifier of the instance to be retrieved.
Returns:
a tuple that includes the instance and whether it needed to be updated.
When using match_fields and actual model values differ from ``kwargs``, the
model value will be updated.
Raises:
NotFoundError: If no instance is found matching ``kwargs`` and ``match_fields``.
"""
self.uniquify = self._get_uniquify(uniquify)
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
if match_fields := self._get_match_fields(match_fields=match_fields):
match_filter = {
field_name: kwargs.get(field_name)
for field_name in match_fields
if kwargs.get(field_name) is not None
}
else:
match_filter = kwargs
existing = self.get_one(*filters, **match_filter, load=load, execution_options=execution_options)
updated = False
for field_name, new_field_value in kwargs.items():
field = getattr(existing, field_name, MISSING)
if field is not MISSING and field != new_field_value:
updated = True
setattr(existing, field_name, new_field_value)
existing = self._attach_to_session(existing, strategy="merge")
self._flush_or_commit(auto_commit=auto_commit)
self._refresh(
existing,
attribute_names=attribute_names,
with_for_update=with_for_update,
auto_refresh=auto_refresh,
)
self._expunge(existing, auto_expunge=auto_expunge)
return existing, updated
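The update step in ``get_and_update`` (and ``get_or_upsert``) compares each incoming value against the loaded instance using a ``MISSING`` sentinel, so attributes absent from the model are skipped rather than mistaken for ``None``. A self-contained sketch of that pattern (``Row`` here is a stand-in class, not a library type):

```python
from typing import Any

MISSING = object()  # sentinel distinguishing "attribute absent" from None


def apply_changes(instance: Any, new_values: dict[str, Any]) -> bool:
    """Set changed attributes on ``instance``; return True if anything changed."""
    updated = False
    for field_name, new_value in new_values.items():
        current = getattr(instance, field_name, MISSING)
        # Only touch attributes that exist on the instance and actually differ.
        if current is not MISSING and current != new_value:
            setattr(instance, field_name, new_value)
            updated = True
    return updated


class Row:
    def __init__(self) -> None:
        self.name = "old"
        self.count = 1
```

Because ``MISSING`` is a unique object, a stored value of ``None`` still compares correctly, and unknown keys in ``new_values`` are silently ignored instead of creating new attributes.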
def count(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
statement: Optional[Select[tuple[ModelT]]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> int:
"""Get the count of records returned by a query.
Args:
*filters: Types for specific filtering operations.
statement: To facilitate customization of the underlying select query.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
**kwargs: Instance attribute value filters.
Returns:
Count of records returned by query, ignoring pagination.
"""
self.uniquify = self._get_uniquify(uniquify)
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
execution_options = self._get_execution_options(execution_options)
statement = self.statement if statement is None else statement
loader_options, loader_options_have_wildcard = self._get_loader_options(load)
statement = self._get_base_stmt(
statement=statement,
loader_options=loader_options,
execution_options=execution_options,
)
statement = self._apply_filters(*filters, apply_pagination=False, statement=statement)
statement = self._filter_select_by_kwargs(statement, kwargs)
results = self._execute(
statement=self._get_count_stmt(
statement=statement, loader_options=loader_options, execution_options=execution_options
),
uniquify=loader_options_have_wildcard,
)
return cast("int", results.scalar_one())
def update(
self,
data: ModelT,
*,
attribute_names: Optional[Iterable[str]] = None,
with_for_update: Optional[bool] = None,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> ModelT:
"""Update instance with the attribute values present on `data`.
Args:
data: An instance that should have a value for `self.id_attribute` that
exists in the collection.
attribute_names: an iterable of attribute names to pass into the ``update``
method.
with_for_update: Indicates that FOR UPDATE should be used, or may be a
dictionary containing flags to indicate a more specific set of
FOR UPDATE flags for the SELECT
auto_expunge: Remove object from session before returning.
auto_refresh: Refresh object from session before returning.
auto_commit: Commit objects before returning.
id_attribute: Allows customization of the unique identifier to use for model fetching.
Defaults to `id`, but can reference any surrogate or candidate key for the table.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
Returns:
The updated instance.
Raises:
NotFoundError: If no instance found with same identifier as `data`.
"""
self.uniquify = self._get_uniquify(uniquify)
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
item_id = self.get_id_attribute_value(
data,
id_attribute=id_attribute,
)
# this will raise for not found, and will put the item in the session
self.get(item_id, id_attribute=id_attribute, load=load, execution_options=execution_options)
# this will merge the inbound data to the instance we just put in the session
instance = self._attach_to_session(data, strategy="merge")
self._flush_or_commit(auto_commit=auto_commit)
self._refresh(
instance,
attribute_names=attribute_names,
with_for_update=with_for_update,
auto_refresh=auto_refresh,
)
self._expunge(instance, auto_expunge=auto_expunge)
return instance
def update_many(
self,
data: list[ModelT],
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> list[ModelT]:
"""Update one or more instances with the attribute values present on `data`.
This function has an optimized bulk update based on the configured SQL dialect:
- For backends supporting `RETURNING` with `executemany`, a single bulk update with a returning clause is executed and the refreshed instances are returned.
- For other backends, a bulk update is executed and the original `data` is returned.
Args:
data: A list of instances to update. Each should have a value for `self.id_attribute` that exists in the
collection.
auto_expunge: Remove object from session before returning.
auto_commit: Commit objects before returning.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
Returns:
The updated instances.
Raises:
NotFoundError: If no instance found with same identifier as `data`.
"""
self.uniquify = self._get_uniquify(uniquify)
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
data_to_update: list[dict[str, Any]] = [v.to_dict() if isinstance(v, self.model_type) else v for v in data] # type: ignore[misc]
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
execution_options = self._get_execution_options(execution_options)
loader_options = self._get_loader_options(load)[0]
supports_returning = self._dialect.update_executemany_returning and self._dialect.name != "oracle"
statement = self._get_update_many_statement(
self.model_type,
supports_returning,
loader_options=loader_options,
execution_options=execution_options,
)
if supports_returning:
instances = list(
self.session.scalars(
statement,
cast("_CoreSingleExecuteParams", data_to_update), # this is not correct but the only way
# currently to deal with an SQLAlchemy typing issue. See
# https://github.com/sqlalchemy/sqlalchemy/discussions/9925
execution_options=execution_options,
),
)
self._flush_or_commit(auto_commit=auto_commit)
for instance in instances:
self._expunge(instance, auto_expunge=auto_expunge)
return instances
self.session.execute(statement, data_to_update, execution_options=execution_options)
self._flush_or_commit(auto_commit=auto_commit)
return data
def _get_update_many_statement(
self,
model_type: type[ModelT],
supports_returning: bool,
loader_options: Union[list[_AbstractLoad], None],
execution_options: Union[dict[str, Any], None],
) -> Union[Update, ReturningUpdate[tuple[ModelT]]]:
# Base update statement is static
statement = self._get_base_stmt(
statement=update(table=model_type), loader_options=loader_options, execution_options=execution_options
)
if supports_returning:
return statement.returning(model_type)
return statement
def list_and_count(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
statement: Optional[Select[tuple[ModelT]]] = None,
auto_expunge: Optional[bool] = None,
count_with_window_function: Optional[bool] = None,
order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> tuple[list[ModelT], int]:
"""List records with total count.
Args:
*filters: Types for specific filtering operations.
statement: To facilitate customization of the underlying select query.
auto_expunge: Remove object from session before returning.
count_with_window_function: When false, list and count will use two queries instead of an analytical window function.
order_by: Set default order options for queries.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
**kwargs: Instance attribute value filters.
Returns:
A tuple of the list of instances and the total count of records, ignoring pagination.
"""
count_with_window_function = (
count_with_window_function if count_with_window_function is not None else self.count_with_window_function
)
self.uniquify = self._get_uniquify(uniquify)
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
if self._dialect.name in {"spanner", "spanner+spanner"} or not count_with_window_function:
return self._list_and_count_basic(
*filters,
auto_expunge=auto_expunge,
statement=statement,
load=load,
execution_options=execution_options,
order_by=order_by,
error_messages=error_messages,
**kwargs,
)
return self._list_and_count_window(
*filters,
auto_expunge=auto_expunge,
statement=statement,
load=load,
execution_options=execution_options,
error_messages=error_messages,
order_by=order_by,
**kwargs,
)
def _expunge(self, instance: ModelT, auto_expunge: Optional[bool]) -> None:
if auto_expunge is None:
auto_expunge = self.auto_expunge
return self.session.expunge(instance) if auto_expunge else None
def _flush_or_commit(self, auto_commit: Optional[bool]) -> None:
if auto_commit is None:
auto_commit = self.auto_commit
return self.session.commit() if auto_commit else self.session.flush()
def _refresh(
self,
instance: ModelT,
auto_refresh: Optional[bool],
attribute_names: Optional[Iterable[str]] = None,
with_for_update: Optional[bool] = None,
) -> None:
if auto_refresh is None:
auto_refresh = self.auto_refresh
return (
self.session.refresh(
instance=instance,
attribute_names=attribute_names,
with_for_update=with_for_update,
)
if auto_refresh
else None
)
def _list_and_count_window(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
auto_expunge: Optional[bool] = None,
statement: Optional[Select[tuple[ModelT]]] = None,
order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> tuple[list[ModelT], int]:
"""List records with total count.
Args:
*filters: Types for specific filtering operations.
auto_expunge: Remove object from session before returning.
statement: To facilitate customization of the underlying select query.
order_by: Set default order options for queries.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set relationships to be loaded
execution_options: Set default execution options
**kwargs: Instance attribute value filters.
Returns:
The list of instances and the total count, computed in a single query via an analytical window function, ignoring pagination.
"""
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
execution_options = self._get_execution_options(execution_options)
statement = self.statement if statement is None else statement
loader_options, loader_options_have_wildcard = self._get_loader_options(load)
statement = self._get_base_stmt(
statement=statement,
loader_options=loader_options,
execution_options=execution_options,
)
if order_by is None:
order_by = self.order_by or []
statement = self._apply_order_by(statement=statement, order_by=order_by)
statement = self._apply_filters(*filters, statement=statement)
statement = self._filter_select_by_kwargs(statement, kwargs)
result = self._execute(statement.add_columns(over(sql_func.count())), uniquify=loader_options_have_wildcard)
count: int = 0
instances: list[ModelT] = []
for i, (instance, count_value) in enumerate(result):
self._expunge(instance, auto_expunge=auto_expunge)
instances.append(instance)
if i == 0:
count = count_value
return instances, count
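``_list_and_count_window`` avoids a second round trip by attaching ``COUNT(*) OVER ()`` to the select, so every returned row carries the total alongside the entity. The same idea in plain SQL against SQLite (window functions require SQLite >= 3.25; the ``item`` table is illustrative):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE item (name TEXT);
    INSERT INTO item (name) VALUES ('a'), ('b'), ('c');
    """
)
# Each row carries the total via an analytical window function, so the
# rows and their count arrive in a single query.
rows = conn.execute(
    "SELECT name, COUNT(*) OVER () FROM item WHERE name != 'c'"
).fetchall()
names = [name for name, _count in rows]
total = rows[0][1] if rows else 0
```

Note the window counts rows after filtering but before any LIMIT/OFFSET would apply, which is exactly why the repository applies pagination filters to the statement while reading the count from the first row.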
def _list_and_count_basic(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
auto_expunge: Optional[bool] = None,
statement: Optional[Select[tuple[ModelT]]] = None,
order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> tuple[list[ModelT], int]:
"""List records with total count.
Args:
*filters: Types for specific filtering operations.
auto_expunge: Remove object from session before returning.
statement: To facilitate customization of the underlying select query.
order_by: Set default order options for queries.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set relationships to be loaded
execution_options: Set default execution options
**kwargs: Instance attribute value filters.
Returns:
The list of instances and the total count, computed with two separate queries, ignoring pagination.
"""
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
execution_options = self._get_execution_options(execution_options)
statement = self.statement if statement is None else statement
loader_options, loader_options_have_wildcard = self._get_loader_options(load)
statement = self._get_base_stmt(
statement=statement,
loader_options=loader_options,
execution_options=execution_options,
)
if order_by is None:
order_by = self.order_by or []
statement = self._apply_order_by(statement=statement, order_by=order_by)
statement = self._apply_filters(*filters, statement=statement)
statement = self._filter_select_by_kwargs(statement, kwargs)
count_result = self.session.execute(
self._get_count_stmt(
statement,
loader_options=loader_options,
execution_options=execution_options,
),
)
count = count_result.scalar_one()
result = self._execute(statement, uniquify=loader_options_have_wildcard)
instances: list[ModelT] = []
for (instance,) in result:
self._expunge(instance, auto_expunge=auto_expunge)
instances.append(instance)
return instances, count
def _get_count_stmt(
self,
statement: Select[tuple[ModelT]],
loader_options: Optional[list[_AbstractLoad]],
execution_options: Optional[dict[str, Any]],
) -> Select[tuple[int]]:
# Count statement transformations are static
return (
statement.with_only_columns(sql_func.count(text("1")), maintain_column_froms=True)
.limit(None)
.offset(None)
.order_by(None)
)
def upsert(
self,
data: ModelT,
*,
attribute_names: Optional[Iterable[str]] = None,
with_for_update: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_commit: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
match_fields: Optional[Union[list[str], str]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> ModelT:
"""Modify or create instance.
Updates instance with the attribute values present on `data`, or creates a new instance if
one doesn't exist.
Args:
data: Instance to update if it exists, or to be created. The identifier used to
determine whether an existing instance exists is the value of the attribute on
`data` named by the value of `self.id_attribute`.
attribute_names: an iterable of attribute names to pass into the ``update`` method.
with_for_update: Indicates that FOR UPDATE should be used, or may be a
dictionary containing flags to indicate a more specific set of
FOR UPDATE flags for the SELECT
auto_expunge: Remove object from session before returning.
auto_refresh: Refresh object from session before returning.
auto_commit: Commit objects before returning.
match_fields: a list of keys to use to match the existing model. When
empty, all fields are matched.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
Returns:
The updated or created instance.
Raises:
NotFoundError: If no instance found with same identifier as `data`.
"""
self.uniquify = self._get_uniquify(uniquify)
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
if match_fields := self._get_match_fields(match_fields=match_fields):
match_filter = {
field_name: getattr(data, field_name, None)
for field_name in match_fields
if getattr(data, field_name, None) is not None
}
elif getattr(data, self.id_attribute, None) is not None:
match_filter = {self.id_attribute: getattr(data, self.id_attribute, None)}
else:
match_filter = data.to_dict(exclude={self.id_attribute})
existing = self.get_one_or_none(load=load, execution_options=execution_options, **match_filter)
if not existing:
return self.add(
data,
auto_commit=auto_commit,
auto_expunge=auto_expunge,
auto_refresh=auto_refresh,
)
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
for field_name, new_field_value in data.to_dict(exclude={self.id_attribute}).items():
field = getattr(existing, field_name, MISSING)
if field is not MISSING and field != new_field_value:
setattr(existing, field_name, new_field_value)
instance = self._attach_to_session(existing, strategy="merge")
self._flush_or_commit(auto_commit=auto_commit)
self._refresh(
instance,
attribute_names=attribute_names,
with_for_update=with_for_update,
auto_refresh=auto_refresh,
)
self._expunge(instance, auto_expunge=auto_expunge)
return instance
def upsert_many(
self,
data: list[ModelT],
*,
auto_expunge: Optional[bool] = None,
auto_commit: Optional[bool] = None,
no_merge: bool = False,
match_fields: Optional[Union[list[str], str]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> list[ModelT]:
"""Modify or create multiple instances.
Update instances with the attribute values present on `data`, or create a new instance if
one doesn't exist.
!!! tip
In most cases, you will want to set `match_fields` to the combination of attributes, excluding the primary key, that defines uniqueness for a row.
Args:
data: A list of instances to update or create. The identifier used to determine
whether an existing instance exists is the value of the attribute on ``data``
named by the value of :attr:`id_attribute`.
auto_expunge: Remove object from session before returning.
auto_commit: Commit objects before returning.
no_merge: Skip the usage of optimized Merge statements
match_fields: a list of keys to use to match the existing model. When
empty, automatically uses ``self.id_attribute`` (`id` by default) to match.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
Returns:
The updated or created instances.
Raises:
NotFoundError: If no instance found with same identifier as ``data``.
"""
self.uniquify = self._get_uniquify(uniquify)
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
instances: list[ModelT] = []
data_to_update: list[ModelT] = []
data_to_insert: list[ModelT] = []
match_fields = self._get_match_fields(match_fields=match_fields)
if match_fields is None:
match_fields = [self.id_attribute]
match_filter: list[Union[StatementFilter, ColumnElement[bool]]] = []
if match_fields:
for field_name in match_fields:
field = get_instrumented_attr(self.model_type, field_name)
matched_values = [
field_data for datum in data if (field_data := getattr(datum, field_name)) is not None
]
match_filter.append(any_(matched_values) == field if self._prefer_any else field.in_(matched_values)) # type: ignore[arg-type]
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
existing_objs = self.list(
*match_filter,
load=load,
execution_options=execution_options,
auto_expunge=False,
)
for field_name in match_fields:
field = get_instrumented_attr(self.model_type, field_name)
matched_values = list(
{getattr(datum, field_name) for datum in existing_objs if datum}, # ensure the list is unique
)
match_filter.append(any_(matched_values) == field if self._prefer_any else field.in_(matched_values)) # type: ignore[arg-type]
existing_ids = self._get_object_ids(existing_objs=existing_objs)
data = self._merge_on_match_fields(data, existing_objs, match_fields)
for datum in data:
if getattr(datum, self.id_attribute, None) in existing_ids:
data_to_update.append(datum)
else:
data_to_insert.append(datum)
if data_to_insert:
instances.extend(
self.add_many(data_to_insert, auto_commit=False, auto_expunge=False),
)
if data_to_update:
instances.extend(
self.update_many(
data_to_update,
auto_commit=False,
auto_expunge=False,
load=load,
execution_options=execution_options,
),
)
self._flush_or_commit(auto_commit=auto_commit)
for instance in instances:
self._expunge(instance, auto_expunge=auto_expunge)
return instances
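After merging ids from existing rows, ``upsert_many`` splits the incoming data into a bulk-insert batch and a bulk-update batch based on whether each row's identifier already exists. A reduced sketch of that partitioning step (``id`` being the default ``id_attribute``; ``Obj`` is a stand-in class):

```python
from typing import Any


def partition_for_upsert(
    data: list[Any],
    existing_ids: set[Any],
    id_attribute: str = "id",
) -> tuple[list[Any], list[Any]]:
    """Split rows into (to_insert, to_update) by id membership."""
    to_insert: list[Any] = []
    to_update: list[Any] = []
    for datum in data:
        if getattr(datum, id_attribute, None) in existing_ids:
            to_update.append(datum)
        else:
            to_insert.append(datum)
    return to_insert, to_update


class Obj:
    def __init__(self, id: Any = None) -> None:
        self.id = id
```

Rows with no id (or an id not present in the database) fall into the insert batch, which is why the preceding merge step matters: it stamps matched rows with their existing ids so they route to the update batch instead.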
def _get_object_ids(self, existing_objs: list[ModelT]) -> list[Any]:
return [obj_id for datum in existing_objs if (obj_id := getattr(datum, self.id_attribute)) is not None]
def _get_match_fields(
self,
match_fields: Optional[Union[list[str], str]] = None,
id_attribute: Optional[str] = None,
) -> Optional[list[str]]:
id_attribute = id_attribute or self.id_attribute
match_fields = match_fields or self.match_fields
if isinstance(match_fields, str):
match_fields = [match_fields]
return match_fields
def _merge_on_match_fields(
self,
data: list[ModelT],
existing_data: list[ModelT],
match_fields: Optional[Union[list[str], str]] = None,
) -> list[ModelT]:
match_fields = self._get_match_fields(match_fields=match_fields)
if match_fields is None:
match_fields = [self.id_attribute]
for existing_datum in existing_data:
for _row_id, datum in enumerate(data):
match = all(
getattr(datum, field_name) == getattr(existing_datum, field_name) for field_name in match_fields
)
if match and getattr(existing_datum, self.id_attribute) is not None:
setattr(datum, self.id_attribute, getattr(existing_datum, self.id_attribute))
return data
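``_merge_on_match_fields`` copies the primary key from each existing row onto any incoming row whose match fields all agree, so the later insert/update partition treats it as an update. A standalone sketch using simple objects (the ``email`` attribute and ``Obj`` class are illustrative):

```python
from typing import Any


def merge_on_match_fields(
    data: list[Any],
    existing: list[Any],
    match_fields: list[str],
    id_attribute: str = "id",
) -> list[Any]:
    """Copy ids from matching existing rows onto incoming rows, in place."""
    for existing_datum in existing:
        for datum in data:
            matched = all(
                getattr(datum, f) == getattr(existing_datum, f)
                for f in match_fields
            )
            # Only stamp the id when every match field agrees and the
            # existing row actually has an identifier to copy.
            if matched and getattr(existing_datum, id_attribute) is not None:
                setattr(datum, id_attribute, getattr(existing_datum, id_attribute))
    return data


class Obj:
    def __init__(self, id: Any = None, email: str = "") -> None:
        self.id = id
        self.email = email
```

Incoming rows with no counterpart among the existing rows keep their original (typically empty) id and are later routed to the insert batch.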
def list(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
auto_expunge: Optional[bool] = None,
statement: Optional[Select[tuple[ModelT]]] = None,
order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> list[ModelT]:
"""Get a list of instances, optionally filtered.
Args:
*filters: Types for specific filtering operations.
auto_expunge: Remove object from session before returning.
statement: To facilitate customization of the underlying select query.
order_by: Set default order options for queries.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
**kwargs: Instance attribute value filters.
Returns:
The list of instances, after filtering applied.
"""
self.uniquify = self._get_uniquify(uniquify)
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
execution_options = self._get_execution_options(execution_options)
statement = self.statement if statement is None else statement
loader_options, loader_options_have_wildcard = self._get_loader_options(load)
statement = self._get_base_stmt(
statement=statement,
loader_options=loader_options,
execution_options=execution_options,
)
if order_by is None:
order_by = self.order_by or []
statement = self._apply_order_by(statement=statement, order_by=order_by)
statement = self._apply_filters(*filters, statement=statement)
statement = self._filter_select_by_kwargs(statement, kwargs)
result = self._execute(statement, uniquify=loader_options_have_wildcard)
instances = list(result.scalars())
for instance in instances:
self._expunge(instance, auto_expunge=auto_expunge)
return cast("list[ModelT]", instances)
@classmethod
def check_health(cls, session: Union[Session, scoped_session[Session]]) -> bool:
"""Perform a health check on the database.
Args:
session: The session through which the health check statement is executed.
Returns:
``True`` if healthy.
"""
with wrap_sqlalchemy_exception():
return ( # type: ignore[no-any-return]
session.execute(cls._get_health_check_statement(session))
).scalar_one() == 1
@staticmethod
def _get_health_check_statement(session: Union[Session, scoped_session[Session]]) -> TextClause:
if session.bind and session.bind.dialect.name == "oracle":
return text("SELECT 1 FROM DUAL")
return text("SELECT 1")
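# Illustrative sketch (not part of the library): the health-check statement
# above varies by dialect because Oracle has no bare SELECT and requires the
# DUAL dummy table. The same dispatch, with plain strings standing in for
# SQLAlchemy TextClause objects:
def _example_health_check_sql(dialect_name: str) -> str:
    # Oracle needs a FROM clause; every other tested backend accepts SELECT 1.
    if dialect_name == "oracle":
        return "SELECT 1 FROM DUAL"
    return "SELECT 1"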
def _attach_to_session(self, model: ModelT, strategy: Literal["add", "merge"] = "add", load: bool = True) -> ModelT:
"""Attach detached instance to the session.
Args:
model: The instance to be attached to the session.
strategy: How the instance should be attached.
- "add": New instance added to session
- "merge": Instance merged with existing, or new one added.
load: Boolean, when False, merge switches into
a "high performance" mode which causes it to forego emitting history
events as well as all database access. This flag is used for
cases such as transferring graphs of objects into a session
from a second level cache, or to transfer just-loaded objects
into the session owned by a worker thread or process
without re-querying the database.
Returns:
Instance attached to the session - if `"merge"` strategy, may not be the same instance
that was provided.
"""
if strategy == "add":
self.session.add(model)
return model
if strategy == "merge":
return self.session.merge(model, load=load)
msg = "Unexpected value for `strategy`, must be `'add'` or `'merge'`" # type: ignore[unreachable]
raise ValueError(msg)
def _execute(
self,
statement: Select[Any],
uniquify: bool = False,
) -> Result[Any]:
result = self.session.execute(statement)
if uniquify or self.uniquify:
result = result.unique()
return result
class SQLAlchemySyncSlugRepository(
SQLAlchemySyncRepository[ModelT],
SQLAlchemySyncSlugRepositoryProtocol[ModelT],
):
"""Extends the repository to include slug model features."""
def get_by_slug(
self,
slug: str,
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> Optional[ModelT]:
"""Select a record by its slug value.
Args:
slug: The slug value to search for.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
**kwargs: Additional keyword arguments.
Returns:
The matching instance, or ``None`` if not found.
"""
return self.get_one_or_none(
slug=slug,
load=load,
execution_options=execution_options,
error_messages=error_messages,
uniquify=uniquify,
)
def get_available_slug(
self,
value_to_slugify: str,
**kwargs: Any,
) -> str:
"""Get a unique slug for the supplied value.
If the value is found to exist, a random 4-character alphanumeric suffix is appended to the end.
Override this method to change the default behavior.
Args:
value_to_slugify (str): A string that should be converted to a unique slug.
**kwargs: Additional filter criteria (currently unused).
Returns:
str: A unique slug for the supplied value. This is safe for URLs and other unique identifiers.
"""
slug = slugify(value_to_slugify)
if self._is_slug_unique(slug):
return slug
random_string = "".join(random.choices(string.ascii_lowercase + string.digits, k=4)) # noqa: S311
return f"{slug}-{random_string}"
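# Hedged sketch (illustrative only): how get_available_slug behaves, using a
# plain set as a stand-in for the database uniqueness check. `_slugify_example`
# is a simplified hypothetical version of advanced_alchemy.utils.text.slugify.
import random
import re
import string

def _slugify_example(value: str) -> str:
    # Lowercase and collapse non-alphanumeric runs into single hyphens.
    return re.sub(r"[^a-z0-9]+", "-", value.lower()).strip("-")

def _available_slug_example(value: str, existing: set) -> str:
    slug = _slugify_example(value)
    if slug not in existing:
        return slug
    # Collision: append a random 4-character alphanumeric suffix.
    suffix = "".join(random.choices(string.ascii_lowercase + string.digits, k=4))
    return f"{slug}-{suffix}"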
def _is_slug_unique(
self,
slug: str,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> bool:
return self.exists(slug=slug, load=load, execution_options=execution_options, **kwargs) is False
class SQLAlchemySyncQueryRepository:
"""SQLAlchemy Query Repository.
This is a loosely typed helper for selecting data in ways that don't align with the normal repository pattern.
"""
error_messages: Optional[ErrorMessages] = None
def __init__(
self,
*,
session: Union[Session, scoped_session[Session]],
error_messages: Optional[ErrorMessages] = None,
**kwargs: Any,
) -> None:
"""Repository pattern for SQLAlchemy models.
Args:
session: Session managing the unit-of-work for the operation.
error_messages: A set of error messages to use for operations.
**kwargs: Additional arguments.
"""
super().__init__(**kwargs)
self.session = session
self.error_messages = error_messages
self._dialect = self.session.bind.dialect if self.session.bind is not None else self.session.get_bind().dialect
def get_one(
self,
statement: Select[tuple[Any]],
**kwargs: Any,
) -> Row[Any]:
"""Get instance identified by ``kwargs``.
Args:
statement: To facilitate customization of the underlying select query.
**kwargs: Instance attribute value filters.
Returns:
The retrieved instance.
Raises:
NotFoundError: If no instance is found matching ``kwargs``.
"""
with wrap_sqlalchemy_exception(error_messages=self.error_messages):
statement = self._filter_statement_by_kwargs(statement, **kwargs)
instance = (self.execute(statement)).scalar_one_or_none()
return self.check_not_found(instance)
def get_one_or_none(
self,
statement: Select[Any],
**kwargs: Any,
) -> Optional[Row[Any]]:
"""Get instance identified by ``kwargs`` or None if not found.
Args:
statement: To facilitate customization of the underlying select query.
**kwargs: Instance attribute value filters.
Returns:
The retrieved instance or None
"""
with wrap_sqlalchemy_exception(error_messages=self.error_messages):
statement = self._filter_statement_by_kwargs(statement, **kwargs)
instance = (self.execute(statement)).scalar_one_or_none()
return instance
def count(self, statement: Select[Any], **kwargs: Any) -> int:
"""Get the count of records returned by a query.
Args:
statement: To facilitate customization of the underlying select query.
**kwargs: Instance attribute value filters.
Returns:
Count of records returned by query, ignoring pagination.
"""
with wrap_sqlalchemy_exception(error_messages=self.error_messages):
statement = statement.with_only_columns(sql_func.count(text("1")), maintain_column_froms=True).order_by(
None,
)
statement = self._filter_statement_by_kwargs(statement, **kwargs)
results = self.execute(statement)
return results.scalar_one() # type: ignore # noqa: PGH003
def list_and_count(
self,
statement: Select[Any],
count_with_window_function: Optional[bool] = None,
**kwargs: Any,
) -> tuple[list[Row[Any]], int]:
"""List records with total count.
Args:
statement: To facilitate customization of the underlying select query.
count_with_window_function: Force list and count to use two queries instead of an analytical window function.
**kwargs: Instance attribute value filters.
Returns:
A tuple of the list of matching rows and the total count of records, ignoring pagination.
"""
if self._dialect.name in {"spanner", "spanner+spanner"} or count_with_window_function:
return self._list_and_count_basic(statement=statement, **kwargs)
return self._list_and_count_window(statement=statement, **kwargs)
def _list_and_count_window(
self,
statement: Select[Any],
**kwargs: Any,
) -> tuple[list[Row[Any]], int]:
"""List records with total count, using an analytical window function for the count.
Args:
statement: To facilitate customization of the underlying select query.
**kwargs: Instance attribute value filters.
Returns:
A tuple of the list of matching rows and the total count, computed with an analytical window function, ignoring pagination.
"""
with wrap_sqlalchemy_exception(error_messages=self.error_messages):
statement = statement.add_columns(over(sql_func.count(text("1"))))
statement = self._filter_statement_by_kwargs(statement, **kwargs)
result = self.execute(statement)
count: int = 0
instances: list[Row[Any]] = []
for i, (instance, count_value) in enumerate(result):
instances.append(instance)
if i == 0:
count = count_value
return instances, count
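# Illustrative sketch (not library code): with `count(*) OVER ()`, every row of
# the result carries the total row count, so a single round trip yields both
# the page of rows and the count. Emulated here with plain tuples:
def _example_rows_with_window_count(rows: list) -> list:
    total = len(rows)  # the database computes this analytically on each row
    return [(row, total) for row in rows]

def _example_split_rows_and_count(result: list) -> tuple:
    count = 0
    instances = []
    for i, (instance, count_value) in enumerate(result):
        instances.append(instance)
        if i == 0:
            count = count_value  # same value on every row; take the first
    return instances, count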
def _get_count_stmt(self, statement: Select[Any]) -> Select[Any]:
return statement.with_only_columns(sql_func.count(text("1")), maintain_column_froms=True).order_by(None) # pyright: ignore[reportUnknownVariable]
def _list_and_count_basic(
self,
statement: Select[Any],
**kwargs: Any,
) -> tuple[list[Row[Any]], int]:
"""List records with total count, using two separate queries.
Args:
statement: To facilitate customization of the underlying select query.
**kwargs: Instance attribute value filters.
Returns:
A tuple of the list of matching rows and the total count, computed with a second query, ignoring pagination.
"""
with wrap_sqlalchemy_exception(error_messages=self.error_messages):
statement = self._filter_statement_by_kwargs(statement, **kwargs)
count_result = self.session.execute(self._get_count_stmt(statement))
count = count_result.scalar_one()
result = self.execute(statement)
instances: list[Row[Any]] = []
for (instance,) in result:
instances.append(instance)
return instances, count
def list(self, statement: Select[Any], **kwargs: Any) -> list[Row[Any]]:
"""Get a list of instances, optionally filtered.
Args:
statement: To facilitate customization of the underlying select query.
**kwargs: Instance attribute value filters.
Returns:
The list of instances, after filtering applied.
"""
with wrap_sqlalchemy_exception(error_messages=self.error_messages):
statement = self._filter_statement_by_kwargs(statement, **kwargs)
result = self.execute(statement)
return list(result.all())
def _filter_statement_by_kwargs(
self,
statement: Select[Any],
/,
**kwargs: Any,
) -> Select[Any]:
"""Filter the collection by kwargs.
Args:
statement: statement to filter
**kwargs: key/value pairs such that objects remaining in the statement after filtering
have the property that their attribute named `key` has value equal to `value`.
"""
with wrap_sqlalchemy_exception(error_messages=self.error_messages):
return statement.filter_by(**kwargs)
# The following is SQLAlchemy implementation detail and shouldn't be directly accessed.
@staticmethod
def check_not_found(item_or_none: Optional[T]) -> T:
"""Raise :class:`NotFoundError` if ``item_or_none`` is ``None``.
Args:
item_or_none: Item to be tested for existence.
Returns:
The item, if it exists.
"""
if item_or_none is None:
msg = "No item found when one was expected"
raise NotFoundError(msg)
return item_or_none
def execute(
self,
statement: Union[
ReturningDelete[tuple[Any]], ReturningUpdate[tuple[Any]], Select[tuple[Any]], Update, Delete, Select[Any]
],
) -> Result[Any]:
return self.session.execute(statement)
# python-advanced-alchemy-1.0.1/advanced_alchemy/repository/_util.py
from collections.abc import Iterable, Sequence
from typing import Any, Literal, Optional, Protocol, Union, cast, overload
from sqlalchemy import (
Delete,
Dialect,
Select,
Update,
)
from sqlalchemy.orm import (
InstrumentedAttribute,
MapperProperty,
RelationshipProperty,
joinedload,
lazyload,
selectinload,
)
from sqlalchemy.orm.strategy_options import (
_AbstractLoad,  # pyright: ignore[reportPrivateUsage]
)
from sqlalchemy.sql import ColumnElement, ColumnExpressionArgument
from sqlalchemy.sql.base import ExecutableOption
from sqlalchemy.sql.dml import ReturningDelete, ReturningUpdate
from typing_extensions import TypeAlias
from advanced_alchemy.base import ModelProtocol
from advanced_alchemy.exceptions import ErrorMessages
from advanced_alchemy.exceptions import wrap_sqlalchemy_exception as _wrap_sqlalchemy_exception
from advanced_alchemy.filters import (
InAnyFilter,
PaginationFilter,
StatementFilter,
StatementTypeT,
)
from advanced_alchemy.repository.typing import ModelT, OrderingPair
WhereClauseT = ColumnExpressionArgument[bool]
SingleLoad: TypeAlias = Union[
_AbstractLoad,
Literal["*"],
InstrumentedAttribute[Any],
RelationshipProperty[Any],
MapperProperty[Any],
]
LoadCollection: TypeAlias = Sequence[Union[SingleLoad, Sequence[SingleLoad]]]
ExecutableOptions: TypeAlias = Sequence[ExecutableOption]
LoadSpec: TypeAlias = Union[LoadCollection, SingleLoad, ExecutableOption, ExecutableOptions]
OrderByT: TypeAlias = Union[
str,
InstrumentedAttribute[Any],
RelationshipProperty[Any],
]
# NOTE: For backward compatibility with Litestar - this is imported from here within the litestar codebase.
wrap_sqlalchemy_exception = _wrap_sqlalchemy_exception
DEFAULT_ERROR_MESSAGE_TEMPLATES: ErrorMessages = {
"integrity": "There was a data validation error during processing",
"foreign_key": "A foreign key is missing or invalid",
"multiple_rows": "Multiple matching rows found",
"duplicate_key": "A record matching the supplied data already exists.",
"other": "There was an error during data processing",
"check_constraint": "The data failed a check constraint during processing",
"not_found": "The requested resource was not found",
}
"""Default error messages for repository errors."""
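# Hedged, self-contained sketch: repositories overlay user-supplied messages on
# these defaults (see `_get_error_messages` elsewhere in this package). A
# hypothetical merge helper mirroring that behavior:
from typing import Optional

def _example_merge_error_messages(defaults: dict, overrides: Optional[dict]) -> dict:
    # Copy the defaults so callers' template dicts are never mutated.
    merged = dict(defaults)
    if overrides:
        merged.update(overrides)
    return merged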
def get_instrumented_attr(
model: type[ModelProtocol],
key: Union[str, InstrumentedAttribute[Any]],
) -> InstrumentedAttribute[Any]:
"""Get an instrumented attribute from a model.
Args:
model: SQLAlchemy model class.
key: Either a string attribute name or an :class:`sqlalchemy.orm.InstrumentedAttribute`.
Returns:
:class:`sqlalchemy.orm.InstrumentedAttribute`: The instrumented attribute from the model.
"""
if isinstance(key, str):
return cast("InstrumentedAttribute[Any]", getattr(model, key))
return key
def model_from_dict(model: type[ModelT], **kwargs: Any) -> ModelT:
"""Create an ORM model instance from a dictionary of attributes.
Args:
model: The SQLAlchemy model class to instantiate.
**kwargs: Keyword arguments containing model attribute values.
Returns:
ModelT: A new instance of the model populated with the provided values.
"""
data = {
column_name: kwargs[column_name]
for column_name in model.__mapper__.columns.keys() # noqa: SIM118 # pyright: ignore[reportUnknownMemberType]
if column_name in kwargs
}
return model(**data)
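# Illustrative stdlib-only sketch of model_from_dict's filtering step: only
# keys that correspond to mapped columns survive; extraneous kwargs are
# silently dropped rather than raising TypeError at construction.
def _example_filter_to_columns(column_names: list, **kwargs: object) -> dict:
    return {name: kwargs[name] for name in column_names if name in kwargs}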
def get_abstract_loader_options(
loader_options: Union[LoadSpec, None],
default_loader_options: Union[list[_AbstractLoad], None] = None,
default_options_have_wildcards: bool = False,
merge_with_default: bool = True,
inherit_lazy_relationships: bool = True,
cycle_count: int = 0,
) -> tuple[list[_AbstractLoad], bool]:
"""Generate SQLAlchemy loader options for eager loading relationships.
Args:
loader_options: Specification for how to load relationships (:class:`~advanced_alchemy.repository.typing.LoadSpec` or :class:`None`). Can be:
- None: Use defaults
- :class:`sqlalchemy.orm.strategy_options._AbstractLoad`: Direct SQLAlchemy loader option
- :class:`sqlalchemy.orm.InstrumentedAttribute`: Model relationship attribute
- :class:`sqlalchemy.orm.RelationshipProperty`: SQLAlchemy relationship
- str: ``"*"`` for wildcard loading
- :class:`typing.Sequence` of any of the above
default_loader_options: :class:`typing.Sequence` of :class:`sqlalchemy.orm.strategy_options._AbstractLoad` loader options to start with.
default_options_have_wildcards: Whether the default options contain wildcards.
merge_with_default: Whether to merge the default options with the loader options.
inherit_lazy_relationships: Whether to inherit the ``lazy`` configuration from the model's relationships.
cycle_count: Number of times this function has been called recursively.
Returns:
tuple[:class:`list`[:class:`sqlalchemy.orm.strategy_options._AbstractLoad`], bool]: A tuple containing:
- :class:`list` of :class:`sqlalchemy.orm.strategy_options._AbstractLoad` SQLAlchemy loader option objects
- Boolean indicating if any wildcard loaders are present
"""
loads: list[_AbstractLoad] = []
if cycle_count == 0 and not inherit_lazy_relationships:
loads.append(lazyload("*"))
if cycle_count == 0 and merge_with_default and default_loader_options is not None:
loads.extend(default_loader_options)
options_have_wildcards = default_options_have_wildcards
if loader_options is None:
return (loads, options_have_wildcards)
if isinstance(loader_options, _AbstractLoad):
return ([loader_options], options_have_wildcards)
if isinstance(loader_options, InstrumentedAttribute):
loader_options = [loader_options.property]
if isinstance(loader_options, RelationshipProperty):
class_ = loader_options.class_attribute
return (
[selectinload(class_)]
if loader_options.uselist
else [joinedload(class_, innerjoin=loader_options.innerjoin)],
options_have_wildcards if loader_options.uselist else True,
)
if isinstance(loader_options, str) and loader_options == "*":
options_have_wildcards = True
return ([joinedload("*")], options_have_wildcards)
if isinstance(loader_options, (list, tuple)):
for attribute in loader_options: # pyright: ignore[reportUnknownVariableType]
if isinstance(attribute, (list, tuple)):
load_chain, options_have_wildcards = get_abstract_loader_options(
loader_options=attribute, # pyright: ignore[reportUnknownArgumentType]
default_options_have_wildcards=options_have_wildcards,
inherit_lazy_relationships=inherit_lazy_relationships,
merge_with_default=merge_with_default,
cycle_count=cycle_count + 1,
)
loader = load_chain[-1]
for sub_load in load_chain[-2::-1]:
loader = sub_load.options(loader)
loads.append(loader)
else:
load_chain, options_have_wildcards = get_abstract_loader_options(
loader_options=attribute, # pyright: ignore[reportUnknownArgumentType]
default_options_have_wildcards=options_have_wildcards,
inherit_lazy_relationships=inherit_lazy_relationships,
merge_with_default=merge_with_default,
cycle_count=cycle_count + 1,
)
loads.extend(load_chain)
return (loads, options_have_wildcards)
class FilterableRepositoryProtocol(Protocol[ModelT]):
"""Protocol defining the interface for filterable repositories.
This protocol defines the required attributes and methods that any
filterable repository implementation must provide.
"""
model_type: type[ModelT]
"""The SQLAlchemy model class this repository manages."""
class FilterableRepository(FilterableRepositoryProtocol[ModelT]):
"""Default implementation of a filterable repository.
Provides core filtering, ordering and pagination functionality for
SQLAlchemy models.
"""
model_type: type[ModelT]
"""The SQLAlchemy model class this repository manages."""
prefer_any_dialects: Optional[tuple[str]] = ("postgresql",)
"""List of dialects that prefer to use ``field.id = ANY(:1)`` instead of ``field.id IN (...)``."""
order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None
"""List or single :class:`~advanced_alchemy.repository.typing.OrderingPair` to use for sorting."""
_prefer_any: bool = False
"""Whether to prefer ANY() over IN() in queries."""
_dialect: Dialect
"""The SQLAlchemy :class:`sqlalchemy.dialects.Dialect` being used."""
@overload
def _apply_filters(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
apply_pagination: bool = True,
statement: Select[tuple[ModelT]],
) -> Select[tuple[ModelT]]: ...
@overload
def _apply_filters(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
apply_pagination: bool = True,
statement: Delete,
) -> Delete: ...
@overload
def _apply_filters(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
apply_pagination: bool = True,
statement: Union[ReturningDelete[tuple[ModelT]], ReturningUpdate[tuple[ModelT]]],
) -> Union[ReturningDelete[tuple[ModelT]], ReturningUpdate[tuple[ModelT]]]: ...
@overload
def _apply_filters(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
apply_pagination: bool = True,
statement: Update,
) -> Update: ...
def _apply_filters(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
apply_pagination: bool = True,
statement: StatementTypeT,
) -> StatementTypeT:
"""Apply filters to a SQL statement.
Args:
*filters: Filter conditions to apply.
apply_pagination: Whether to apply pagination filters.
statement: The base SQL statement to filter.
Returns:
StatementTypeT: The filtered SQL statement.
"""
for filter_ in filters:
if isinstance(filter_, (PaginationFilter,)):
if apply_pagination:
statement = filter_.append_to_statement(statement, self.model_type)
elif isinstance(filter_, (InAnyFilter,)):
statement = filter_.append_to_statement(statement, self.model_type)
elif isinstance(filter_, ColumnElement):
statement = cast("StatementTypeT", statement.where(filter_))
else:
statement = filter_.append_to_statement(statement, self.model_type)
return statement
def _filter_select_by_kwargs(
self,
statement: StatementTypeT,
kwargs: Union[dict[Any, Any], Iterable[tuple[Any, Any]]],
) -> StatementTypeT:
"""Filter a statement using keyword arguments.
Args:
statement: :class:`sqlalchemy.sql.Select` The SQL statement to filter.
kwargs: Dictionary or iterable of tuples containing filter criteria.
Keys should be model attribute names, values are what to filter for.
Returns:
StatementTypeT: The filtered SQL statement.
"""
for key, val in dict(kwargs).items():
field = get_instrumented_attr(self.model_type, key)
statement = cast("StatementTypeT", statement.where(field == val))
return statement
def _apply_order_by(
self,
statement: StatementTypeT,
order_by: Union[
list[tuple[Union[str, InstrumentedAttribute[Any]], bool]],
tuple[Union[str, InstrumentedAttribute[Any]], bool],
],
) -> StatementTypeT:
"""Apply ordering to a SQL statement.
Args:
statement: The SQL statement to order.
order_by: Ordering specification. Either a single tuple or list of tuples where:
- First element is the field name or :class:`sqlalchemy.orm.InstrumentedAttribute` to order by
- Second element is a boolean indicating descending (True) or ascending (False)
Returns:
StatementTypeT: The ordered SQL statement.
"""
if not isinstance(order_by, list):
order_by = [order_by]
for order_field, is_desc in order_by:
field = get_instrumented_attr(self.model_type, order_field)
statement = self._order_by_attribute(statement, field, is_desc)
return statement
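# Hedged sketch: applying a list of (field, is_desc) ordering pairs to plain
# dicts with the stdlib. Python's sort is stable, so applying the pairs in
# reverse order yields the same precedence as chained ORDER BY clauses.
def _example_apply_order_by(rows: list, order_by: list) -> list:
    ordered = list(rows)
    for field, is_desc in reversed(order_by):
        ordered.sort(key=lambda row: row[field], reverse=is_desc)
    return ordered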
def _order_by_attribute(
self,
statement: StatementTypeT,
field: InstrumentedAttribute[Any],
is_desc: bool,
) -> StatementTypeT:
"""Apply ordering by a single attribute to a SQL statement.
Args:
statement: The SQL statement to order.
field: The model attribute to order by.
is_desc: Whether to order in descending (True) or ascending (False) order.
Returns:
StatementTypeT: The ordered SQL statement.
"""
if not isinstance(statement, Select):
return statement
return cast("StatementTypeT", statement.order_by(field.desc() if is_desc else field.asc()))
# python-advanced-alchemy-1.0.1/advanced_alchemy/repository/memory/__init__.py
from advanced_alchemy.repository.memory._async import SQLAlchemyAsyncMockRepository, SQLAlchemyAsyncMockSlugRepository
from advanced_alchemy.repository.memory._sync import SQLAlchemySyncMockRepository, SQLAlchemySyncMockSlugRepository
__all__ = [
"SQLAlchemyAsyncMockRepository",
"SQLAlchemyAsyncMockSlugRepository",
"SQLAlchemySyncMockRepository",
"SQLAlchemySyncMockSlugRepository",
]
# python-advanced-alchemy-1.0.1/advanced_alchemy/repository/memory/_async.py
import datetime
import random
import re
import string
from collections import abc
from collections.abc import Iterable
from typing import Any, Optional, Union, cast, overload
from unittest.mock import create_autospec
from sqlalchemy import (
ColumnElement,
Dialect,
Select,
StatementLambdaElement,
Update,
)
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.ext.asyncio.scoping import async_scoped_session
from sqlalchemy.orm import InstrumentedAttribute
from sqlalchemy.orm.strategy_options import _AbstractLoad # pyright: ignore[reportPrivateUsage]
from sqlalchemy.sql.dml import ReturningUpdate
from typing_extensions import Self
from advanced_alchemy.exceptions import ErrorMessages, IntegrityError, NotFoundError, RepositoryError
from advanced_alchemy.filters import (
BeforeAfter,
CollectionFilter,
LimitOffset,
NotInCollectionFilter,
NotInSearchFilter,
OnBeforeAfter,
OrderBy,
SearchFilter,
StatementFilter,
)
from advanced_alchemy.repository._async import SQLAlchemyAsyncRepositoryProtocol, SQLAlchemyAsyncSlugRepositoryProtocol
from advanced_alchemy.repository._util import DEFAULT_ERROR_MESSAGE_TEMPLATES, LoadSpec
from advanced_alchemy.repository.memory.base import (
AnyObject,
InMemoryStore,
SQLAlchemyInMemoryStore,
SQLAlchemyMultiStore,
)
from advanced_alchemy.repository.typing import MISSING, ModelT, OrderingPair
from advanced_alchemy.utils.dataclass import Empty, EmptyType
from advanced_alchemy.utils.text import slugify
class SQLAlchemyAsyncMockRepository(SQLAlchemyAsyncRepositoryProtocol[ModelT]):
"""In memory repository."""
__database__: SQLAlchemyMultiStore[ModelT] = SQLAlchemyMultiStore(SQLAlchemyInMemoryStore)
__database_registry__: dict[type[Self], SQLAlchemyMultiStore[ModelT]] = {}
loader_options: Optional[LoadSpec] = None
"""Default loader options for the repository."""
execution_options: Optional[dict[str, Any]] = None
"""Default execution options for the repository."""
model_type: type[ModelT]
id_attribute: Any = "id"
match_fields: Optional[Union[list[str], str]] = None
uniquify: bool = False
_exclude_kwargs: set[str] = {
"statement",
"session",
"auto_expunge",
"auto_refresh",
"auto_commit",
"attribute_names",
"with_for_update",
"count_with_window_function",
"loader_options",
"execution_options",
"order_by",
"load",
"error_messages",
"wrap_exceptions",
"uniquify",
}
def __init__(
self,
*,
statement: Union[Select[tuple[ModelT]], StatementLambdaElement, None] = None,
session: Union[AsyncSession, async_scoped_session[AsyncSession]],
auto_expunge: bool = False,
auto_refresh: bool = True,
auto_commit: bool = False,
order_by: Union[list[OrderingPair], OrderingPair, None] = None,
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
wrap_exceptions: bool = True,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> None:
self.session = session
self.statement = create_autospec("Select[Tuple[ModelT]]", instance=True)
self.auto_expunge = auto_expunge
self.auto_refresh = auto_refresh
self.auto_commit = auto_commit
self.error_messages = self._get_error_messages(error_messages=error_messages)
self.wrap_exceptions = wrap_exceptions
self.order_by = order_by
self._dialect: Dialect = create_autospec(Dialect, instance=True)
self._dialect.name = "mock"
self.__filtered_store__: InMemoryStore[ModelT] = self.__database__.store_type()
self._default_options: Any = []
self._default_execution_options: Any = {}
self._loader_options: Any = []
self._loader_options_have_wildcards = False
self.uniquify = bool(uniquify)
def __init_subclass__(cls) -> None:
cls.__database_registry__[cls] = cls.__database__ # pyright: ignore[reportGeneralTypeIssues,reportUnknownMemberType]
@staticmethod
def _get_error_messages(
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
default_messages: Union[ErrorMessages, None, EmptyType] = Empty,
) -> Optional[ErrorMessages]:
if error_messages == Empty:
error_messages = None
default_messages = cast(
"Optional[ErrorMessages]",
default_messages if default_messages != Empty else DEFAULT_ERROR_MESSAGE_TEMPLATES,
)
if error_messages is not None and default_messages is not None:
default_messages.update(cast("ErrorMessages", error_messages))
return default_messages
@classmethod
def __database_add__(cls, identity: Any, data: ModelT) -> ModelT:
return cast("ModelT", cls.__database__.add(identity, data)) # pyright: ignore[reportUnnecessaryCast,reportGeneralTypeIssues]
@classmethod
def __database_clear__(cls) -> None:
for database in cls.__database_registry__.values(): # pyright: ignore[reportGeneralTypeIssues,reportUnknownMemberType]
database.remove_all()
@overload
def __collection__(self) -> InMemoryStore[ModelT]: ...
@overload
def __collection__(self, identity: type[AnyObject]) -> InMemoryStore[AnyObject]: ...
def __collection__(
self,
identity: Optional[type[AnyObject]] = None,
) -> Union[InMemoryStore[AnyObject], InMemoryStore[ModelT]]:
if identity:
return self.__database__.store(identity)
return self.__filtered_store__ or self.__database__.store(self.model_type)
@staticmethod
def check_not_found(item_or_none: Union[ModelT, None]) -> ModelT:
if item_or_none is None:
msg = "No item found when one was expected"
raise NotFoundError(msg)
return item_or_none
@classmethod
def get_id_attribute_value(
cls,
item: Union[ModelT, type[ModelT]],
id_attribute: Union[str, InstrumentedAttribute[Any], None] = None,
) -> Any:
"""Get value of attribute named as :attr:`id_attribute` on ``item``.
Args:
item: Anything that should have an attribute named as :attr:`id_attribute` value.
id_attribute: Allows customization of the unique identifier to use for model fetching.
Defaults to `None`, but can reference any surrogate or candidate key for the table.
Returns:
The value of attribute on ``item`` named as :attr:`id_attribute`.
"""
if isinstance(id_attribute, InstrumentedAttribute):
id_attribute = id_attribute.key
return getattr(item, id_attribute if id_attribute is not None else cls.id_attribute)
@classmethod
def set_id_attribute_value(
cls,
item_id: Any,
item: ModelT,
id_attribute: Union[str, InstrumentedAttribute[Any], None] = None,
) -> ModelT:
"""Return the ``item`` after the ID is set to the appropriate attribute.
Args:
item_id: Value of ID to be set on instance
item: Anything that should have an attribute named as :attr:`id_attribute` value.
id_attribute: Allows customization of the unique identifier to use for model fetching.
Defaults to `None`, but can reference any surrogate or candidate key for the table.
Returns:
Item with ``item_id`` set to :attr:`id_attribute`
"""
if isinstance(id_attribute, InstrumentedAttribute):
id_attribute = id_attribute.key
setattr(item, id_attribute if id_attribute is not None else cls.id_attribute, item_id)
return item
def _exclude_unused_kwargs(self, kwargs: dict[str, Any]) -> dict[str, Any]:
return {key: value for key, value in kwargs.items() if key not in self._exclude_kwargs}
def _apply_limit_offset_pagination(self, result: list[ModelT], limit: int, offset: int) -> list[ModelT]:
return result[offset : offset + limit]
def _filter_in_collection(
self,
result: list[ModelT],
field_name: str,
values: abc.Collection[Any],
) -> list[ModelT]:
return [item for item in result if getattr(item, field_name) in values]
def _filter_not_in_collection(
self,
result: list[ModelT],
field_name: str,
values: abc.Collection[Any],
) -> list[ModelT]:
if not values:
return result
return [item for item in result if getattr(item, field_name) not in values]
def _filter_on_datetime_field(
self,
result: list[ModelT],
field_name: str,
before: Optional[datetime.datetime] = None,
after: Optional[datetime.datetime] = None,
on_or_before: Optional[datetime.datetime] = None,
on_or_after: Optional[datetime.datetime] = None,
) -> list[ModelT]:
result_: list[ModelT] = []
for item in result:
attr: datetime.datetime = getattr(item, field_name)
if before is not None and attr >= before:
continue
if after is not None and attr <= after:
continue
if on_or_before is not None and attr > on_or_before:
continue
if on_or_after is not None and attr < on_or_after:
continue
result_.append(item)
return result_
def _filter_by_like(
self,
result: list[ModelT],
field_name: Union[str, set[str]],
value: str,
ignore_case: bool,
) -> list[ModelT]:
pattern = re.compile(rf".*{value}.*", re.IGNORECASE) if ignore_case else re.compile(rf".*{value}.*")
fields = {field_name} if isinstance(field_name, str) else field_name
items: list[ModelT] = []
for field in fields:
items.extend(
[
item
for item in result
if isinstance(getattr(item, field), str) and pattern.match(getattr(item, field))
],
)
return list(set(items))
def _filter_by_not_like(
self,
result: list[ModelT],
field_name: Union[str, set[str]],
value: str,
ignore_case: bool,
) -> list[ModelT]:
pattern = re.compile(rf".*{value}.*", re.IGNORECASE) if ignore_case else re.compile(rf".*{value}.*")
fields = {field_name} if isinstance(field_name, str) else field_name
items: list[ModelT] = []
for field in fields:
items.extend(
[
item
for item in result
if isinstance(getattr(item, field), str) and pattern.match(getattr(item, field))
],
)
return list(set(result).difference(set(items)))
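# Illustrative sketch of the LIKE-style matching above: a SQL `%value%` search
# is emulated with a containment regex, and "not like" is the complement of
# the matched set. Unlike the methods above, `re.escape` is added here so
# literal needles containing regex metacharacters behave like SQL LIKE.
import re

def _example_like(values: list, needle: str, ignore_case: bool = False) -> list:
    flags = re.IGNORECASE if ignore_case else 0
    pattern = re.compile(rf".*{re.escape(needle)}.*", flags)
    return [v for v in values if pattern.match(v)]

def _example_not_like(values: list, needle: str, ignore_case: bool = False) -> list:
    matched = set(_example_like(values, needle, ignore_case))
    return [v for v in values if v not in matched]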
def _filter_result_by_kwargs(
self,
result: Iterable[ModelT],
/,
kwargs: Union[dict[Any, Any], Iterable[tuple[Any, Any]]],
) -> list[ModelT]:
kwargs_: dict[Any, Any] = kwargs if isinstance(kwargs, dict) else dict(*kwargs)
kwargs_ = self._exclude_unused_kwargs(kwargs_)
try:
return [item for item in result if all(getattr(item, field) == value for field, value in kwargs_.items())]
except AttributeError as error:
raise RepositoryError from error
def _order_by(self, result: list[ModelT], field_name: str, sort_desc: bool = False) -> list[ModelT]:
return sorted(result, key=lambda item: getattr(item, field_name), reverse=sort_desc)
def _apply_filters(
self,
result: list[ModelT],
*filters: Union[StatementFilter, ColumnElement[bool]],
apply_pagination: bool = True,
) -> list[ModelT]:
for filter_ in filters:
if isinstance(filter_, LimitOffset):
if apply_pagination:
result = self._apply_limit_offset_pagination(result, filter_.limit, filter_.offset)
elif isinstance(filter_, BeforeAfter):
result = self._filter_on_datetime_field(
result,
field_name=filter_.field_name,
before=filter_.before,
after=filter_.after,
)
elif isinstance(filter_, OnBeforeAfter):
result = self._filter_on_datetime_field(
result,
field_name=filter_.field_name,
on_or_before=filter_.on_or_before,
on_or_after=filter_.on_or_after,
)
elif isinstance(filter_, NotInCollectionFilter):
if filter_.values is not None: # pyright: ignore # noqa: PGH003
result = self._filter_not_in_collection(result, filter_.field_name, filter_.values) # pyright: ignore # noqa: PGH003
elif isinstance(filter_, CollectionFilter):
if filter_.values is not None: # pyright: ignore # noqa: PGH003
result = self._filter_in_collection(result, filter_.field_name, filter_.values) # pyright: ignore # noqa: PGH003
elif isinstance(filter_, OrderBy):
result = self._order_by(
result,
filter_.field_name,
sort_desc=filter_.sort_order == "desc",
)
elif isinstance(filter_, NotInSearchFilter):
result = self._filter_by_not_like(
result,
filter_.field_name,
value=filter_.value,
ignore_case=bool(filter_.ignore_case),
)
elif isinstance(filter_, SearchFilter):
result = self._filter_by_like(
result,
filter_.field_name,
value=filter_.value,
ignore_case=bool(filter_.ignore_case),
)
elif not isinstance(filter_, ColumnElement):
msg = f"Unexpected filter: {filter_}"
raise RepositoryError(msg)
return result
def _get_match_fields(
self,
match_fields: Union[list[str], str, None],
id_attribute: Optional[str] = None,
) -> Optional[list[str]]:
id_attribute = id_attribute or self.id_attribute
match_fields = match_fields or self.match_fields
if isinstance(match_fields, str):
match_fields = [match_fields]
return match_fields
async def _list_and_count_basic(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
**kwargs: Any,
) -> tuple[list[ModelT], int]:
result = await self.list(*filters, **kwargs)
return result, len(result)
async def _list_and_count_window(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
**kwargs: Any,
) -> tuple[list[ModelT], int]:
return await self._list_and_count_basic(*filters, **kwargs)
def _find_or_raise_not_found(self, id_: Any) -> ModelT:
return self.check_not_found(self.__collection__().get_or_none(id_))
def _find_one_or_raise_error(self, result: list[ModelT]) -> ModelT:
if not result:
msg = "No item found when one was expected"
raise IntegrityError(msg)
if len(result) > 1:
msg = "Multiple objects when one was expected"
raise IntegrityError(msg)
return result[0]
def _get_update_many_statement(
self,
model_type: type[ModelT],
supports_returning: bool,
loader_options: Optional[list[_AbstractLoad]],
execution_options: Optional[dict[str, Any]],
) -> Union[Update, ReturningUpdate[tuple[ModelT]]]:
return self.statement # type: ignore[no-any-return] # pyright: ignore[reportReturnType]
@classmethod
async def check_health(cls, session: Union[AsyncSession, async_scoped_session[AsyncSession]]) -> bool:
return True
async def get(
self,
item_id: Any,
*,
auto_expunge: Optional[bool] = None,
statement: Union[Select[tuple[ModelT]], StatementLambdaElement, None] = None,
id_attribute: Union[str, InstrumentedAttribute[Any], None] = None,
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> ModelT:
return self._find_or_raise_not_found(item_id)
async def get_one(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
auto_expunge: Optional[bool] = None,
statement: Union[Select[tuple[ModelT]], StatementLambdaElement, None] = None,
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> ModelT:
return self.check_not_found(await self.get_one_or_none(**kwargs))
async def get_one_or_none(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
auto_expunge: Optional[bool] = None,
statement: Union[Select[tuple[ModelT]], StatementLambdaElement, None] = None,
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> Union[ModelT, None]:
result = self._filter_result_by_kwargs(self.__collection__().list(), kwargs)
if len(result) > 1:
msg = "Multiple objects when one was expected"
raise IntegrityError(msg)
return result[0] if result else None
async def get_or_upsert(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
match_fields: Union[list[str], str, None] = None,
upsert: bool = True,
attribute_names: Optional[Iterable[str]] = None,
with_for_update: Optional[bool] = None,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> tuple[ModelT, bool]:
kwargs_ = self._exclude_unused_kwargs(kwargs)
if match_fields := self._get_match_fields(match_fields=match_fields):
match_filter = {
# sourcery skip: remove-none-from-default-get
field_name: kwargs_.get(field_name, None)
for field_name in match_fields
if kwargs_.get(field_name, None) is not None
}
else:
match_filter = kwargs_
existing = await self.get_one_or_none(**match_filter)
if not existing:
return (await self.add(self.model_type(**kwargs_)), True)
if upsert:
for field_name, new_field_value in kwargs_.items():
field = getattr(existing, field_name, MISSING)
if field is not MISSING and field != new_field_value:
setattr(existing, field_name, new_field_value)
existing = await self.update(existing)
return existing, False
async def get_and_update(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
match_fields: Union[list[str], str, None] = None,
attribute_names: Optional[Iterable[str]] = None,
with_for_update: Optional[bool] = None,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> tuple[ModelT, bool]:
kwargs_ = self._exclude_unused_kwargs(kwargs)
if match_fields := self._get_match_fields(match_fields=match_fields):
match_filter = {
# sourcery skip: remove-none-from-default-get
field_name: kwargs_.get(field_name, None)
for field_name in match_fields
if kwargs_.get(field_name, None) is not None
}
else:
match_filter = kwargs_
existing = await self.get_one(**match_filter)
updated = False
for field_name, new_field_value in kwargs_.items():
field = getattr(existing, field_name, MISSING)
if field is not MISSING and field != new_field_value:
updated = True
setattr(existing, field_name, new_field_value)
existing = await self.update(existing)
return existing, updated
async def exists(
self,
*filters: "Union[StatementFilter, ColumnElement[bool]]",
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> bool:
existing = await self.count(*filters, **kwargs)
return existing > 0
async def count(
self,
*filters: "Union[StatementFilter, ColumnElement[bool]]",
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> int:
result = self._apply_filters(self.__collection__().list(), *filters)
return len(self._filter_result_by_kwargs(result, kwargs))
async def add(
self,
data: ModelT,
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
) -> ModelT:
try:
self.__database__.add(self.model_type, data)
except KeyError as exc:
            msg = "Item already exists in collection"
raise IntegrityError(msg) from exc
return data
async def add_many(
self,
data: list[ModelT],
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
) -> list[ModelT]:
for obj in data:
await self.add(obj) # pyright: ignore[reportCallIssue]
return data
async def update(
self,
data: ModelT,
*,
attribute_names: Optional[Iterable[str]] = None,
with_for_update: Optional[bool] = None,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
id_attribute: Union[str, InstrumentedAttribute[Any], None] = None,
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> ModelT:
self._find_or_raise_not_found(self.__collection__().key(data))
return self.__collection__().update(data)
async def update_many(
self,
data: list[ModelT],
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> list[ModelT]:
return [self.__collection__().update(obj) for obj in data if obj in self.__collection__()]
async def delete(
self,
item_id: Any,
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
id_attribute: Union[str, InstrumentedAttribute[Any], None] = None,
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> ModelT:
try:
return self._find_or_raise_not_found(item_id)
finally:
self.__collection__().remove(item_id)
async def delete_many(
self,
item_ids: list[Any],
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
id_attribute: Union[str, InstrumentedAttribute[Any], None] = None,
chunk_size: Optional[int] = None,
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> list[ModelT]:
deleted: list[ModelT] = []
for id_ in item_ids:
if obj := self.__collection__().get_or_none(id_):
deleted.append(obj)
self.__collection__().remove(id_)
return deleted
async def delete_where(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
sanity_check: bool = True,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> list[ModelT]:
result = self.__collection__().list()
result = self._apply_filters(result, *filters)
models = self._filter_result_by_kwargs(result, kwargs)
item_ids = [getattr(model, self.id_attribute) for model in models]
return await self.delete_many(item_ids=item_ids)
async def upsert(
self,
data: ModelT,
*,
attribute_names: Optional[Iterable[str]] = None,
with_for_update: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_commit: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
match_fields: Union[list[str], str, None] = None,
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> ModelT:
# sourcery skip: assign-if-exp, reintroduce-else
if data in self.__collection__():
return await self.update(data)
return await self.add(data)
async def upsert_many(
self,
data: list[ModelT],
*,
auto_expunge: Optional[bool] = None,
auto_commit: Optional[bool] = None,
no_merge: bool = False,
match_fields: Union[list[str], str, None] = None,
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> list[ModelT]:
return [await self.upsert(item) for item in data]
async def list_and_count(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
statement: Union[Select[tuple[ModelT]], StatementLambdaElement, None] = None,
auto_expunge: Optional[bool] = None,
count_with_window_function: Optional[bool] = None,
order_by: Union[list[OrderingPair], OrderingPair, None] = None,
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> tuple[list[ModelT], int]:
return await self._list_and_count_basic(*filters, **kwargs)
async def list(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> list[ModelT]:
result = self.__collection__().list()
result = self._apply_filters(result, *filters)
return self._filter_result_by_kwargs(result, kwargs)
class SQLAlchemyAsyncMockSlugRepository(
SQLAlchemyAsyncMockRepository[ModelT],
SQLAlchemyAsyncSlugRepositoryProtocol[ModelT],
):
async def get_by_slug(
self,
slug: str,
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> Union[ModelT, None]:
"""Select record by slug value."""
return await self.get_one_or_none(slug=slug)
async def get_available_slug(
self,
value_to_slugify: str,
**kwargs: Any,
) -> str:
        """Get a unique slug for the supplied value.

        If the slug already exists, a random 4-character suffix of lowercase
        letters and digits is appended. Override this method to change the
        default behavior.

        Args:
            value_to_slugify (str): A string that should be converted to a unique slug.
            **kwargs: Additional filter criteria; unused by the default implementation.

        Returns:
            str: A unique slug for the supplied value. This is safe for URLs and other unique identifiers.
        """
slug = slugify(value_to_slugify)
if await self._is_slug_unique(slug):
return slug
random_string = "".join(random.choices(string.ascii_lowercase + string.digits, k=4)) # noqa: S311
return f"{slug}-{random_string}"
async def _is_slug_unique(
self,
slug: str,
**kwargs: Any,
) -> bool:
        return not await self.exists(slug=slug)
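The mock repositories above emulate SQL semantics with plain Python: a `LIKE '%value%'` search becomes a substring regex, and `LIMIT`/`OFFSET` becomes list slicing. The following is a standalone sketch of those two techniques; it does not import the library, and the names `rows`, `like`, and `limit_offset` are illustrative, not part of the Advanced Alchemy API.

```python
import re


def like(items: list[str], value: str, ignore_case: bool = False) -> list[str]:
    """Emulate SQL ``LIKE '%value%'`` by searching for the literal value."""
    flags = re.IGNORECASE if ignore_case else 0
    # re.escape keeps regex metacharacters in ``value`` literal, as LIKE would.
    pattern = re.compile(rf".*{re.escape(value)}.*", flags)
    return [item for item in items if pattern.match(item)]


def limit_offset(items: list[str], limit: int, offset: int) -> list[str]:
    """Emulate ``OFFSET offset LIMIT limit`` with a list slice."""
    return items[offset : offset + limit]


rows = ["alpha", "beta", "alphabet", "gamma"]
print(like(rows, "alpha"))       # ['alpha', 'alphabet']
print(limit_offset(rows, 2, 1))  # ['beta', 'alphabet']
```

Note the slice end is `offset + limit`, not `limit`: a plain `items[offset:limit]` would silently return fewer rows whenever `offset > 0`.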
# python-advanced-alchemy-1.0.1/advanced_alchemy/repository/memory/_sync.py
# Do not edit this file directly. It has been autogenerated from
# advanced_alchemy/repository/memory/_async.py
import datetime
import random
import re
import string
from collections import abc
from collections.abc import Iterable
from typing import Any, Optional, Union, cast, overload
from unittest.mock import create_autospec
from sqlalchemy import (
ColumnElement,
Dialect,
Select,
StatementLambdaElement,
Update,
)
from sqlalchemy.orm import InstrumentedAttribute, Session
from sqlalchemy.orm.scoping import scoped_session
from sqlalchemy.orm.strategy_options import _AbstractLoad # pyright: ignore[reportPrivateUsage]
from sqlalchemy.sql.dml import ReturningUpdate
from typing_extensions import Self
from advanced_alchemy.exceptions import ErrorMessages, IntegrityError, NotFoundError, RepositoryError
from advanced_alchemy.filters import (
BeforeAfter,
CollectionFilter,
LimitOffset,
NotInCollectionFilter,
NotInSearchFilter,
OnBeforeAfter,
OrderBy,
SearchFilter,
StatementFilter,
)
from advanced_alchemy.repository._sync import SQLAlchemySyncRepositoryProtocol, SQLAlchemySyncSlugRepositoryProtocol
from advanced_alchemy.repository._util import DEFAULT_ERROR_MESSAGE_TEMPLATES, LoadSpec
from advanced_alchemy.repository.memory.base import (
AnyObject,
InMemoryStore,
SQLAlchemyInMemoryStore,
SQLAlchemyMultiStore,
)
from advanced_alchemy.repository.typing import MISSING, ModelT, OrderingPair
from advanced_alchemy.utils.dataclass import Empty, EmptyType
from advanced_alchemy.utils.text import slugify
class SQLAlchemySyncMockRepository(SQLAlchemySyncRepositoryProtocol[ModelT]):
"""In memory repository."""
__database__: SQLAlchemyMultiStore[ModelT] = SQLAlchemyMultiStore(SQLAlchemyInMemoryStore)
__database_registry__: dict[type[Self], SQLAlchemyMultiStore[ModelT]] = {}
loader_options: Optional[LoadSpec] = None
"""Default loader options for the repository."""
execution_options: Optional[dict[str, Any]] = None
"""Default execution options for the repository."""
model_type: type[ModelT]
id_attribute: Any = "id"
match_fields: Optional[Union[list[str], str]] = None
uniquify: bool = False
_exclude_kwargs: set[str] = {
"statement",
"session",
"auto_expunge",
"auto_refresh",
"auto_commit",
"attribute_names",
"with_for_update",
"count_with_window_function",
"loader_options",
"execution_options",
"order_by",
"load",
"error_messages",
"wrap_exceptions",
"uniquify",
}
def __init__(
self,
*,
statement: Union[Select[tuple[ModelT]], StatementLambdaElement, None] = None,
session: Union[Session, scoped_session[Session]],
auto_expunge: bool = False,
auto_refresh: bool = True,
auto_commit: bool = False,
order_by: Union[list[OrderingPair], OrderingPair, None] = None,
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
wrap_exceptions: bool = True,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> None:
self.session = session
self.statement = create_autospec("Select[Tuple[ModelT]]", instance=True)
self.auto_expunge = auto_expunge
self.auto_refresh = auto_refresh
self.auto_commit = auto_commit
self.error_messages = self._get_error_messages(error_messages=error_messages)
self.wrap_exceptions = wrap_exceptions
self.order_by = order_by
self._dialect: Dialect = create_autospec(Dialect, instance=True)
self._dialect.name = "mock"
self.__filtered_store__: InMemoryStore[ModelT] = self.__database__.store_type()
self._default_options: Any = []
self._default_execution_options: Any = {}
self._loader_options: Any = []
self._loader_options_have_wildcards = False
self.uniquify = bool(uniquify)
def __init_subclass__(cls) -> None:
cls.__database_registry__[cls] = cls.__database__ # pyright: ignore[reportGeneralTypeIssues,reportUnknownMemberType]
@staticmethod
def _get_error_messages(
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
default_messages: Union[ErrorMessages, None, EmptyType] = Empty,
) -> Optional[ErrorMessages]:
if error_messages == Empty:
error_messages = None
default_messages = cast(
"Optional[ErrorMessages]",
default_messages if default_messages != Empty else DEFAULT_ERROR_MESSAGE_TEMPLATES,
)
if error_messages is not None and default_messages is not None:
default_messages.update(cast("ErrorMessages", error_messages))
return default_messages
@classmethod
def __database_add__(cls, identity: Any, data: ModelT) -> ModelT:
return cast("ModelT", cls.__database__.add(identity, data)) # pyright: ignore[reportUnnecessaryCast,reportGeneralTypeIssues]
@classmethod
def __database_clear__(cls) -> None:
for database in cls.__database_registry__.values(): # pyright: ignore[reportGeneralTypeIssues,reportUnknownMemberType]
database.remove_all()
@overload
def __collection__(self) -> InMemoryStore[ModelT]: ...
@overload
def __collection__(self, identity: type[AnyObject]) -> InMemoryStore[AnyObject]: ...
def __collection__(
self,
identity: Optional[type[AnyObject]] = None,
) -> Union[InMemoryStore[AnyObject], InMemoryStore[ModelT]]:
if identity:
return self.__database__.store(identity)
return self.__filtered_store__ or self.__database__.store(self.model_type)
@staticmethod
def check_not_found(item_or_none: Union[ModelT, None]) -> ModelT:
if item_or_none is None:
msg = "No item found when one was expected"
raise NotFoundError(msg)
return item_or_none
@classmethod
def get_id_attribute_value(
cls,
item: Union[ModelT, type[ModelT]],
id_attribute: Union[str, InstrumentedAttribute[Any], None] = None,
) -> Any:
"""Get value of attribute named as :attr:`id_attribute` on ``item``.
Args:
item: Anything that should have an attribute named as :attr:`id_attribute` value.
id_attribute: Allows customization of the unique identifier to use for model fetching.
Defaults to `None`, but can reference any surrogate or candidate key for the table.
Returns:
The value of attribute on ``item`` named as :attr:`id_attribute`.
"""
if isinstance(id_attribute, InstrumentedAttribute):
id_attribute = id_attribute.key
return getattr(item, id_attribute if id_attribute is not None else cls.id_attribute)
@classmethod
def set_id_attribute_value(
cls,
item_id: Any,
item: ModelT,
id_attribute: Union[str, InstrumentedAttribute[Any], None] = None,
) -> ModelT:
"""Return the ``item`` after the ID is set to the appropriate attribute.
Args:
item_id: Value of ID to be set on instance
item: Anything that should have an attribute named as :attr:`id_attribute` value.
id_attribute: Allows customization of the unique identifier to use for model fetching.
Defaults to `None`, but can reference any surrogate or candidate key for the table.
Returns:
Item with ``item_id`` set to :attr:`id_attribute`
"""
if isinstance(id_attribute, InstrumentedAttribute):
id_attribute = id_attribute.key
setattr(item, id_attribute if id_attribute is not None else cls.id_attribute, item_id)
return item
def _exclude_unused_kwargs(self, kwargs: dict[str, Any]) -> dict[str, Any]:
return {key: value for key, value in kwargs.items() if key not in self._exclude_kwargs}
def _apply_limit_offset_pagination(self, result: list[ModelT], limit: int, offset: int) -> list[ModelT]:
        # The slice end must be offset + limit; ``result[offset:limit]`` would
        # drop rows whenever offset > 0.
        return result[offset : offset + limit]
def _filter_in_collection(
self,
result: list[ModelT],
field_name: str,
values: abc.Collection[Any],
) -> list[ModelT]:
return [item for item in result if getattr(item, field_name) in values]
def _filter_not_in_collection(
self,
result: list[ModelT],
field_name: str,
values: abc.Collection[Any],
) -> list[ModelT]:
if not values:
return result
return [item for item in result if getattr(item, field_name) not in values]
def _filter_on_datetime_field(
self,
result: list[ModelT],
field_name: str,
before: Optional[datetime.datetime] = None,
after: Optional[datetime.datetime] = None,
on_or_before: Optional[datetime.datetime] = None,
on_or_after: Optional[datetime.datetime] = None,
) -> list[ModelT]:
        # Keep only items that satisfy every provided bound, matching the AND
        # semantics of the SQL filters (the previous version appended an item
        # once per satisfied bound, producing duplicates and OR semantics).
        result_: list[ModelT] = []
        for item in result:
            attr: datetime.datetime = getattr(item, field_name)
            if before is not None and attr >= before:
                continue
            if after is not None and attr <= after:
                continue
            if on_or_before is not None and attr > on_or_before:
                continue
            if on_or_after is not None and attr < on_or_after:
                continue
            result_.append(item)
        return result_
def _filter_by_like(
self,
result: list[ModelT],
field_name: Union[str, set[str]],
value: str,
ignore_case: bool,
) -> list[ModelT]:
        # Escape the search value so LIKE-style matching treats it literally.
        pattern = re.compile(rf".*{re.escape(value)}.*", re.IGNORECASE if ignore_case else 0)
fields = {field_name} if isinstance(field_name, str) else field_name
items: list[ModelT] = []
for field in fields:
items.extend(
[
item
for item in result
if isinstance(getattr(item, field), str) and pattern.match(getattr(item, field))
],
)
return list(set(items))
def _filter_by_not_like(
self,
result: list[ModelT],
field_name: Union[str, set[str]],
value: str,
ignore_case: bool,
) -> list[ModelT]:
        # Escape the search value so LIKE-style matching treats it literally.
        pattern = re.compile(rf".*{re.escape(value)}.*", re.IGNORECASE if ignore_case else 0)
fields = {field_name} if isinstance(field_name, str) else field_name
items: list[ModelT] = []
for field in fields:
items.extend(
[
item
for item in result
if isinstance(getattr(item, field), str) and pattern.match(getattr(item, field))
],
)
return list(set(result).difference(set(items)))
def _filter_result_by_kwargs(
self,
result: Iterable[ModelT],
/,
kwargs: Union[dict[Any, Any], Iterable[tuple[Any, Any]]],
) -> list[ModelT]:
        kwargs_: dict[Any, Any] = kwargs if isinstance(kwargs, dict) else dict(kwargs)
kwargs_ = self._exclude_unused_kwargs(kwargs_)
try:
return [item for item in result if all(getattr(item, field) == value for field, value in kwargs_.items())]
except AttributeError as error:
raise RepositoryError from error
def _order_by(self, result: list[ModelT], field_name: str, sort_desc: bool = False) -> list[ModelT]:
return sorted(result, key=lambda item: getattr(item, field_name), reverse=sort_desc)
def _apply_filters(
self,
result: list[ModelT],
*filters: Union[StatementFilter, ColumnElement[bool]],
apply_pagination: bool = True,
) -> list[ModelT]:
for filter_ in filters:
if isinstance(filter_, LimitOffset):
if apply_pagination:
result = self._apply_limit_offset_pagination(result, filter_.limit, filter_.offset)
elif isinstance(filter_, BeforeAfter):
result = self._filter_on_datetime_field(
result,
field_name=filter_.field_name,
before=filter_.before,
after=filter_.after,
)
elif isinstance(filter_, OnBeforeAfter):
result = self._filter_on_datetime_field(
result,
field_name=filter_.field_name,
on_or_before=filter_.on_or_before,
on_or_after=filter_.on_or_after,
)
elif isinstance(filter_, NotInCollectionFilter):
if filter_.values is not None: # pyright: ignore # noqa: PGH003
result = self._filter_not_in_collection(result, filter_.field_name, filter_.values) # pyright: ignore # noqa: PGH003
elif isinstance(filter_, CollectionFilter):
if filter_.values is not None: # pyright: ignore # noqa: PGH003
result = self._filter_in_collection(result, filter_.field_name, filter_.values) # pyright: ignore # noqa: PGH003
elif isinstance(filter_, OrderBy):
result = self._order_by(
result,
filter_.field_name,
sort_desc=filter_.sort_order == "desc",
)
elif isinstance(filter_, NotInSearchFilter):
result = self._filter_by_not_like(
result,
filter_.field_name,
value=filter_.value,
ignore_case=bool(filter_.ignore_case),
)
elif isinstance(filter_, SearchFilter):
result = self._filter_by_like(
result,
filter_.field_name,
value=filter_.value,
ignore_case=bool(filter_.ignore_case),
)
elif not isinstance(filter_, ColumnElement):
msg = f"Unexpected filter: {filter_}"
raise RepositoryError(msg)
return result
def _get_match_fields(
self,
match_fields: Union[list[str], str, None],
id_attribute: Optional[str] = None,
) -> Optional[list[str]]:
id_attribute = id_attribute or self.id_attribute
match_fields = match_fields or self.match_fields
if isinstance(match_fields, str):
match_fields = [match_fields]
return match_fields
def _list_and_count_basic(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
**kwargs: Any,
) -> tuple[list[ModelT], int]:
result = self.list(*filters, **kwargs)
return result, len(result)
def _list_and_count_window(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
**kwargs: Any,
) -> tuple[list[ModelT], int]:
return self._list_and_count_basic(*filters, **kwargs)
def _find_or_raise_not_found(self, id_: Any) -> ModelT:
return self.check_not_found(self.__collection__().get_or_none(id_))
def _find_one_or_raise_error(self, result: list[ModelT]) -> ModelT:
if not result:
msg = "No item found when one was expected"
raise IntegrityError(msg)
if len(result) > 1:
msg = "Multiple objects when one was expected"
raise IntegrityError(msg)
return result[0]
def _get_update_many_statement(
self,
model_type: type[ModelT],
supports_returning: bool,
loader_options: Optional[list[_AbstractLoad]],
execution_options: Optional[dict[str, Any]],
) -> Union[Update, ReturningUpdate[tuple[ModelT]]]:
return self.statement # type: ignore[no-any-return] # pyright: ignore[reportReturnType]
@classmethod
def check_health(cls, session: Union[Session, scoped_session[Session]]) -> bool:
return True
def get(
self,
item_id: Any,
*,
auto_expunge: Optional[bool] = None,
statement: Union[Select[tuple[ModelT]], StatementLambdaElement, None] = None,
id_attribute: Union[str, InstrumentedAttribute[Any], None] = None,
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> ModelT:
return self._find_or_raise_not_found(item_id)
def get_one(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
auto_expunge: Optional[bool] = None,
statement: Union[Select[tuple[ModelT]], StatementLambdaElement, None] = None,
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> ModelT:
return self.check_not_found(self.get_one_or_none(**kwargs))
def get_one_or_none(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
auto_expunge: Optional[bool] = None,
statement: Union[Select[tuple[ModelT]], StatementLambdaElement, None] = None,
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> Union[ModelT, None]:
result = self._filter_result_by_kwargs(self.__collection__().list(), kwargs)
if len(result) > 1:
msg = "Multiple objects when one was expected"
raise IntegrityError(msg)
return result[0] if result else None
def get_or_upsert(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
match_fields: Union[list[str], str, None] = None,
upsert: bool = True,
attribute_names: Optional[Iterable[str]] = None,
with_for_update: Optional[bool] = None,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> tuple[ModelT, bool]:
kwargs_ = self._exclude_unused_kwargs(kwargs)
if match_fields := self._get_match_fields(match_fields=match_fields):
match_filter = {
# sourcery skip: remove-none-from-default-get
field_name: kwargs_.get(field_name, None)
for field_name in match_fields
if kwargs_.get(field_name, None) is not None
}
else:
match_filter = kwargs_
existing = self.get_one_or_none(**match_filter)
if not existing:
return (self.add(self.model_type(**kwargs_)), True)
if upsert:
for field_name, new_field_value in kwargs_.items():
field = getattr(existing, field_name, MISSING)
if field is not MISSING and field != new_field_value:
setattr(existing, field_name, new_field_value)
existing = self.update(existing)
return existing, False
def get_and_update(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
match_fields: Union[list[str], str, None] = None,
attribute_names: Optional[Iterable[str]] = None,
with_for_update: Optional[bool] = None,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> tuple[ModelT, bool]:
kwargs_ = self._exclude_unused_kwargs(kwargs)
if match_fields := self._get_match_fields(match_fields=match_fields):
match_filter = {
# sourcery skip: remove-none-from-default-get
field_name: kwargs_.get(field_name, None)
for field_name in match_fields
if kwargs_.get(field_name, None) is not None
}
else:
match_filter = kwargs_
existing = self.get_one(**match_filter)
updated = False
for field_name, new_field_value in kwargs_.items():
field = getattr(existing, field_name, MISSING)
if field is not MISSING and field != new_field_value:
updated = True
setattr(existing, field_name, new_field_value)
existing = self.update(existing)
return existing, updated
def exists(
self,
*filters: "Union[StatementFilter, ColumnElement[bool]]",
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> bool:
existing = self.count(*filters, **kwargs)
return existing > 0
def count(
self,
*filters: "Union[StatementFilter, ColumnElement[bool]]",
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> int:
result = self._apply_filters(self.__collection__().list(), *filters)
return len(self._filter_result_by_kwargs(result, kwargs))
def add(
self,
data: ModelT,
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
) -> ModelT:
try:
self.__database__.add(self.model_type, data)
except KeyError as exc:
msg = "Item already exist in collection"
raise IntegrityError(msg) from exc
return data
def add_many(
self,
data: list[ModelT],
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
) -> list[ModelT]:
for obj in data:
self.add(obj) # pyright: ignore[reportCallIssue]
return data
def update(
self,
data: ModelT,
*,
attribute_names: Optional[Iterable[str]] = None,
with_for_update: Optional[bool] = None,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
id_attribute: Union[str, InstrumentedAttribute[Any], None] = None,
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> ModelT:
self._find_or_raise_not_found(self.__collection__().key(data))
return self.__collection__().update(data)
def update_many(
self,
data: list[ModelT],
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> list[ModelT]:
return [self.__collection__().update(obj) for obj in data if obj in self.__collection__()]
def delete(
self,
item_id: Any,
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
id_attribute: Union[str, InstrumentedAttribute[Any], None] = None,
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> ModelT:
try:
return self._find_or_raise_not_found(item_id)
finally:
self.__collection__().remove(item_id)
def delete_many(
self,
item_ids: list[Any],
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
id_attribute: Union[str, InstrumentedAttribute[Any], None] = None,
chunk_size: Optional[int] = None,
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> list[ModelT]:
deleted: list[ModelT] = []
for id_ in item_ids:
if obj := self.__collection__().get_or_none(id_):
deleted.append(obj)
self.__collection__().remove(id_)
return deleted
def delete_where(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
sanity_check: bool = True,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> list[ModelT]:
result = self.__collection__().list()
result = self._apply_filters(result, *filters)
models = self._filter_result_by_kwargs(result, kwargs)
item_ids = [getattr(model, self.id_attribute) for model in models]
return self.delete_many(item_ids=item_ids)
def upsert(
self,
data: ModelT,
*,
attribute_names: Optional[Iterable[str]] = None,
with_for_update: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_commit: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
match_fields: Union[list[str], str, None] = None,
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> ModelT:
# sourcery skip: assign-if-exp, reintroduce-else
if data in self.__collection__():
return self.update(data)
return self.add(data)
def upsert_many(
self,
data: list[ModelT],
*,
auto_expunge: Optional[bool] = None,
auto_commit: Optional[bool] = None,
no_merge: bool = False,
match_fields: Union[list[str], str, None] = None,
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> list[ModelT]:
return [self.upsert(item) for item in data]
def list_and_count(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
statement: Union[Select[tuple[ModelT]], StatementLambdaElement, None] = None,
auto_expunge: Optional[bool] = None,
count_with_window_function: Optional[bool] = None,
order_by: Union[list[OrderingPair], OrderingPair, None] = None,
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> tuple[list[ModelT], int]:
return self._list_and_count_basic(*filters, **kwargs)
def list(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> list[ModelT]:
result = self.__collection__().list()
result = self._apply_filters(result, *filters)
return self._filter_result_by_kwargs(result, kwargs)
class SQLAlchemySyncMockSlugRepository(
SQLAlchemySyncMockRepository[ModelT],
SQLAlchemySyncSlugRepositoryProtocol[ModelT],
):
def get_by_slug(
self,
slug: str,
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> Union[ModelT, None]:
"""Select record by slug value."""
return self.get_one_or_none(slug=slug)
def get_available_slug(
self,
value_to_slugify: str,
**kwargs: Any,
) -> str:
"""Get a unique slug for the supplied value.
If the value is found to exist, a random 4 digit character is appended to the end.
Override this method to change the default behavior
Args:
value_to_slugify (str): A string that should be converted to a unique slug.
**kwargs: stuff
Returns:
str: a unique slug for the supplied value. This is safe for URLs and other unique identifiers.
"""
slug = slugify(value_to_slugify)
if self._is_slug_unique(slug):
return slug
random_string = "".join(random.choices(string.ascii_lowercase + string.digits, k=4)) # noqa: S311
return f"{slug}-{random_string}"
def _is_slug_unique(
self,
slug: str,
**kwargs: Any,
) -> bool:
return self.exists(slug=slug) is False
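The slug strategy implemented above can be exercised outside the repository. This is a minimal standalone sketch of the same algorithm, where `existing_slugs` and the `slugify` helper are hypothetical stand-ins for the repository's uniqueness check and the library's slugifier:

```python
import random
import re
import string


def slugify(value: str) -> str:
    # Minimal slugifier: lowercase, collapse non-alphanumeric runs into "-".
    return re.sub(r"[^a-z0-9]+", "-", value.lower()).strip("-")


def get_available_slug(value: str, existing_slugs: "set[str]") -> str:
    # Mirror of the repository method: return the plain slug when free,
    # otherwise append a random 4-character lowercase/digit suffix.
    slug = slugify(value)
    if slug not in existing_slugs:
        return slug
    suffix = "".join(random.choices(string.ascii_lowercase + string.digits, k=4))
    return f"{slug}-{suffix}"
```

Note that the suffix is random rather than sequential, so a collision after appending is possible in principle; the repository version accepts that trade-off for simplicity.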
# python-advanced-alchemy-1.0.1/advanced_alchemy/repository/memory/base.py
# ruff: noqa: PD011
import builtins
import contextlib
from collections import defaultdict
from inspect import isclass, signature
from typing import TYPE_CHECKING, Any, Generic, Union, cast, overload
from sqlalchemy import ColumnElement, inspect
from sqlalchemy.orm import RelationshipProperty, Session, class_mapper, object_mapper
from typing_extensions import TypeVar
from advanced_alchemy.exceptions import AdvancedAlchemyError
from advanced_alchemy.repository.typing import _MISSING, MISSING, ModelT # pyright: ignore[reportPrivateUsage]
if TYPE_CHECKING:
from collections.abc import Iterable
from sqlalchemy.orm import Mapper
CollectionT = TypeVar("CollectionT")
T = TypeVar("T")
AnyObject = TypeVar("AnyObject", bound="Any")
class _NotSet:
pass
class InMemoryStore(Generic[T]):
def __init__(self) -> None:
self._store: dict[Any, T] = {}
def _resolve_key(self, key: Any) -> Any:
"""Test different key representations
Args:
key: The key to test
Raises:
KeyError: Raised if key is not present
Returns:
The key representation that is present in the store
"""
for key_ in (key, str(key)):
if key_ in self._store:
return key_
raise KeyError
def key(self, obj: T) -> Any:
return hash(obj)
def add(self, obj: T) -> T:
if (key := self.key(obj)) not in self._store:
self._store[key] = obj
return obj
raise KeyError
def update(self, obj: T) -> T:
key = self._resolve_key(self.key(obj))
self._store[key] = obj
return obj
@overload
def get(self, key: Any, default: type[_NotSet] = _NotSet) -> T: ...
@overload
def get(self, key: Any, default: AnyObject) -> "Union[T, AnyObject]": ...
def get(
self, key: Any, default: "Union[AnyObject, type[_NotSet]]" = _NotSet
) -> "Union[T, AnyObject]": # pragma: no cover
"""Get the object identified by `key`, or return `default` if set or raise a `KeyError` otherwise
Args:
key: The key to test
default: Value to return if key is not present. Defaults to _NotSet.
Raises:
KeyError: Raised if key is not present
Returns:
The object identified by key
"""
try:
key = self._resolve_key(key)
except KeyError as error:
if isclass(default) and not issubclass(default, _NotSet): # pyright: ignore[reportUnnecessaryIsInstance]
return cast("AnyObject", default)
raise KeyError from error
return self._store[key]
def get_or_none(self, key: Any, default: Any = _NotSet) -> "Union[T, None]":
return self.get(key) if default is _NotSet else self.get(key, default)
def remove(self, key: Any) -> T:
return self._store.pop(self._resolve_key(key))
def list(self) -> list[T]:
return list(self._store.values())
def remove_all(self) -> None:
self._store = {}
def __contains__(self, obj: T) -> bool:
try:
self._resolve_key(self.key(obj))
except KeyError:
return False
else:
return True
def __bool__(self) -> bool:
return bool(self._store)
class MultiStore(Generic[T]):
def __init__(self, store_type: "type[InMemoryStore[T]]") -> None:
self.store_type = store_type
self._store: defaultdict[Any, InMemoryStore[T]] = defaultdict(store_type)
def add(self, identity: Any, obj: T) -> T:
return self._store[identity].add(obj)
def store(self, identity: Any) -> "InMemoryStore[T]":
return self._store[identity]
def identity(self, obj: T) -> Any:
return type(obj)
def remove_all(self) -> None:
self._store = defaultdict(self.store_type)
class SQLAlchemyInMemoryStore(InMemoryStore[ModelT]):
id_attribute: str = "id"
def _update_relationship(self, data: ModelT, ref: ModelT) -> None: # pragma: no cover
"""Set relationship data fields targeting ref class to ref.
Example:
```python
class Parent(Base):
child = relationship("Child")
class Child(Base):
pass
```
If data and ref are respectively `Parent` and `Child` instances,
then `data.child` will be set to `ref`.
Args:
data: Model instance on which to update relationships
ref: Target model instance to set on data relationships
"""
ref_mapper = object_mapper(ref)
for relationship in object_mapper(data).relationships:
local = next(iter(relationship.local_columns))
remote = next(iter(relationship.remote_side))
if not local.key or not remote.key:
msg = f"Cannot update relationship {relationship} for model {ref_mapper.class_}"
raise AdvancedAlchemyError(msg)
value = getattr(data, relationship.key)
if not value and relationship.mapper.class_ is ref_mapper.class_:
if relationship.uselist:
for elem in value:
if local_value := getattr(data, local.key):
setattr(elem, remote.key, local_value)
else:
setattr(data, relationship.key, ref)
def _update_fks(self, data: ModelT) -> None: # pragma: no cover
"""Update foreign key fields according to their corresponding relationships.
This makes sure that `data.child_id` == `data.child.id`
or `data.children[0].parent_id` == `data.id`
Args:
data: Instance to be updated
"""
ref_mapper = object_mapper(data)
for relationship in ref_mapper.relationships:
if value := getattr(data, relationship.key):
local = next(iter(relationship.local_columns))
remote = next(iter(relationship.remote_side))
if not local.key or not remote.key:
msg = f"Cannot update relationship {relationship} for model {ref_mapper.class_}"
raise AdvancedAlchemyError(msg)
if relationship.uselist:
for elem in value:
if local_value := getattr(data, local.key):
setattr(elem, remote.key, local_value)
self._update_relationship(elem, data)
# Remove duplicates added by orm when updating list items
if isinstance(value, list):
setattr(data, relationship.key, type(value)(set(value))) # pyright: ignore[reportUnknownArgumentType]
else:
if remote_value := getattr(value, remote.key):
setattr(data, local.key, remote_value)
self._update_relationship(value, data)
def _set_defaults(self, data: ModelT) -> None: # pragma: no cover
"""Set fields with dynamic defaults.
Args:
data: Instance to be updated
"""
for elem in object_mapper(data).c:
default = getattr(elem, "default", MISSING)
value = getattr(data, elem.key, MISSING)
# If value is MISSING, it may be a declared_attr whose name can't be
# determined from the column/relationship element returned
if value is not MISSING and not value and not isinstance(default, _MISSING) and default is not None:
if default.is_scalar:
default_value: Any = default.arg
elif default.is_callable:
default_callable = default.arg.__func__ if isinstance(default.arg, staticmethod) else default.arg # pyright: ignore[reportUnknownMemberType]
if (
# Eager test because inspect.signature() does not
# recognize builtins
hasattr(builtins, default_callable.__name__)
# If present, context contains information about the current
# statement and can be used to access values from other columns.
# As we can't reproduce such context in Pydantic, we don't want
# include a default_factory in that case.
or "context" not in signature(default_callable).parameters
):
default_value = default.arg({}) # pyright: ignore[reportUnknownMemberType, reportCallIssue]
else:
continue
else:
continue
setattr(data, elem.key, default_value)
def changed_attrs(self, data: ModelT) -> "Iterable[str]": # pragma: no cover
res: list[str] = []
mapper = inspect(data)
if mapper is None:
msg = f"Cannot inspect {data.__class__} model"
raise AdvancedAlchemyError(msg)
attrs = class_mapper(data.__class__).column_attrs
for attr in attrs:
hist = getattr(mapper.attrs, attr.key).history
if hist.has_changes():
res.append(attr.key)
return res
def key(self, obj: ModelT) -> str:
return str(getattr(obj, self.id_attribute))
def add(self, obj: ModelT) -> ModelT:
self._set_defaults(obj)
self._update_fks(obj)
return super().add(obj)
def update(self, obj: ModelT) -> ModelT:
existing = self.get(self.key(obj))
for attr in self.changed_attrs(obj):
setattr(existing, attr, getattr(obj, attr))
self._update_fks(existing)
return super().update(existing)
class SQLAlchemyMultiStore(MultiStore[ModelT]):
def _new_instances(self, instance: ModelT) -> "Iterable[ModelT]":
session = Session()
session.add(instance)
relations = list(session.new)
session.expunge_all()
return relations
def _set_relationships_for_fks(self, data: ModelT) -> None: # pragma: no cover
"""Set relationships matching newly added foreign keys on the instance.
Example:
```python
class Parent(Base):
id: Mapped[UUID]
class Child(Base):
id: Mapped[UUID]
parent_id: Mapped[UUID] = mapped_column(ForeignKey("parent.id"))
parent: Mapped[Parent] = relationship(Parent)
```
If `data` is a Child instance and `parent_id` is set, `parent` will be set
to the matching Parent instance if found in the repository
Args:
data: The model to update
"""
obj_mapper = object_mapper(data)
mappers: dict[str, Mapper[Any]] = {}
column_relationships: dict[ColumnElement[Any], RelationshipProperty[Any]] = {}
for mapper in obj_mapper.registry.mappers:
for table in mapper.tables:
mappers[table.name] = mapper
for relationship in obj_mapper.relationships:
for column in relationship.local_columns:
column_relationships[column] = relationship
# sourcery skip: assign-if-exp
if state := inspect(data):
new_attrs: dict[str, Any] = state.dict
else:
new_attrs = {}
for column in obj_mapper.columns:
if column.key not in new_attrs or not column.foreign_keys:
continue
remote_mapper = mappers[next(iter(column.foreign_keys))._table_key()] # noqa: SLF001 # pyright: ignore[reportPrivateUsage]
try:
obj = self.store(remote_mapper.class_).get(new_attrs.get(column.key, None))
except KeyError:
continue
with contextlib.suppress(KeyError):
setattr(data, column_relationships[column].key, obj)
def add(self, identity: Any, obj: ModelT) -> ModelT:
for relation in self._new_instances(obj):
instance_type = self.identity(relation)
self._set_relationships_for_fks(relation)
if relation in self.store(instance_type):
continue
self.store(instance_type).add(relation)
return obj
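The store semantics above — duplicate adds are rejected, and lookups fall back to the string form of the key (the SQLAlchemy store keys objects by `str(id)`) — can be illustrated with a minimal generic store. `MiniStore` is a simplified sketch, not the shipped class:

```python
from typing import Any, Generic, TypeVar

T = TypeVar("T")


class MiniStore(Generic[T]):
    """Simplified sketch of InMemoryStore: a dict with string-fallback key lookup."""

    def __init__(self) -> None:
        self._store: dict[Any, T] = {}

    def _resolve_key(self, key: Any) -> Any:
        # Try the key as given, then its string form, as _resolve_key does above.
        for candidate in (key, str(key)):
            if candidate in self._store:
                return candidate
        raise KeyError(key)

    def add(self, key: Any, obj: T) -> T:
        # A duplicate key is an integrity error in the real store.
        if key in self._store:
            raise KeyError(key)
        self._store[key] = obj
        return obj

    def get(self, key: Any) -> T:
        return self._store[self._resolve_key(key)]
```

Because keys are stored as strings in the SQLAlchemy-backed store, a lookup with an integer id succeeds via the `str(key)` fallback.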
# python-advanced-alchemy-1.0.1/advanced_alchemy/repository/typing.py
from typing import TYPE_CHECKING, Any, Union
from sqlalchemy.orm import InstrumentedAttribute
from typing_extensions import TypeAlias, TypeVar
if TYPE_CHECKING:
from sqlalchemy import RowMapping, Select
from advanced_alchemy import base
from advanced_alchemy.repository._async import SQLAlchemyAsyncRepository
from advanced_alchemy.repository._sync import SQLAlchemySyncRepository
from advanced_alchemy.repository.memory._async import SQLAlchemyAsyncMockRepository
from advanced_alchemy.repository.memory._sync import SQLAlchemySyncMockRepository
__all__ = (
"MISSING",
"ModelOrRowMappingT",
"ModelT",
"OrderingPair",
"RowMappingT",
"RowT",
"SQLAlchemyAsyncRepositoryT",
"SQLAlchemySyncRepositoryT",
"SelectT",
"T",
)
T = TypeVar("T")
ModelT = TypeVar("ModelT", bound="base.ModelProtocol")
"""Type variable for SQLAlchemy models.
:class:`~advanced_alchemy.base.ModelProtocol`
"""
SelectT = TypeVar("SelectT", bound="Select[Any]")
"""Type variable for SQLAlchemy select statements.
:class:`~sqlalchemy.sql.Select`
"""
RowT = TypeVar("RowT", bound=tuple[Any, ...])
"""Type variable for rows.
:class:`~sqlalchemy.engine.Row`
"""
RowMappingT = TypeVar("RowMappingT", bound="RowMapping")
"""Type variable for row mappings.
:class:`~sqlalchemy.engine.RowMapping`
"""
ModelOrRowMappingT = TypeVar("ModelOrRowMappingT", bound="Union[base.ModelProtocol, RowMapping]")
"""Type variable for models or row mappings.
:class:`~advanced_alchemy.base.ModelProtocol` | :class:`~sqlalchemy.engine.RowMapping`
"""
SQLAlchemySyncRepositoryT = TypeVar(
"SQLAlchemySyncRepositoryT",
bound="Union[SQLAlchemySyncRepository[Any], SQLAlchemySyncMockRepository[Any]]",
default="Any",
)
"""Type variable for synchronous SQLAlchemy repositories.
:class:`~advanced_alchemy.repository.SQLAlchemySyncRepository`
"""
SQLAlchemyAsyncRepositoryT = TypeVar(
"SQLAlchemyAsyncRepositoryT",
bound="Union[SQLAlchemyAsyncRepository[Any], SQLAlchemyAsyncMockRepository[Any]]",
default="Any",
)
"""Type variable for asynchronous SQLAlchemy repositories.
:class:`~advanced_alchemy.repository.SQLAlchemyAsyncRepository`
"""
OrderingPair: TypeAlias = tuple[Union[str, InstrumentedAttribute[Any]], bool]
"""Type alias for ordering pairs.
A tuple of (column, ascending) where:
- column: Union[str, :class:`sqlalchemy.orm.InstrumentedAttribute`]
- ascending: bool
This type is used to specify ordering criteria for repository queries.
"""
class _MISSING:
"""Placeholder for missing values."""
MISSING = _MISSING()
"""Missing value placeholder.
:class:`~advanced_alchemy.repository.typing._MISSING`
"""
# python-advanced-alchemy-1.0.1/advanced_alchemy/service/__init__.py
from advanced_alchemy.repository import (
DEFAULT_ERROR_MESSAGE_TEMPLATES,
Empty,
EmptyType,
ErrorMessages,
LoadSpec,
ModelOrRowMappingT,
ModelT,
OrderingPair,
model_from_dict,
)
from advanced_alchemy.service._async import (
SQLAlchemyAsyncQueryService,
SQLAlchemyAsyncRepositoryReadService,
SQLAlchemyAsyncRepositoryService,
)
from advanced_alchemy.service._sync import (
SQLAlchemySyncQueryService,
SQLAlchemySyncRepositoryReadService,
SQLAlchemySyncRepositoryService,
)
from advanced_alchemy.service._util import ResultConverter, find_filter
from advanced_alchemy.service.pagination import OffsetPagination
from advanced_alchemy.service.typing import (
FilterTypeT,
ModelDictListT,
ModelDictT,
ModelDTOT,
SupportedSchemaModel,
is_dict,
is_dict_with_field,
is_dict_without_field,
is_dto_data,
is_msgspec_struct,
is_msgspec_struct_with_field,
is_msgspec_struct_without_field,
is_pydantic_model,
is_pydantic_model_with_field,
is_pydantic_model_without_field,
is_schema,
is_schema_or_dict,
is_schema_or_dict_with_field,
is_schema_or_dict_without_field,
is_schema_with_field,
is_schema_without_field,
schema_dump,
)
__all__ = (
"DEFAULT_ERROR_MESSAGE_TEMPLATES",
"Empty",
"EmptyType",
"ErrorMessages",
"FilterTypeT",
"LoadSpec",
"ModelDTOT",
"ModelDictListT",
"ModelDictT",
"ModelOrRowMappingT",
"ModelT",
"OffsetPagination",
"OrderingPair",
"ResultConverter",
"SQLAlchemyAsyncQueryService",
"SQLAlchemyAsyncRepositoryReadService",
"SQLAlchemyAsyncRepositoryService",
"SQLAlchemySyncQueryService",
"SQLAlchemySyncRepositoryReadService",
"SQLAlchemySyncRepositoryService",
"SupportedSchemaModel",
"find_filter",
"is_dict",
"is_dict_with_field",
"is_dict_without_field",
"is_dto_data",
"is_msgspec_struct",
"is_msgspec_struct_with_field",
"is_msgspec_struct_without_field",
"is_pydantic_model",
"is_pydantic_model_with_field",
"is_pydantic_model_without_field",
"is_schema",
"is_schema_or_dict",
"is_schema_or_dict_with_field",
"is_schema_or_dict_without_field",
"is_schema_with_field",
"is_schema_without_field",
"model_from_dict",
"schema_dump",
)
# python-advanced-alchemy-1.0.1/advanced_alchemy/service/_async.py
"""Service object implementation for SQLAlchemy.
RepositoryService object is generic on the domain model type which
should be a SQLAlchemy model.
"""
from collections.abc import AsyncIterator, Iterable, Sequence
from contextlib import asynccontextmanager
from functools import cached_property
from typing import Any, ClassVar, Generic, Optional, Union, cast
from sqlalchemy import Select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.ext.asyncio.scoping import async_scoped_session
from sqlalchemy.orm import InstrumentedAttribute
from sqlalchemy.sql import ColumnElement
from typing_extensions import Self
from advanced_alchemy.config.asyncio import SQLAlchemyAsyncConfig
from advanced_alchemy.exceptions import AdvancedAlchemyError, ErrorMessages, ImproperConfigurationError, RepositoryError
from advanced_alchemy.filters import StatementFilter
from advanced_alchemy.repository import (
SQLAlchemyAsyncQueryRepository,
)
from advanced_alchemy.repository._util import LoadSpec, model_from_dict
from advanced_alchemy.repository.typing import ModelT, OrderingPair, SQLAlchemyAsyncRepositoryT
from advanced_alchemy.service._util import ResultConverter
from advanced_alchemy.service.typing import (
BulkModelDictT,
ModelDictListT,
ModelDictT,
is_dict,
is_dto_data,
is_msgspec_struct,
is_pydantic_model,
)
from advanced_alchemy.utils.dataclass import Empty, EmptyType
class SQLAlchemyAsyncQueryService(ResultConverter):
"""Simple service to execute the basic Query repository.."""
def __init__(
self,
session: Union[AsyncSession, async_scoped_session[AsyncSession]],
**repo_kwargs: Any,
) -> None:
"""Configure the service object.
Args:
session: Session managing the unit-of-work for the operation.
**repo_kwargs: Optional configuration values to pass into the repository
"""
self.repository = SQLAlchemyAsyncQueryRepository(
session=session,
**repo_kwargs,
)
@classmethod
@asynccontextmanager
async def new(
cls,
session: Optional[Union[AsyncSession, async_scoped_session[AsyncSession]]] = None,
config: Optional[SQLAlchemyAsyncConfig] = None,
) -> AsyncIterator[Self]:
"""Context manager that returns instance of service object.
Handles construction of the database session.
Returns:
The service object instance.
"""
if not config and not session:
raise AdvancedAlchemyError(detail="Please supply an optional configuration or session to use.")
if session:
yield cls(session=session)
elif config:
async with config.get_session() as db_session:
yield cls(session=db_session)
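The `new()` pattern above — a classmethod context manager that accepts either a live session or a config able to produce one — generalizes beyond SQLAlchemy. A hedged sketch with hypothetical `FakeSession`/`FakeConfig` stand-ins:

```python
import asyncio
from contextlib import asynccontextmanager
from typing import AsyncIterator, Optional


class FakeSession:
    """Hypothetical stand-in for an AsyncSession."""


class FakeConfig:
    """Hypothetical stand-in for SQLAlchemyAsyncConfig."""

    @asynccontextmanager
    async def get_session(self) -> "AsyncIterator[FakeSession]":
        yield FakeSession()


class QueryService:
    def __init__(self, session: FakeSession) -> None:
        self.session = session

    @classmethod
    @asynccontextmanager
    async def new(
        cls,
        session: Optional[FakeSession] = None,
        config: Optional[FakeConfig] = None,
    ) -> "AsyncIterator[QueryService]":
        # Mirror the guard above: one of session/config must be supplied.
        if not config and not session:
            raise ValueError("Please supply a configuration or session to use.")
        if session:
            yield cls(session=session)
        elif config:
            async with config.get_session() as db_session:
                yield cls(session=db_session)


async def main() -> None:
    async with QueryService.new(config=FakeConfig()) as svc:
        assert isinstance(svc.session, FakeSession)


asyncio.run(main())
```

Stacking `@classmethod` over `@asynccontextmanager` works because `asynccontextmanager` returns a plain function that `classmethod` can bind; entering the context with neither argument raises at `__aenter__` time.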
class SQLAlchemyAsyncRepositoryReadService(ResultConverter, Generic[ModelT, SQLAlchemyAsyncRepositoryT]):
"""Service object that operates on a repository object."""
repository_type: type[SQLAlchemyAsyncRepositoryT]
"""Type of the repository to use."""
loader_options: ClassVar[Optional[LoadSpec]] = None
"""Default loader options for the repository."""
execution_options: ClassVar[Optional[dict[str, Any]]] = None
"""Default execution options for the repository."""
match_fields: ClassVar[Optional[Union[list[str], str]]] = None
"""List of dialects that prefer to use ``field.id = ANY(:1)`` instead of ``field.id IN (...)``."""
uniquify: ClassVar[bool] = False
"""Optionally apply the ``unique()`` method to results before returning."""
count_with_window_function: ClassVar[bool] = True
"""Use an analytical window function to count results. This allows the count to be performed in a single query."""
_repository_instance: SQLAlchemyAsyncRepositoryT
def __init__(
self,
session: Union[AsyncSession, async_scoped_session[AsyncSession]],
*,
statement: Optional[Select[Any]] = None,
auto_expunge: bool = False,
auto_refresh: bool = True,
auto_commit: bool = False,
order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None,
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
wrap_exceptions: bool = True,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
count_with_window_function: Optional[bool] = None,
**repo_kwargs: Any,
) -> None:
"""Configure the service object.
Args:
session: Session managing the unit-of-work for the operation.
statement: To facilitate customization of the underlying select query.
auto_expunge: Remove object from session before returning.
auto_refresh: Refresh object from session before returning.
auto_commit: Commit objects before returning.
order_by: Set default order options for queries.
error_messages: A set of custom error messages to use for operations
wrap_exceptions: Wrap exceptions in a RepositoryError
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
count_with_window_function: When false, list and count will use two queries instead of an analytical window function.
**repo_kwargs: passed as keyword args to repo instantiation.
"""
load = load if load is not None else self.loader_options
execution_options = execution_options if execution_options is not None else self.execution_options
count_with_window_function = (
count_with_window_function if count_with_window_function is not None else self.count_with_window_function
)
self._repository_instance: SQLAlchemyAsyncRepositoryT = self.repository_type( # type: ignore[assignment]
session=session,
statement=statement,
auto_expunge=auto_expunge,
auto_refresh=auto_refresh,
auto_commit=auto_commit,
order_by=order_by,
error_messages=error_messages,
wrap_exceptions=wrap_exceptions,
load=load,
execution_options=execution_options,
uniquify=self._get_uniquify(uniquify),
count_with_window_function=count_with_window_function,
**repo_kwargs,
)
def _get_uniquify(self, uniquify: Optional[bool] = None) -> bool:
return bool(uniquify or self.uniquify)
@property
def repository(self) -> SQLAlchemyAsyncRepositoryT:
"""Return the repository instance."""
if not self._repository_instance:
msg = "Repository not initialized"
raise ImproperConfigurationError(msg)
return self._repository_instance
@cached_property
def model_type(self) -> type[ModelT]:
"""Return the model type."""
return cast("type[ModelT]", self.repository.model_type)
async def count(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
statement: Optional[Select[tuple[ModelT]]] = None,
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> int:
"""Count of records returned by query.
Args:
*filters: arguments for filtering.
statement: To facilitate customization of the underlying select query.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
**kwargs: key value pairs of filter types.
Returns:
A count of the collection, filtered, but ignoring pagination.
"""
return await self.repository.count(
*filters,
statement=statement,
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=self._get_uniquify(uniquify),
**kwargs,
)
async def exists(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> bool:
"""Wrap repository exists operation.
Args:
*filters: Types for specific filtering operations.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
**kwargs: Keyword arguments for attribute based filtering.
Returns:
True if any matching record exists, otherwise False.
"""
return await self.repository.exists(
*filters,
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=self._get_uniquify(uniquify),
**kwargs,
)
async def get(
self,
item_id: Any,
*,
statement: Optional[Select[tuple[ModelT]]] = None,
id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
auto_expunge: Optional[bool] = None,
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> ModelT:
"""Wrap repository scalar operation.
Args:
item_id: Identifier of instance to be retrieved.
auto_expunge: Remove object from session before returning.
statement: To facilitate customization of the underlying select query.
id_attribute: Allows customization of the unique identifier to use for model fetching.
Defaults to `id`, but can reference any surrogate or candidate key for the table.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
Returns:
Representation of instance with identifier `item_id`.
"""
return cast(
"ModelT",
await self.repository.get(
item_id=item_id,
auto_expunge=auto_expunge,
statement=statement,
id_attribute=id_attribute,
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=self._get_uniquify(uniquify),
),
)
async def get_one(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
statement: Optional[Select[tuple[ModelT]]] = None,
auto_expunge: Optional[bool] = None,
load: Optional[LoadSpec] = None,
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> ModelT:
"""Wrap repository scalar operation.
Args:
*filters: Types for specific filtering operations.
auto_expunge: Remove object from session before returning.
statement: To facilitate customization of the underlying select query.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
**kwargs: Instance attribute value filters.
Returns:
The first instance matching the filters.
"""
return cast(
"ModelT",
await self.repository.get_one(
*filters,
auto_expunge=auto_expunge,
statement=statement,
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=self._get_uniquify(uniquify),
**kwargs,
),
)
async def get_one_or_none(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
statement: Optional[Select[tuple[ModelT]]] = None,
auto_expunge: Optional[bool] = None,
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> Optional[ModelT]:
"""Wrap repository scalar operation.
Args:
*filters: Types for specific filtering operations.
auto_expunge: Remove object from session before returning.
statement: To facilitate customization of the underlying select query.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
**kwargs: Instance attribute value filters.
Returns:
The first matching instance, or ``None`` if no match is found.
"""
return cast(
"Optional[ModelT]",
await self.repository.get_one_or_none(
*filters,
auto_expunge=auto_expunge,
statement=statement,
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=self._get_uniquify(uniquify),
**kwargs,
),
)
async def to_model_on_create(self, data: "ModelDictT[ModelT]") -> "ModelDictT[ModelT]":
"""Convenience method to allow for custom behavior on create.
Args:
data: The data to be converted to a model.
Returns:
The data to be converted to a model.
"""
return data
async def to_model_on_update(self, data: "ModelDictT[ModelT]") -> "ModelDictT[ModelT]":
"""Convenience method to allow for custom behavior on update.
Args:
data: The data to be converted to a model.
Returns:
The data to be converted to a model.
"""
return data
async def to_model_on_delete(self, data: "ModelDictT[ModelT]") -> "ModelDictT[ModelT]":
"""Convenience method to allow for custom behavior on delete.
Args:
data: The data to be converted to a model.
Returns:
The data to be converted to a model.
"""
return data
async def to_model_on_upsert(self, data: "ModelDictT[ModelT]") -> "ModelDictT[ModelT]":
"""Convenience method to allow for custom behavior on upsert.
Args:
data: The data to be converted to a model.
Returns:
The data to be converted to a model.
"""
return data
async def to_model(
self,
data: "ModelDictT[ModelT]",
operation: Optional[str] = None,
) -> ModelT:
"""Parse and Convert input into a model.
Args:
data: Representations to be created.
operation: Optional operation flag so that you can provide behavior based on CRUD operation
Returns:
Representation of created instances.
"""
operation_map = {
"create": self.to_model_on_create,
"update": self.to_model_on_update,
"delete": self.to_model_on_delete,
"upsert": self.to_model_on_upsert,
}
if operation and (op := operation_map.get(operation)):
data = await op(data)
if is_dict(data):
return model_from_dict(model=self.model_type, **data)
if is_pydantic_model(data):
return model_from_dict(
model=self.model_type,
**data.model_dump(exclude_unset=True),
)
if is_msgspec_struct(data):
from msgspec import UNSET
return model_from_dict(
model=self.model_type,
**{f: val for f in data.__struct_fields__ if (val := getattr(data, f, None)) != UNSET},
)
if is_dto_data(data):
return cast("ModelT", data.create_instance())
return cast("ModelT", data)
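# The ``to_model_on_*`` hooks above allow per-operation normalization before
# conversion to a model. A minimal sketch (the ``Product`` service and its
# ``slug`` field are illustrative only):
#
#     class ProductService(SQLAlchemyAsyncRepositoryService[Product, ProductRepository]):
#         repository_type = ProductRepository
#
#         async def to_model_on_create(self, data):
#             if is_dict(data) and "slug" not in data:
#                 data["slug"] = data["name"].lower().replace(" ", "-")
#             return data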
async def list_and_count(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
statement: Optional[Select[tuple[ModelT]]] = None,
auto_expunge: Optional[bool] = None,
count_with_window_function: Optional[bool] = None,
order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None,
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> tuple[Sequence[ModelT], int]:
"""List of records and total count returned by query.
Args:
*filters: Types for specific filtering operations.
statement: To facilitate customization of the underlying select query.
auto_expunge: Remove object from session before returning.
count_with_window_function: When false, list and count will use two queries instead of an analytical window function.
order_by: Set default order options for queries.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
**kwargs: Instance attribute value filters.
Returns:
List of instances and count of total collection, ignoring pagination.
"""
return cast(
"tuple[Sequence[ModelT], int]",
await self.repository.list_and_count(
*filters,
statement=statement,
auto_expunge=auto_expunge,
count_with_window_function=count_with_window_function,
order_by=order_by,
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=self._get_uniquify(uniquify),
**kwargs,
),
)
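# Usage sketch: paginated listing with a total count using ``LimitOffset`` from
# ``advanced_alchemy.filters`` (service and filter values are illustrative):
#
#     items, total = await service.list_and_count(LimitOffset(limit=20, offset=0))
#     # force the two-query strategy instead of an analytical window function
#     items, total = await service.list_and_count(count_with_window_function=False)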
@classmethod
@asynccontextmanager
async def new(
cls,
session: Optional[Union[AsyncSession, async_scoped_session[AsyncSession]]] = None,
statement: Optional[Select[tuple[ModelT]]] = None,
config: Optional[SQLAlchemyAsyncConfig] = None,
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
count_with_window_function: Optional[bool] = None,
) -> AsyncIterator[Self]:
"""Context manager that returns instance of service object.
Handles construction of the database session.
Returns:
The service object instance.
"""
if not config and not session:
raise AdvancedAlchemyError(detail="Please supply a configuration or session to use.")
if session:
yield cls(
statement=statement,
session=session,
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=uniquify,
count_with_window_function=count_with_window_function,
)
elif config:
async with config.get_session() as db_session:
yield cls(
statement=statement,
session=db_session,
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=uniquify,
count_with_window_function=count_with_window_function,
)
async def list(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
statement: Optional[Select[tuple[ModelT]]] = None,
auto_expunge: Optional[bool] = None,
order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None,
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> Sequence[ModelT]:
"""Wrap repository scalars operation.
Args:
*filters: Types for specific filtering operations.
auto_expunge: Remove object from session before returning.
statement: To facilitate customization of the underlying select query.
order_by: Set default order options for queries.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
**kwargs: Instance attribute value filters.
Returns:
The list of instances retrieved from the repository.
"""
return cast(
"Sequence[ModelT]",
await self.repository.list(
*filters,
statement=statement,
auto_expunge=auto_expunge,
order_by=order_by,
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=self._get_uniquify(uniquify),
**kwargs,
),
)
class SQLAlchemyAsyncRepositoryService(
SQLAlchemyAsyncRepositoryReadService[ModelT, SQLAlchemyAsyncRepositoryT],
Generic[ModelT, SQLAlchemyAsyncRepositoryT],
):
"""Service object that operates on a repository object."""
async def create(
self,
data: "ModelDictT[ModelT]",
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
) -> "ModelT":
"""Wrap repository instance creation.
Args:
data: Representation to be created.
auto_expunge: Remove object from session before returning.
auto_refresh: Refresh object from session before returning.
auto_commit: Commit objects before returning.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
Returns:
Representation of created instance.
"""
data = await self.to_model(data, "create")
return cast(
"ModelT",
await self.repository.add(
data=data,
auto_commit=auto_commit,
auto_expunge=auto_expunge,
auto_refresh=auto_refresh,
error_messages=error_messages,
),
)
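# Usage sketch: ``create`` accepts a model instance, a dict, or a Pydantic or
# msgspec object, all routed through ``to_model`` (names are illustrative):
#
#     user = await service.create({"login": "gordon", "name": "Gordon"})
#     user = await service.create(UserCreateSchema(login="gordon"), auto_commit=True)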
async def create_many(
self,
data: "BulkModelDictT[ModelT]",
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
) -> Sequence[ModelT]:
"""Wrap repository bulk instance creation.
Args:
data: Representations to be created.
auto_expunge: Remove object from session before returning.
auto_commit: Commit objects before returning.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
Returns:
Representation of created instances.
"""
if is_dto_data(data):
data = data.create_instance()
data = [(await self.to_model(datum, "create")) for datum in cast("ModelDictListT[ModelT]", data)]
return cast(
"Sequence[ModelT]",
await self.repository.add_many(
data=cast("list[ModelT]", data), # pyright: ignore[reportUnnecessaryCast]
auto_commit=auto_commit,
auto_expunge=auto_expunge,
error_messages=error_messages,
),
)
async def update(
self,
data: "ModelDictT[ModelT]",
item_id: Optional[Any] = None,
*,
attribute_names: Optional[Iterable[str]] = None,
with_for_update: Optional[bool] = None,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> "ModelT":
"""Wrap repository update operation.
Args:
data: Representation to be updated.
item_id: Identifier of item to be updated.
attribute_names: an iterable of attribute names to pass into the ``update``
method.
with_for_update: indicating FOR UPDATE should be used, or may be a
dictionary containing flags to indicate a more specific set of
FOR UPDATE flags for the SELECT
auto_expunge: Remove object from session before returning.
auto_refresh: Refresh object from session before returning.
auto_commit: Commit objects before returning.
id_attribute: Allows customization of the unique identifier to use for model fetching.
Defaults to `id`, but can reference any surrogate or candidate key for the table.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
Returns:
Updated representation.
"""
data = await self.to_model(data, "update")
if (
item_id is None
and self.repository.get_id_attribute_value( # pyright: ignore[reportUnknownMemberType]
item=data,
id_attribute=id_attribute,
)
is None
):
msg = (
"Could not identify ID attribute value. One of the following is required: "
f"``item_id`` or ``data.{id_attribute or self.repository.id_attribute}``"
)
raise RepositoryError(msg)
if item_id is not None:
data = self.repository.set_id_attribute_value(item_id=item_id, item=data, id_attribute=id_attribute) # pyright: ignore[reportUnknownMemberType]
return cast(
"ModelT",
await self.repository.update(
data=data,
attribute_names=attribute_names,
with_for_update=with_for_update,
auto_commit=auto_commit,
auto_expunge=auto_expunge,
auto_refresh=auto_refresh,
id_attribute=id_attribute,
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=self._get_uniquify(uniquify),
),
)
async def update_many(
self,
data: "BulkModelDictT[ModelT]",
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> Sequence[ModelT]:
"""Wrap repository bulk instance update.
Args:
data: Representations to be updated.
auto_expunge: Remove object from session before returning.
auto_commit: Commit objects before returning.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
Returns:
Representation of updated instances.
"""
if is_dto_data(data):
data = data.create_instance()
data = [(await self.to_model(datum, "update")) for datum in cast("ModelDictListT[ModelT]", data)]
return cast(
"Sequence[ModelT]",
await self.repository.update_many(
cast("list[ModelT]", data), # pyright: ignore[reportUnnecessaryCast]
auto_commit=auto_commit,
auto_expunge=auto_expunge,
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=self._get_uniquify(uniquify),
),
)
async def upsert(
self,
data: "ModelDictT[ModelT]",
item_id: Optional[Any] = None,
*,
attribute_names: Optional[Iterable[str]] = None,
with_for_update: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_commit: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
match_fields: Optional[Union[list[str], str]] = None,
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> ModelT:
"""Wrap repository upsert operation.
Args:
data: Instance to update if it exists, otherwise to be created. The identifier
used to determine whether an existing instance exists is the value of the
attribute on `data` named by `self.id_attribute`.
item_id: Identifier of the object for upsert.
attribute_names: an iterable of attribute names to pass into the ``update`` method.
with_for_update: indicating FOR UPDATE should be used, or may be a
dictionary containing flags to indicate a more specific set of
FOR UPDATE flags for the SELECT
auto_expunge: Remove object from session before returning.
auto_refresh: Refresh object from session before returning.
auto_commit: Commit objects before returning.
match_fields: a list of keys to use to match the existing model. When
empty, all fields are matched.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
Returns:
Updated or created representation.
"""
data = await self.to_model(data, "upsert")
item_id = item_id if item_id is not None else self.repository.get_id_attribute_value(item=data) # pyright: ignore[reportUnknownMemberType]
if item_id is not None:
self.repository.set_id_attribute_value(item_id, data) # pyright: ignore[reportUnknownMemberType]
return cast(
"ModelT",
await self.repository.upsert(
data=data,
attribute_names=attribute_names,
with_for_update=with_for_update,
auto_expunge=auto_expunge,
auto_commit=auto_commit,
auto_refresh=auto_refresh,
match_fields=match_fields,
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=self._get_uniquify(uniquify),
),
)
async def upsert_many(
self,
data: "BulkModelDictT[ModelT]",
*,
auto_expunge: Optional[bool] = None,
auto_commit: Optional[bool] = None,
no_merge: bool = False,
match_fields: Optional[Union[list[str], str]] = None,
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> Sequence[ModelT]:
"""Wrap repository upsert operation.
Args:
data: Instances to update if existing, otherwise to be created.
auto_expunge: Remove object from session before returning.
auto_commit: Commit objects before returning.
no_merge: Skip the usage of optimized Merge statements (**reserved for future use**)
match_fields: a list of keys to use to match the existing model. When
empty, all fields are matched.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
Returns:
Updated or created representation.
"""
if is_dto_data(data):
data = data.create_instance()
data = [(await self.to_model(datum, "upsert")) for datum in cast("ModelDictListT[ModelT]", data)]
return cast(
"Sequence[ModelT]",
await self.repository.upsert_many(
data=cast("list[ModelT]", data), # pyright: ignore[reportUnnecessaryCast]
auto_expunge=auto_expunge,
auto_commit=auto_commit,
no_merge=no_merge,
match_fields=match_fields,
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=self._get_uniquify(uniquify),
),
)
async def get_or_upsert(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
match_fields: Optional[Union[list[str], str]] = None,
upsert: bool = True,
attribute_names: Optional[Iterable[str]] = None,
with_for_update: Optional[bool] = None,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> tuple[ModelT, bool]:
"""Wrap repository instance creation.
Args:
*filters: Types for specific filtering operations.
match_fields: a list of keys to use to match the existing model. When
empty, all fields are matched.
upsert: When using match_fields and actual model values differ from
`kwargs`, perform an update operation on the model.
attribute_names: an iterable of attribute names to pass into the ``update``
method.
with_for_update: indicating FOR UPDATE should be used, or may be a
dictionary containing flags to indicate a more specific set of
FOR UPDATE flags for the SELECT
auto_expunge: Remove object from session before returning.
auto_refresh: Refresh object from session before returning.
auto_commit: Commit objects before returning.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
**kwargs: Instance attribute values used to match, create, or update the instance.
Returns:
A tuple of the retrieved or created instance, and a boolean indicating whether it was created.
"""
match_fields = match_fields or self.match_fields
validated_model = await self.to_model(kwargs, "create")
return cast(
"tuple[ModelT, bool]",
await self.repository.get_or_upsert(
*filters,
match_fields=match_fields,
upsert=upsert,
attribute_names=attribute_names,
with_for_update=with_for_update,
auto_commit=auto_commit,
auto_expunge=auto_expunge,
auto_refresh=auto_refresh,
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=self._get_uniquify(uniquify),
**validated_model.to_dict(),
),
)
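# Usage sketch: ``get_or_upsert`` returns the instance plus a flag indicating
# whether it was newly created (field names are illustrative):
#
#     user, created = await service.get_or_upsert(match_fields=["login"], login="gordon")
#     if created:
#         ...  # e.g. enqueue a welcome email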
async def get_and_update(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
match_fields: Optional[Union[list[str], str]] = None,
attribute_names: Optional[Iterable[str]] = None,
with_for_update: Optional[bool] = None,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> tuple[ModelT, bool]:
"""Wrap repository instance creation.
Args:
*filters: Types for specific filtering operations.
match_fields: a list of keys to use to match the existing model. When
empty, all fields are matched.
attribute_names: an iterable of attribute names to pass into the ``update``
method.
with_for_update: indicating FOR UPDATE should be used, or may be a
dictionary containing flags to indicate a more specific set of
FOR UPDATE flags for the SELECT
auto_expunge: Remove object from session before returning.
auto_refresh: Refresh object from session before returning.
auto_commit: Commit objects before returning.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
**kwargs: Instance attribute values used to match and update the instance.
Returns:
A tuple of the updated instance and a boolean indicating whether it was updated.
"""
match_fields = match_fields or self.match_fields
validated_model = await self.to_model(kwargs, "update")
return cast(
"tuple[ModelT, bool]",
await self.repository.get_and_update(
*filters,
match_fields=match_fields,
attribute_names=attribute_names,
with_for_update=with_for_update,
auto_commit=auto_commit,
auto_expunge=auto_expunge,
auto_refresh=auto_refresh,
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=self._get_uniquify(uniquify),
**validated_model.to_dict(),
),
)
async def delete(
self,
item_id: Any,
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> ModelT:
"""Wrap repository delete operation.
Args:
item_id: Identifier of instance to be deleted.
auto_commit: Commit objects before returning.
auto_expunge: Remove object from session before returning.
id_attribute: Allows customization of the unique identifier to use for model fetching.
Defaults to `id`, but can reference any surrogate or candidate key for the table.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
Returns:
Representation of the deleted instance.
"""
return cast(
"ModelT",
await self.repository.delete(
item_id=item_id,
auto_commit=auto_commit,
auto_expunge=auto_expunge,
id_attribute=id_attribute,
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=self._get_uniquify(uniquify),
),
)
async def delete_many(
self,
item_ids: list[Any],
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
chunk_size: Optional[int] = None,
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> Sequence[ModelT]:
"""Wrap repository bulk instance deletion.
Args:
item_ids: Identifiers of the instances to be deleted.
auto_expunge: Remove object from session before returning.
auto_commit: Commit objects before returning.
id_attribute: Allows customization of the unique identifier to use for model fetching.
Defaults to `id`, but can reference any surrogate or candidate key for the table.
chunk_size: Allows customization of the ``insertmanyvalues_max_parameters`` setting for the driver.
Defaults to `950` if left unset.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
Returns:
Representation of removed instances.
"""
return cast(
"Sequence[ModelT]",
await self.repository.delete_many(
item_ids=item_ids,
auto_commit=auto_commit,
auto_expunge=auto_expunge,
id_attribute=id_attribute,
chunk_size=chunk_size,
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=self._get_uniquify(uniquify),
),
)
async def delete_where(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
sanity_check: bool = True,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> Sequence[ModelT]:
"""Wrap repository scalars operation.
Args:
*filters: Types for specific filtering operations.
auto_expunge: Remove object from session before returning.
auto_commit: Commit objects before returning.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
sanity_check: When true, the length of selected instances is compared to the deleted row count
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
**kwargs: Instance attribute value filters.
Returns:
The list of instances deleted from the repository.
"""
return cast(
"Sequence[ModelT]",
await self.repository.delete_where(
*filters,
auto_commit=auto_commit,
auto_expunge=auto_expunge,
error_messages=error_messages,
sanity_check=sanity_check,
load=load,
execution_options=execution_options,
uniquify=self._get_uniquify(uniquify),
**kwargs,
),
)
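# Usage sketch: bulk delete by attribute filters; the returned sequence holds
# the deleted instances (the ``is_active`` field is illustrative):
#
#     removed = await service.delete_where(is_active=False, auto_commit=True)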
# advanced_alchemy/service/_sync.py
# Do not edit this file directly. It has been autogenerated from
# advanced_alchemy/service/_async.py
"""Service object implementation for SQLAlchemy.
RepositoryService object is generic on the domain model type which
should be a SQLAlchemy model.
"""
from collections.abc import Iterable, Iterator, Sequence
from contextlib import contextmanager
from functools import cached_property
from typing import Any, ClassVar, Generic, Optional, Union, cast
from sqlalchemy import Select
from sqlalchemy.orm import InstrumentedAttribute, Session
from sqlalchemy.orm.scoping import scoped_session
from sqlalchemy.sql import ColumnElement
from typing_extensions import Self
from advanced_alchemy.config.sync import SQLAlchemySyncConfig
from advanced_alchemy.exceptions import AdvancedAlchemyError, ErrorMessages, ImproperConfigurationError, RepositoryError
from advanced_alchemy.filters import StatementFilter
from advanced_alchemy.repository import SQLAlchemySyncQueryRepository
from advanced_alchemy.repository._util import LoadSpec, model_from_dict
from advanced_alchemy.repository.typing import ModelT, OrderingPair, SQLAlchemySyncRepositoryT
from advanced_alchemy.service._util import ResultConverter
from advanced_alchemy.service.typing import (
BulkModelDictT,
ModelDictListT,
ModelDictT,
is_dict,
is_dto_data,
is_msgspec_struct,
is_pydantic_model,
)
from advanced_alchemy.utils.dataclass import Empty, EmptyType
class SQLAlchemySyncQueryService(ResultConverter):
"""Simple service to execute the basic Query repository.."""
def __init__(
self,
session: Union[Session, scoped_session[Session]],
**repo_kwargs: Any,
) -> None:
"""Configure the service object.
Args:
session: Session managing the unit-of-work for the operation.
**repo_kwargs: Optional configuration values to pass into the repository
"""
self.repository = SQLAlchemySyncQueryRepository(
session=session,
**repo_kwargs,
)
@classmethod
@contextmanager
def new(
cls,
session: Optional[Union[Session, scoped_session[Session]]] = None,
config: Optional[SQLAlchemySyncConfig] = None,
) -> Iterator[Self]:
"""Context manager that returns instance of service object.
Handles construction of the database session.
Returns:
The service object instance.
"""
if not config and not session:
raise AdvancedAlchemyError(detail="Please supply a configuration or session to use.")
if session:
yield cls(session=session)
elif config:
with config.get_session() as db_session:
yield cls(session=db_session)
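# Usage sketch: constructing the service from a config inside the ``new``
# context manager (``db_config`` is an illustrative ``SQLAlchemySyncConfig``):
#
#     with SQLAlchemySyncQueryService.new(config=db_config) as service:
#         ...  # run ad-hoc queries through service.repository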
class SQLAlchemySyncRepositoryReadService(ResultConverter, Generic[ModelT, SQLAlchemySyncRepositoryT]):
"""Service object that operates on a repository object."""
repository_type: type[SQLAlchemySyncRepositoryT]
"""Type of the repository to use."""
loader_options: ClassVar[Optional[LoadSpec]] = None
"""Default loader options for the repository."""
execution_options: ClassVar[Optional[dict[str, Any]]] = None
"""Default execution options for the repository."""
match_fields: ClassVar[Optional[Union[list[str], str]]] = None
"""List of dialects that prefer to use ``field.id = ANY(:1)`` instead of ``field.id IN (...)``."""
uniquify: ClassVar[bool] = False
"""Optionally apply the ``unique()`` method to results before returning."""
count_with_window_function: ClassVar[bool] = True
"""Use an analytical window function to count results. This allows the count to be performed in a single query."""
_repository_instance: SQLAlchemySyncRepositoryT
def __init__(
self,
session: Union[Session, scoped_session[Session]],
*,
statement: Optional[Select[Any]] = None,
auto_expunge: bool = False,
auto_refresh: bool = True,
auto_commit: bool = False,
order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None,
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
wrap_exceptions: bool = True,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
count_with_window_function: Optional[bool] = None,
**repo_kwargs: Any,
) -> None:
"""Configure the service object.
Args:
session: Session managing the unit-of-work for the operation.
statement: To facilitate customization of the underlying select query.
auto_expunge: Remove object from session before returning.
auto_refresh: Refresh object from session before returning.
auto_commit: Commit objects before returning.
order_by: Set default order options for queries.
error_messages: A set of custom error messages to use for operations
wrap_exceptions: Wrap exceptions in a RepositoryError
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
count_with_window_function: When false, list and count will use two queries instead of an analytical window function.
**repo_kwargs: passed as keyword args to repo instantiation.
"""
load = load if load is not None else self.loader_options
execution_options = execution_options if execution_options is not None else self.execution_options
count_with_window_function = (
count_with_window_function if count_with_window_function is not None else self.count_with_window_function
)
self._repository_instance: SQLAlchemySyncRepositoryT = self.repository_type( # type: ignore[assignment]
session=session,
statement=statement,
auto_expunge=auto_expunge,
auto_refresh=auto_refresh,
auto_commit=auto_commit,
order_by=order_by,
error_messages=error_messages,
wrap_exceptions=wrap_exceptions,
load=load,
execution_options=execution_options,
uniquify=self._get_uniquify(uniquify),
count_with_window_function=count_with_window_function,
**repo_kwargs,
)
def _get_uniquify(self, uniquify: Optional[bool] = None) -> bool:
return bool(uniquify or self.uniquify)
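The resolution rule above can be sketched in isolation. This is a hypothetical stand-in, not library code: a truthy per-call value wins, otherwise the class-level default applies (note that with ``bool(x or default)`` an explicit ``uniquify=False`` cannot override a class-level ``True``).

```python
# Minimal sketch of per-call / class-level flag resolution (not library code).
class Service:
    uniquify = False  # class-level default

    def _get_uniquify(self, uniquify=None):
        # Truthy per-call values win; otherwise fall back to the class default.
        return bool(uniquify or self.uniquify)


svc = Service()
print(svc._get_uniquify())      # False: falls back to the class default
print(svc._get_uniquify(True))  # True: per-call override
```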
@property
def repository(self) -> SQLAlchemySyncRepositoryT:
"""Return the repository instance."""
if not self._repository_instance:
msg = "Repository not initialized"
raise ImproperConfigurationError(msg)
return self._repository_instance
@cached_property
def model_type(self) -> type[ModelT]:
"""Return the model type."""
return cast("type[ModelT]", self.repository.model_type)
def count(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
statement: Optional[Select[tuple[ModelT]]] = None,
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> int:
"""Count of records returned by query.
Args:
*filters: arguments for filtering.
statement: To facilitate customization of the underlying select query.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
**kwargs: key value pairs of filter types.
Returns:
A count of the collection, filtered, but ignoring pagination.
"""
return self.repository.count(
*filters,
statement=statement,
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=self._get_uniquify(uniquify),
**kwargs,
)
def exists(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> bool:
"""Wrap repository exists operation.
Args:
*filters: Types for specific filtering operations.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
**kwargs: Keyword arguments for attribute based filtering.
Returns:
            True if an instance matching the filters exists, otherwise False.
"""
return self.repository.exists(
*filters,
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=self._get_uniquify(uniquify),
**kwargs,
)
def get(
self,
item_id: Any,
*,
statement: Optional[Select[tuple[ModelT]]] = None,
id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
auto_expunge: Optional[bool] = None,
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> ModelT:
"""Wrap repository scalar operation.
Args:
item_id: Identifier of instance to be retrieved.
auto_expunge: Remove object from session before returning.
statement: To facilitate customization of the underlying select query.
id_attribute: Allows customization of the unique identifier to use for model fetching.
Defaults to `id`, but can reference any surrogate or candidate key for the table.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
Returns:
Representation of instance with identifier `item_id`.
"""
return cast(
"ModelT",
self.repository.get(
item_id=item_id,
auto_expunge=auto_expunge,
statement=statement,
id_attribute=id_attribute,
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=self._get_uniquify(uniquify),
),
)
def get_one(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
statement: Optional[Select[tuple[ModelT]]] = None,
auto_expunge: Optional[bool] = None,
load: Optional[LoadSpec] = None,
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> ModelT:
"""Wrap repository scalar operation.
Args:
*filters: Types for specific filtering operations.
auto_expunge: Remove object from session before returning.
statement: To facilitate customization of the underlying select query.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
            **kwargs: Attribute-based filter values.
        Returns:
            The instance matching the given filters.
"""
return cast(
"ModelT",
self.repository.get_one(
*filters,
auto_expunge=auto_expunge,
statement=statement,
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=self._get_uniquify(uniquify),
**kwargs,
),
)
def get_one_or_none(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
statement: Optional[Select[tuple[ModelT]]] = None,
auto_expunge: Optional[bool] = None,
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> Optional[ModelT]:
"""Wrap repository scalar operation.
Args:
*filters: Types for specific filtering operations.
auto_expunge: Remove object from session before returning.
statement: To facilitate customization of the underlying select query.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
            **kwargs: Attribute-based filter values.
        Returns:
            The matching instance, or ``None`` if no instance is found.
"""
return cast(
"Optional[ModelT]",
self.repository.get_one_or_none(
*filters,
auto_expunge=auto_expunge,
statement=statement,
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=self._get_uniquify(uniquify),
**kwargs,
),
)
def to_model_on_create(self, data: "ModelDictT[ModelT]") -> "ModelDictT[ModelT]":
"""Convenience method to allow for custom behavior on create.
Args:
data: The data to be converted to a model.
Returns:
The data to be converted to a model.
"""
return data
def to_model_on_update(self, data: "ModelDictT[ModelT]") -> "ModelDictT[ModelT]":
"""Convenience method to allow for custom behavior on update.
Args:
data: The data to be converted to a model.
Returns:
The data to be converted to a model.
"""
return data
def to_model_on_delete(self, data: "ModelDictT[ModelT]") -> "ModelDictT[ModelT]":
"""Convenience method to allow for custom behavior on delete.
Args:
data: The data to be converted to a model.
Returns:
The data to be converted to a model.
"""
return data
def to_model_on_upsert(self, data: "ModelDictT[ModelT]") -> "ModelDictT[ModelT]":
"""Convenience method to allow for custom behavior on upsert.
Args:
data: The data to be converted to a model.
Returns:
The data to be converted to a model.
"""
return data
def to_model(
self,
data: "ModelDictT[ModelT]",
operation: Optional[str] = None,
) -> ModelT:
"""Parse and Convert input into a model.
Args:
data: Representations to be created.
operation: Optional operation flag so that you can provide behavior based on CRUD operation
Returns:
Representation of created instances.
"""
operation_map = {
"create": self.to_model_on_create,
"update": self.to_model_on_update,
"delete": self.to_model_on_delete,
"upsert": self.to_model_on_upsert,
}
if operation and (op := operation_map.get(operation)):
data = op(data)
if is_dict(data):
return model_from_dict(model=self.model_type, **data)
if is_pydantic_model(data):
return model_from_dict(
model=self.model_type,
**data.model_dump(exclude_unset=True),
)
if is_msgspec_struct(data):
from msgspec import UNSET
return model_from_dict(
model=self.model_type,
**{f: val for f in data.__struct_fields__ if (val := getattr(data, f, None)) != UNSET},
)
if is_dto_data(data):
return cast("ModelT", data.create_instance())
return cast("ModelT", data)
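The ``to_model_on_*`` hooks form a small dispatch table: the operation name selects a per-operation transform that runs before the generic conversion. A stripped-down, hypothetical version of the pattern (``MiniService`` and its hooks are illustrative names, not library API):

```python
# Hypothetical stand-in for the to_model_on_* hook dispatch (not library code).
class MiniService:
    def to_model_on_create(self, data):
        return {**data, "created": True}

    def to_model_on_update(self, data):
        return {**data, "updated": True}

    def to_model(self, data, operation=None):
        operation_map = {
            "create": self.to_model_on_create,
            "update": self.to_model_on_update,
        }
        # Apply the per-operation hook first, if one matches.
        if operation and (op := operation_map.get(operation)):
            data = op(data)
        return data


svc = MiniService()
print(svc.to_model({"name": "a"}, "create"))  # {'name': 'a', 'created': True}
```

Subclasses override only the hook they care about; unknown or missing operation names fall through to the plain conversion.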
def list_and_count(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
statement: Optional[Select[tuple[ModelT]]] = None,
auto_expunge: Optional[bool] = None,
count_with_window_function: Optional[bool] = None,
order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None,
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> tuple[Sequence[ModelT], int]:
"""List of records and total count returned by query.
Args:
*filters: Types for specific filtering operations.
statement: To facilitate customization of the underlying select query.
auto_expunge: Remove object from session before returning.
count_with_window_function: When false, list and count will use two queries instead of an analytical window function.
order_by: Set default order options for queries.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
**kwargs: Instance attribute value filters.
Returns:
List of instances and count of total collection, ignoring pagination.
"""
return cast(
"tuple[Sequence[ModelT], int]",
self.repository.list_and_count(
*filters,
statement=statement,
auto_expunge=auto_expunge,
count_with_window_function=count_with_window_function,
order_by=order_by,
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=self._get_uniquify(uniquify),
**kwargs,
),
)
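When ``count_with_window_function`` is enabled, the page of rows and the total count come back from a single query by attaching an analytical window function to the select. The underlying SQL pattern can be demonstrated against an in-memory SQLite database (window functions require SQLite 3.25+; the ``author`` table is illustrative):

```python
import sqlite3

# One round trip returns both the page of rows and the total row count
# by attaching COUNT(*) OVER () to the select.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE author (id INTEGER PRIMARY KEY, name TEXT)")
conn.executemany("INSERT INTO author (name) VALUES (?)", [("a",), ("b",), ("c",)])

rows = conn.execute(
    "SELECT id, name, COUNT(*) OVER () AS total FROM author ORDER BY id LIMIT 2"
).fetchall()
page = [(r[0], r[1]) for r in rows]
total = rows[0][2] if rows else 0
print(page, total)  # [(1, 'a'), (2, 'b')] 3
```

With ``count_with_window_function=False`` the same result takes two statements: one ``SELECT ... LIMIT`` for the page and one ``SELECT COUNT(*)`` for the total.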
@classmethod
@contextmanager
def new(
cls,
session: Optional[Union[Session, scoped_session[Session]]] = None,
statement: Optional[Select[tuple[ModelT]]] = None,
config: Optional[SQLAlchemySyncConfig] = None,
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
count_with_window_function: Optional[bool] = None,
) -> Iterator[Self]:
"""Context manager that returns instance of service object.
        Handles construction of the database session.
Returns:
The service object instance.
"""
if not config and not session:
raise AdvancedAlchemyError(detail="Please supply an optional configuration or session to use.")
if session:
yield cls(
statement=statement,
session=session,
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=uniquify,
count_with_window_function=count_with_window_function,
)
elif config:
with config.get_session() as db_session:
yield cls(
statement=statement,
session=db_session,
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=uniquify,
count_with_window_function=count_with_window_function,
)
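The ``new`` classmethod accepts either an existing session or a config that can open one, and yields a configured instance either way. A hypothetical, dependency-free sketch of the same shape (``MiniService`` and ``FakeConfig`` are illustrative, not library API):

```python
from contextlib import contextmanager


class FakeConfig:
    # Stands in for a config object whose get_session() is a context manager.
    @contextmanager
    def get_session(self):
        yield "session-from-config"


class MiniService:
    def __init__(self, session):
        self.session = session

    @classmethod
    @contextmanager
    def new(cls, session=None, config=None):
        if not config and not session:
            raise ValueError("Please supply a configuration or session to use.")
        if session:
            # Caller owns the session; just bind to it.
            yield cls(session=session)
        elif config:
            # Service owns the session lifecycle via the config.
            with config.get_session() as db_session:
                yield cls(session=db_session)


with MiniService.new(config=FakeConfig()) as svc:
    print(svc.session)  # session-from-config
```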
def list(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
statement: Optional[Select[tuple[ModelT]]] = None,
auto_expunge: Optional[bool] = None,
order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None,
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> Sequence[ModelT]:
"""Wrap repository scalars operation.
Args:
*filters: Types for specific filtering operations.
auto_expunge: Remove object from session before returning.
statement: To facilitate customization of the underlying select query.
order_by: Set default order options for queries.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
**kwargs: Instance attribute value filters.
Returns:
The list of instances retrieved from the repository.
"""
return cast(
"Sequence[ModelT]",
self.repository.list(
*filters,
statement=statement,
auto_expunge=auto_expunge,
order_by=order_by,
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=self._get_uniquify(uniquify),
**kwargs,
),
)
class SQLAlchemySyncRepositoryService(
SQLAlchemySyncRepositoryReadService[ModelT, SQLAlchemySyncRepositoryT],
Generic[ModelT, SQLAlchemySyncRepositoryT],
):
"""Service object that operates on a repository object."""
def create(
self,
data: "ModelDictT[ModelT]",
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
) -> "ModelT":
"""Wrap repository instance creation.
Args:
data: Representation to be created.
auto_expunge: Remove object from session before returning.
auto_refresh: Refresh object from session before returning.
auto_commit: Commit objects before returning.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
Returns:
Representation of created instance.
"""
data = self.to_model(data, "create")
return cast(
"ModelT",
self.repository.add(
data=data,
auto_commit=auto_commit,
auto_expunge=auto_expunge,
auto_refresh=auto_refresh,
error_messages=error_messages,
),
)
def create_many(
self,
data: "BulkModelDictT[ModelT]",
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
) -> Sequence[ModelT]:
"""Wrap repository bulk instance creation.
Args:
data: Representations to be created.
auto_expunge: Remove object from session before returning.
auto_commit: Commit objects before returning.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
Returns:
Representation of created instances.
"""
if is_dto_data(data):
data = data.create_instance()
data = [(self.to_model(datum, "create")) for datum in cast("ModelDictListT[ModelT]", data)]
return cast(
"Sequence[ModelT]",
self.repository.add_many(
data=cast("list[ModelT]", data), # pyright: ignore[reportUnnecessaryCast]
auto_commit=auto_commit,
auto_expunge=auto_expunge,
error_messages=error_messages,
),
)
def update(
self,
data: "ModelDictT[ModelT]",
item_id: Optional[Any] = None,
*,
attribute_names: Optional[Iterable[str]] = None,
with_for_update: Optional[bool] = None,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> "ModelT":
"""Wrap repository update operation.
Args:
data: Representation to be updated.
item_id: Identifier of item to be updated.
attribute_names: an iterable of attribute names to pass into the ``update``
method.
with_for_update: indicating FOR UPDATE should be used, or may be a
dictionary containing flags to indicate a more specific set of
FOR UPDATE flags for the SELECT
auto_expunge: Remove object from session before returning.
auto_refresh: Refresh object from session before returning.
auto_commit: Commit objects before returning.
id_attribute: Allows customization of the unique identifier to use for model fetching.
Defaults to `id`, but can reference any surrogate or candidate key for the table.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
Returns:
Updated representation.
"""
data = self.to_model(data, "update")
if (
item_id is None
and self.repository.get_id_attribute_value( # pyright: ignore[reportUnknownMemberType]
item=data,
id_attribute=id_attribute,
)
is None
):
msg = (
"Could not identify ID attribute value. One of the following is required: "
f"``item_id`` or ``data.{id_attribute or self.repository.id_attribute}``"
)
raise RepositoryError(msg)
if item_id is not None:
data = self.repository.set_id_attribute_value(item_id=item_id, item=data, id_attribute=id_attribute) # pyright: ignore[reportUnknownMemberType]
return cast(
"ModelT",
self.repository.update(
data=data,
attribute_names=attribute_names,
with_for_update=with_for_update,
auto_commit=auto_commit,
auto_expunge=auto_expunge,
auto_refresh=auto_refresh,
id_attribute=id_attribute,
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=self._get_uniquify(uniquify),
),
)
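The id-resolution guard in ``update`` (an explicit ``item_id`` wins; otherwise the model must already carry its id attribute) can be sketched in isolation. ``resolve_id`` is a hypothetical helper, not part of the library:

```python
# Hypothetical sketch of the id-resolution rule used by ``update``.
def resolve_id(item_id, data, id_attribute="id"):
    # An explicit item_id takes precedence over the value carried by data.
    value = item_id if item_id is not None else data.get(id_attribute)
    if value is None:
        raise ValueError(
            "Could not identify ID attribute value. One of the following is "
            f"required: item_id or data.{id_attribute}"
        )
    return value


print(resolve_id(7, {"name": "x"}))              # 7
print(resolve_id(None, {"id": 3, "name": "x"}))  # 3
```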
def update_many(
self,
data: "BulkModelDictT[ModelT]",
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> Sequence[ModelT]:
"""Wrap repository bulk instance update.
Args:
data: Representations to be updated.
auto_expunge: Remove object from session before returning.
auto_commit: Commit objects before returning.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
Returns:
Representation of updated instances.
"""
if is_dto_data(data):
data = data.create_instance()
data = [(self.to_model(datum, "update")) for datum in cast("ModelDictListT[ModelT]", data)]
return cast(
"Sequence[ModelT]",
self.repository.update_many(
cast("list[ModelT]", data), # pyright: ignore[reportUnnecessaryCast]
auto_commit=auto_commit,
auto_expunge=auto_expunge,
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=self._get_uniquify(uniquify),
),
)
def upsert(
self,
data: "ModelDictT[ModelT]",
item_id: Optional[Any] = None,
*,
attribute_names: Optional[Iterable[str]] = None,
with_for_update: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_commit: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
match_fields: Optional[Union[list[str], str]] = None,
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> ModelT:
"""Wrap repository upsert operation.
Args:
data: Instance to update existing, or be created. Identifier used to determine if an
existing instance exists is the value of an attribute on `data` named as value of
`self.id_attribute`.
item_id: Identifier of the object for upsert.
attribute_names: an iterable of attribute names to pass into the ``update`` method.
with_for_update: indicating FOR UPDATE should be used, or may be a
dictionary containing flags to indicate a more specific set of
FOR UPDATE flags for the SELECT
auto_expunge: Remove object from session before returning.
auto_refresh: Refresh object from session before returning.
auto_commit: Commit objects before returning.
match_fields: a list of keys to use to match the existing model. When
empty, all fields are matched.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
Returns:
Updated or created representation.
"""
data = self.to_model(data, "upsert")
item_id = item_id if item_id is not None else self.repository.get_id_attribute_value(item=data) # pyright: ignore[reportUnknownMemberType]
if item_id is not None:
self.repository.set_id_attribute_value(item_id, data) # pyright: ignore[reportUnknownMemberType]
return cast(
"ModelT",
self.repository.upsert(
data=data,
attribute_names=attribute_names,
with_for_update=with_for_update,
auto_expunge=auto_expunge,
auto_commit=auto_commit,
auto_refresh=auto_refresh,
match_fields=match_fields,
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=self._get_uniquify(uniquify),
),
)
def upsert_many(
self,
data: "BulkModelDictT[ModelT]",
*,
auto_expunge: Optional[bool] = None,
auto_commit: Optional[bool] = None,
no_merge: bool = False,
match_fields: Optional[Union[list[str], str]] = None,
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> Sequence[ModelT]:
"""Wrap repository upsert operation.
Args:
data: Instance to update existing, or be created.
auto_expunge: Remove object from session before returning.
auto_commit: Commit objects before returning.
no_merge: Skip the usage of optimized Merge statements (**reserved for future use**)
match_fields: a list of keys to use to match the existing model. When
empty, all fields are matched.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
Returns:
Updated or created representation.
"""
if is_dto_data(data):
data = data.create_instance()
data = [(self.to_model(datum, "upsert")) for datum in cast("ModelDictListT[ModelT]", data)]
return cast(
"Sequence[ModelT]",
self.repository.upsert_many(
data=cast("list[ModelT]", data), # pyright: ignore[reportUnnecessaryCast]
auto_expunge=auto_expunge,
auto_commit=auto_commit,
no_merge=no_merge,
match_fields=match_fields,
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=self._get_uniquify(uniquify),
),
)
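``match_fields`` controls which columns identify an existing row during an upsert: rows agreeing on every match field are updated, all others are inserted. A pure-Python illustration of that rule (hypothetical, operating on a list of dicts rather than a table):

```python
# Hypothetical illustration of match_fields semantics during an upsert.
def upsert_many(rows, data, match_fields):
    for item in data:
        match = next(
            (r for r in rows if all(r[f] == item[f] for f in match_fields)),
            None,
        )
        if match is not None:
            match.update(item)       # update the existing row in place
        else:
            rows.append(dict(item))  # insert a new row
    return rows


rows = [{"email": "a@x.io", "name": "Ann"}]
upsert_many(
    rows,
    [{"email": "a@x.io", "name": "Anna"}, {"email": "b@x.io", "name": "Bob"}],
    match_fields=["email"],
)
print(rows)  # [{'email': 'a@x.io', 'name': 'Anna'}, {'email': 'b@x.io', 'name': 'Bob'}]
```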
def get_or_upsert(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
match_fields: Optional[Union[list[str], str]] = None,
upsert: bool = True,
attribute_names: Optional[Iterable[str]] = None,
with_for_update: Optional[bool] = None,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> tuple[ModelT, bool]:
        """Wrap repository get_or_upsert operation.
Args:
*filters: Types for specific filtering operations.
match_fields: a list of keys to use to match the existing model. When
empty, all fields are matched.
upsert: When using match_fields and actual model values differ from
`kwargs`, perform an update operation on the model.
attribute_names: an iterable of attribute names to pass into the ``update``
method.
with_for_update: indicating FOR UPDATE should be used, or may be a
dictionary containing flags to indicate a more specific set of
FOR UPDATE flags for the SELECT
auto_expunge: Remove object from session before returning.
auto_refresh: Refresh object from session before returning.
auto_commit: Commit objects before returning.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
            **kwargs: Attribute values used to match and populate the instance.
        Returns:
            A tuple of the retrieved or created instance and a boolean that is ``True`` when a new instance was created.
"""
match_fields = match_fields or self.match_fields
validated_model = self.to_model(kwargs, "create")
return cast(
"tuple[ModelT, bool]",
self.repository.get_or_upsert(
*filters,
match_fields=match_fields,
upsert=upsert,
attribute_names=attribute_names,
with_for_update=with_for_update,
auto_commit=auto_commit,
auto_expunge=auto_expunge,
auto_refresh=auto_refresh,
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=self._get_uniquify(uniquify),
**validated_model.to_dict(),
),
)
def get_and_update(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
match_fields: Optional[Union[list[str], str]] = None,
attribute_names: Optional[Iterable[str]] = None,
with_for_update: Optional[bool] = None,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
auto_refresh: Optional[bool] = None,
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> tuple[ModelT, bool]:
        """Wrap repository get_and_update operation.
Args:
*filters: Types for specific filtering operations.
match_fields: a list of keys to use to match the existing model. When
empty, all fields are matched.
attribute_names: an iterable of attribute names to pass into the ``update``
method.
with_for_update: indicating FOR UPDATE should be used, or may be a
dictionary containing flags to indicate a more specific set of
FOR UPDATE flags for the SELECT
auto_expunge: Remove object from session before returning.
auto_refresh: Refresh object from session before returning.
auto_commit: Commit objects before returning.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
            **kwargs: Attribute values used to match and update the instance.
        Returns:
            A tuple of the instance and a boolean indicating whether it was updated.
"""
match_fields = match_fields or self.match_fields
validated_model = self.to_model(kwargs, "update")
return cast(
"tuple[ModelT, bool]",
self.repository.get_and_update(
*filters,
match_fields=match_fields,
attribute_names=attribute_names,
with_for_update=with_for_update,
auto_commit=auto_commit,
auto_expunge=auto_expunge,
auto_refresh=auto_refresh,
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=self._get_uniquify(uniquify),
**validated_model.to_dict(),
),
)
def delete(
self,
item_id: Any,
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> ModelT:
"""Wrap repository delete operation.
Args:
item_id: Identifier of instance to be deleted.
auto_commit: Commit objects before returning.
auto_expunge: Remove object from session before returning.
id_attribute: Allows customization of the unique identifier to use for model fetching.
Defaults to `id`, but can reference any surrogate or candidate key for the table.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
Returns:
Representation of the deleted instance.
"""
return cast(
"ModelT",
self.repository.delete(
item_id=item_id,
auto_commit=auto_commit,
auto_expunge=auto_expunge,
id_attribute=id_attribute,
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=self._get_uniquify(uniquify),
),
)
def delete_many(
self,
item_ids: list[Any],
*,
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
chunk_size: Optional[int] = None,
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
) -> Sequence[ModelT]:
"""Wrap repository bulk instance deletion.
Args:
            item_ids: Identifiers of the instances to be deleted.
auto_expunge: Remove object from session before returning.
auto_commit: Commit objects before returning.
id_attribute: Allows customization of the unique identifier to use for model fetching.
Defaults to `id`, but can reference any surrogate or candidate key for the table.
chunk_size: Allows customization of the ``insertmanyvalues_max_parameters`` setting for the driver.
Defaults to `950` if left unset.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
Returns:
Representation of removed instances.
"""
return cast(
"Sequence[ModelT]",
self.repository.delete_many(
item_ids=item_ids,
auto_commit=auto_commit,
auto_expunge=auto_expunge,
id_attribute=id_attribute,
chunk_size=chunk_size,
error_messages=error_messages,
load=load,
execution_options=execution_options,
uniquify=self._get_uniquify(uniquify),
),
)
def delete_where(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
auto_commit: Optional[bool] = None,
auto_expunge: Optional[bool] = None,
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
sanity_check: bool = True,
load: Optional[LoadSpec] = None,
execution_options: Optional[dict[str, Any]] = None,
uniquify: Optional[bool] = None,
**kwargs: Any,
) -> Sequence[ModelT]:
"""Wrap repository scalars operation.
Args:
*filters: Types for specific filtering operations.
auto_expunge: Remove object from session before returning.
auto_commit: Commit objects before returning.
error_messages: An optional dictionary of templates to use
for friendlier error messages to clients
sanity_check: When true, the length of selected instances is compared to the deleted row count
load: Set default relationships to be loaded
execution_options: Set default execution options
uniquify: Optionally apply the ``unique()`` method to results before returning.
**kwargs: Instance attribute value filters.
Returns:
The list of instances deleted from the repository.
"""
return cast(
"Sequence[ModelT]",
self.repository.delete_where(
*filters,
auto_commit=auto_commit,
auto_expunge=auto_expunge,
error_messages=error_messages,
sanity_check=sanity_check,
load=load,
execution_options=execution_options,
uniquify=self._get_uniquify(uniquify),
**kwargs,
),
)
# python-advanced-alchemy-1.0.1/advanced_alchemy/service/_typing.py
"""This is a simple wrapper around a few important classes in each library.
This is used to ensure compatibility when one or more of the libraries are installed.
"""
from typing import (
Any,
ClassVar,
Optional,
Protocol,
cast,
runtime_checkable,
)
from typing_extensions import TypeVar, dataclass_transform
T = TypeVar("T")
T_co = TypeVar("T_co", covariant=True)
try:
from pydantic import BaseModel, FailFast, TypeAdapter # pyright: ignore[reportGeneralTypeIssues]
PYDANTIC_INSTALLED = True
except ImportError:
@runtime_checkable
class BaseModel(Protocol): # type: ignore[no-redef]
"""Placeholder Implementation"""
model_fields: "ClassVar[dict[str, Any]]"
def model_dump(self, *args: Any, **kwargs: Any) -> "dict[str, Any]":
"""Placeholder"""
return {}
@runtime_checkable
class TypeAdapter(Protocol[T_co]): # type: ignore[no-redef]
"""Placeholder Implementation"""
def __init__(
self,
type: Any, # noqa: A002
*,
config: "Optional[Any]" = None,
_parent_depth: int = 2,
module: "Optional[str]" = None,
) -> None:
"""Init"""
def validate_python(
self,
object: Any, # noqa: A002
/,
*,
strict: "Optional[bool]" = None,
from_attributes: "Optional[bool]" = None,
context: "Optional[dict[str, Any]]" = None,
) -> T_co:
"""Stub"""
return cast("T_co", object)
@runtime_checkable
class FailFast(Protocol): # type: ignore[no-redef]
"""Placeholder Implementation for FailFast"""
def __init__(self, *args: Any, **kwargs: Any) -> None:
"""Init"""
PYDANTIC_INSTALLED = False # pyright: ignore[reportConstantRedefinition]
try:
from msgspec import (
UNSET,
Struct,
UnsetType, # pyright: ignore[reportAssignmentType,reportGeneralTypeIssues]
convert, # pyright: ignore[reportGeneralTypeIssues]
)
MSGSPEC_INSTALLED: bool = True
except ImportError:
import enum
@dataclass_transform()
@runtime_checkable
class Struct(Protocol): # type: ignore[no-redef]
"""Placeholder Implementation"""
__struct_fields__: "ClassVar[tuple[str, ...]]"
def convert(*args: Any, **kwargs: Any) -> Any: # type: ignore[no-redef] # noqa: ARG001
"""Placeholder implementation"""
return {}
class UnsetType(enum.Enum): # type: ignore[no-redef]
UNSET = "UNSET"
UNSET = UnsetType.UNSET # pyright: ignore[reportConstantRedefinition]
MSGSPEC_INSTALLED = False # pyright: ignore[reportConstantRedefinition]
try:
from litestar.dto.data_structures import DTOData
LITESTAR_INSTALLED = True
except ImportError:
@runtime_checkable
class DTOData(Protocol[T]): # type: ignore[no-redef]
"""Placeholder implementation"""
__slots__ = ("_backend", "_data_as_builtins")
def __init__(self, backend: Any, data_as_builtins: Any) -> None:
"""Placeholder init"""
def create_instance(self, **kwargs: Any) -> T:
"""Placeholder implementation"""
return cast("T", kwargs)
def update_instance(self, instance: T, **kwargs: Any) -> T:
"""Placeholder implementation"""
return cast("T", kwargs)
def as_builtins(self) -> Any:
"""Placeholder implementation"""
return {}
LITESTAR_INSTALLED = False # pyright: ignore[reportConstantRedefinition]
__all__ = (
"LITESTAR_INSTALLED",
"MSGSPEC_INSTALLED",
"PYDANTIC_INSTALLED",
"UNSET",
"BaseModel",
"DTOData",
"FailFast",
"Struct",
"TypeAdapter",
"UnsetType",
"convert",
)
# File: python-advanced-alchemy-1.0.1/advanced_alchemy/service/_util.py
"""Service object implementation for SQLAlchemy.
RepositoryService object is generic on the domain model type which
should be a SQLAlchemy model.
"""
import datetime
from collections.abc import Sequence
from enum import Enum
from functools import partial
from pathlib import Path, PurePath
from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast, overload
from uuid import UUID
from advanced_alchemy.exceptions import AdvancedAlchemyError
from advanced_alchemy.filters import LimitOffset, StatementFilter
from advanced_alchemy.repository.typing import ModelOrRowMappingT
from advanced_alchemy.service.pagination import OffsetPagination
from advanced_alchemy.service.typing import (
MSGSPEC_INSTALLED,
PYDANTIC_INSTALLED,
BaseModel,
FilterTypeT,
ModelDTOT,
Struct,
convert,
get_type_adapter,
)
if TYPE_CHECKING:
from sqlalchemy import ColumnElement, RowMapping
from advanced_alchemy.base import ModelProtocol
__all__ = ("ResultConverter", "find_filter")
DEFAULT_TYPE_DECODERS = [ # pyright: ignore[reportUnknownVariableType]
(lambda x: x is UUID, lambda t, v: t(v.hex)), # pyright: ignore[reportUnknownLambdaType,reportUnknownMemberType]
(lambda x: x is datetime.datetime, lambda t, v: t(v.isoformat())), # pyright: ignore[reportUnknownLambdaType,reportUnknownMemberType]
(lambda x: x is datetime.date, lambda t, v: t(v.isoformat())), # pyright: ignore[reportUnknownLambdaType,reportUnknownMemberType]
(lambda x: x is datetime.time, lambda t, v: t(v.isoformat())), # pyright: ignore[reportUnknownLambdaType,reportUnknownMemberType]
(lambda x: x is Enum, lambda t, v: t(v.value)), # pyright: ignore[reportUnknownLambdaType,reportUnknownMemberType]
]
def _default_msgspec_deserializer(
target_type: Any,
value: Any,
type_decoders: "Union[Sequence[tuple[Callable[[Any], bool], Callable[[Any, Any], Any]]], None]" = None,
) -> Any: # pragma: no cover
"""Transform values non-natively supported by ``msgspec``
Args:
target_type: Encountered type
value: Value to coerce
type_decoders: Optional sequence of type decoders
Returns:
A ``msgspec``-supported type
"""
if isinstance(value, target_type):
return value
if type_decoders:
for predicate, decoder in type_decoders:
if predicate(target_type):
return decoder(target_type, value)
if issubclass(target_type, (Path, PurePath, UUID)):
return target_type(value)
try:
return target_type(value)
except Exception as e:
msg = f"Unsupported type: {type(value)!r}"
raise TypeError(msg) from e
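The deserializer's fallback chain (pass through exact instances, consult registered decoders, then try the constructor) can be exercised standalone; `deserialize` below is a self-contained sketch of `_default_msgspec_deserializer`, using a `date.fromisoformat` decoder as an assumed example entry:

```python
import datetime
from uuid import UUID

def deserialize(target_type, value, type_decoders=None):
    # Mirrors _default_msgspec_deserializer: exact instances pass through,
    # registered decoders run next, and the constructor is the fallback.
    if isinstance(value, target_type):
        return value
    if type_decoders:
        for predicate, decoder in type_decoders:
            if predicate(target_type):
                return decoder(target_type, value)
    try:
        return target_type(value)
    except Exception as e:
        msg = f"Unsupported type: {type(value)!r}"
        raise TypeError(msg) from e

decoders = [(lambda t: t is datetime.date, lambda t, v: t.fromisoformat(v))]
parsed_date = deserialize(datetime.date, "2024-01-02", decoders)
parsed_uuid = deserialize(UUID, "12345678-1234-5678-1234-567812345678")
```

Note that without the decoder entry, `datetime.date("2024-01-02")` would raise, which is exactly why the decoder table exists.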
def find_filter(
filter_type: type[FilterTypeT],
filters: "Union[Sequence[Union[StatementFilter, ColumnElement[bool]]], Sequence[StatementFilter]]",
) -> "Union[FilterTypeT, None]":
"""Get the filter specified by filter type from the filters.
Args:
filter_type: The type of filter to find.
filters: filter types to apply to the query
Returns:
The match filter instance or None
"""
return next(
(cast("Optional[FilterTypeT]", filter_) for filter_ in filters if isinstance(filter_, filter_type)),
None,
)
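`find_filter` reduces to a typed `next()` over `isinstance` checks: the first filter of the requested type wins, otherwise `None`. A self-contained sketch with stand-in filter classes (`SearchFilter` is hypothetical, `LimitOffset` mirrors the real filter's fields):

```python
from dataclasses import dataclass
from typing import Optional, Sequence

@dataclass
class LimitOffset:  # stand-in for advanced_alchemy.filters.LimitOffset
    limit: int
    offset: int

@dataclass
class SearchFilter:  # hypothetical second filter type
    field_name: str
    value: str

def find_filter(filter_type, filters: Sequence[object]) -> Optional[object]:
    # First isinstance match wins; None when no filter of that type exists.
    return next((f for f in filters if isinstance(f, filter_type)), None)

filters = [SearchFilter("name", "ada"), LimitOffset(limit=10, offset=20)]
found = find_filter(LimitOffset, filters)
missing = find_filter(str, filters)
```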
class ResultConverter:
"""Simple mixin to help convert to a paginated response model.
Single objects are transformed to the supplied schema type, and lists of objects are automatically transformed into an `OffsetPagination` response of the supplied schema type.
Args:
data: A database model instance or row mapping.
Type: :class:`~advanced_alchemy.repository.typing.ModelOrRowMappingT`
Returns:
The converted schema object.
"""
@overload
def to_schema(
self,
data: "ModelOrRowMappingT",
*,
schema_type: None = None,
) -> "ModelOrRowMappingT": ...
@overload
def to_schema(
self,
data: "Union[ModelProtocol, RowMapping]",
*,
schema_type: "type[ModelDTOT]",
) -> "ModelDTOT": ...
@overload
def to_schema(
self,
data: "ModelOrRowMappingT",
total: "Optional[int]" = None,
*,
schema_type: None = None,
) -> "ModelOrRowMappingT": ...
@overload
def to_schema(
self,
data: "Union[ModelProtocol, RowMapping]",
total: "Optional[int]" = None,
*,
schema_type: "type[ModelDTOT]",
) -> "ModelDTOT": ...
@overload
def to_schema(
self,
data: "ModelOrRowMappingT",
total: "Optional[int]" = None,
filters: "Union[Sequence[Union[StatementFilter, ColumnElement[bool]]], Sequence[StatementFilter], None]" = None,
*,
schema_type: None = None,
) -> "ModelOrRowMappingT": ...
@overload
def to_schema(
self,
data: "Union[ModelProtocol, RowMapping]",
total: "Optional[int]" = None,
filters: "Union[Sequence[Union[StatementFilter, ColumnElement[bool]]], Sequence[StatementFilter], None]" = None,
*,
schema_type: "type[ModelDTOT]",
) -> "ModelDTOT": ...
@overload
def to_schema(
self,
data: "Sequence[ModelOrRowMappingT]",
*,
schema_type: None = None,
) -> "OffsetPagination[ModelOrRowMappingT]": ...
@overload
def to_schema(
self,
data: "Union[Sequence[ModelProtocol], Sequence[RowMapping]]",
*,
schema_type: "type[ModelDTOT]",
) -> "OffsetPagination[ModelDTOT]": ...
@overload
def to_schema(
self,
data: "Sequence[ModelOrRowMappingT]",
total: "Optional[int]" = None,
filters: "Union[Sequence[Union[StatementFilter, ColumnElement[bool]]], Sequence[StatementFilter], None]" = None,
*,
schema_type: None = None,
) -> "OffsetPagination[ModelOrRowMappingT]": ...
@overload
def to_schema(
self,
data: "Union[Sequence[ModelProtocol], Sequence[RowMapping]]",
total: "Optional[int]" = None,
filters: "Union[Sequence[Union[StatementFilter, ColumnElement[bool]]], Sequence[StatementFilter], None]" = None,
*,
schema_type: "type[ModelDTOT]",
) -> "OffsetPagination[ModelDTOT]": ...
def to_schema(
self,
data: "Union[ModelOrRowMappingT, Sequence[ModelOrRowMappingT], ModelProtocol, Sequence[ModelProtocol], RowMapping, Sequence[RowMapping]]",
total: "Optional[int]" = None,
filters: "Union[Sequence[Union[StatementFilter, ColumnElement[bool]]], Sequence[StatementFilter], None]" = None,
*,
schema_type: "Optional[type[ModelDTOT]]" = None,
) -> "Union[ModelOrRowMappingT, OffsetPagination[ModelOrRowMappingT], ModelDTOT, OffsetPagination[ModelDTOT]]":
"""Convert the object to a response schema.
When `schema_type` is None, the model is returned with no conversion.
Args:
data: The return from one of the service calls.
Type: :class:`~advanced_alchemy.repository.typing.ModelOrRowMappingT`
total: The total number of rows in the data.
filters: :class:`~advanced_alchemy.filters.StatementFilter` | :class:`sqlalchemy.sql.expression.ColumnElement` Collection of route filters.
schema_type: :class:`~advanced_alchemy.service.typing.ModelDTOT` Optional schema type to convert the data to
Returns:
- :class:`~advanced_alchemy.base.ModelProtocol` | :class:`sqlalchemy.orm.RowMapping` | :class:`~advanced_alchemy.service.pagination.OffsetPagination` | :class:`msgspec.Struct` | :class:`pydantic.BaseModel`
"""
if filters is None:
filters = []
if schema_type is None:
if not isinstance(data, Sequence):
return cast("ModelOrRowMappingT", data) # type: ignore[unreachable,unused-ignore]
limit_offset = find_filter(LimitOffset, filters=filters)
total = total or len(data)
limit_offset = limit_offset if limit_offset is not None else LimitOffset(limit=len(data), offset=0)
return OffsetPagination[ModelOrRowMappingT](
items=cast("Sequence[ModelOrRowMappingT]", data),
limit=limit_offset.limit,
offset=limit_offset.offset,
total=total,
)
if MSGSPEC_INSTALLED and issubclass(schema_type, Struct):
if not isinstance(data, Sequence):
return cast(
"ModelDTOT",
convert(
obj=data,
type=schema_type,
from_attributes=True,
dec_hook=partial(
_default_msgspec_deserializer,
type_decoders=DEFAULT_TYPE_DECODERS,
),
),
)
limit_offset = find_filter(LimitOffset, filters=filters)
total = total or len(data)
limit_offset = limit_offset if limit_offset is not None else LimitOffset(limit=len(data), offset=0)
return OffsetPagination[ModelDTOT](
items=convert(
obj=data,
type=list[schema_type], # type: ignore[valid-type]
from_attributes=True,
dec_hook=partial(
_default_msgspec_deserializer,
type_decoders=DEFAULT_TYPE_DECODERS,
),
),
limit=limit_offset.limit,
offset=limit_offset.offset,
total=total,
)
if PYDANTIC_INSTALLED and issubclass(schema_type, BaseModel):
if not isinstance(data, Sequence):
return cast(
"ModelDTOT",
get_type_adapter(schema_type).validate_python(data, from_attributes=True),
)
limit_offset = find_filter(LimitOffset, filters=filters)
total = total if total else len(data)
limit_offset = limit_offset if limit_offset is not None else LimitOffset(limit=len(data), offset=0)
return OffsetPagination[ModelDTOT](
items=get_type_adapter(list[schema_type]).validate_python(data, from_attributes=True), # type: ignore[valid-type] # pyright: ignore[reportUnknownArgumentType]
limit=limit_offset.limit,
offset=limit_offset.offset,
total=total,
)
if not MSGSPEC_INSTALLED and not PYDANTIC_INSTALLED:
msg = "Either Msgspec or Pydantic must be installed to use schema conversion"
raise AdvancedAlchemyError(msg)
msg = "`schema_type` should be a valid Pydantic or Msgspec schema"
raise AdvancedAlchemyError(msg)
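The `schema_type=None` branch above is the simplest path: single objects pass through untouched, sequences are wrapped in `OffsetPagination`, and a missing `LimitOffset` filter defaults to `limit=len(data), offset=0`. A self-contained sketch of just that branch (the real method also handles msgspec and pydantic conversion):

```python
from dataclasses import dataclass
from typing import Generic, Optional, Sequence, TypeVar

T = TypeVar("T")

@dataclass
class OffsetPagination(Generic[T]):
    items: Sequence[T]
    limit: int
    offset: int
    total: int

def to_schema(data, total: Optional[int] = None, limit_offset=None):
    # Sketch of the schema_type=None branch only.
    if not isinstance(data, (list, tuple)):
        return data  # single object: no conversion
    total = total or len(data)
    limit, offset = limit_offset if limit_offset else (len(data), 0)
    return OffsetPagination(items=data, limit=limit, offset=offset, total=total)

page = to_schema(["a", "b", "c"], total=10, limit_offset=(3, 6))
```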
# File: python-advanced-alchemy-1.0.1/advanced_alchemy/service/pagination.py
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Generic, TypeVar
T = TypeVar("T")
__all__ = ("OffsetPagination",)
@dataclass
class OffsetPagination(Generic[T]):
"""Container for data returned using limit/offset pagination."""
__slots__ = ("items", "limit", "offset", "total")
items: Sequence[T]
"""List of data being sent as part of the response."""
limit: int
"""Maximal number of items to send."""
offset: int
"""Offset from the beginning of the query.
Identical to an index.
"""
total: int
"""Total number of items."""
# File: python-advanced-alchemy-1.0.1/advanced_alchemy/service/typing.py
"""Service object implementation for SQLAlchemy.
RepositoryService object is generic on the domain model type which
should be a SQLAlchemy model.
"""
from functools import lru_cache
from typing import (
TYPE_CHECKING,
Annotated,
Any,
TypeVar,
Union,
cast,
overload,
)
from typing_extensions import TypeAlias, TypeGuard
from advanced_alchemy.repository.typing import ModelT
from advanced_alchemy.service._typing import (
LITESTAR_INSTALLED,
MSGSPEC_INSTALLED,
PYDANTIC_INSTALLED,
UNSET,
BaseModel,
DTOData,
FailFast,
Struct,
TypeAdapter,
UnsetType,
convert,
)
if TYPE_CHECKING:
from collections.abc import Sequence
from advanced_alchemy.filters import StatementFilter
PYDANTIC_USE_FAILFAST = False # leave permanently disabled for now
T = TypeVar("T")
FilterTypeT = TypeVar("FilterTypeT", bound="StatementFilter")
"""Type variable for filter types.
:class:`~advanced_alchemy.filters.StatementFilter`
"""
SupportedSchemaModel: TypeAlias = Union[Struct, BaseModel]
"""Type alias for objects that support to_dict or model_dump methods."""
ModelDTOT = TypeVar("ModelDTOT", bound="SupportedSchemaModel")
"""Type variable for model DTOs.
:class:`msgspec.Struct`|:class:`pydantic.BaseModel`
"""
PydanticOrMsgspecT = SupportedSchemaModel
"""Type alias for pydantic or msgspec models.
:class:`msgspec.Struct` or :class:`pydantic.BaseModel`
"""
ModelDictT: TypeAlias = "Union[dict[str, Any], ModelT, SupportedSchemaModel, DTOData[ModelT]]"
"""Type alias for model dictionaries.
Represents:
- :type:`dict[str, Any]` | :class:`~advanced_alchemy.base.ModelProtocol` | :class:`msgspec.Struct` | :class:`pydantic.BaseModel` | :class:`litestar.dto.data_structures.DTOData`
"""
ModelDictListT: TypeAlias = "Sequence[Union[dict[str, Any], ModelT, SupportedSchemaModel]]"
"""Type alias for model dictionary lists.
A list or sequence of any of the following:
- :type:`Sequence`[:type:`dict[str, Any]` | :class:`~advanced_alchemy.base.ModelProtocol` | :class:`msgspec.Struct` | :class:`pydantic.BaseModel`]
"""
BulkModelDictT: TypeAlias = (
"Union[Sequence[Union[dict[str, Any], ModelT, SupportedSchemaModel]], DTOData[list[ModelT]]]"
)
"""Type alias for bulk model dictionaries.
:type:`Sequence`[:type:`dict[str, Any]` | :class:`~advanced_alchemy.base.ModelProtocol` | :class:`msgspec.Struct` | :class:`pydantic.BaseModel`] | :class:`litestar.dto.data_structures.DTOData`
"""
@lru_cache(typed=True)
def get_type_adapter(f: type[T]) -> TypeAdapter[T]:
"""Caches and returns a pydantic type adapter.
Args:
f: Type to create a type adapter for.
Returns:
:class:`pydantic.TypeAdapter`[:class:`typing.TypeVar`[T]]
"""
if PYDANTIC_USE_FAILFAST:
return TypeAdapter(
Annotated[f, FailFast()],
)
return TypeAdapter(f)
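The point of the `lru_cache` wrapper is that constructing a pydantic `TypeAdapter` is relatively expensive, so repeated requests for the same type reuse one cached instance. A stand-in with a sentinel tuple in place of the real adapter (hypothetical, for illustration only) shows the caching behavior:

```python
from functools import lru_cache

@lru_cache(typed=True)
def get_type_adapter(tp):
    # Stand-in: the real function builds a pydantic TypeAdapter here.
    # lru_cache means the construction cost is paid once per type.
    return ("adapter", tp)

first = get_type_adapter(int)
second = get_type_adapter(int)  # served from the cache
```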
def is_dto_data(v: Any) -> TypeGuard[DTOData[Any]]:
"""Check if a value is a Litestar DTOData object.
Args:
v: Value to check.
Returns:
bool
"""
return LITESTAR_INSTALLED and isinstance(v, DTOData)
def is_pydantic_model(v: Any) -> TypeGuard[BaseModel]:
"""Check if a value is a pydantic model.
Args:
v: Value to check.
Returns:
bool
"""
return PYDANTIC_INSTALLED and isinstance(v, BaseModel)
def is_msgspec_struct(v: Any) -> TypeGuard[Struct]:
"""Check if a value is a msgspec struct.
Args:
v: Value to check.
Returns:
bool
"""
return MSGSPEC_INSTALLED and isinstance(v, Struct)
def is_dataclass(obj: Any) -> TypeGuard[Any]:
"""Check if an object is a dataclass."""
return hasattr(obj, "__dataclass_fields__")
def is_dataclass_with_field(obj: Any, field_name: str) -> TypeGuard[object]: # Can't specify dataclass type directly
"""Check if an object is a dataclass and has a specific field."""
return is_dataclass(obj) and hasattr(obj, field_name)
def is_dataclass_without_field(obj: Any, field_name: str) -> TypeGuard[object]:
"""Check if an object is a dataclass and does not have a specific field."""
return is_dataclass(obj) and not hasattr(obj, field_name)
def is_dict(v: Any) -> TypeGuard[dict[str, Any]]:
"""Check if a value is a dictionary.
Args:
v: Value to check.
Returns:
bool
"""
return isinstance(v, dict)
def is_dict_with_field(v: Any, field_name: str) -> TypeGuard[dict[str, Any]]:
"""Check if a dictionary has a specific field.
Args:
v: Value to check.
field_name: Field name to check for.
Returns:
bool
"""
return is_dict(v) and field_name in v
def is_dict_without_field(v: Any, field_name: str) -> TypeGuard[dict[str, Any]]:
"""Check if a dictionary does not have a specific field.
Args:
v: Value to check.
field_name: Field name to check for.
Returns:
bool
"""
return is_dict(v) and field_name not in v
def is_pydantic_model_with_field(v: Any, field_name: str) -> TypeGuard[BaseModel]:
"""Check if a pydantic model has a specific field.
Args:
v: Value to check.
field_name: Field name to check for.
Returns:
bool
"""
return is_pydantic_model(v) and hasattr(v, field_name)
def is_pydantic_model_without_field(v: Any, field_name: str) -> TypeGuard[BaseModel]:
"""Check if a pydantic model does not have a specific field.
Args:
v: Value to check.
field_name: Field name to check for.
Returns:
bool
"""
return is_pydantic_model(v) and not hasattr(v, field_name)
def is_msgspec_struct_with_field(v: Any, field_name: str) -> TypeGuard[Struct]:
"""Check if a msgspec struct has a specific field.
Args:
v: Value to check.
field_name: Field name to check for.
Returns:
bool
"""
return is_msgspec_struct(v) and hasattr(v, field_name)
def is_msgspec_struct_without_field(v: Any, field_name: str) -> "TypeGuard[Struct]":
"""Check if a msgspec struct does not have a specific field.
Args:
v: Value to check.
field_name: Field name to check for.
Returns:
bool
"""
return is_msgspec_struct(v) and not hasattr(v, field_name)
def is_schema(v: Any) -> "TypeGuard[SupportedSchemaModel]":
"""Check if a value is a msgspec Struct or Pydantic model.
Args:
v: Value to check.
Returns:
bool
"""
return is_msgspec_struct(v) or is_pydantic_model(v)
def is_schema_or_dict(v: Any) -> "TypeGuard[Union[SupportedSchemaModel, dict[str, Any]]]":
"""Check if a value is a msgspec Struct, Pydantic model, or dict.
Args:
v: Value to check.
Returns:
bool
"""
return is_schema(v) or is_dict(v)
def is_schema_with_field(v: Any, field_name: str) -> "TypeGuard[SupportedSchemaModel]":
"""Check if a value is a msgspec Struct or Pydantic model with a specific field.
Args:
v: Value to check.
field_name: Field name to check for.
Returns:
bool
"""
return is_msgspec_struct_with_field(v, field_name) or is_pydantic_model_with_field(v, field_name)
def is_schema_without_field(v: Any, field_name: str) -> "TypeGuard[SupportedSchemaModel]":
"""Check if a value is a msgspec Struct or Pydantic model without a specific field.
Args:
v: Value to check.
field_name: Field name to check for.
Returns:
bool
"""
return is_schema(v) and not is_schema_with_field(v, field_name)
def is_schema_or_dict_with_field(v: Any, field_name: str) -> "TypeGuard[Union[SupportedSchemaModel, dict[str, Any]]]":
"""Check if a value is a msgspec Struct, Pydantic model, or dict with a specific field.
Args:
v: Value to check.
field_name: Field name to check for.
Returns:
bool
"""
return is_schema_with_field(v, field_name) or is_dict_with_field(v, field_name)
def is_schema_or_dict_without_field(
v: Any, field_name: str
) -> "TypeGuard[Union[SupportedSchemaModel, dict[str, Any]]]":
"""Check if a value is a msgspec Struct, Pydantic model, or dict without a specific field.
Args:
v: Value to check.
field_name: Field name to check for.
Returns:
bool
"""
return is_schema_or_dict(v) and not is_schema_or_dict_with_field(v, field_name)
@overload
def schema_dump(
data: "Union[dict[str, Any], SupportedSchemaModel, DTOData[ModelT]]", exclude_unset: bool = True
) -> "Union[dict[str, Any], ModelT]": ...
@overload
def schema_dump(data: ModelT, exclude_unset: bool = True) -> ModelT: ...
def schema_dump( # noqa: PLR0911
data: "Union[dict[str, Any], ModelT, SupportedSchemaModel, DTOData[ModelT]]", exclude_unset: bool = True
) -> "Union[dict[str, Any], ModelT]":
"""Dump a data object to a dictionary.
Args:
data: :type:`dict[str, Any]` | :class:`advanced_alchemy.base.ModelProtocol` | :class:`msgspec.Struct` | :class:`pydantic.BaseModel` | :class:`litestar.dto.data_structures.DTOData[ModelT]`
exclude_unset: :type:`bool` Whether to exclude unset values.
Returns:
Union[:type: dict[str, Any], :class:`~advanced_alchemy.base.ModelProtocol`]
"""
if is_dict(data):
return data
if is_pydantic_model(data):
return data.model_dump(exclude_unset=exclude_unset)
if is_msgspec_struct(data):
if exclude_unset:
return {f: val for f in data.__struct_fields__ if (val := getattr(data, f, None)) != UNSET}
return {f: getattr(data, f, None) for f in data.__struct_fields__}
if is_dto_data(data):
return cast("dict[str, Any]", data.as_builtins())
if hasattr(data, "__dict__"):
return data.__dict__
return cast("ModelT", data) # type: ignore[no-return-any]
__all__ = (
"LITESTAR_INSTALLED",
"MSGSPEC_INSTALLED",
"PYDANTIC_INSTALLED",
"PYDANTIC_USE_FAILFAST",
"UNSET",
"BaseModel",
"BulkModelDictT",
"DTOData",
"FailFast",
"FilterTypeT",
"ModelDTOT",
"ModelDictListT",
"ModelDictT",
"PydanticOrMsgspecT",
"Struct",
"SupportedSchemaModel",
"TypeAdapter",
"UnsetType",
"convert",
"get_type_adapter",
"is_dataclass",
"is_dataclass_with_field",
"is_dataclass_without_field",
"is_dict",
"is_dict_with_field",
"is_dict_without_field",
"is_dto_data",
"is_msgspec_struct",
"is_msgspec_struct_with_field",
"is_msgspec_struct_without_field",
"is_pydantic_model",
"is_pydantic_model_with_field",
"is_pydantic_model_without_field",
"is_schema",
"is_schema_or_dict",
"is_schema_or_dict_with_field",
"is_schema_or_dict_without_field",
"is_schema_with_field",
"is_schema_without_field",
"schema_dump",
)
if TYPE_CHECKING:
if not PYDANTIC_INSTALLED:
from advanced_alchemy.service._typing import BaseModel, FailFast, TypeAdapter
else:
from pydantic import BaseModel, FailFast, TypeAdapter # type: ignore[assignment] # noqa: TC004
if not MSGSPEC_INSTALLED:
from advanced_alchemy.service._typing import UNSET, Struct, UnsetType, convert
else:
from msgspec import UNSET, Struct, UnsetType, convert # type: ignore[assignment] # noqa: TC004
if not LITESTAR_INSTALLED:
from advanced_alchemy.service._typing import DTOData
else:
from litestar.dto import DTOData # type: ignore[assignment] # noqa: TC004
# File: python-advanced-alchemy-1.0.1/advanced_alchemy/types/__init__.py
from advanced_alchemy.types.datetime import DateTimeUTC
from advanced_alchemy.types.encrypted_string import (
EncryptedString,
EncryptedText,
EncryptionBackend,
FernetBackend,
PGCryptoBackend,
)
from advanced_alchemy.types.guid import GUID, NANOID_INSTALLED, UUID_UTILS_INSTALLED
from advanced_alchemy.types.identity import BigIntIdentity
from advanced_alchemy.types.json import ORA_JSONB, JsonB
__all__ = (
"GUID",
"NANOID_INSTALLED",
"ORA_JSONB",
"UUID_UTILS_INSTALLED",
"BigIntIdentity",
"DateTimeUTC",
"EncryptedString",
"EncryptedText",
"EncryptionBackend",
"FernetBackend",
"JsonB",
"PGCryptoBackend",
)
# File: python-advanced-alchemy-1.0.1/advanced_alchemy/types/datetime.py
# ruff: noqa: FA100
import datetime
from typing import Optional
from sqlalchemy import DateTime
from sqlalchemy.engine import Dialect
from sqlalchemy.types import TypeDecorator
__all__ = ("DateTimeUTC",)
class DateTimeUTC(TypeDecorator[datetime.datetime]):
"""Timezone Aware DateTime.
Ensure UTC is stored in the database and that TZ aware dates are returned for all dialects.
"""
impl = DateTime(timezone=True)
cache_ok = True
@property
def python_type(self) -> type[datetime.datetime]:
return datetime.datetime
def process_bind_param(self, value: Optional[datetime.datetime], dialect: Dialect) -> Optional[datetime.datetime]:
if value is None:
return value
if not value.tzinfo:
msg = "tzinfo is required"
raise TypeError(msg)
return value.astimezone(datetime.timezone.utc)
def process_result_value(self, value: Optional[datetime.datetime], dialect: Dialect) -> Optional[datetime.datetime]:
if value is None:
return value
if value.tzinfo is None:
return value.replace(tzinfo=datetime.timezone.utc)
return value
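The two conversions above can be exercised with plain `datetime` objects, no SQLAlchemy required; these standalone functions mirror the bind and result logic of `DateTimeUTC`:

```python
import datetime

def process_bind_param(value):
    # Same rules as DateTimeUTC: naive datetimes are rejected,
    # aware ones are normalized to UTC before storage.
    if value is None:
        return None
    if not value.tzinfo:
        raise TypeError("tzinfo is required")
    return value.astimezone(datetime.timezone.utc)

def process_result_value(value):
    # Drivers that return naive datetimes get UTC re-attached on the way out.
    if value is not None and value.tzinfo is None:
        return value.replace(tzinfo=datetime.timezone.utc)
    return value

est = datetime.timezone(datetime.timedelta(hours=-5))
stored = process_bind_param(datetime.datetime(2024, 1, 1, 7, 0, tzinfo=est))
loaded = process_result_value(datetime.datetime(2024, 1, 1, 12, 0))
```

So 07:00 at UTC-5 is stored as 12:00 UTC, and a naive 12:00 read back from the driver is interpreted as 12:00 UTC rather than local time.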
# File: python-advanced-alchemy-1.0.1/advanced_alchemy/types/encrypted_string.py
import abc
import base64
import contextlib
import os
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
from sqlalchemy import String, Text, TypeDecorator
from sqlalchemy import func as sql_func
from advanced_alchemy.exceptions import IntegrityError
if TYPE_CHECKING:
from sqlalchemy.engine import Dialect
cryptography = None # type: ignore[var-annotated,unused-ignore]
with contextlib.suppress(ImportError):
from cryptography.fernet import Fernet
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
__all__ = ("EncryptedString", "EncryptedText", "EncryptionBackend", "FernetBackend", "PGCryptoBackend")
class EncryptionBackend(abc.ABC):
"""Abstract base class for encryption backends.
This class defines the interface that all encryption backends must implement.
Concrete implementations should provide the actual encryption/decryption logic.
Attributes:
passphrase (bytes): The encryption passphrase used by the backend.
"""
def mount_vault(self, key: "Union[str, bytes]") -> None:
"""Mounts the vault with the provided encryption key.
Args:
key (str | bytes): The encryption key used to initialize the backend.
"""
if isinstance(key, str):
key = key.encode()
@abc.abstractmethod
def init_engine(self, key: "Union[bytes, str]") -> None: # pragma: nocover
"""Initializes the encryption engine with the provided key.
Args:
key (bytes | str): The encryption key.
Raises:
NotImplementedError: If the method is not implemented by the subclass.
"""
@abc.abstractmethod
def encrypt(self, value: Any) -> str: # pragma: nocover
"""Encrypts the given value.
Args:
value (Any): The value to encrypt.
Returns:
str: The encrypted value.
Raises:
NotImplementedError: If the method is not implemented by the subclass.
"""
@abc.abstractmethod
def decrypt(self, value: Any) -> str: # pragma: nocover
"""Decrypts the given value.
Args:
value (Any): The value to decrypt.
Returns:
str: The decrypted value.
Raises:
NotImplementedError: If the method is not implemented by the subclass.
"""
class PGCryptoBackend(EncryptionBackend):
"""PostgreSQL pgcrypto-based encryption backend.
This backend uses PostgreSQL's pgcrypto extension for encryption/decryption operations.
Requires the pgcrypto extension to be installed in the database.
Attributes:
passphrase (bytes): The base64-encoded passphrase used for encryption and decryption.
"""
def init_engine(self, key: "Union[bytes, str]") -> None:
"""Initializes the pgcrypto engine with the provided key.
Args:
key (bytes | str): The encryption key.
"""
if isinstance(key, str):
key = key.encode()
self.passphrase = base64.urlsafe_b64encode(key)
def encrypt(self, value: Any) -> str:
"""Encrypts the given value using pgcrypto.
Args:
value (Any): The value to encrypt.
Returns:
str: The encrypted value.
Raises:
TypeError: If the value is not a string.
"""
if not isinstance(value, str): # pragma: nocover
value = repr(value)
value = value.encode()
return sql_func.pgp_sym_encrypt(value, self.passphrase) # type: ignore[return-value]
def decrypt(self, value: Any) -> str:
"""Decrypts the given value using pgcrypto.
Args:
value (Any): The value to decrypt.
Returns:
str: The decrypted value.
Raises:
TypeError: If the value is not a string.
"""
if not isinstance(value, str): # pragma: nocover
value = str(value)
return sql_func.pgp_sym_decrypt(value, self.passphrase) # type: ignore[return-value]
class FernetBackend(EncryptionBackend):
"""Fernet-based encryption backend.
This backend uses the Python cryptography library's Fernet implementation
for encryption/decryption operations. Provides symmetric encryption with
built-in rotation support.
Attributes:
key (bytes): The base64-encoded key used for encryption and decryption.
fernet (cryptography.fernet.Fernet): The Fernet instance used for encryption/decryption.
"""
def mount_vault(self, key: "Union[str, bytes]") -> None:
"""Mounts the vault with the provided encryption key.
This method hashes the key using SHA256 before initializing the engine.
Args:
key (str | bytes): The encryption key.
"""
if isinstance(key, str):
key = key.encode()
digest = hashes.Hash(hashes.SHA256(), backend=default_backend()) # pyright: ignore[reportPossiblyUnboundVariable]
digest.update(key)
engine_key = digest.finalize()
self.init_engine(engine_key)
def init_engine(self, key: "Union[bytes, str]") -> None:
"""Initializes the Fernet engine with the provided key.
Args:
key (bytes | str): The encryption key.
"""
if isinstance(key, str):
key = key.encode()
self.key = base64.urlsafe_b64encode(key)
self.fernet = Fernet(self.key) # pyright: ignore[reportPossiblyUnboundVariable]
def encrypt(self, value: Any) -> str:
"""Encrypts the given value using Fernet.
Args:
value (Any): The value to encrypt.
Returns:
str: The encrypted value.
Raises:
TypeError: If the value is not a string.
cryptography.fernet.InvalidToken: If encryption fails.
"""
if not isinstance(value, str):
value = repr(value)
value = value.encode()
encrypted = self.fernet.encrypt(value)
return encrypted.decode("utf-8")
def decrypt(self, value: Any) -> str:
"""Decrypts the given value using Fernet.
Args:
value (Any): The value to decrypt.
Returns:
str: The decrypted value.
Raises:
TypeError: If the value is not a string.
cryptography.fernet.InvalidToken: If decryption fails.
"""
if not isinstance(value, str): # pragma: nocover
value = str(value)
decrypted: Union[str, bytes] = self.fernet.decrypt(value.encode())
if not isinstance(decrypted, str):
decrypted = decrypted.decode("utf-8") # pyright: ignore[reportAttributeAccessIssue]
return decrypted
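The key-derivation path in `mount_vault` + `init_engine` is: SHA-256 the passphrase, then urlsafe-base64 the 32-byte digest into the form `Fernet` expects. A sketch using stdlib `hashlib` in place of cryptography's `hashes.Hash` (an assumed-equivalent substitution, so the snippet runs without the optional dependency):

```python
import base64
import hashlib

def derive_fernet_key(key):
    # Mirrors mount_vault + init_engine: hash the passphrase to a fixed
    # 32-byte digest, then urlsafe-base64 it into a Fernet-shaped key.
    if isinstance(key, str):
        key = key.encode()
    digest = hashlib.sha256(key).digest()
    return base64.urlsafe_b64encode(digest)

fernet_key = derive_fernet_key("super-secret-passphrase")
```

Hashing first means any passphrase length is accepted while the derived key is always the 32 bytes (44 base64 characters) Fernet requires.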
DEFAULT_ENCRYPTION_KEY = os.urandom(32)
class EncryptedString(TypeDecorator[str]):
"""SQLAlchemy TypeDecorator for storing encrypted string values in a database.
This type provides transparent encryption/decryption of string values using the specified backend.
It extends :class:`sqlalchemy.types.TypeDecorator` and implements String as its underlying type.
Args:
key (str | bytes | Callable[[], str | bytes] | None): The encryption key. Can be a string, bytes, or callable returning either. Defaults to os.urandom(32).
backend (Type[EncryptionBackend] | None): The encryption backend class to use. Defaults to FernetBackend.
length (int | None): The length of the unencrypted string. This is used for documentation and validation purposes only, as encrypted strings will be longer.
**kwargs (Any | None): Additional arguments passed to the underlying String type.
Attributes:
key (str | bytes | Callable[[], str | bytes]): The encryption key.
backend (EncryptionBackend): The encryption backend instance.
length (int | None): The unencrypted string length.
"""
impl = String
cache_ok = True
def __init__(
self,
key: "Union[str, bytes, Callable[[], Union[str, bytes]]]" = DEFAULT_ENCRYPTION_KEY,
backend: "type[EncryptionBackend]" = FernetBackend,
length: "Optional[int]" = None,
**kwargs: Any,
) -> None:
"""Initializes the EncryptedString TypeDecorator.
Args:
key (str | bytes | Callable[[], str | bytes] | None): The encryption key. Can be a string, bytes, or callable returning either. Defaults to os.urandom(32).
backend (Type[EncryptionBackend] | None): The encryption backend class to use. Defaults to FernetBackend.
length (int | None): The length of the unencrypted string. This is used for documentation and validation purposes only.
**kwargs (Any | None): Additional arguments passed to the underlying String type.
"""
super().__init__()
self.key = key
self.backend = backend()
self.length = length
@property
def python_type(self) -> type[str]:
"""Returns the Python type for this type decorator.
Returns:
Type[str]: The Python string type.
"""
return str
def load_dialect_impl(self, dialect: "Dialect") -> Any:
"""Loads the appropriate dialect implementation based on the database dialect.
Note: The actual column length will be larger than the specified length due to encryption overhead.
For most encryption methods, the encrypted string will be approximately 1.35x longer than the original.
Args:
dialect (Dialect): The SQLAlchemy dialect.
Returns:
Any: The dialect-specific type descriptor.
"""
if dialect.name in {"mysql", "mariadb"}:
# For MySQL/MariaDB, always use Text to avoid length limitations
return dialect.type_descriptor(Text())
if dialect.name == "oracle":
# Oracle has a 4000-byte limit for VARCHAR2 (by default)
return dialect.type_descriptor(String(length=4000))
return dialect.type_descriptor(String())
def process_bind_param(self, value: Any, dialect: "Dialect") -> "Union[str, None]":
"""Processes the value before binding it to the SQL statement.
This method encrypts the value using the specified backend and validates length if specified.
Args:
value (Any): The value to process.
dialect (Dialect): The SQLAlchemy dialect.
Returns:
str | None: The encrypted value or None if the input is None.
Raises:
            IntegrityError: If the value exceeds the specified length.
"""
if value is None:
return value
# Validate length if specified
if self.length is not None and len(str(value)) > self.length:
msg = f"Unencrypted value exceeds maximum unencrypted length of {self.length}"
raise IntegrityError(msg)
self.mount_vault()
return self.backend.encrypt(value)
def process_result_value(self, value: Any, dialect: "Dialect") -> "Union[str, None]":
"""Processes the value after retrieving it from the database.
This method decrypts the value using the specified backend.
Args:
value (Any): The value to process.
dialect (Dialect): The SQLAlchemy dialect.
Returns:
str | None: The decrypted value or None if the input is None.
"""
if value is None:
return value
self.mount_vault()
return self.backend.decrypt(value)
def mount_vault(self) -> None:
"""Mounts the vault with the encryption key.
If the key is callable, it is called to retrieve the key. Otherwise, the key is used directly.
"""
key = self.key() if callable(self.key) else self.key
self.backend.mount_vault(key)
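A minimal sketch of the bind/result round-trip that ``EncryptedString`` performs. The ``toy_encrypt``/``toy_decrypt`` helpers are hypothetical stand-ins using base64 instead of a real Fernet backend, so the sketch runs without ``cryptography``; only the shape of the hooks mirrors the type above.

```python
import base64


def toy_encrypt(value: str) -> str:
    # Stand-in for backend.encrypt(): any reversible transform works here.
    return base64.b64encode(value.encode()).decode("utf-8")


def toy_decrypt(stored: str) -> str:
    # Stand-in for backend.decrypt().
    return base64.b64decode(stored.encode()).decode("utf-8")


def bind_param(value, length=None):
    # Mirrors process_bind_param: None passes through, the length check runs
    # against the *unencrypted* value, then the value is encrypted.
    if value is None:
        return None
    if length is not None and len(str(value)) > length:
        raise ValueError("Unencrypted value exceeds maximum unencrypted length")
    return toy_encrypt(value)


def result_value(stored):
    # Mirrors process_result_value: None passes through, otherwise decrypt.
    if stored is None:
        return None
    return toy_decrypt(stored)


stored = bind_param("secret")
assert stored != "secret"  # what lands in the column differs from the plaintext
assert result_value(stored) == "secret"
assert bind_param(None) is None
```

Note that the length check applies to the plaintext, which is why the real column type is sized larger than ``length`` (encryption overhead).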
class EncryptedText(EncryptedString):
"""SQLAlchemy TypeDecorator for storing encrypted text/CLOB values in a database.
This type provides transparent encryption/decryption of text values using the specified backend.
It extends :class:`EncryptedString` and implements Text as its underlying type.
This is suitable for storing larger encrypted text content compared to EncryptedString.
Args:
key (str | bytes | Callable[[], str | bytes] | None): The encryption key. Can be a string, bytes, or callable returning either. Defaults to os.urandom(32).
backend (Type[EncryptionBackend] | None): The encryption backend class to use. Defaults to FernetBackend.
**kwargs (Any | None): Additional arguments passed to the underlying String type.
"""
impl = Text
cache_ok = True
def load_dialect_impl(self, dialect: "Dialect") -> Any:
"""Loads the appropriate dialect implementation for Text type.
Args:
dialect (Dialect): The SQLAlchemy dialect.
Returns:
Any: The dialect-specific Text type descriptor.
"""
return dialect.type_descriptor(Text())
# advanced_alchemy/types/guid.py
# ruff: noqa: FA100
from base64 import b64decode
from importlib.util import find_spec
from typing import Any, Optional, Union, cast
from uuid import UUID
from sqlalchemy.dialects.mssql import UNIQUEIDENTIFIER as MSSQL_UNIQUEIDENTIFIER
from sqlalchemy.dialects.oracle import RAW as ORA_RAW
from sqlalchemy.dialects.postgresql import UUID as PG_UUID
from sqlalchemy.engine import Dialect
from sqlalchemy.types import BINARY, CHAR, TypeDecorator
from typing_extensions import Buffer
__all__ = ("GUID",)
UUID_UTILS_INSTALLED = find_spec("uuid_utils")
NANOID_INSTALLED = find_spec("fastnanoid")
class GUID(TypeDecorator[UUID]):
"""Platform-independent GUID type.
Uses PostgreSQL's UUID type (Postgres, DuckDB, Cockroach),
MSSQL's UNIQUEIDENTIFIER type, Oracle's RAW(16) type,
otherwise uses BINARY(16) or CHAR(32),
storing as stringified hex values.
    Will accept stringified UUIDs as a hex string or an actual UUID.
    """
impl = BINARY(16)
cache_ok = True
@property
def python_type(self) -> type[UUID]:
return UUID
def __init__(self, *args: Any, binary: bool = True, **kwargs: Any) -> None:
self.binary = binary
def load_dialect_impl(self, dialect: Dialect) -> Any:
if dialect.name in {"postgresql", "duckdb", "cockroachdb"}:
return dialect.type_descriptor(PG_UUID())
if dialect.name == "oracle":
return dialect.type_descriptor(ORA_RAW(16))
if dialect.name == "mssql":
return dialect.type_descriptor(MSSQL_UNIQUEIDENTIFIER())
if self.binary:
return dialect.type_descriptor(BINARY(16))
return dialect.type_descriptor(CHAR(32))
def process_bind_param(
self,
value: Optional[Union[bytes, str, UUID]],
dialect: Dialect,
) -> Optional[Union[bytes, str]]:
if value is None:
return value
if dialect.name in {"postgresql", "duckdb", "cockroachdb", "mssql"}:
return str(value)
value = self.to_uuid(value)
if value is None:
return value
if dialect.name in {"oracle", "spanner+spanner"}:
return value.bytes
return value.bytes if self.binary else value.hex
def process_result_value(
self,
value: Optional[Union[bytes, str, UUID]],
dialect: Dialect,
) -> Optional[UUID]:
if value is None:
return value
if value.__class__.__name__ == "UUID":
return cast("UUID", value)
if dialect.name == "spanner+spanner":
return UUID(bytes=b64decode(cast("str | Buffer", value)))
if self.binary:
return UUID(bytes=cast("bytes", value))
return UUID(hex=cast("str", value))
@staticmethod
def to_uuid(value: Any) -> Optional[UUID]:
if value.__class__.__name__ == "UUID" or value is None:
return cast("Optional[UUID]", value)
try:
value = UUID(hex=value)
except (TypeError, ValueError):
value = UUID(bytes=value)
return cast("Optional[UUID]", value)
def compare_values(self, x: Any, y: Any) -> bool:
"""Compare two values for equality."""
if x.__class__.__name__ == "UUID" and y.__class__.__name__ == "UUID":
return cast("bool", x.bytes == y.bytes)
return cast("bool", x == y)
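The coercion performed by ``GUID.to_uuid`` can be sketched standalone with only the stdlib: pass through ``UUID`` instances and ``None``, try the value as a hex string first, and fall back to treating it as 16 raw bytes. The ``to_uuid`` below restates that logic outside the class for illustration.

```python
from typing import Any, Optional
from uuid import UUID, uuid4


def to_uuid(value: Any) -> Optional[UUID]:
    # Mirrors GUID.to_uuid: pass UUIDs and None through unchanged, otherwise
    # try a hex string first and fall back to raw bytes.
    if value.__class__.__name__ == "UUID" or value is None:
        return value
    try:
        return UUID(hex=value)
    except (TypeError, ValueError):
        return UUID(bytes=value)


u = uuid4()
assert to_uuid(u) is u
assert to_uuid(u.hex) == u   # 32-character hex string
assert to_uuid(str(u)) == u  # the dashed form is also accepted by UUID(hex=...)
assert to_uuid(u.bytes) == u  # 16 raw bytes
assert to_uuid(None) is None
```

The ``__class__.__name__`` comparison (rather than ``isinstance``) lets the check also match UUID types from other packages, such as ``uuid_utils``.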
# advanced_alchemy/types/identity.py
from sqlalchemy.types import BigInteger, Integer
BigIntIdentity = BigInteger().with_variant(Integer, "sqlite")
"""A ``BigInteger`` variant that falls back to ``Integer`` on dialects (such as SQLite) that do not support auto-incrementing ``BigInteger`` columns."""
# advanced_alchemy/types/json.py
# ruff: noqa: FA100
from typing import Any, Optional, Union, cast
from sqlalchemy import text, util
from sqlalchemy.dialects.oracle import BLOB as ORA_BLOB
from sqlalchemy.dialects.postgresql import JSONB as PG_JSONB
from sqlalchemy.engine import Dialect
from sqlalchemy.types import JSON as _JSON
from sqlalchemy.types import SchemaType, TypeDecorator, TypeEngine
from advanced_alchemy._serialization import decode_json, encode_json
__all__ = ("ORA_JSONB",)
class ORA_JSONB(TypeDecorator[dict[str, Any]], SchemaType): # noqa: N801
"""Oracle Binary JSON type.
JsonB = _JSON().with_variant(PG_JSONB, "postgresql").with_variant(ORA_JSONB, "oracle")
"""
impl = ORA_BLOB
cache_ok = True
@property
def python_type(self) -> type[dict[str, Any]]:
return dict
def __init__(self, *args: Any, **kwargs: Any) -> None:
"""Initialize JSON type"""
self.name = kwargs.pop("name", None)
self.oracle_strict = kwargs.pop("oracle_strict", True)
def coerce_compared_value(self, op: Any, value: Any) -> Any:
return self.impl.coerce_compared_value(op=op, value=value) # type: ignore[no-untyped-call, call-arg]
def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]:
return dialect.type_descriptor(ORA_BLOB())
def process_bind_param(self, value: Any, dialect: Dialect) -> Optional[Any]:
return value if value is None else encode_json(value)
def process_result_value(self, value: Union[bytes, None], dialect: Dialect) -> Optional[Any]:
if dialect.oracledb_ver < (2,): # type: ignore[attr-defined]
return value if value is None else decode_json(value)
return value
def _should_create_constraint(self, compiler: Any, **kw: Any) -> bool:
return cast("bool", compiler.dialect.name == "oracle")
def _variant_mapping_for_set_table(self, column: Any) -> Optional[dict[str, Any]]:
if column.type._variant_mapping: # noqa: SLF001
variant_mapping = dict(column.type._variant_mapping) # noqa: SLF001
variant_mapping["_default"] = column.type
else:
variant_mapping = None
return variant_mapping
@util.preload_module("sqlalchemy.sql.schema")
def _set_table(self, column: Any, table: Any) -> None:
schema = util.preloaded.sql_schema
variant_mapping = self._variant_mapping_for_set_table(column)
constraint_options = "(strict)" if self.oracle_strict else ""
sqltext = text(f"{column.name} is json {constraint_options}")
e = schema.CheckConstraint(
sqltext,
name=f"{column.name}_is_json",
_create_rule=util.portable_instancemethod( # type: ignore[no-untyped-call]
self._should_create_constraint,
{"variant_mapping": variant_mapping},
),
_type_bound=True,
)
table.append_constraint(e)
JsonB = (
_JSON().with_variant(PG_JSONB, "postgresql").with_variant(ORA_JSONB, "oracle").with_variant(PG_JSONB, "cockroachdb")
)
"""A JSON type that uses native ``JSONB`` where possible and ``Binary`` or ``Blob`` as
an alternative.
"""
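The bind/result behavior that ``ORA_JSONB`` adds on top of the ``BLOB`` column can be sketched with the stdlib ``json`` module standing in for ``encode_json``/``decode_json`` (an assumption for illustration only): serialize on the way in, parse on the way out, and let ``None`` pass through untouched.

```python
import json
from typing import Any, Optional


def process_bind_param(value: Optional[Any]) -> Optional[str]:
    # Mirrors ORA_JSONB.process_bind_param: None passes through; anything
    # else is serialized to a JSON string before hitting the BLOB column.
    return value if value is None else json.dumps(value)


def process_result_value(value: Optional[str]) -> Optional[Any]:
    # Mirrors the pre-oracledb-2.x branch of process_result_value, where the
    # driver does not decode JSON natively.
    return value if value is None else json.loads(value)


row = {"id": 1, "tags": ["a", "b"]}
stored = process_bind_param(row)
assert isinstance(stored, str)
assert process_result_value(stored) == row
assert process_bind_param(None) is None
```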
# advanced_alchemy/utils/__init__.py  (empty)

# advanced_alchemy/utils/dataclass.py
from dataclasses import Field, fields, is_dataclass
from inspect import isclass
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Protocol, final, runtime_checkable
if TYPE_CHECKING:
from collections.abc import Iterable
from collections.abc import Set as AbstractSet
from typing_extensions import TypeAlias, TypeGuard
__all__ = (
"DataclassProtocol",
"Empty",
"EmptyType",
"extract_dataclass_fields",
"extract_dataclass_items",
"is_dataclass_class",
"is_dataclass_instance",
"simple_asdict",
)
@final
class Empty:
"""A sentinel class used as placeholder."""
EmptyType: "TypeAlias" = type[Empty]
"""Type alias for the :class:`~advanced_alchemy.utils.dataclass.Empty` sentinel class."""
@runtime_checkable
class DataclassProtocol(Protocol):
"""Protocol for instance checking dataclasses"""
__dataclass_fields__: "ClassVar[dict[str, Any]]"
def extract_dataclass_fields(
dt: "DataclassProtocol",
exclude_none: bool = False,
exclude_empty: bool = False,
include: "Optional[AbstractSet[str]]" = None,
exclude: "Optional[AbstractSet[str]]" = None,
) -> "tuple[Field[Any], ...]":
"""Extract dataclass fields.
Args:
dt: :class:`DataclassProtocol` instance.
exclude_none: Whether to exclude None values.
exclude_empty: Whether to exclude Empty values.
include: An iterable of fields to include.
exclude: An iterable of fields to exclude.
Returns:
A tuple of dataclass fields.
"""
include = include or set()
exclude = exclude or set()
if common := (include & exclude):
msg = f"Fields {common} are both included and excluded."
raise ValueError(msg)
dataclass_fields: Iterable[Field[Any]] = fields(dt)
if exclude_none:
dataclass_fields = (field for field in dataclass_fields if getattr(dt, field.name) is not None)
if exclude_empty:
dataclass_fields = (field for field in dataclass_fields if getattr(dt, field.name) is not Empty)
if include:
dataclass_fields = (field for field in dataclass_fields if field.name in include)
if exclude:
dataclass_fields = (field for field in dataclass_fields if field.name not in exclude)
return tuple(dataclass_fields)
def extract_dataclass_items(
dt: "DataclassProtocol",
exclude_none: bool = False,
exclude_empty: bool = False,
include: "Optional[AbstractSet[str]]" = None,
exclude: "Optional[AbstractSet[str]]" = None,
) -> tuple[tuple[str, Any], ...]:
"""Extract dataclass name, value pairs.
    Unlike the ``asdict`` method exported by the stdlib, this function does not deep-copy values.
Args:
dt: :class:`DataclassProtocol` instance.
exclude_none: Whether to exclude None values.
exclude_empty: Whether to exclude Empty values.
include: An iterable of fields to include.
exclude: An iterable of fields to exclude.
Returns:
A tuple of key/value pairs.
"""
dataclass_fields = extract_dataclass_fields(dt, exclude_none, exclude_empty, include, exclude)
return tuple((field.name, getattr(dt, field.name)) for field in dataclass_fields)
def simple_asdict(
obj: "DataclassProtocol",
exclude_none: bool = False,
exclude_empty: bool = False,
convert_nested: bool = True,
exclude: "Optional[AbstractSet[str]]" = None,
) -> "dict[str, Any]":
"""Convert a dataclass to a dictionary.
This method has important differences to the standard library version:
- it does not deepcopy values
- it does not recurse into collections
Args:
obj: :class:`DataclassProtocol` instance.
exclude_none: Whether to exclude None values.
exclude_empty: Whether to exclude Empty values.
convert_nested: Whether to recursively convert nested dataclasses.
exclude: An iterable of fields to exclude.
Returns:
A dictionary of key/value pairs.
"""
ret: dict[str, Any] = {}
for field in extract_dataclass_fields(obj, exclude_none, exclude_empty, exclude=exclude):
value = getattr(obj, field.name)
if is_dataclass_instance(value) and convert_nested:
ret[field.name] = simple_asdict(value, exclude_none, exclude_empty)
else:
ret[field.name] = getattr(obj, field.name)
return ret
def is_dataclass_instance(obj: Any) -> "TypeGuard[DataclassProtocol]":
"""Check if an object is a dataclass instance.
Args:
obj: An object to check.
Returns:
True if the object is a dataclass instance.
"""
return hasattr(type(obj), "__dataclass_fields__") # pyright: ignore[reportUnknownArgumentType]
def is_dataclass_class(annotation: Any) -> "TypeGuard[type[DataclassProtocol]]":
    """Wrap :func:`is_dataclass <dataclasses.is_dataclass>` in a :data:`typing.TypeGuard`.
Args:
annotation: tested to determine if instance or type of :class:`dataclasses.dataclass`.
Returns:
``True`` if instance or type of ``dataclass``.
"""
try:
return isclass(annotation) and is_dataclass(annotation)
except TypeError: # pragma: no cover
return False
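The semantics of ``simple_asdict`` — recurse into nested dataclasses, but leave collections and other values untouched rather than deep-copying them — can be shown with a minimal standalone re-implementation (``mini_asdict`` is a hypothetical sketch, not the library function):

```python
from dataclasses import dataclass, fields, is_dataclass


def mini_asdict(obj):
    # Minimal sketch of simple_asdict's behavior: convert nested dataclass
    # *instances*, but do NOT recurse into collections and do NOT deep-copy.
    out = {}
    for f in fields(obj):
        value = getattr(obj, f.name)
        if is_dataclass(value) and not isinstance(value, type):
            out[f.name] = mini_asdict(value)
        else:
            out[f.name] = value
    return out


@dataclass
class Inner:
    x: int


@dataclass
class Outer:
    inner: Inner
    items: list


inner = Inner(x=1)
result = mini_asdict(Outer(inner=inner, items=[inner]))
assert result["inner"] == {"x": 1}  # nested dataclass converted
assert result["items"][0] is inner  # collections are left untouched
```

This is exactly the contrast drawn in the docstring above: the stdlib ``asdict`` would have deep-copied ``items`` and converted the dataclass inside it.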
# advanced_alchemy/utils/deprecation.py
import inspect
from functools import wraps
from typing import Callable, Literal, Optional
from warnings import warn
from typing_extensions import ParamSpec, TypeVar
__all__ = ("deprecated", "warn_deprecation")
T = TypeVar("T")
P = ParamSpec("P")
DeprecatedKind = Literal["function", "method", "classmethod", "attribute", "property", "class", "parameter", "import"]
def warn_deprecation(
version: str,
deprecated_name: str,
kind: DeprecatedKind,
*,
removal_in: Optional[str] = None,
alternative: Optional[str] = None,
info: Optional[str] = None,
pending: bool = False,
) -> None:
"""Warn about a call to a (soon to be) deprecated function.
Args:
version: Advanced Alchemy version where the deprecation will occur
deprecated_name: Name of the deprecated function
removal_in: Advanced Alchemy version where the deprecated function will be removed
alternative: Name of a function that should be used instead
info: Additional information
pending: Use :class:`warnings.PendingDeprecationWarning` instead of :class:`warnings.DeprecationWarning`
kind: Type of the deprecated thing
"""
parts = []
if kind == "import":
access_type = "Import of"
elif kind in {"function", "method"}:
access_type = "Call to"
else:
access_type = "Use of"
if pending:
parts.append(f"{access_type} {kind} awaiting deprecation {deprecated_name!r}") # pyright: ignore[reportUnknownMemberType]
else:
parts.append(f"{access_type} deprecated {kind} {deprecated_name!r}") # pyright: ignore[reportUnknownMemberType]
parts.extend( # pyright: ignore[reportUnknownMemberType]
(
f"Deprecated in advanced-alchemy {version}",
f"This {kind} will be removed in {removal_in or 'the next major version'}",
),
)
if alternative:
parts.append(f"Use {alternative!r} instead") # pyright: ignore[reportUnknownMemberType]
if info:
parts.append(info) # pyright: ignore[reportUnknownMemberType]
text = ". ".join(parts) # pyright: ignore[reportUnknownArgumentType]
warning_class = PendingDeprecationWarning if pending else DeprecationWarning
warn(text, warning_class, stacklevel=2)
def deprecated(
version: str,
*,
removal_in: Optional[str] = None,
alternative: Optional[str] = None,
info: Optional[str] = None,
pending: bool = False,
kind: Optional[Literal["function", "method", "classmethod", "property"]] = None,
) -> Callable[[Callable[P, T]], Callable[P, T]]:
"""Create a decorator wrapping a function, method or property with a warning call about a (pending) deprecation.
Args:
version: Advanced Alchemy version where the deprecation will occur
removal_in: Advanced Alchemy version where the deprecated function will be removed
alternative: Name of a function that should be used instead
info: Additional information
pending: Use :class:`warnings.PendingDeprecationWarning` instead of :class:`warnings.DeprecationWarning`
kind: Type of the deprecated callable. If ``None``, will use ``inspect`` to figure
out if it's a function or method
Returns:
A decorator wrapping the function call with a warning
"""
def decorator(func: Callable[P, T]) -> Callable[P, T]:
@wraps(func)
def wrapped(*args: P.args, **kwargs: P.kwargs) -> T:
warn_deprecation(
version=version,
deprecated_name=func.__name__,
info=info,
alternative=alternative,
pending=pending,
removal_in=removal_in,
kind=kind or ("method" if inspect.ismethod(func) else "function"),
)
return func(*args, **kwargs)
return wrapped
return decorator
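A stripped-down sketch of the decorator pattern used above — wrap the callable with ``functools.wraps`` and emit a ``DeprecationWarning`` on each call. ``mini_deprecated`` is a hypothetical, self-contained stand-in; the real decorator delegates message building to ``warn_deprecation``.

```python
import warnings
from functools import wraps


def mini_deprecated(version: str, alternative: str = ""):
    # Minimal sketch: warn on every call, preserving the wrapped function's
    # metadata (__name__, __doc__) via functools.wraps.
    def decorator(func):
        @wraps(func)
        def wrapped(*args, **kwargs):
            msg = f"Call to deprecated function {func.__name__!r}. Deprecated in {version}"
            if alternative:
                msg += f". Use {alternative!r} instead"
            warnings.warn(msg, DeprecationWarning, stacklevel=2)
            return func(*args, **kwargs)

        return wrapped

    return decorator


@mini_deprecated("1.0", alternative="new_add")
def old_add(a, b):
    return a + b


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    assert old_add(1, 2) == 3
assert caught and issubclass(caught[0].category, DeprecationWarning)
assert old_add.__name__ == "old_add"  # wraps preserved the metadata
```

``stacklevel=2`` points the warning at the caller of the deprecated function rather than at the wrapper itself.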
# advanced_alchemy/utils/fixtures.py
from typing import TYPE_CHECKING, Any, Union
from advanced_alchemy._serialization import decode_json
from advanced_alchemy.exceptions import MissingDependencyError
if TYPE_CHECKING:
from pathlib import Path
from anyio import Path as AsyncPath
__all__ = ("open_fixture", "open_fixture_async")
def open_fixture(fixtures_path: "Union[Path, AsyncPath]", fixture_name: str) -> Any:
"""Loads JSON file with the specified fixture name
Args:
        fixtures_path (:class:`pathlib.Path` | :class:`anyio.Path`): The path to look for fixtures.
fixture_name (str): The fixture name to load.
Raises:
:class:`FileNotFoundError`: Fixtures not found.
Returns:
Any: The parsed JSON data
"""
from pathlib import Path
fixture = Path(fixtures_path / f"{fixture_name}.json")
if fixture.exists():
with fixture.open(mode="r", encoding="utf-8") as f:
f_data = f.read()
return decode_json(f_data)
msg = f"Could not find the {fixture_name} fixture"
raise FileNotFoundError(msg)
async def open_fixture_async(fixtures_path: "Union[Path, AsyncPath]", fixture_name: str) -> Any:
"""Loads JSON file with the specified fixture name
Args:
        fixtures_path (:class:`pathlib.Path` | :class:`anyio.Path`): The path to look for fixtures.
fixture_name (str): The fixture name to load.
Raises:
:class:`~advanced_alchemy.exceptions.MissingDependencyError`: The `anyio` library is required to use this function.
:class:`FileNotFoundError`: Fixtures not found.
Returns:
Any: The parsed JSON data
"""
try:
from anyio import Path as AsyncPath
except ImportError as exc:
msg = "The `anyio` library is required to use this function. Please install it with `pip install anyio`."
raise MissingDependencyError(msg) from exc
fixture = AsyncPath(fixtures_path / f"{fixture_name}.json")
if await fixture.exists():
async with await fixture.open(mode="r", encoding="utf-8") as f:
f_data = await f.read()
return decode_json(f_data)
msg = f"Could not find the {fixture_name} fixture"
raise FileNotFoundError(msg)
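The lookup contract of ``open_fixture`` — resolve ``<name>.json`` under the fixtures directory, parse it, and raise ``FileNotFoundError`` when absent — can be exercised standalone with a temporary directory. ``mini_open_fixture`` is a hypothetical sketch using stdlib ``json`` in place of the library's ``decode_json``.

```python
import json
import tempfile
from pathlib import Path


def mini_open_fixture(fixtures_path: Path, fixture_name: str):
    # Mirrors open_fixture: look for "<name>.json" under the fixtures
    # directory, parse it, and raise FileNotFoundError when it is absent.
    fixture = fixtures_path / f"{fixture_name}.json"
    if fixture.exists():
        return json.loads(fixture.read_text(encoding="utf-8"))
    raise FileNotFoundError(f"Could not find the {fixture_name} fixture")


with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp)
    (path / "users.json").write_text(json.dumps([{"id": 1}]), encoding="utf-8")
    assert mini_open_fixture(path, "users") == [{"id": 1}]
    try:
        mini_open_fixture(path, "missing")
    except FileNotFoundError:
        pass
    else:
        raise AssertionError("expected FileNotFoundError")
```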
# advanced_alchemy/utils/module_loader.py
"""General utility functions."""
import sys
from importlib import import_module
from importlib.util import find_spec
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional
if TYPE_CHECKING:
from types import ModuleType
__all__ = (
"import_string",
"module_to_os_path",
)
def module_to_os_path(dotted_path: str = "app") -> Path:
"""Find Module to OS Path.
Return a path to the base directory of the project or the module
specified by `dotted_path`.
Args:
dotted_path: The path to the module. Defaults to "app".
Raises:
TypeError: The module could not be found.
Returns:
Path: The path to the module.
"""
try:
if (src := find_spec(dotted_path)) is None: # pragma: no cover
msg = f"Couldn't find the path for {dotted_path}"
raise TypeError(msg)
except ModuleNotFoundError as e:
msg = f"Couldn't find the path for {dotted_path}"
raise TypeError(msg) from e
path = Path(str(src.origin))
return path.parent if path.is_file() else path
def import_string(dotted_path: str) -> Any:
"""Dotted Path Import.
Import a dotted module path and return the attribute/class designated by the
last name in the path. Raise ImportError if the import failed.
Args:
dotted_path: The path of the module to import.
Raises:
ImportError: Could not import the module.
Returns:
object: The imported object.
"""
def _is_loaded(module: "Optional[ModuleType]") -> bool:
spec = getattr(module, "__spec__", None)
initializing = getattr(spec, "_initializing", False)
return bool(module and spec and not initializing)
def _cached_import(module_path: str, class_name: str) -> Any:
"""Import and cache a class from a module.
Args:
module_path: dotted path to module.
class_name: Class or function name.
Returns:
object: The imported class or function
"""
# Check whether module is loaded and fully initialized.
module = sys.modules.get(module_path)
if not _is_loaded(module):
module = import_module(module_path)
return getattr(module, class_name)
try:
module_path, class_name = dotted_path.rsplit(".", 1)
except ValueError as e:
msg = "%s doesn't look like a module path"
raise ImportError(msg, dotted_path) from e
try:
return _cached_import(module_path, class_name)
except AttributeError as e:
msg = "Module '%s' does not define a '%s' attribute/class"
raise ImportError(msg, module_path, class_name) from e
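The core of ``import_string`` can be restated compactly: split the dotted path on the last ``.``, import (or reuse) the module, and fetch the attribute, converting failures into ``ImportError``. ``mini_import_string`` below is a hypothetical sketch that omits the ``__spec__._initializing`` check the real function performs for partially initialized modules.

```python
import sys
from importlib import import_module


def mini_import_string(dotted_path: str):
    # Minimal sketch of import_string: split "pkg.module.attr" on the last
    # dot, import the module (reusing sys.modules when already loaded), and
    # fetch the attribute.
    try:
        module_path, attr_name = dotted_path.rsplit(".", 1)
    except ValueError as e:
        raise ImportError(f"{dotted_path!r} doesn't look like a module path") from e
    module = sys.modules.get(module_path) or import_module(module_path)
    try:
        return getattr(module, attr_name)
    except AttributeError as e:
        raise ImportError(f"Module {module_path!r} does not define {attr_name!r}") from e


import json
import os.path

assert mini_import_string("json.dumps") is json.dumps
assert mini_import_string("os.path.join") is os.path.join
```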
# advanced_alchemy/utils/portals.py
"""This module provides a portal provider and portal for calling async functions from synchronous code."""
import asyncio
import functools
import queue
import threading
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Optional, TypeVar, cast
from warnings import warn
from advanced_alchemy.exceptions import ImproperConfigurationError
if TYPE_CHECKING:
from collections.abc import Coroutine
__all__ = ("Portal", "PortalProvider", "PortalProviderSingleton")
_R = TypeVar("_R")
class PortalProviderSingleton(type):
"""A singleton metaclass for PortalProvider."""
_instances: "ClassVar[dict[type, PortalProvider]]" = {}
def __call__(cls, *args: Any, **kwargs: Any) -> "PortalProvider":
if cls not in cls._instances: # pyright: ignore[reportUnnecessaryContains]
cls._instances[cls] = super().__call__(*args, **kwargs)
return cls._instances[cls] # pyright: ignore[reportUnknownVariableType,reportUnknownMemberType]
class PortalProvider(metaclass=PortalProviderSingleton):
"""A provider for creating and managing threaded portals."""
def __init__(self) -> None:
"""Initialize the PortalProvider."""
self._request_queue: queue.Queue[
tuple[
Callable[..., Coroutine[Any, Any, Any]],
tuple[Any, ...],
dict[str, Any],
queue.Queue[tuple[Optional[Any], Optional[Exception]]],
]
] = queue.Queue()
self._result_queue: queue.Queue[tuple[Optional[Any], Optional[Exception]]] = queue.Queue()
self._loop: Optional[asyncio.AbstractEventLoop] = None
self._thread: Optional[threading.Thread] = None
self._ready_event: threading.Event = threading.Event()
@property
def portal(self) -> "Portal":
"""The portal instance."""
return Portal(self)
@property
def is_running(self) -> bool:
"""Whether the portal provider is running."""
return self._thread is not None and self._thread.is_alive()
@property
def is_ready(self) -> bool:
"""Whether the portal provider is ready."""
return self._ready_event.is_set()
@property
def loop(self) -> "asyncio.AbstractEventLoop": # pragma: no cover
"""The event loop."""
if self._loop is None:
msg = "The PortalProvider is not started. Did you forget to call .start()?"
raise ImproperConfigurationError(msg)
return self._loop
def start(self) -> None:
"""Starts the background thread and event loop."""
if self._thread is not None: # pragma: no cover
warn("PortalProvider already started", stacklevel=2)
return
self._thread = threading.Thread(target=self._run_event_loop, daemon=True)
self._thread.start()
self._ready_event.wait() # Wait for the loop to be ready
def stop(self) -> None:
"""Stops the background thread and event loop."""
if self._loop is None or self._thread is None:
return
self._loop.call_soon_threadsafe(self._loop.stop)
self._thread.join()
self._loop.close()
self._loop = None
self._thread = None
self._ready_event.clear()
def _run_event_loop(self) -> None: # pragma: no cover
"""The main function of the background thread."""
self._loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._loop)
self._ready_event.set() # Signal that the loop is ready
self._loop.run_forever()
async def _async_caller(
self,
func: "Callable[..., Coroutine[Any, Any, _R]]",
args: tuple[Any, ...],
kwargs: dict[str, Any],
) -> _R:
"""Wrapper to run the async function and send the result to the result queue."""
result: _R = await func(*args, **kwargs)
return result
def call(self, func: "Callable[..., Coroutine[Any, Any, _R]]", *args: Any, **kwargs: Any) -> _R:
"""Calls an async function from a synchronous context.
Args:
func: The async function to call.
*args: Positional arguments to the function.
**kwargs: Keyword arguments to the function.
Returns:
The result of the async function.
Raises:
Exception: If the async function raises an exception.
"""
if self._loop is None:
msg = "The PortalProvider is not started. Did you forget to call .start()?"
raise ImproperConfigurationError(msg)
# Create a new result queue
local_result_queue: queue.Queue[tuple[Optional[_R], Optional[Exception]]] = queue.Queue()
# Send the request to the background thread
self._request_queue.put((func, args, kwargs, local_result_queue))
# Trigger the execution in the event loop
_handle = self._loop.call_soon_threadsafe(self._process_request)
# Wait for the result from the background thread
result, exception = local_result_queue.get()
if exception:
raise exception
return cast("_R", result)
def _process_request(self) -> None: # pragma: no cover
"""Processes a request from the request queue in the event loop."""
assert self._loop is not None # noqa: S101
if not self._request_queue.empty():
func, args, kwargs, local_result_queue = self._request_queue.get()
future = asyncio.run_coroutine_threadsafe(self._async_caller(func, args, kwargs), self._loop)
# Attach a callback to handle the result/exception
future.add_done_callback(
functools.partial(self._handle_future_result, local_result_queue=local_result_queue), # pyright: ignore[reportArgumentType]
)
def _handle_future_result(
self,
future: "asyncio.Future[Any]",
local_result_queue: "queue.Queue[tuple[Optional[Any], Optional[Exception]]]",
) -> None: # pragma: no cover
"""Handles the result or exception from the completed future."""
try:
result = future.result()
local_result_queue.put((result, None))
except Exception as e: # noqa: BLE001
local_result_queue.put((None, e))
class Portal:
def __init__(self, provider: "PortalProvider") -> None:
self._provider = provider
def call(self, func: "Callable[..., Coroutine[Any, Any, _R]]", *args: Any, **kwargs: Any) -> _R:
"""Calls an async function using the associated PortalProvider."""
return self._provider.call(func, *args, **kwargs)
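The essence of the portal pattern — a daemon thread owns an event loop, and synchronous callers submit coroutines to it — can be sketched without the request/result queues, using ``asyncio.run_coroutine_threadsafe`` directly. ``MiniPortal`` is a hypothetical simplification; the real ``PortalProvider`` adds a singleton metaclass and per-call result queues.

```python
import asyncio
import threading


class MiniPortal:
    """Stripped-down sketch of the PortalProvider pattern."""

    def __init__(self):
        self._loop = None
        self._ready = threading.Event()
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()
        self._ready.wait()  # block until the loop exists, like start()

    def _run(self):
        # Background thread: create a loop, signal readiness, run forever.
        self._loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self._loop)
        self._ready.set()
        self._loop.run_forever()

    def call(self, func, *args, **kwargs):
        # Submit the coroutine to the background loop and block on the result;
        # exceptions raised inside the coroutine propagate to the caller.
        future = asyncio.run_coroutine_threadsafe(func(*args, **kwargs), self._loop)
        return future.result()

    def stop(self):
        self._loop.call_soon_threadsafe(self._loop.stop)
        self._thread.join()
        self._loop.close()


async def add(a, b):
    await asyncio.sleep(0)
    return a + b


portal = MiniPortal()
assert portal.call(add, 1, 2) == 3
portal.stop()
```

``run_coroutine_threadsafe`` returns a ``concurrent.futures.Future``, so blocking on ``future.result()`` from the synchronous side is safe as long as the caller is not itself running on the portal's loop.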
# advanced_alchemy/utils/text.py
"""General utility functions."""
import re
import unicodedata
from typing import Optional
__all__ = (
"check_email",
"slugify",
)
def check_email(email: str) -> str:
    """Validate an email address (minimal check for an ``@``) and return it lowercased."""
if "@" not in email:
msg = "Invalid email!"
raise ValueError(msg)
return email.lower()
def slugify(value: str, allow_unicode: bool = False, separator: Optional[str] = None) -> str:
"""Slugify.
Convert to ASCII if ``allow_unicode`` is ``False``. Convert spaces or repeated
dashes to single dashes. Remove characters that aren't alphanumerics,
underscores, or hyphens. Convert to lowercase. Also strip leading and
trailing whitespace, dashes, and underscores.
Args:
value (str): the string to slugify
allow_unicode (bool, optional): allow unicode characters in slug. Defaults to False.
separator (str, optional): by default a `-` is used to delimit word boundaries.
Set this to configure something different.
Returns:
str: a slugified string of the value parameter
"""
if allow_unicode:
value = unicodedata.normalize("NFKC", value)
else:
value = unicodedata.normalize("NFKD", value).encode("ascii", "ignore").decode("ascii")
value = re.sub(r"[^\w\s-]", "", value.lower())
if separator is not None:
return re.sub(r"[-\s]+", "-", value).strip("-_").replace("-", separator)
return re.sub(r"[-\s]+", "-", value).strip("-_")
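The normalization pipeline above can be restated as a compact standalone function to make each step visible: Unicode normalization, non-word-character removal, whitespace/dash collapsing, and edge trimming. ``mini_slugify`` is a hypothetical sketch mirroring the same steps.

```python
import re
import unicodedata


def mini_slugify(value: str, allow_unicode: bool = False, separator: str = None) -> str:
    # Mirrors slugify above, step by step.
    if allow_unicode:
        value = unicodedata.normalize("NFKC", value)
    else:
        # Decompose accents, then drop anything outside ASCII.
        value = unicodedata.normalize("NFKD", value).encode("ascii", "ignore").decode("ascii")
    value = re.sub(r"[^\w\s-]", "", value.lower())   # drop punctuation
    value = re.sub(r"[-\s]+", "-", value).strip("-_")  # collapse runs, trim edges
    if separator is not None:
        return value.replace("-", separator)
    return value


assert mini_slugify("Hello, World!") == "hello-world"
assert mini_slugify("  spaced   out  ") == "spaced-out"
assert mini_slugify("Crème Brûlée") == "creme-brulee"  # accents stripped via NFKD
assert mini_slugify("Hello World", separator="_") == "hello_world"
```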
# codecov.yml
coverage:
status:
project:
default:
target: auto
threshold: 2%
patch:
default:
target: auto
comment:
require_changes: true
# docs/Makefile
# Minimal makefile for Sphinx documentation
#
# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS ?=
SPHINXBUILD ?= sphinx-build
SOURCEDIR = .
BUILDDIR = _build
# Put it first so that "make" without argument is like "make help".
help:
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
.PHONY: help Makefile
# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
<!-- docs/PYPI_README.md -->