python-0.20.7/.editorconfig

root = true

[*]
indent_style = tab
indent_size = 4
end_of_line = lf
charset = utf-8
trim_trailing_whitespace = true
insert_final_newline = true

[*.py]
max_line_length = 99

[*.md]
trim_trailing_whitespace = false

[*.{yaml,yml,py,pyi}]
indent_style = space

[{.gitlab-ci.yml,.github/workflows/*.yml,.pre-commit-config.yaml}]
indent_size = 2

python-0.20.7/.github/workflows/python-package.yml

name: Python package

on: [push, pull_request]

jobs:
  build:
    runs-on: ubuntu-latest
    strategy:
      fail-fast: false
      matrix:
        python-version: ["3.10", "3.11", "3.12", "3.13"]

    steps:
      - uses: actions/checkout@v4
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install libolm
        run: sudo apt-get install libolm3
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          python -m pip install python-olm --extra-index-url https://gitlab.matrix.org/api/v4/projects/27/packages/pypi/simple
          python -m pip install .[test]
      - name: Test with pytest
        run: |
          export MEOW_TEST_PG_URL=postgres://meow:meow@localhost/meow
          pytest

    services:
      postgres:
        image: postgres
        env:
          POSTGRES_USER: meow
          POSTGRES_PASSWORD: meow
          POSTGRES_DB: meow
        ports:
          - 5432:5432
        options: >-
          --health-cmd pg_isready
          --health-interval 10s
          --health-timeout 5s
          --health-retries 5

  lint:
    runs-on: ubuntu-latest

    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
        with:
          python-version: "3.13"
      - uses: isort/isort-action@master
        with:
          sortPaths: "./mautrix"
      - uses: psf/black@stable
        with:
          src: "./mautrix"
          version: "24.10.0"
      - name: pre-commit
        run: |
          pip install pre-commit
          pre-commit run -av trailing-whitespace
          pre-commit run -av end-of-file-fixer
          pre-commit run -av check-yaml
          pre-commit run -av check-added-large-files

python-0.20.7/.gitignore

build/
dist/
*.egg-info
.venv
pip-selfcheck.json
*.pyc
__pycache__

python-0.20.7/.gitlab-ci.yml

build docs builder:
  stage: build
  image: docker:stable
  tags:
    - amd64
  only:
    refs:
      - master
    changes:
      - docs/Dockerfile
      - docs/requirements.txt
  before_script:
    - docker login -u $CI_REGISTRY_USER -p $CI_REGISTRY_PASSWORD $CI_REGISTRY
  script:
    - cd docs
    - docker build --tag $CI_REGISTRY_IMAGE/doc-builder:latest .
    - docker push $CI_REGISTRY_IMAGE/doc-builder:latest
    - docker rmi $CI_REGISTRY_IMAGE/doc-builder:latest

build docs:
  stage: deploy
  image: dock.mau.dev/mautrix/python/doc-builder
  tags:
    - webdeploy
  only:
    - master
  script:
    - cd docs
    - make html
    - mkdir -p /srv/web/docs.mau.fi/python/latest/
    - rsync -rcthvl --delete _build/html/ /srv/web/docs.mau.fi/python/latest/

build tag docs:
  stage: deploy
  image: dock.mau.dev/mautrix/python/doc-builder
  tags:
    - webdeploy
  only:
    - tags
  script:
    - cd docs
    - make html
    - mkdir -p /srv/web/docs.mau.fi/python/$CI_COMMIT_TAG/
    - rsync -rcthvl --delete _build/html/ /srv/web/docs.mau.fi/python/$CI_COMMIT_TAG/

python-0.20.7/.pre-commit-config.yaml

repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v5.0.0
    hooks:
      - id: trailing-whitespace
        exclude_types: [markdown]
      - id: end-of-file-fixer
      - id: check-yaml
      - id: check-added-large-files
  - repo: https://github.com/psf/black
    rev: 24.10.0
    hooks:
      - id: black
        language_version: python3
        files: ^mautrix/.*\.pyi?$
  - repo: https://github.com/PyCQA/isort
    rev: 5.13.2
    hooks:
      - id: isort
        files: ^mautrix/.*\.pyi?$

python-0.20.7/CHANGELOG.md

## v0.20.7 (2025-01-03)

* *(types)* Removed support for generating reply fallbacks to implement [MSC2781]. Stripping fallbacks is still supported.

[MSC2781]: https://github.com/matrix-org/matrix-spec-proposals/pull/2781

## v0.20.6 (2024-07-12)

* *(bridge)* Added `/register` call if `/versions` fails with `M_FORBIDDEN`.

## v0.20.5 (2024-07-09)

**Note:** The `bridge` module is deprecated as all bridges are being rewritten in Go. See for more info.

* *(client)* Added support for authenticated media downloads.
* *(bridge)* Stopped using cached homeserver URLs for double puppeting if one is set in the config file.
* *(crypto)* Fixed error when checking OTK counts before uploading new keys.
* *(types)* Added MSC2530 (captions) fields to `MediaMessageEventContent`.

## v0.20.4 (2024-01-09)

* Dropped Python 3.9 support.
* *(client)* Changed media download methods to log requests and to raise exceptions on non-successful status codes.

## v0.20.3 (2023-11-10)

* *(client)* Deprecated MSC2716 methods and added new Beeper-specific batch send methods, as upstream MSC2716 support has been abandoned.
* *(util.async_db)* Added `PRAGMA synchronous = NORMAL;` to default pragmas.
* *(types)* Fixed `guest_can_join` field name in room directory response (thanks to [@ashfame] in [#163]).

[@ashfame]: https://github.com/ashfame
[#163]: https://github.com/mautrix/python/pull/163

## v0.20.2 (2023-09-09)

* *(crypto)* Changed `OlmMachine.share_keys` to make the OTK count parameter optional. When omitted, the count is fetched from the server.
* *(appservice)* Added option to run appservice transaction event handlers synchronously.
* *(appservice)* Added `log` and `hs_token` parameters to `AppServiceServerMixin` to allow using it as a standalone class without extending.
* *(api)* Added support for setting appservice `user_id` and `device_id` query parameters manually without using `AppServiceAPI`.

## v0.20.1 (2023-08-29)

* *(util.program)* Removed `--base-config` flag in bridges, as there are no valid use cases (package data should always work) and it's easy to cause issues by pointing the flag at the wrong file.
* *(bridge)* Added support for the `com.devture.shared_secret_auth` login type for automatic double puppeting.
* *(bridge)* Dropped support for syncing with double puppets. MSC2409 is now the only way to receive ephemeral events. * *(bridge)* Added support for double puppeting with arbitrary `as_token`s. ## v0.20.0 (2023-06-25) * Dropped Python 3.8 support. * **Breaking change *(.state_store)*** Removed legacy SQLAlchemy state store implementations. * **Mildly breaking change *(util.async_db)*** Changed `SQLiteDatabase` to not remove prefix slashes from database paths. * Library users should use `sqlite:path.db` instead of `sqlite:///path.db` for relative paths, and `sqlite:/path.db` instead of `sqlite:////path.db` for absolute paths. * Bridge configs do this migration automatically. * *(util.async_db)* Added warning log if using SQLite database path that isn't writable. * *(util.program)* Fixed `manual_stop` not working if it's called during startup. * *(client)* Stabilized support for asynchronous uploads. * `unstable_create_msc` was renamed to `create_mxc`, and the `max_stall_ms` parameters for downloading were renamed to `timeout_ms`. * *(crypto)* Added option to not rotate keys when devices change. * *(crypto)* Added option to remove all keys that were received before the automatic ratcheting was implemented (in v0.19.10). * *(types)* Improved reply fallback removal to have a smaller chance of false positives for messages that don't use reply fallbacks. ## v0.19.16 (2023-05-26) * *(appservice)* Fixed Python 3.8 compatibility. ## v0.19.15 (2023-05-24) * *(client)* Fixed dispatching room ephemeral events (i.e. typing notifications) in syncer. ## v0.19.14 (2023-05-16) * *(bridge)* Implemented appservice pinging using MSC2659. * *(bridge)* Started reusing aiosqlite connection pool for crypto db. * This fixes the crypto pool getting stuck if the bridge exits unexpectedly (the default pool is closed automatically at any type of exit). ## v0.19.13 (2023-04-24) * *(crypto)* Fixed bug with redacting megolm sessions when device is deleted. ## v0.19.12 (2023-04-18) * *(bridge)* Fixed backwards-compatibility with new key deletion config options. ## v0.19.11 (2023-04-14) * *(crypto)* Fixed bug in previous release which caused errors if the `max_age` of a megolm session was not known. * *(crypto)* Changed key receiving handler to fetch encryption config from server if it's not cached locally (to find `max_age` and `max_messages` more reliably). ## v0.19.10 (2023-04-13) * *(crypto, bridge)* Added options to automatically ratchet/delete megolm sessions to minimize access to old messages. ## v0.19.9 (2023-04-12) * *(crypto)* Fixed bug in crypto store migration when using outbound sessions with max age higher than usual. ## v0.19.8 (2023-04-06) * *(crypto)* Updated crypto store schema to match mautrix-go. * *(types)* Fixed `set_thread_parent` adding reply fallbacks to the message body. ## v0.19.7 (2023-03-22) * *(bridge, crypto)* Fixed key sharing trust checker not resolving cross-signing signatures when minimum trust level is set to cross-signed. ## v0.19.6 (2023-03-13) * *(crypto)* Added cache checks to prevent invalidating group session when the server sends a duplicate member event in /sync. * *(util.proxy)* Fixed `min_wait_seconds` behavior and added `max_wait_seconds` and `multiply_wait_seconds` to `proxy_with_retry`. ## v0.19.5 (2023-03-07) * *(util.proxy)* Added utility for dynamic proxies (from mautrix-instagram/facebook). * *(types)* Added default value for `upload_size` in `MediaRepoConfig` as the field is optional in the spec. 
* *(bridge)* Changed ghost invite handling to only process one per room at a time (thanks to [@maltee1] in [#132]). [#132]: https://github.com/mautrix/python/pull/132 ## v0.19.4 (2023-02-12) * *(types)* Changed `set_thread_parent` to inherit the existing thread parent if a `MessageEvent` is passed, as starting threads from a message in a thread is not allowed. * *(util.background_task)* Added new utility for creating background tasks safely, by ensuring that the task is not garbage collected before finishing and logging uncaught exceptions immediately. ## v0.19.3 (2023-01-27) * *(bridge)* Bumped default timeouts for decrypting incoming messages. ## v0.19.2 (2023-01-14) * *(util.async_body)* Added utility for reading aiohttp response into a bytearray (so that the output is mutable, e.g. for decrypting or encrypting media). * *(client.api)* Fixed retry loop for MSC3870 URL uploads not exiting properly after too many errors. ## v0.19.1 (2023-01-11) * Marked Python 3.11 as supported. Python 3.8 support will likely be dropped in the coming months. * *(client.api)* Added request payload memory optimization to MSC3870 URL uploads. * aiohttp will duplicate the entire request body if it's raw bytes, which wastes a lot of memory. The optimization is passing an iterator instead of raw bytes, so aiohttp won't accidentally duplicate the whole thing. * The main `HTTPAPI` has had the optimization for a while, but uploading to URL calls aiohttp manually. ## v0.19.0 (2023-01-10) * **Breaking change *(appservice)*** Removed typing status from state store. * **Breaking change *(appservice)*** Removed `is_typing` parameter from `IntentAPI.set_typing` to make the signature match `ClientAPI.set_typing`. `timeout=0` is equivalent to the old `is_typing=False`. * **Breaking change *(types)*** Removed legacy fields in Beeper MSS events. * *(bridge)* Removed accidentally nested reply loop when accepting invites as the bridge bot. * *(bridge)* Fixed decoding JSON values in config override env vars. ## v0.18.9 (2022-12-14) * *(util.async_db)* Changed aiosqlite connector to force-enable foreign keys, WAL mode and busy_timeout. * The values can be changed by manually specifying the same PRAGMAs in the `init_commands` db arg, e.g. `- PRAGMA foreign_keys = OFF`. * *(types)* Added workaround to `StateEvent.deserialize` to handle Conduit's broken `unsigned` fields. * *(client.state_store)* Fixed `set_power_level` to allow raw dicts the same way as `set_encryption_info` does (thanks to [@bramenn] in [#127]). [@bramenn]: https://github.com/bramenn [#127]: https://github.com/mautrix/python/pull/127 ## v0.18.8 (2022-11-18) * *(crypto.store.asyncpg)* Fixed bug causing `put_group_session` to fail when trying to log unique key errors. * *(client)* Added wrapper for `create_room` to update the state store with initial state and invites (applies to anything extending `StoreUpdatingAPI`, such as the high-level `Client` and appservice `IntentAPI` classes). ## v0.18.7 (2022-11-08) ## v0.18.6 (2022-10-24) * *(util.formatter)* Added conversion method for `
` tag and defaulted to converting back to `---`. ## v0.18.5 (2022-10-20) * *(appservice)* Added try blocks around [MSC3202] handler functions to log errors instead of failing the entire transaction. This matches the behavior of errors in normal appservice event handlers. ## v0.18.4 (2022-10-13) * *(client.api)* Added option to pass custom data to `/createRoom` to enable using custom fields and testing MSCs without changing the library. * *(client.api)* Updated [MSC3870] support to send file name in upload complete call. * *(types)* Changed `set_edit` to clear reply metadata as edits can't change the reply status. * *(util.formatter)* Fixed edge case causing negative entity lengths when splitting entity strings. ## v0.18.3 (2022-10-11) * *(util.async_db)* Fixed mistake in default no-op database error handler causing the wrong exception to be raised. * *(crypto.store.asyncpg)* Updated `put_group_session` to catch unique key errors and log instead of raising. * *(client.api)* Updated [MSC3870] support to catch and retry on all connection errors instead of only non-200 status codes when uploading. ## v0.18.2 (2022-09-24) * *(crypto)* Fixed handling key requests when using appservice-mode (MSC2409) encryption. * *(appservice)* Added workaround for dumb servers that send `"unsigned": null` in events. ## v0.18.1 (2022-09-15) * *(crypto)* Fixed error sharing megolm session if a single recipient device has ran out of one-time keys. ## v0.18.0 (2022-09-15) * **Breaking change *(util.async_db)*** Added checks to prevent calling `.start()` on a database multiple times. * *(appservice)* Fixed [MSC2409] support to read to-device events from the correct field. * *(appservice)* Added support for automatically calling functions when a transaction contains [MSC2409] to-device events or [MSC3202] encryption data. * *(bridge)* Added option to use [MSC2409] and [MSC3202] for end-to-bridge encryption. However, this may not work with the Synapse implementation as it hasn't been tested yet. * *(bridge)* Replaced `homeserver` -> `asmux` flag with more generic `software` field. * *(bridge)* Added support for overriding parts of config with environment variables. * If the value starts with `json::`, it'll be parsed as JSON instead of using as a raw string. * *(client.api)* Added support for [MSC3870] for both uploading and downloading media. * *(types)* Added `knock_restricted` join rule to `JoinRule` enum. * *(crypto)* Added warning logs if claiming one-time keys for other users fails. [MSC3870]: https://github.com/matrix-org/matrix-spec-proposals/pull/3870 ## v0.17.8 (2022-08-22) * *(crypto)* Fixed parsing `/keys/claim` responses with no `failures` field. * *(bridge)* Fixed parsing e2ee key sharing allow/minimum level config. ## v0.17.7 (2022-08-22) * *(util.async_db)* Added `init_commands` to run commands on each SQLite connection (e.g. to enable `PRAGMA`s). No-op on Postgres. * *(bridge)* Added check to make sure e2ee keys are intact on server. If they aren't, the crypto database will be wiped and the bridge will stop. ## v0.17.6 (2022-08-17) * *(bridge)* Added hidden option to use appservice login for double puppeting. * *(client)* Fixed sync handling throwing an error if event parsing failed. * *(errors)* Added `M_UNKNOWN_ENDPOINT` error code from [MSC3743] * *(appservice)* Updated [MSC3202] support to handle one time keys correctly. [MSC3743]: https://github.com/matrix-org/matrix-spec-proposals/pull/3743 ## v0.17.5 (2022-08-15) * *(types)* Added `m.read.private` to receipt types. 
* *(appservice)* Stopped `ensure_registered` and `invite_user` raising `IntentError`s (now they raise the original Matrix error instead). ## v0.17.4 (2022-07-28) * *(bridge)* Started rejecting reusing access tokens when enabling double puppeting. Reuse is detected by presence of encryption keys on the device. * *(client.api)* Added wrapper method for the `/context` API. * *(api, errors)* Implemented new error codes from [MSC3848]. * *(types)* Disabled deserializing `m.direct` content (it didn't work and it wasn't really necessary). * *(client.state_store)* Updated `set_encryption_info` to allow raw dicts. This fixes the bug where sending a `m.room.encryption` event with a raw dict as the content would throw an error from the state store. * *(crypto)* Fixed error when fetching keys for user with no cross-signing keys (thanks to [@maltee1] in [#109]). [MSC3848]: https://github.com/matrix-org/matrix-spec-proposals/pull/3848 [#109]: https://github.com/mautrix/python/pull/109 ## v0.17.3 (2022-07-12) * *(types)* Updated `BeeperMessageStatusEventContent` fields. ## v0.17.2 (2022-07-06) * *(api)* Updated request logging to log full URL instead of only path. * *(bridge)* Fixed migrating key sharing allow flag to new config format. * *(appservice)* Added `beeper_new_messages` flag for `batch_send` method. ## v0.17.1 (2022-07-05) * *(crypto)* Fixed Python 3.8/9 compatibility broken in v0.17.0. * *(crypto)* Added some tests for attachments and store code. * *(crypto)* Improved logging when device change validation fails. ## v0.17.0 (2022-07-05) * **Breaking change *(bridge)*** Added options to check cross-signing status for bridge users. This requires changes to the base config. * New options include requiring cross-signed devices (with TOFU) for sending and/or receiving messages, and an option to drop any unencrypted messages. * **Breaking change *(crypto)*** Removed `sender_key` parameter from CryptoStore's `has_group_session` and `put_group_session`, and also OlmMachine's `wait_for_session`. * **Breaking change *(crypto.store.memory)*** Updated the key of the `_inbound_sessions` dict to be (room_id, session_id), removing the identity key in the middle. This only affects custom stores based on the memory store. * *(crypto)* Added basic cross-signing validation code. * *(crypto)* Marked device_id and sender_key as deprecated in Megolm events as per Matrix 1.3. * *(api)* Bumped request logs to `DEBUG` level. * Also added new `sensitive` parameter to the `request` method to prevent logging content in sensitive requests. The `login` method was updated to mark the content as sensitive if a password or token is provided. * *(bridge.commands)* Switched the order of the user ID parameter in `set-pl`, `set-avatar` and `set-displayname`. ## v0.16.11 (2022-06-28) * *(appservice)* Fixed the `extra_content` parameter in membership methods causing duplicate join events through the `ensure_joined` mechanism. ## v0.16.10 (2022-06-24) * *(bridge)* Started requiring Matrix v1.1 support from homeservers. * *(bridge)* Added hack to automatically send a read receipt for messages sent to Matrix with double puppeting (to work around weird unread count issues). ## v0.16.9 (2022-06-22) * *(client)* Added support for knocking on rooms (thanks to [@maltee1] in [#105]). * *(bridge)* Added config option to set key rotation settings with e2be. [#105]: https://github.com/mautrix/python/pull/105 ## v0.16.8 (2022-06-20) * *(bridge)* Updated e2be helper to stop bridge if syncing fails. 
* *(util.async_db)* Updated asyncpg connector to stop program if an asyncpg `InternalClientError` is thrown. These errors usually cause everything to get stuck. * The behavior can be disabled by passing `meow_exit_on_ice` = `false` in the `db_args`. ## v0.16.7 (2022-06-19) * *(util.formatter)* Added support for parsing `img` tags * By default, the `alt` or `title` attribute will be used as plaintext. * *(types)* Added `notifications` object to power level content class. * *(bridge)* Added utility methods for handling incoming knocks in `MatrixHandler` (thanks to [@maltee1] in [#103]). * *(appservice)* Updated `IntentAPI` to add the `fi.mau.double_puppet_source` to all state events sent with double puppeted intents (previously it was only added to non-state events). [#103]: https://github.com/mautrix/python/pull/103 ## v0.16.6 (2022-06-02) * *(bridge)* Fixed double puppeting `start` method not handling some errors from /whoami correctly. * *(types)* Added `com.beeper.message_send_status` event type for bridging status. ## v0.16.5 (2022-05-26) * *(bridge.commands)* Added `reason` field for `CommandEvent.redact`. * *(client.api)* Added `reason` field for the `unban_user` method (thanks to [@maltee1] in [#101]). * *(bridge)* Changed automatic DM portal creation to only apply when the invite event specifies `"is_direct": true` (thanks to [@maltee1] in [#102]). * *(util.program)* Changed `Program` to use create and set an event loop explicitly instead of using `get_event_loop`. * *(util.program)* Added optional `exit_code` parameter to `manual_stop`. * *(util.manhole)* Removed usage of loop parameters to fix Python 3.10 compatibility. * *(appservice.api)* Switched `IntentAPI.batch_send` method to use custom Event classes instead of the default ones (since some normal event fields aren't applicable when batch sending). [@maltee1]: https://github.com/maltee1 [#101]: https://github.com/mautrix/python/pull/101 [#102]: https://github.com/mautrix/python/pull/102 ## v0.16.4 (2022-05-10) * *(types, bridge)* Dropped support for appservice login with unstable prefix. * *(util.async_db)* Fixed some database start errors causing unnecessary noise in logs. * *(bridge.commands)* Added helper method to redact bridge commands. ## v0.16.3 (2022-04-21) * *(types)* Changed `set_thread_parent` to have an explicit option for disabling the thread-as-reply fallback. ## v0.16.2 (2022-04-21) * *(types)* Added `get_thread_parent` and `set_thread_parent` helper methods for `MessageEventContent`. * *(bridge)* Increased timeout for `MessageSendCheckpoint.send`. ## v0.16.1 (2022-04-17) * **Breaking change** Removed `r0` path support. * The new `v3` paths are implemented since Synapse 1.48, Dendrite 0.6.5, and Conduit 0.4.0. Servers older than these are no longer supported. ## v0.16.0 (2022-04-11) * **Breaking change *(types)*** Removed custom `REPLY` relation type and changed `RelatesTo` structure to match the actual event content. * Applications using `content.get_reply_to()` and `content.set_reply()` will keep working with no changes. * *(types)* Added `THREAD` relation type and `is_falling_back` field to `RelatesTo`. ## v0.15.8 (2022-04-08) * *(client.api)* Added experimental prometheus metric for file upload speed. * *(util.async_db)* Improved type hints for `UpgradeTable.register` * *(util.async_db)* Changed connection string log to redact database password. ## v0.15.7 (2022-04-05) * *(api)* Added `file_name` parameter to `HTTPAPI.get_download_url`. ## v0.15.6 (2022-03-30) * *(types)* Fixed removing nested (i.e. 
malformed) reply fallbacks generated by some clients. * *(types)* Added automatic reply fallback trimming to `set_reply()` to prevent accidentally creating nested reply fallbacks. ## v0.15.5 (2022-03-28) * *(crypto)* Changed default behavior of OlmMachine to ignore instead of reject key requests from other users. * Fixed some type hints ## v0.15.3 & v0.15.4 (2022-03-25) * *(client.api)* Fixed incorrect HTTP methods in async media uploads. ## v0.15.2 (2022-03-25) * *(client.api)* Added support for async media uploads ([MSC2246]). * Moved `async_getter_lock` decorator to `mautrix.util` (from `mautrix.bridge`). * The old import path will keep working. [MSC2246]: https://github.com/matrix-org/matrix-spec-proposals/pull/2246 ## v0.15.1 (2022-03-23) * *(types)* Added `ensure_has_html` method for `TextMessageEventContent` to generate a HTML `formatted_body` from the plaintext `body` correctly (i.e. escaping HTML and replacing newlines). ## v0.15.0 (2022-03-16) * **Breaking change** Removed Python 3.7 support. * **Breaking change *(api)*** Removed `r0` from default path builders in order to update to `v3` and per-endpoint versioning. * The client API modules have been updated to specify v3 in the paths, other direct usage of `Path`, `ClientPath` and `MediaPath` will have to be updated manually. `UnstableClientPath` no longer exists and should be replaced with `Path.unstable`. * There's a temporary hacky backwards-compatibility layer which replaces /v3 with /r0 if the server doesn't advertise support for Matrix v1.1 or higher. It can be activated by calling the `.versions()` method in `ClientAPI`. The bridge module calls that method automatically. * **Breaking change *(util.formatter)*** Removed lxml-based HTML parser. * The parsed data format is still compatible with lxml, so it is possible to use lxml with `MatrixParser` by setting `lxml.html.fromstring` as the `read_html` method. * **Breaking change *(crypto)*** Moved `TrustState`, `DeviceIdentity`, `OlmEventKeys` and `DecryptedOlmEvent` dataclasses from `crypto.types` into `types.crypto`. * **Breaking change *(bridge)*** Made `User.get_puppet` abstract and added new abstract `User.get_portal_with` and `Portal.get_dm_puppet` methods. * Added a redundant `__all__` to various `__init__.py` files to appease pyright. * *(api)* Reduced aiohttp memory usage when uploading large files by making an in-memory async iterable instead of passing the bytes directly. * *(bridge)* Removed legacy community utilities. * *(bridge)* Added support for creating DM portals with minimal bridge-specific code. * *(util.async_db)* Fixed counting number of db upgrades. * *(util.async_db)* Added support for schema migrations that jump versions. * *(util.async_db)* Added system for preventing using the same database for multiple programs. * To enable it, provide an unique program name as the `owner_name` parameter in `Database.create`. * Additionally, if `ignore_foreign_tables` is set to `True`, it will check for tables of some known software like Synapse and Dendrite. * The `bridge` module enables both options by default. * *(util.db)* Module deprecated. The async_db module is recommended. However, the SQLAlchemy helpers will remain until maubot has switched to asyncpg. * *(util.magic)* Allowed `bytearray` as an input type for the `mimetype` method. * *(crypto.attachments)* Added method to encrypt a `bytearray` in-place to avoid unnecessarily duplicating data in memory. ## v0.14.10 (2022-02-01) * *(bridge)* Fixed accidentally broken Python 3.7 compatibility. 
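The v0.15.0 entry above notes that the htmlparser-based formatter still produces an lxml-compatible tree, so lxml can be plugged back in by setting `lxml.html.fromstring` as the `read_html` method of `MatrixParser`. A minimal sketch of that wiring; the `staticmethod` wrapping and the `parse()` call shown in the comment are assumptions, only the class, module path and `read_html` hook are stated in this changelog:

```python
# Sketch only: swap MatrixParser's HTML reader for lxml, per the v0.15.0 note above.
import lxml.html

from mautrix.util.formatter import MatrixParser


class LXMLMatrixParser(MatrixParser):
    # read_html only needs to turn an HTML string into an lxml-compatible element
    # tree; staticmethod keeps Python from binding it as an instance method (assumption).
    read_html = staticmethod(lxml.html.fromstring)


# MatrixParser became async and instance-based in v0.14.0, so usage would be roughly:
#   content = await LXMLMatrixParser().parse(html)  # method name assumed
```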
## v0.14.9 (2022-02-01) * *(client.api)* Added `reason` field to `leave_room` and `invite_user` methods. ## v0.14.8 (2022-01-31) * *(util.formatter)* Deprecated the lxml-based HTML parser and made the htmlparser-based parser the default. The lxml-based parser will be removed in v0.15. * *(client.api)* Fixed `filter_json` parameter in `get_messages` not being sent to the server correctly. * *(bridge)* Added utilities for implementing disappearing messages. ## v0.14.7 (2022-01-29) * *(client)* Fixed error inviting users with custom member event content if the server had disabled fetching profiles. * *(util.utf16_surrogate)* Added utilities for adding/removing unicode surrogate pairs in strings. * *(util.magic)* Added check to make sure the parameter to `mimetype()` is either `bytes` or `str`. ## v0.14.6 (2022-01-26) * **Breaking change *(util.message_send_checkpoint)*** Changed order of `send` parameters to match `BridgeState.send` (this is not used by most software, which is why the breaking change is in a patch release). * *(util.async_db)* Changed the default size of the aiosqlite thread pool to 1, as it doesn't reliably work with higher values. * *(util.async_db)* Added logging for database queries that take a long time (>1 second). * *(client)* Added logging for sync requests that take a long time (>40 seconds, with the timeout being 30 seconds). * *(util.variation_selector)* Fixed variation selectors being incorrectly added even if the emoji had a skin tone selector. * *(bridge)* Fixed the process getting stuck if a config error caused the bridge to stop itself without stopping the SQLite thread. * Added pre-commit hooks to run black, isort and some other checks. ## v0.14.5 (2022-01-14) * *(util.formatter)* Removed the default handler for room pill conversion. * This means they'll be formatted as normal links unless the bridge or other thing using the formatter overrides `room_pill_to_fstring`. * *(types)* Fixed the `event_id` property of `MatrixURI`s throwing an error (instead of returning `None`) when the parsed link didn't contain a second part with an event ID. ## v0.14.4 (2022-01-13) * Bumped minimum yarl version to 1.5. v1.4 and below didn't allow `URL.build()` with a scheme but no host, which is used in the `matrix:` URI generator that was added in v0.14.3. * *(appservice)* Removed support for adding a `group_id` to user namespaces in registration files. * *(types)* Updated `Serializable.parse_json` type hint to allow `bytes` in addition to `str` (because `json.loads` allows both). * *(bridge)* Added `retry_num` parameter to `User.send_remote_checkpoint`. ## v0.14.3 (2022-01-05) * *(types)* Added `MatrixURI` type to parse and build `matrix:` URIs and `https://matrix.to` URLs. * *(util.formatter)* `matrix:` URIs are now supported in incoming messages (using the new parser mentioned above). * *(util.variation_selector)* Switched to generating list of emoji using data directly from the Unicode spec instead of emojibase. * *(util.formatter)* Whitespace in non-`pre` elements is now compressed into a single space. Newlines are also replaced with a space instead of removed completely. Whitespace after a block element is removed completely. * *(util.ffmpeg)* Added option to override output path, which allows outputting to stdout (by specifying `-`). * *(util.config)* Changed `ConfigUpdateHelper.copy` to ignore comments if the entity being copied is a commentable yaml object (e.g. map or list). 
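As a small illustration of the `MatrixURI` helper added in v0.14.3 above: only the type itself (exported from `mautrix.types`) and the `event_id` property (see v0.14.4) are stated in this changelog, so the `parse()` classmethod and the `room_id` attribute in this sketch are assumptions:

```python
# Sketch only: parsing a matrix.to room link with the MatrixURI type from v0.14.3.
from mautrix.types import MatrixURI

uri = MatrixURI.parse("https://matrix.to/#/!abcdef:example.com")  # parse() is assumed
print(uri.room_id)   # assumed attribute holding the first path part
print(uri.event_id)  # per v0.14.4: None when the link has no event ID part
```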
## v0.14.2 (2021-12-30) * *(appservice)* Fixed `IntentAPI` throwing an error when `redact` was called with a `reason`, but without `extra_content`. ## v0.14.1 (2021-12-29) * *(util.ffmpeg)* Added simple utility module that wraps ffmpeg and tempfiles to convert audio/video files to different formats, primarily intended for bridging. FFmpeg must be installed separately and be present in `$PATH`. ## v0.14.0 (2021-12-26) * **Breaking change *(mautrix.util.formatter)*** Made `MatrixParser` async and non-static. * Being async is necessary for bridges that need to make database calls to convert mentions (e.g. Telegram has @username mentions, which can't be extracted from the Matrix user ID). * Being non-static allows passing additional context into the formatter by extending the class and setting instance variables. * *(util.formatter)* Added support for parsing [spoilers](https://spec.matrix.org/v1.1/client-server-api/#spoiler-messages). * *(crypto.olm)* Added `describe` method for `OlmSession`s. * *(crypto)* Fixed sorting Olm sessions (now sorted by last successful decrypt time instead of alphabetically by session ID). * *(crypto.store.asyncpg)* Fixed caching Olm sessions so that using the same session twice wouldn't cause corruption. * *(crypto.attachments)* Added support for decrypting files from non-spec-compliant clients (e.g. FluffyChat) that have a non-zero counter part in the AES initialization vector. * *(util.async_db)* Added support for using Postgres positional param syntax in the async SQLite helper (by regex-replacing `$` with `?`). * *(util.async_db)* Added wrapper methods for `executemany` in `Database` and aiosqlite `TxnConnection`. * *(bridge)* Changed portal cleanup to leave and forget rooms using double puppeting instead of just kicking the user. ## v0.13.3 (2021-12-15) * Fixed type hints in the `mautrix.crypto.store` module. * Added debug logs for detecting crypto sync handling slowness. ## v0.13.2 (2021-12-15) * Switched message double puppet indicator convention from `"net.maunium..puppet": true` to `"fi.mau.double_puppet_source": ""`. * Added double puppet indicator to redactions made with `IntentAPI.redact`. ## v0.13.1 (2021-12-12) * Changed lack of media encryption dependencies (pycryptodome) to be a fatal error like lack of normal encryption dependencies (olm) are in v0.13.0. * Added base methods for implementing relay mode in bridges (started by [@Alejo0290] in [#72]). [@Alejo0290]: https://github.com/Alejo0290 [#72]: https://github.com/mautrix/python/pull/72 ## v0.13.0 (2021-12-09) * Formatted all code using [black](https://github.com/psf/black) and [isort](https://github.com/PyCQA/isort). * Added `power_level_override` parameter to `ClientAPI.create_room`. * Added default implementations of `delete-portal` and `unbridge` commands for bridges * Added automatic Olm session recreation if an incoming message fails to decrypt. * Added automatic key re-requests in bridges if the Megolm session doesn't arrive on time. * Changed `ClientAPI.send_text` to parse the HTML to generate a plaintext body instead of using the HTML directly when a separate plaintext body is not provided (also affects `send_notice` and `send_emote`). * Changed lack of encryption dependencies to be a fatal error if encryption is enabled in bridge config. * Fixed `StoreUpdatingAPI` not updating the local state store when using friendly membership methods like `kick_user`. * Switched Bridge class to use async_db (asyncpg/aiosqlite) instead of the legacy SQLAlchemy db by default. 
* Removed deprecated `ClientAPI.parse_mxid` method (use `ClientAPI.parse_user_id` instead). * Renamed `ClientAPI.get_room_alias` to `ClientAPI.resolve_room_alias`. ## v0.12.5 (2021-11-30) * Added wrapper for [MSC2716]'s `/batch_send` endpoint in `IntentAPI`. * Added some Matrix request metrics (thanks to [@jaller94] in [#68]). * Added utility method for adding variation selector 16 to emoji strings the same way as Element does (using emojibase data). [MSC2716]: https://github.com/matrix-org/matrix-spec-proposals/pull/2716 [@jaller94]: https://github.com/jaller94 [#68]: https://github.com/mautrix/python/pull/68 ## v0.12.4 (2021-11-25) * *(util.formatter)* Added support for parsing Matrix HTML colors. ## v0.12.3 (2021-11-23) * Added autogenerated docs with Sphinx. * Rendered version available at https://docs.mau.fi/python/latest/ (also version-specific docs at https://docs.mau.fi/python/v0.12.3/). * Added asyncpg to client state store unit tests. * Fixed client state store `get_members` being broken on asyncpg (broken in 0.12.2). * Fixed `get_members_filtered` not taking the `memberships` parameter into account in the memory store. ## v0.12.2 (2021-11-20) * Added more control over which membership states to return in client state store. * Added some basic tests for the client state store. * Fixed `OlmMachine.account` property not being defined before calling `load`. ## v0.12.1 (2021-11-19) * Added default (empty) value for `unsigned` in the event classes. * Updated the `PgStateStore` in the client module to fully implement the crypto `StateStore` abstract class. * The crypto module now has a `PgCryptoStateStore` that combines the client `PgStateStore` with the abstract crypto state store. ## v0.12.0 (2021-11-19) * **Breaking change (client):** The `whoami` method now returns a dataclass with `user_id` and `device_id` fields, instead of just returning the `user_id` as a string. * Added `delete` method for crypto stores (useful when changing the device ID). * Added `DECRYPTED` step for message send checkpoints. * Added proper user agent to bridge state and message send checkpoint requests. ## v0.11.4 (2021-11-16) * Improved default event filter in bridges * The filtering method is now `allow_matrix_event` instead of `filter_matrix_event` and the return value is reversed. * Most bridges now don't need to override the method, so the old method isn't used at all. * Added support for the stable version of [MSC2778]. ## v0.11.3 (2021-11-13) * Updated registering appservice ghosts to use `inhibit_login` flag to prevent lots of unnecessary access tokens from being created. * If you want to log in as an appservice ghost, you should use [MSC2778]'s appservice login (e.g. like the [bridge e2ee module does](https://github.com/mautrix/python/blob/v0.11.2/mautrix/bridge/e2ee.py#L178-L182) for example) * Fixed unnecessary warnings about message send endpoints in some cases where the endpoint wasn't configured. ## v0.11.2 (2021-11-11) * Updated message send checkpoint system to handle all cases where messages are dropped or consumed by mautrix-python. ## v0.11.1 (2021-11-10) * Fixed regression in Python 3.8 support in v0.11.0 due to `asyncio.Queue` type hinting. * Made the limit of HTTP connections to the homeserver configurable (thanks to [@justinbot] in [#64]). [#64]: https://github.com/mautrix/python/pull/64 ## v0.11.0 (2021-11-09) * Added support for message send checkpoints (as HTTP requests, similar to the bridge state reporting system) by [@sumnerevans]. 
* Added support for aiosqlite with the same interface as asyncpg. * This includes some minor breaking changes to the asyncpg interface. * Made config writing atomic (using a tempfile) to prevent the config disappearing when disk is full. * Changed prometheus to start before rest of `startup_actions` (thanks to [@Half-Shot] in [#63]). * Stopped reporting `STARTING` bridge state on startup by [@sumnerevans]. [@Half-Shot]: https://github.com/Half-Shot [#63]: https://github.com/mautrix/python/pull/63 ## v0.10.11 (2021-10-26) * Added support for custom bridge bot welcome messages (thanks to [@justinbot] in [#58]). [@justinbot]: https://github.com/justinbot [#58]: https://github.com/mautrix/python/pull/58 ## v0.10.10 (2021-10-08) * Added support for disabling bridge management commands based on custom rules (thanks to [@tadzik] in [#56]). [@tadzik]: https://github.com/tadzik [#56]: https://github.com/mautrix/python/pull/56 ## v0.10.9 (2021-09-29) * Changed `remove_room_alias` to ignore `M_NOT_FOUND` errors by default, to preserve Synapse behavior on spec-compliant server implementations. The `raise_404` argument can be set to `True` to not suppress the errors. * Fixed bridge state pings returning `UNCONFIGURED` as a global state event. ## v0.10.8 (2021-09-23) * **Breaking change (serialization):** Removed `Generic[T]` backwards compatibility from `SerializableAttrs` (announced in [v0.9.6](https://github.com/mautrix/python/releases/tag/v0.9.6)). * Stopped using `self.log` in `Program` config load errors as the logger won't be initialized yet. * Added check to ensure reply fallback removal is only attempted once. * Fixed `remove_event_handler` throwing a `KeyError` if no event handlers had been registered for the specified event type. * Fixed deserialization showing wrong key names on missing key errors. ## v0.10.7 (2021-08-31) * Removed Python 3.9+ features that were accidentally used in v0.10.6. ## v0.10.6 (2021-08-30) * Split `_http_handle_transaction` in `AppServiceServerMixin` to allow easier reuse. ## v0.10.5 (2021-08-25) * Fixed `MemoryStateStore`'s `get_members()` implementation (thanks to [@hifi] in [#54]). * Re-added `/_matrix/app/com.beeper.bridge_state` endpoint. [@hifi]: https://github.com/hifi [#54]: https://github.com/mautrix/python/pull/54 ## v0.10.4 (2021-08-18) * Improved support for sending member events manually (when using the `extra_content` field in join, invite, etc). * There's now a `fill_member_event` method that's called by manual member event sending that adds the displayname and avatar URL. Alternatively, `fill_member_event_callback` can be set to fill the member event manually. ## v0.10.3 (2021-08-14) * **Breaking change:** The bridge status notification system now uses a `BridgeStateEvent` enum instead of the `ok` boolean. * Added better log messages when bridge encryption error notice fails to send. * Added manhole for all bridges. * Dropped Python 3.6 support in manhole. * Switched to using `PyCF_ALLOW_TOP_LEVEL_AWAIT` for manhole in Python 3.8+. ## v0.9.10 (2021-07-24) * Fixed async `Database` class mutating the `db_args` dict passed to it. * Fixed `None`/`null` values with factory defaults being deserialized into the `attr.Factory` object instead of the expected value. ## v0.9.9 (2021-07-16) * **Breaking change:** Made the `is_direct` property required in the bridge `Portal` class. The property was first added in v0.8.4 and is used for handling `m.room.encryption` events (enabling encryption). * Added PEP 561 typing info (by [@sumnerevans] in [#49]). 
* Added support for [MSC3202] in appservice module. * Made bridge state filling more customizable. * Moved `BridgeState` class from `mautrix.bridge` to `mautrix.util.bridge_state`. * Fixed receiving appservice transactions with `Authorization` header (i.e. fixed [MSC2832] support). [MSC3202]: https://github.com/matrix-org/matrix-spec-proposals/pull/3202 [MSC2832]: https://github.com/matrix-org/matrix-spec-proposals/pull/2832 [@sumnerevans]: https://github.com/sumnerevans [#49]: https://github.com/mautrix/python/pull/49 ## v0.9.8 (2021-06-24) * Added `remote_id` field to `push_bridge_state` method. ## v0.9.7 (2021-06-22) * Added tests for `factory` and `hidden` serializable attrs. * Added `login-matrix`, `logout-matrix`, `ping-matrix` and `clear-cache-matrix` commands in the bridge module. To enable the commands, bridges must implement the `User.get_puppet()` method to return the `Puppet` instance corresponding to the user's remote ID. * Fixed logging events that were ignored due to lack of permissions of the sender. * Fixed deserializing encrypted edit events ([mautrix/telegram#623]). [mautrix/telegram#623]: https://github.com/mautrix/telegram/issues/623 ## v0.9.6 (2021-06-20) * Replaced `GenericSerializable` with a bound `TypeVar`. * This means that classes extending `SerializableAttrs` no longer have to use the `class Foo(SerializableAttrs['Foo'])` syntax to get type hints, just `class Foo(SerializableAttrs)` is enough. * Backwards compatibility for using the `['Foo']` syntax will be kept until v0.10. * Added `field()` as a wrapper for `attr.ib()` that makes it easier to add custom metadata for serializable attrs things. * Added some tests for type utilities. * Changed attribute used to exclude links from output in HTML parser. * New attribute is `data-mautrix-exclude-plaintext` and works for basic formatting (e.g. ``) in addition to ``. * The previous attribute wasn't actually checked correctly, so it never worked. ## v0.9.5 (2021-06-11) * Added `SynapseAdminPath` to build `/_synapse/admin` paths. ## v0.9.4 (2021-06-09) * Updated bridge status pushing utility to support `remote_id` and `remote_name` fields to specify which account on the remote network is bridged. ## v0.9.3 (2021-06-04) * Switched to stable space prefixes. * Added option to send arbitrary content with membership events. * Added warning if media encryption dependencies aren't installed. * Added support for pycryptodomex for media encryption. * Added utilities for pushing bridge status to an arbitrary HTTP endpoint. ## v0.9.2 (2021-04-26) * Changed `update_direct_chats` bridge method to only send updated `m.direct` data if the content was modified. ## v0.9.1 (2021-04-20) * Added type classes for VoIP. * Added methods for modifying push rules and room tags. * Switched to `asyncio.create_task` everywhere (replacing the older `loop.create_task` and `asyncio.ensure_future`). ## v0.9.0 (2021-04-16) * Added option to retry all HTTP requests when encountering a HTTP network error or gateway error response (502/503/504) * Disabled by default, you need to set the `default_retry_count` field in `HTTPAPI` (or `Client`), or the `default_http_retry_count` field in `AppService` to enable. * Can also be enabled with `HTTPAPI.request()`s `retry_count` parameter. * The `mautrix.util.network_retry` module was removed as it became redundant. * Fixed GET requests having a body ([#44]). [#44]: https://github.com/mautrix/python/issues/44 ## v0.8.18 (2021-04-01) * Made HTTP request user agents more configurable. 
* Bridges will now include the name and version by default. * Added some event types and classes for space events. * Fixed local power level check failing for `m.room.member` events. ## v0.8.17 (2021-03-22) * Added warning log when giving up on decrypting message. * Added mimetype magic utility that supports both file-magic and python-magic. * Updated asmux DM endpoint (`net.maunium.asmux` -> `com.beeper.asmux`). * Moved RowProxy and ResultProxy imports into type checking ([#46]). This should fix SQLAlchemy 1.4+, but SQLAlchemy databases will likely be deprecated entirely in the future. [#46]: https://github.com/mautrix/python/issues/46 ## v0.8.16 (2021-02-16) * Made the Bridge class automatically fetch media repo config at startup. Bridges are recommended to check `bridge.media_config.upload_size` before even downloading remote media. ## v0.8.15 (2021-02-08) * Fixed the high-level `Client` class to not try to update state if there' no `state_store` set. ## v0.8.14 (2021-02-07) * Added option to override the asyncpg pool used in the async `Database` wrapper. ## v0.8.13 (2021-02-07) * Stopped checking error message when checking if user is not registered on whoami. Now it only requires the `M_FORBIDDEN` errcode instead of a specific human-readable error message. * Added handling for missing `unsigned` object in membership events (thanks to [@jevolk] in [#39]). * Added warning message when receiving encrypted messages with end-to-bridge encryption disabled. * Added utility for mutexes in caching async getters to prevent race conditions. [@jevolk]: https://github.com/jevolk [#39]: https://github.com/mautrix/python/pull/39 ## v0.8.12 (2021-02-01) * Added handling for `M_NOT_FOUND` errors when getting pinned messages. * Fixed bridge message send retrying so it always uses the same transaction ID. * Fixed high-level `Client` class to automatically update state store with events from sync. ## v0.8.11 (2021-01-22) * Added automatic login retry if double puppeting token is invalid on startup or gets invalidated while syncing. * Fixed ExtensibleEnum leaking keys between different types. * Allowed changing bot used in ensure_joined. ## v0.8.10 (2021-01-22) * Changed attr deserialization errors to log full data instead of only known fields when deserialization fails. ## v0.8.9 (2021-01-21) * Allowed `postgresql://` scheme in end-to-bridge encryption database URL (in addition to `postgres://`). * Slightly improved attr deserialization error messages. ## v0.8.8 (2021-01-19) * Changed end-to-bridge encryption to fail if homeserver doesn't advertise appservice login. This breaks Synapse 1.21, but there have been plenty of releases since then. * Switched BaseFileConfig to use the built-in [pkgutil] instead of pkg_resources (which requires setuptools). * Added handling for `M_NOT_FOUND` errors when updating `m.direct` account data through double puppeting in bridges. * Added logging of data when attr deserializing fails. * Exposed ExtensibleEnum in `mautrix.types` module. [pkgutil]: https://docs.python.org/3/library/pkgutil.html ## v0.8.7 (2021-01-15) * Changed attr deserializer to deserialize optional missing fields into `None` instead of `attr.NOTHING` by default. * Added option not to use transaction for asyncpg database upgrades. ## v0.8.6 (2020-12-31) * Added logging when sync errors are resolved. * Made `.well-known` fetching ignore the response content type header. * Added handling for users enabling encryption in private chat portals. 
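The mimetype magic utility from v0.8.17 above comes up again later in this changelog (v0.14.7 and v0.15.0 note that `mimetype()` accepts `bytes`, `str` and eventually `bytearray`). A minimal usage sketch, assuming the `mautrix.util.magic` module path implied by the `*(util.magic)*` entries:

```python
# Sketch only: detect a MIME type via the file-magic/python-magic wrapper.
from mautrix.util import magic

with open("avatar.png", "rb") as file:
    data = file.read()
print(magic.mimetype(data))  # e.g. "image/png"
```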
## v0.8.5 (2020-12-06) * Made SerializableEnum work with int values. * Added TraceLogger type hints to command handling classes. ## v0.8.4 (2020-12-02) * Added logging when sync errors are resolved. * Made `.well-known` fetching ignore the response content type header. * Added handling for users enabling encryption in private chat portals. ## v0.8.3 (2020-11-17) * Fixed typo in HTML reply fallback generation when target message is plaintext. * Made `CommandEvent.mark_read` async instead of returning an awaitable, because sometimes it didn't return an awaitable. ## v0.8.2 (2020-11-10) * Added utility function for retrying network calls (`from mautrix.util.network_retry import call_with_net_retry`). * Updated `Portal._send_message` to use aforementioned utility function. ## v0.8.1 (2020-11-09) * Changed `Portal._send_message` to retry after 5 seconds (up to 5 attempts total by default) if server returns 502/504 error or the connection fails. ## v0.8.0 (2020-11-07) * Added support for cross-server double puppeting (thanks to [@ShadowJonathan] in [#26]). * Added support for receiving ephemeral events pushed directly ([MSC2409]). * Added `opt_prometheus` utility to add support for metrics without a hard dependency on the prometheus_client library. * Added `formatted()` helper method to get the `formatted_body` of a text message. * Bridge command system improvements (thanks to [@witchent] in [#29], [#30] and [#31]). * `CommandEvent`s now know which portal they were ran in. They also have a `main_intent` property that gets the portal's main intent or the bridge bot. * `CommandEvent.reply()` will now use the portal's main intent if the bridge bot is not in the room. * The `needs_auth` and `needs_admin` permissions are now included here instead of separately in each bridge. * Added `discard-megolm-session` command. * Moved `set-pl` and `clean-rooms` commands from mautrix-telegram. * Switched to using yarl instead of manually concatenating base URL with path. * Switched to appservice login ([MSC2778]) instead of shared secret login for bridge bot login in the end-to-bridge encryption helper. * Switched to `TEXT` instead of `VARCHAR(255)` in all databases ([#28]). * Changed replies to use a custom `net.maunium.reply` relation type instead of `m.reference`. * Fixed potential db unique key conflicts when the membership state caches were updated from `get_joined_members`. * Fixed database connection errors causing sync loops to stop completely. * Fixed `EventType`s sometimes having `None` instead of `EventType.Class.UNKNOWN` as the type class. * Fixed regex escaping in bridge registration generation. [MSC2778]: https://github.com/matrix-org/matrix-spec-proposals/pull/2778 [MSC2409]: https://github.com/matrix-org/matrix-spec-proposals/pull/2409 [@ShadowJonathan]: https://github.com/ShadowJonathan [@witchent]: https://github.com/witchent [#26]: https://github.com/mautrix/python/pull/26 [#28]: https://github.com/mautrix/python/issues/28 [#29]: https://github.com/mautrix/python/pull/29 [#30]: https://github.com/mautrix/python/pull/30 [#31]: https://github.com/mautrix/python/pull/31 ## v0.7.14 (2020-10-27) * Wrapped union types in `NewType` to allow `setattr`. This fixes Python 3.6 and 3.9 compatibility. ## v0.7.13 (2020-10-09) * Extended session wait time when handling encrypted messages in bridges: it'll now wait for 5 seconds, then send an error, then wait for 10 more seconds. If the keys arrive in those 10 seconds, the message is bridged and the error is redacted, otherwise the error is edited. 
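The `formatted()` helper mentioned under v0.8.0 above is only described as returning the `formatted_body` of a text message; the constructor fields and the no-argument call in this sketch are assumptions for illustration:

```python
# Sketch only: reading back the HTML body of a text message via formatted().
from mautrix.types import Format, MessageType, TextMessageEventContent

content = TextMessageEventContent(
    msgtype=MessageType.TEXT,
    body="hello *world*",
    format=Format.HTML,
    formatted_body="hello <strong>world</strong>",
)
print(content.formatted())  # expected: "hello <strong>world</strong>"
```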
## v0.7.11 (2020-10-02) * Lock olm sessions between encrypting and sending to make sure messages go out in the correct order. ## v0.7.10 (2020-09-29) * Fixed deserializing the `info` object in media msgtypes into dataclasses. ## v0.7.9 (2020-09-28) * Added parameter to change how long `EncryptionManager.decrypt()` should wait for the megolm session to arrive. * Changed `get_displayname` and `get_avatar_url` to ignore `M_NOT_FOUND` errors. * Updated type hint of `set_reply` to allow `EventID`s. ## v0.7.8 (2020-09-27) * Made the `UUID` type de/serializable by default. ## v0.7.7 (2020-09-25) * Added utility method for waiting for incoming group sessions in OlmMachine. * Made end-to-bridge encryption helper wait for incoming group sessions for 3 seconds. ## v0.7.6 (2020-09-22) * Fixed bug where parsing invite fails if `unsigned` is not set or null. * Added trace logs when bridge module ignores messages. ## v0.7.5 (2020-09-19) * Added utility for measuring async method time in prometheus. ## v0.7.4 (2020-09-19) * Made `sender_device` optional in decrypted olm events. * Added opt_prometheus utility for using prometheus as an optional dependency. * Added Matrix event time processing metric for bridges when prometheus is installed. ## v0.7.3 (2020-09-17) * Added support for telling the user about decryption errors in bridge module. ## v0.7.2 (2020-09-12) * Added bridge config option to pass custom arguments to SQLAlchemy's `create_engine`. ## v0.7.1 (2020-09-09) * Added optional automatic prometheus config to the `Program` class. ## v0.7.0 (2020-09-04) * Added support for e2ee key sharing in `OlmMachine` (both sending and responding to requests). * Added option for automatically sharing keys from bridges. * Added account data get/set methods for `ClientAPI`. * Added helper for bridges to update `m.direct` account data. * Added default user ID and alias namespaces for bridge registration generation. * Added asyncpg-based client state store implementation. * Added filtering query parameters to `ClientAPI.get_members`. * Changed attachment encryption methods to return `EncryptedFile` objects instead of dicts. * Changed `SimpleLock` to use `asyncio.Event` instead of `asyncio.Future`. * Made SQLAlchemy optional for bridges. * Fixed error when profile endpoint responses are missing keys. ## v0.6.1 (2020-07-30) * Fixed disabling notifications in many rooms at the same time. ## v0.6.0 (2020-07-27) * Added native end-to-end encryption module. * Switched e2be helper to use native e2ee instead of matrix-nio. * Includes crypto stores based on pickle and asyncpg. * Added e2ee helper to high-level client module. * Added support for getting `prev_content` from the top level in addition to `unsigned`. ## v0.5.8 (2020-07-27) * Fixed deserializer using `attr.NOTHING` instead of `None` when there's no default value. ## v0.5.7 (2020-06-16) * Added `alt_aliases` to canonical alias state event content (added in Matrix client-server spec r0.6.1). ## v0.5.6 (2020-06-15) * Added support for adding aliases for bridge commands. ## v0.5.5 (2020-06-15) * Added option to set default event type class in `EventType.find()`. ## v0.5.4 (2020-06-09) * Fixed notification disabler breaking when not using double puppeting. ## v0.5.3 (2020-06-08) * Added `NotificationDisabler` utility class for easily disabling notifications while a bridge backfills messages. ## v0.5.2 (2020-06-08) * Added support for automatically calling `ensure_registered` if `whoami` says the bridge bot is not registered in `Bridge.wait_for_connection`. 
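v0.7.0 above mentions account data get/set methods on `ClientAPI`. A hedged sketch of what using them might look like; the method names `get_account_data`/`set_account_data` and their signatures are assumptions, the changelog only states that get/set helpers were added:

```python
# Sketch only: storing a small custom blob in user account data (method names assumed).
from mautrix.client import Client


async def remember_last_room(client: Client, room_id: str) -> None:
    # Namespaced custom event type to avoid clashing with spec-defined account data.
    await client.set_account_data("fi.example.last_room", {"room_id": room_id})
    print(await client.get_account_data("fi.example.last_room"))
```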
## v0.5.1 (2020-06-05) * Moved initializing end-to-bridge encryption to before other startup actions. ## v0.5.0 (2020-06-03) * Added extensible enum class ([#14]). * Added some asyncpg utilities. * Added basic config validation support to disallow default values. * Added matrix-nio based end-to-bridge encryption helper for bridges. * Added option to use TLS for appservice listener. * Added support for `Authorization` header from homeserver in appservice transaction handler. * Added option to override appservice transaction handling method. * Split `Bridge` initialization class into a more abstract `Program`. * Split config loading. [#14]: https://github.com/mautrix/python/issues/14 ## v0.4.2 (2020-02-14) * Added option to add custom arguments for programs based on the `Bridge` class. * Added method for stopping a `Bridge`. * Made `Obj` picklable. ## v0.4.1 (2020-01-07) * Removed unfinished `enum.py`. * Increased default config line wrapping width. * Fixed default visibility when adding rooms and users with bridge community helper. ## v0.4.0 (2019-12-28) * Initial "stable" release of the major restructuring. * Package now includes the Matrix client framework and other utilities instead of just an appservice module. * Package renamed from mautrix-appservice to mautrix. * Switched license from MIT to MPLv2. ## v0.3.11 (2019-06-20) * Update state store after sending state event. This is required for some servers like t2bot.io that have disabled echoing state events to appservices. ## v0.3.10.dev1 (2019-05-23) * Hacky fix for null `m.relates_to`'s. ## v0.3.9 (2019-05-11) * Only use json.dumps() in request() if content is json-serializable. ## v0.3.8 (2019-02-13) * Added missing room/event ID quotings. ## v0.3.7 (2018-09-28) * Fixed `get_room_members()` returning `dict_keys` rather than `list` when getting only joined members. ## v0.3.6 (2018-08-06 * Fixed `get_room_joined_memberships()` (thanks to [@turt2live] in [#6]). [@turt2live]: https://github.com/turt2live [#6]: https://github.com/mautrix/python/pull/6 ## v0.3.5 (2018-08-06) * Added parameter to change aiohttp Application parameters. * Fixed `get_power_levels()` with state store implementations that don't throw a `ValueError` on cache miss. ## v0.3.4 (2018-08-05) * Updated `get_room_members()` to use `/joined_members` instead of `/members` when possible. ## v0.3.3 (2018-07-25) * Updated some type hints. ## v0.3.2 (2018-07-23) * Fixed HTTPAPI init for real users. * Fixed content-type for empty objects. ## v0.3.1 (2018-07-22) * Added support for real users. ## v0.3.0 (2018-07-10) * Made `StateStore` into an abstract class for easier custom storage backends. * Fixed response of `/transaction` to return empty object with 200 OK's as per spec. * Fixed URL parameter encoding. * Exported `IntentAPI` for type hinting. ## v0.2.0 (2018-06-24) * Switched to GPLv3 to MIT license. * Updated state store to store full member events rather than just the membership status. ## v0.1.5 (2018-05-06) * Made room avatar in `set_room_avatar()` optional to allow unsetting avatar. ## v0.1.4 (2018-04-26) * Added `send_sticker()`. ## v0.1.3 (2018-03-29) * Fixed AppService log parameter type hint. * Fixed timestamp handling. ## v0.1.2 (2018-03-29) * Return 400 Bad Request if user/room query doesn't have user ID/alias field (respectively). * Added support for timestamp massaging and source URLs. ## v0.1.1 (2018-03-11) * Added type hints. * Added power level checks to `set_state_event()`. 
* Renamed repo to mautrix-appservice-python (PyPI package is still mautrix-appservice). ## v0.1.0 (2018-03-08) * Initial version. Transferred from mautrix-telegram. python-0.20.7/LICENSE000066400000000000000000000405251473573527000141510ustar00rootroot00000000000000Mozilla Public License Version 2.0 ================================== 1. Definitions -------------- 1.1. "Contributor" means each individual or legal entity that creates, contributes to the creation of, or owns Covered Software. 1.2. "Contributor Version" means the combination of the Contributions of others (if any) used by a Contributor and that particular Contributor's Contribution. 1.3. "Contribution" means Covered Software of a particular Contributor. 1.4. "Covered Software" means Source Code Form to which the initial Contributor has attached the notice in Exhibit A, the Executable Form of such Source Code Form, and Modifications of such Source Code Form, in each case including portions thereof. 1.5. "Incompatible With Secondary Licenses" means (a) that the initial Contributor has attached the notice described in Exhibit B to the Covered Software; or (b) that the Covered Software was made available under the terms of version 1.1 or earlier of the License, but not also under the terms of a Secondary License. 1.6. "Executable Form" means any form of the work other than Source Code Form. 1.7. "Larger Work" means a work that combines Covered Software with other material, in a separate file or files, that is not Covered Software. 1.8. "License" means this document. 1.9. "Licensable" means having the right to grant, to the maximum extent possible, whether at the time of the initial grant or subsequently, any and all of the rights conveyed by this License. 1.10. "Modifications" means any of the following: (a) any file in Source Code Form that results from an addition to, deletion from, or modification of the contents of Covered Software; or (b) any new file in Source Code Form that contains any Covered Software. 1.11. "Patent Claims" of a Contributor means any patent claim(s), including without limitation, method, process, and apparatus claims, in any patent Licensable by such Contributor that would be infringed, but for the grant of the License, by the making, using, selling, offering for sale, having made, import, or transfer of either its Contributions or its Contributor Version. 1.12. "Secondary License" means either the GNU General Public License, Version 2.0, the GNU Lesser General Public License, Version 2.1, the GNU Affero General Public License, Version 3.0, or any later versions of those licenses. 1.13. "Source Code Form" means the form of the work preferred for making modifications. 1.14. "You" (or "Your") means an individual or a legal entity exercising rights under this License. For legal entities, "You" includes any entity that controls, is controlled by, or is under common control with You. For purposes of this definition, "control" means (a) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (b) ownership of more than fifty percent (50%) of the outstanding shares or beneficial ownership of such entity. 2. License Grants and Conditions -------------------------------- 2.1. 
Grants Each Contributor hereby grants You a world-wide, royalty-free, non-exclusive license: (a) under intellectual property rights (other than patent or trademark) Licensable by such Contributor to use, reproduce, make available, modify, display, perform, distribute, and otherwise exploit its Contributions, either on an unmodified basis, with Modifications, or as part of a Larger Work; and (b) under Patent Claims of such Contributor to make, use, sell, offer for sale, have made, import, and otherwise transfer either its Contributions or its Contributor Version. 2.2. Effective Date The licenses granted in Section 2.1 with respect to any Contribution become effective for each Contribution on the date the Contributor first distributes such Contribution. 2.3. Limitations on Grant Scope The licenses granted in this Section 2 are the only rights granted under this License. No additional rights or licenses will be implied from the distribution or licensing of Covered Software under this License. Notwithstanding Section 2.1(b) above, no patent license is granted by a Contributor: (a) for any code that a Contributor has removed from Covered Software; or (b) for infringements caused by: (i) Your and any other third party's modifications of Covered Software, or (ii) the combination of its Contributions with other software (except as part of its Contributor Version); or (c) under Patent Claims infringed by Covered Software in the absence of its Contributions. This License does not grant any rights in the trademarks, service marks, or logos of any Contributor (except as may be necessary to comply with the notice requirements in Section 3.4). 2.4. Subsequent Licenses No Contributor makes additional grants as a result of Your choice to distribute the Covered Software under a subsequent version of this License (see Section 10.2) or under the terms of a Secondary License (if permitted under the terms of Section 3.3). 2.5. Representation Each Contributor represents that the Contributor believes its Contributions are its original creation(s) or it has sufficient rights to grant the rights to its Contributions conveyed by this License. 2.6. Fair Use This License is not intended to limit any rights You have under applicable copyright doctrines of fair use, fair dealing, or other equivalents. 2.7. Conditions Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted in Section 2.1. 3. Responsibilities ------------------- 3.1. Distribution of Source Form All distribution of Covered Software in Source Code Form, including any Modifications that You create or to which You contribute, must be under the terms of this License. You must inform recipients that the Source Code Form of the Covered Software is governed by the terms of this License, and how they can obtain a copy of this License. You may not attempt to alter or restrict the recipients' rights in the Source Code Form. 3.2. 
Distribution of Executable Form If You distribute Covered Software in Executable Form then: (a) such Covered Software must also be made available in Source Code Form, as described in Section 3.1, and You must inform recipients of the Executable Form how they can obtain a copy of such Source Code Form by reasonable means in a timely manner, at a charge no more than the cost of distribution to the recipient; and (b) You may distribute such Executable Form under the terms of this License, or sublicense it under different terms, provided that the license for the Executable Form does not attempt to limit or alter the recipients' rights in the Source Code Form under this License. 3.3. Distribution of a Larger Work You may create and distribute a Larger Work under terms of Your choice, provided that You also comply with the requirements of this License for the Covered Software. If the Larger Work is a combination of Covered Software with a work governed by one or more Secondary Licenses, and the Covered Software is not Incompatible With Secondary Licenses, this License permits You to additionally distribute such Covered Software under the terms of such Secondary License(s), so that the recipient of the Larger Work may, at their option, further distribute the Covered Software under the terms of either this License or such Secondary License(s). 3.4. Notices You may not remove or alter the substance of any license notices (including copyright notices, patent notices, disclaimers of warranty, or limitations of liability) contained within the Source Code Form of the Covered Software, except that You may alter any license notices to the extent required to remedy known factual inaccuracies. 3.5. Application of Additional Terms You may choose to offer, and to charge a fee for, warranty, support, indemnity or liability obligations to one or more recipients of Covered Software. However, You may do so only on Your own behalf, and not on behalf of any Contributor. You must make it absolutely clear that any such warranty, support, indemnity, or liability obligation is offered by You alone, and You hereby agree to indemnify every Contributor for any liability incurred by such Contributor as a result of warranty, support, indemnity or liability terms You offer. You may include additional disclaimers of warranty and limitations of liability specific to any jurisdiction. 4. Inability to Comply Due to Statute or Regulation --------------------------------------------------- If it is impossible for You to comply with any of the terms of this License with respect to some or all of the Covered Software due to statute, judicial order, or regulation then You must: (a) comply with the terms of this License to the maximum extent possible; and (b) describe the limitations and the code they affect. Such description must be placed in a text file included with all distributions of the Covered Software under this License. Except to the extent prohibited by statute or regulation, such description must be sufficiently detailed for a recipient of ordinary skill to be able to understand it. 5. Termination -------------- 5.1. The rights granted under this License will terminate automatically if You fail to comply with any of its terms. 
However, if You become compliant, then the rights granted under this License from a particular Contributor are reinstated (a) provisionally, unless and until such Contributor explicitly and finally terminates Your grants, and (b) on an ongoing basis, if such Contributor fails to notify You of the non-compliance by some reasonable means prior to 60 days after You have come back into compliance. Moreover, Your grants from a particular Contributor are reinstated on an ongoing basis if such Contributor notifies You of the non-compliance by some reasonable means, this is the first time You have received notice of non-compliance with this License from such Contributor, and You become compliant prior to 30 days after Your receipt of the notice. 5.2. If You initiate litigation against any entity by asserting a patent infringement claim (excluding declaratory judgment actions, counter-claims, and cross-claims) alleging that a Contributor Version directly or indirectly infringes any patent, then the rights granted to You by any and all Contributors for the Covered Software under Section 2.1 of this License shall terminate. 5.3. In the event of termination under Sections 5.1 or 5.2 above, all end user license agreements (excluding distributors and resellers) which have been validly granted by You or Your distributors under this License prior to termination shall survive termination. ************************************************************************ * * * 6. Disclaimer of Warranty * * ------------------------- * * * * Covered Software is provided under this License on an "as is" * * basis, without warranty of any kind, either expressed, implied, or * * statutory, including, without limitation, warranties that the * * Covered Software is free of defects, merchantable, fit for a * * particular purpose or non-infringing. The entire risk as to the * * quality and performance of the Covered Software is with You. * * Should any Covered Software prove defective in any respect, You * * (not any Contributor) assume the cost of any necessary servicing, * * repair, or correction. This disclaimer of warranty constitutes an * * essential part of this License. No use of any Covered Software is * * authorized under this License except under this disclaimer. * * * ************************************************************************ ************************************************************************ * * * 7. Limitation of Liability * * -------------------------- * * * * Under no circumstances and under no legal theory, whether tort * * (including negligence), contract, or otherwise, shall any * * Contributor, or anyone who distributes Covered Software as * * permitted above, be liable to You for any direct, indirect, * * special, incidental, or consequential damages of any character * * including, without limitation, damages for lost profits, loss of * * goodwill, work stoppage, computer failure or malfunction, or any * * and all other commercial damages or losses, even if such party * * shall have been informed of the possibility of such damages. This * * limitation of liability shall not apply to liability for death or * * personal injury resulting from such party's negligence to the * * extent applicable law prohibits such limitation. Some * * jurisdictions do not allow the exclusion or limitation of * * incidental or consequential damages, so this exclusion and * * limitation may not apply to You. * * * ************************************************************************ 8. 
Litigation ------------- Any litigation relating to this License may be brought only in the courts of a jurisdiction where the defendant maintains its principal place of business and such litigation shall be governed by laws of that jurisdiction, without reference to its conflict-of-law provisions. Nothing in this Section shall prevent a party's ability to bring cross-claims or counter-claims. 9. Miscellaneous ---------------- This License represents the complete agreement concerning the subject matter hereof. If any provision of this License is held to be unenforceable, such provision shall be reformed only to the extent necessary to make it enforceable. Any law or regulation which provides that the language of a contract shall be construed against the drafter shall not be used to construe this License against a Contributor. 10. Versions of the License --------------------------- 10.1. New Versions Mozilla Foundation is the license steward. Except as provided in Section 10.3, no one other than the license steward has the right to modify or publish new versions of this License. Each version will be given a distinguishing version number. 10.2. Effect of New Versions You may distribute the Covered Software under the terms of the version of the License under which You originally received the Covered Software, or under the terms of any subsequent version published by the license steward. 10.3. Modified Versions If you create software not governed by this License, and you want to create a new license for such software, you may create and use a modified version of this License if you rename the license and remove any references to the name of the license steward (except to note that such modified license differs from this License). 10.4. Distributing Source Code Form that is Incompatible With Secondary Licenses If You choose to distribute Source Code Form that is Incompatible With Secondary Licenses under the terms of this version of the License, the notice described in Exhibit B of this License must be attached. Exhibit A - Source Code Form License Notice ------------------------------------------- This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this file, You can obtain one at http://mozilla.org/MPL/2.0/. If it is not possible or desirable to put the notice in a particular file, then You may include the notice in a location (such as a LICENSE file in a relevant directory) where a recipient would be likely to look for such a notice. You may add additional accurate notices of copyright ownership. Exhibit B - "Incompatible With Secondary Licenses" Notice --------------------------------------------------------- This Source Code Form is "Incompatible With Secondary Licenses", as defined by the Mozilla Public License, v. 2.0. python-0.20.7/MANIFEST.in000066400000000000000000000001631473573527000146740ustar00rootroot00000000000000include README.rst include CHANGELOG.md include LICENSE include requirements.txt include optional-requirements.txt python-0.20.7/README.rst000066400000000000000000000060631473573527000146320ustar00rootroot00000000000000mautrix-python ============== |PyPI| |Python versions| |License| |Docs| |Code style| |Imports| A Python 3.10+ asyncio Matrix framework. 
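For example, the high-level client framework described under Components below can be used
roughly like this minimal sketch (the user ID, homeserver URL, access token and room ID are
placeholders — adjust them for your deployment):

.. code-block:: python

    import asyncio

    from mautrix.client import ClientAPI


    async def main() -> None:
        # Placeholder credentials; use a real homeserver URL and access token.
        client = ClientAPI(
            "@example:example.com",
            base_url="https://matrix-client.example.com",
            token="<access token>",
        )
        # Send a plain text message to a room the account has already joined.
        await client.send_text("!room:example.com", "Hello from mautrix-python!")


    asyncio.run(main())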
Matrix room: `#maunium:maunium.net`_ Components ---------- * Basic HTTP request sender (mautrix.api_) * `Client API`_ endpoints as functions (mautrix.client.api_) * Medium-level application service framework (mautrix.appservice_) * Basic transaction and user/alias query support (based on Cadair's python-appservice-framework_) * Basic room state storage * Intent wrapper around the client API functions (design based on matrix-appservice-bridge) * Medium-level end-to-end encryption framework (mautrix.crypto_) * Handles all the complicated e2ee key exchange * Uses libolm through python-olm for the low-level crypto * High-level bridging utility framework (mautrix.bridge_) * Base class for bridges * Common bridge configuration and appservice registration generation things * Double-puppeting helper * End-to-bridge encryption helper * High-level client framework (mautrix.client_) * Syncing and event handling helper. * End-to-end encryption helper. * Utilities (mautrix.util_) * Matrix HTML parsing and generating utilities * Manhole system (get a python shell in a running process) * YAML config helpers * Database helpers (new: asyncpg, legacy: SQLAlchemy) * Color logging utility * Very simple HMAC-SHA256 utility for signing tokens (like JWT, but hardcoded to use a single good algorithm) .. _#maunium:maunium.net: https://matrix.to/#/#maunium:maunium.net .. _python-appservice-framework: https://github.com/Cadair/python-appservice-framework/ .. _Client API: https://matrix.org/docs/spec/client_server/r0.6.1.html .. _mautrix.api: https://docs.mau.fi/python/latest/api/mautrix.api.html .. _mautrix.client.api: https://docs.mau.fi/python/latest/api/mautrix.client.api.html .. _mautrix.appservice: https://docs.mau.fi/python/latest/api/mautrix.appservice/index.html .. _mautrix.bridge: https://docs.mau.fi/python/latest/api/mautrix.bridge/index.html .. _mautrix.client: https://docs.mau.fi/python/latest/api/mautrix.client.html .. _mautrix.crypto: https://docs.mau.fi/python/latest/api/mautrix.crypto.html .. _mautrix.util: https://docs.mau.fi/python/latest/api/mautrix.util/index.html .. |PyPI| image:: https://img.shields.io/pypi/v/mautrix.svg :target: https://pypi.python.org/pypi/mautrix :alt: PyPI: mautrix .. |Python versions| image:: https://img.shields.io/pypi/pyversions/mautrix.svg .. |License| image:: https://img.shields.io/github/license/mautrix/python.svg :target: https://github.com/mautrix/python/blob/master/LICENSE :alt: License: MPL-2.0 .. |Docs| image:: https://img.shields.io/gitlab/pipeline-status/mautrix/python?branch=master&gitlab_url=https%3A%2F%2Fmau.dev&label=docs :target: https://docs.mau.fi/python/latest/ .. |Code style| image:: https://img.shields.io/badge/code%20style-black-000000.svg :target: https://github.com/psf/black .. 
|Imports| image:: https://img.shields.io/badge/%20imports-isort-%231674b1?style=flat&labelColor=ef8336 :target: https://pycqa.github.io/isort/ python-0.20.7/dev-requirements.txt000066400000000000000000000000651473573527000171770ustar00rootroot00000000000000pre-commit>=2.10.1,<3 isort>=5.10.1,<6 black>=24,<25 python-0.20.7/docs/000077500000000000000000000000001473573527000140665ustar00rootroot00000000000000python-0.20.7/docs/.dockerignore000066400000000000000000000000241473573527000165360ustar00rootroot00000000000000* !requirements.txt python-0.20.7/docs/.gitignore000066400000000000000000000000071473573527000160530ustar00rootroot00000000000000_build python-0.20.7/docs/Dockerfile000066400000000000000000000002471473573527000160630ustar00rootroot00000000000000FROM python:latest RUN apt-get update && apt-get install -y libolm-dev rsync COPY ./requirements.txt / RUN pip install -r /requirements.txt && rm -f /requirements.txt python-0.20.7/docs/Makefile000066400000000000000000000011721473573527000155270ustar00rootroot00000000000000# Minimal makefile for Sphinx documentation # # You can set these variables from the command line, and also # from the environment for the first two. SPHINXOPTS ?= SPHINXBUILD ?= sphinx-build SOURCEDIR = . BUILDDIR = _build # Put it first so that "make" without argument is like "make help". help: @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) .PHONY: help Makefile # Catch-all target: route all unknown targets to Sphinx using the new # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). %: Makefile @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) python-0.20.7/docs/api/000077500000000000000000000000001473573527000146375ustar00rootroot00000000000000python-0.20.7/docs/api/mautrix.api.rst000066400000000000000000000000651473573527000176330ustar00rootroot00000000000000mautrix.api =========== .. automodule:: mautrix.api python-0.20.7/docs/api/mautrix.appservice/000077500000000000000000000000001473573527000204705ustar00rootroot00000000000000python-0.20.7/docs/api/mautrix.appservice/api.rst000066400000000000000000000002721473573527000217740ustar00rootroot00000000000000Appservice client API ===================== .. currentmodule:: mautrix.appservice .. autoclass:: mautrix.appservice.AppServiceAPI .. autoclass:: mautrix.appservice.ChildAppServiceAPI python-0.20.7/docs/api/mautrix.appservice/as_handler.rst000066400000000000000000000002201473573527000233140ustar00rootroot00000000000000Appservice server mixin ======================= .. currentmodule:: mautrix.appservice .. autoclass:: mautrix.appservice.AppServiceServerMixin python-0.20.7/docs/api/mautrix.appservice/index.rst000066400000000000000000000003531473573527000223320ustar00rootroot00000000000000mautrix.appservice ================== .. module:: mautrix.appservice .. autoclass:: mautrix.appservice.AppService .. toctree:: :maxdepth: 1 Appservice server mixin Appservice client API Intents python-0.20.7/docs/api/mautrix.appservice/intent.rst000066400000000000000000000001521473573527000225210ustar00rootroot00000000000000Intents ============= .. currentmodule:: mautrix.appservice .. autoclass:: mautrix.appservice.IntentAPI python-0.20.7/docs/api/mautrix.client.api.rst000066400000000000000000000002371473573527000211110ustar00rootroot00000000000000mautrix.client.api ================== .. module:: mautrix.client.api .. 
autoclass:: mautrix.client.ClientAPI :inherited-members: :no-show-inheritance: python-0.20.7/docs/api/mautrix.client.mixins.rst000066400000000000000000000014251473573527000216470ustar00rootroot00000000000000mautrix.client mixins ===================== The :class:`Client ` class itself is very small, most of the functionality on top of :class:`ClientAPI ` comes from mixins that it includes. In some cases it might be useful to extend from a mixin instead of the high-level client class (e.g. the appservice module's :class:`IntentAPI ` extends :class:`StoreUpdatingAPI `). Syncer ------ .. autoclass:: mautrix.client.Syncer DecryptionDispatcher -------------------- .. autoclass:: mautrix.client.DecryptionDispatcher EncryptingAPI ------------- .. autoclass:: mautrix.client.EncryptingAPI StoreUpdatingAPI ---------------- .. autoclass:: mautrix.client.StoreUpdatingAPI python-0.20.7/docs/api/mautrix.client.rst000066400000000000000000000006171473573527000203430ustar00rootroot00000000000000mautrix.client ============== .. module:: mautrix.client .. autoclass:: mautrix.client.Client :members: state_store, crypto, crypto_enabled, crypto_log, encryption_blacklist, encrypt, share_group_session, send_message_event, on, add_event_handler, remove_event_handler, add_dispatcher, remove_dispatcher, start, stop .. toctree:: :maxdepth: 1 Mixins python-0.20.7/docs/api/mautrix.client.state_store/000077500000000000000000000000001473573527000221405ustar00rootroot00000000000000python-0.20.7/docs/api/mautrix.client.state_store/asyncpg.rst000066400000000000000000000002421473573527000243340ustar00rootroot00000000000000mautrix.client.state\_store.async\_pg ===================================== .. autoclass:: mautrix.client.state_store.asyncpg.PgStateStore :no-undoc-members: python-0.20.7/docs/api/mautrix.client.state_store/file.rst000066400000000000000000000002221473573527000236050ustar00rootroot00000000000000mautrix.client.state\_store.file ================================ .. autoclass:: mautrix.client.state_store.FileStateStore :no-undoc-members: python-0.20.7/docs/api/mautrix.client.state_store/index.rst000066400000000000000000000004551473573527000240050ustar00rootroot00000000000000mautrix.client.state\_store =========================== .. module:: mautrix.client.state_store .. autoclass:: mautrix.client.state_store.StateStore Implementations --------------- .. toctree:: :maxdepth: 1 In-memory Async database (asyncpg/aiosqlite) Flat file python-0.20.7/docs/api/mautrix.client.state_store/memory.rst000066400000000000000000000002301473573527000241750ustar00rootroot00000000000000mautrix.client.state\_store.memory ================================== .. autoclass:: mautrix.client.state_store.MemoryStateStore :no-undoc-members: python-0.20.7/docs/api/mautrix.crypto.attachments.rst000066400000000000000000000001701473573527000227110ustar00rootroot00000000000000mautrix.crypto.attachments ========================== .. automodule:: mautrix.crypto.attachments :imported-members: python-0.20.7/docs/api/mautrix.types.rst000066400000000000000000000012141473573527000202230ustar00rootroot00000000000000mautrix.types ============= .. currentmodule:: mautrix.types .. autoclass:: mautrix.types.JSON .. autoclass:: mautrix.types.UserID .. autoclass:: mautrix.types.EventID .. autoclass:: mautrix.types.RoomID .. autoclass:: mautrix.types.RoomAlias .. autoclass:: mautrix.types.FilterID .. autoclass:: mautrix.types.ContentURI .. autoclass:: mautrix.types.SyncToken .. autoclass:: mautrix.types.DeviceID .. autoclass:: mautrix.types.SessionID .. 
autoclass:: mautrix.types.SigningKey .. autoclass:: mautrix.types.IdentityKey .. automodule:: mautrix.types :imported-members: :exclude-members: __init__ :no-special-members: :no-show-inheritance: python-0.20.7/docs/api/mautrix.util/000077500000000000000000000000001473573527000173045ustar00rootroot00000000000000python-0.20.7/docs/api/mautrix.util/async_db.rst000066400000000000000000000001211473573527000216120ustar00rootroot00000000000000async\_db ========= .. automodule:: mautrix.util.async_db :imported-members: python-0.20.7/docs/api/mautrix.util/bridge_state.rst000066400000000000000000000001071473573527000224700ustar00rootroot00000000000000bridge\_state ============= .. automodule:: mautrix.util.bridge_state python-0.20.7/docs/api/mautrix.util/config.rst000066400000000000000000000001111473573527000212740ustar00rootroot00000000000000config ====== .. automodule:: mautrix.util.config :imported-members: python-0.20.7/docs/api/mautrix.util/db.rst000066400000000000000000000002531473573527000204230ustar00rootroot00000000000000db == .. automodule:: mautrix.util.db :imported-members: .. deprecated:: 0.15.0 The :mod:`mautrix.util.async_db` utility is now recommended over SQLAlchemy. python-0.20.7/docs/api/mautrix.util/ffmpeg.rst000066400000000000000000000001111473573527000212730ustar00rootroot00000000000000ffmpeg ====== .. automodule:: mautrix.util.ffmpeg :imported-members: python-0.20.7/docs/api/mautrix.util/file_store.rst000066400000000000000000000001011473573527000221610ustar00rootroot00000000000000file\_store =========== .. automodule:: mautrix.util.file_store python-0.20.7/docs/api/mautrix.util/format_duration.rst000066400000000000000000000001201473573527000232240ustar00rootroot00000000000000format\_duration ================ .. automodule:: mautrix.util.format_duration python-0.20.7/docs/api/mautrix.util/formatter.rst000066400000000000000000000001221473573527000220340ustar00rootroot00000000000000formatter ========= .. automodule:: mautrix.util.formatter :imported-members: python-0.20.7/docs/api/mautrix.util/index.rst000066400000000000000000000005101473573527000211410ustar00rootroot00000000000000mautrix.util ============ .. toctree:: :maxdepth: 4 async_db bridge_state config db ffmpeg file_store formatter format_duration logging magic manhole markdown message_send_checkpoint opt_prometheus program signed_token simple_lock simple_template variation_selector python-0.20.7/docs/api/mautrix.util/logging.rst000066400000000000000000000001141473573527000214600ustar00rootroot00000000000000logging ======= .. automodule:: mautrix.util.logging :imported-members: python-0.20.7/docs/api/mautrix.util/magic.rst000066400000000000000000000000601473573527000211120ustar00rootroot00000000000000magic ===== .. automodule:: mautrix.util.magic python-0.20.7/docs/api/mautrix.util/manhole.rst000066400000000000000000000000661473573527000214630ustar00rootroot00000000000000manhole ======= .. automodule:: mautrix.util.manhole python-0.20.7/docs/api/mautrix.util/markdown.rst000066400000000000000000000000711473573527000216560ustar00rootroot00000000000000markdown ======== .. automodule:: mautrix.util.markdown python-0.20.7/docs/api/mautrix.util/message_send_checkpoint.rst000066400000000000000000000001521473573527000247000ustar00rootroot00000000000000message\_send\_checkpoint ========================= .. 
automodule:: mautrix.util.message_send_checkpoint python-0.20.7/docs/api/mautrix.util/opt_prometheus.rst000066400000000000000000000015571473573527000231230ustar00rootroot00000000000000opt\_prometheus =============== The opt\_prometheus module contains no-op implementations of prometheus's ``Counter``, ``Gauge``, ``Summary``, ``Histogram``, ``Info`` and ``Enum``, as well as a helper method for timing async methods. It's useful for creating metrics unconditionally without a hard dependency on prometheus\_client. .. attribute:: is_installed A boolean indicating whether ``prometheus_client`` was successfully imported. :type: bool :canonical: mautrix.util.opt_prometheus.is_installed .. decorator:: async_time(metric) Measure the time that each execution of the decorated async function takes. This is equivalent to the ``time`` method-decorator in the metrics, but supports async functions. :param Gauge/Summary/Histogram metric: The metric instance to store the measures in. :canonical: mautrix.util.opt_prometheus.async_time python-0.20.7/docs/api/mautrix.util/program.rst000066400000000000000000000000661473573527000215070ustar00rootroot00000000000000program ======= .. automodule:: mautrix.util.program python-0.20.7/docs/api/mautrix.util/signed_token.rst000066400000000000000000000001071473573527000225050ustar00rootroot00000000000000signed\_token ============= .. automodule:: mautrix.util.signed_token python-0.20.7/docs/api/mautrix.util/simple_lock.rst000066400000000000000000000001041473573527000223320ustar00rootroot00000000000000simple\_lock ============ .. automodule:: mautrix.util.simple_lock python-0.20.7/docs/api/mautrix.util/simple_template.rst000066400000000000000000000001201473573527000232130ustar00rootroot00000000000000simple\_template ================ .. automodule:: mautrix.util.simple_template python-0.20.7/docs/api/mautrix.util/variation_selector.rst000066400000000000000000000001311473573527000237250ustar00rootroot00000000000000variation\_selector =================== .. 
automodule:: mautrix.util.variation_selector python-0.20.7/docs/conf.py000066400000000000000000000025411473573527000153670ustar00rootroot00000000000000from datetime import datetime import os import sys sys.path.insert(0, os.path.abspath("..")) import mautrix.fixmodule mautrix.__optional_imports__ = True project = "mautrix-python" copyright = f"{datetime.today().year}, Tulir Asokan" author = "Tulir Asokan" release = mautrix.__version__ extensions = [ "sphinx.ext.autodoc", "sphinx.ext.intersphinx", "sphinx.ext.inheritance_diagram", "sphinx.ext.napoleon", ] templates_path = ["_templates"] exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] html_theme = "sphinx_rtd_theme" html_static_path = ["_static"] autodoc_typehints = "description" autodoc_member_order = "bysource" autoclass_content = "class" autodoc_class_signature = "separated" autodoc_type_aliases = { "EventContent": "mautrix.types.EventContent", "StateEventContent": "mautrix.types.StateEventContent", "Event": "mautrix.types.Event", } autodoc_default_options = { "special-members": "__init__", "class-doc-from": "class", "members": True, "undoc-members": True, "show-inheritance": True, } napoleon_google_docstring = True napoleon_numpy_docstring = False napoleon_include_init_with_doc = True intersphinx_mapping = { "python": ("https://docs.python.org/3/", None), "aiohttp": ("https://docs.aiohttp.org/en/stable/", None), "yarl": ("https://yarl.readthedocs.io/en/stable/", None), } python-0.20.7/docs/index.rst000066400000000000000000000007121473573527000157270ustar00rootroot00000000000000Welcome to mautrix-python's documentation! ========================================== .. toctree:: :maxdepth: 4 :caption: API reference api/mautrix.api api/mautrix.client.api api/mautrix.client api/mautrix.client.state_store/index api/mautrix.appservice/index api/mautrix.crypto.attachments api/mautrix.types api/mautrix.util/index Indices and tables ================== * :ref:`genindex` * :ref:`modindex` * :ref:`search` python-0.20.7/docs/requirements.txt000066400000000000000000000004701473573527000173530ustar00rootroot00000000000000sphinx sphinx-rtd-theme # requirements.txt aiohttp attrs yarl # This is most of optional-requirements.txt, except some things like uvloop # that aren't used for anything that's in the docs python-magic ruamel.yaml SQLAlchemy commonmark asyncpg aiosqlite prometheus_client python-olm unpaddedbase64 pycryptodome python-0.20.7/mautrix/000077500000000000000000000000001473573527000146275ustar00rootroot00000000000000python-0.20.7/mautrix/__init__.py000066400000000000000000000004471473573527000167450ustar00rootroot00000000000000__version__ = "0.20.7" __author__ = "Tulir Asokan " __all__ = [ "api", "appservice", "bridge", "client", "crypto", "errors", "util", "types", "__optional_imports__", ] from typing import TYPE_CHECKING __optional_imports__ = TYPE_CHECKING python-0.20.7/mautrix/api.py000066400000000000000000000454151473573527000157630ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. 
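# Example usage (a minimal sketch; the homeserver URL and access token below are
# placeholders, not real values):
#
#     import asyncio
#     from mautrix.api import HTTPAPI, Method, Path
#
#     async def main() -> None:
#         api = HTTPAPI(base_url="https://matrix-client.example.com", token="<access token>")
#         # Path builds client-server API paths; [segment] indexing URL-encodes the segment,
#         # so raw room/event/user IDs are safe to embed.
#         whoami = await api.request(Method.GET, Path.v3.account.whoami)
#         print(whoami["user_id"])
#
#     asyncio.run(main())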
from __future__ import annotations from typing import ClassVar, Literal, Mapping from enum import Enum from json.decoder import JSONDecodeError from urllib.parse import quote as urllib_quote, urljoin as urllib_join import asyncio import inspect import json import logging import platform import time from aiohttp import ClientResponse, ClientSession, __version__ as aiohttp_version from aiohttp.client_exceptions import ClientError, ContentTypeError from yarl import URL from mautrix import __optional_imports__, __version__ as mautrix_version from mautrix.errors import MatrixConnectionError, MatrixRequestError, make_request_error from mautrix.util.async_body import AsyncBody, async_iter_bytes from mautrix.util.logging import TraceLogger from mautrix.util.opt_prometheus import Counter if __optional_imports__: # Safe to import, but it's not actually needed, so don't force-import the whole types module. from mautrix.types import JSON, DeviceID, UserID API_CALLS = Counter( name="bridge_matrix_api_calls", documentation="The number of Matrix client API calls made", labelnames=("method",), ) API_CALLS_FAILED = Counter( name="bridge_matrix_api_calls_failed", documentation="The number of Matrix client API calls which failed", labelnames=("method",), ) class APIPath(Enum): """ The known Matrix API path prefixes. These don't start with a slash so they can be used nicely with yarl. """ CLIENT = "_matrix/client" MEDIA = "_matrix/media" SYNAPSE_ADMIN = "_synapse/admin" def __repr__(self): return self.value def __str__(self): return self.value class Method(Enum): """A HTTP method.""" GET = "GET" POST = "POST" PUT = "PUT" DELETE = "DELETE" PATCH = "PATCH" def __repr__(self): return self.value def __str__(self): return self.value class PathBuilder: """ A utility class to build API paths. Examples: >>> from mautrix.api import Path >>> room_id = "!foo:example.com" >>> event_id = "$bar:example.com" >>> str(Path.v3.rooms[room_id].event[event_id]) "_matrix/client/v3/rooms/%21foo%3Aexample.com/event/%24bar%3Aexample.com" """ def __init__(self, path: str | APIPath = "") -> None: self.path: str = str(path) def __str__(self) -> str: return self.path def __repr__(self): return self.path def __getattr__(self, append: str) -> PathBuilder: if append is None: return self return PathBuilder(f"{self.path}/{append}") def raw(self, append: str) -> PathBuilder: """ Directly append a string to the path. Args: append: The string to append. """ if append is None: return self return PathBuilder(self.path + append) def __eq__(self, other: PathBuilder | str) -> bool: return other.path == self.path if isinstance(other, PathBuilder) else other == self.path @staticmethod def _quote(string: str) -> str: return urllib_quote(string, safe="") def __getitem__(self, append: str | int) -> PathBuilder: if append is None: return self return PathBuilder(f"{self.path}/{self._quote(str(append))}") def replace(self, find: str, replace: str) -> PathBuilder: return PathBuilder(self.path.replace(find, replace)) ClientPath = PathBuilder(APIPath.CLIENT) ClientPath.__doc__ = """ A path builder with the standard client prefix ( ``/_matrix/client``, :attr:`APIPath.CLIENT`). 
""" Path = PathBuilder(APIPath.CLIENT) Path.__doc__ = """A shorter alias for :attr:`ClientPath`""" MediaPath = PathBuilder(APIPath.MEDIA) MediaPath.__doc__ = """ A path builder with the standard media prefix (``/_matrix/media``, :attr:`APIPath.MEDIA`) Examples: >>> from mautrix.api import MediaPath >>> str(MediaPath.v3.config) "_matrix/media/v3/config" """ SynapseAdminPath = PathBuilder(APIPath.SYNAPSE_ADMIN) SynapseAdminPath.__doc__ = """ A path builder for synapse-specific admin API paths (``/_synapse/admin``, :attr:`APIPath.SYNAPSE_ADMIN`) Examples: >>> from mautrix.api import SynapseAdminPath >>> user_id = "@user:example.com" >>> str(SynapseAdminPath.v1.users[user_id]/login) "_synapse/admin/v1/users/%40user%3Aexample.com/login" """ _req_id = 0 def _next_global_req_id() -> int: global _req_id _req_id += 1 return _req_id class HTTPAPI: """HTTPAPI is a simple asyncio Matrix API request sender.""" default_ua: ClassVar[str] = ( f"mautrix-python/{mautrix_version} aiohttp/{aiohttp_version} " f"Python/{platform.python_version()}" ) """ The default value for the ``User-Agent`` header. You should prepend your program name and version here before creating any HTTPAPI instances in order to have proper user agents for all requests. """ global_default_retry_count: ClassVar[int] = 0 """The default retry count to use if an instance-specific value is not passed.""" base_url: URL """The base URL of the homeserver's client-server API to use.""" token: str """The access token to use in requests.""" log: TraceLogger """The :class:`logging.Logger` instance to log requests with.""" session: ClientSession """The aiohttp ClientSession instance to make requests with.""" txn_id: int | None """A counter used for generating transaction IDs.""" default_retry_count: int """The default retry count to use if a custom value is not passed to :meth:`request`""" as_user_id: UserID | None """An optional user ID to set as the user_id query parameter for appservice requests.""" as_device_id: DeviceID | None """ An optional device ID to set as the user_id query parameter for appservice requests (MSC3202). """ def __init__( self, base_url: URL | str, token: str = "", *, client_session: ClientSession = None, default_retry_count: int = None, txn_id: int = 0, log: TraceLogger | None = None, loop: asyncio.AbstractEventLoop | None = None, as_user_id: UserID | None = None, as_device_id: UserID | None = None, ) -> None: """ Args: base_url: The base URL of the homeserver's client-server API to use. token: The access token to use. client_session: The aiohttp client session to use. txn_id: The outgoing transaction ID to start with. log: The :class:`logging.Logger` instance to log requests with. default_retry_count: Default number of retries to do when encountering network errors. as_user_id: An optional user ID to set as the user_id query parameter for appservice requests. as_device_id: An optional device ID to set as the user_id query parameter for appservice requests (MSC3202). 
""" self.base_url = URL(base_url) self.token = token self.log = log or logging.getLogger("mau.http") self.session = client_session or ClientSession( loop=loop, headers={"User-Agent": self.default_ua} ) self.as_user_id = as_user_id self.as_device_id = as_device_id if txn_id is not None: self.txn_id = txn_id if default_retry_count is not None: self.default_retry_count = default_retry_count else: self.default_retry_count = self.global_default_retry_count async def _send( self, method: Method, url: URL, content: bytes | bytearray | str | AsyncBody, query_params: dict[str, str], headers: dict[str, str], ) -> tuple[JSON, ClientResponse]: request = self.session.request( str(method), url, data=content, params=query_params, headers=headers ) async with request as response: if response.status < 200 or response.status >= 300: errcode = unstable_errcode = message = None try: response_data = await response.json() errcode = response_data["errcode"] message = response_data["error"] unstable_errcode = response_data.get("org.matrix.msc3848.unstable.errcode") except (JSONDecodeError, ContentTypeError, KeyError): pass raise make_request_error( http_status=response.status, text=await response.text(), errcode=errcode, message=message, unstable_errcode=unstable_errcode, ) return await response.json(), response def _log_request( self, method: Method, url: URL, content: str | bytes | bytearray | AsyncBody | None, orig_content, query_params: dict[str, str], headers: dict[str, str], req_id: int, sensitive: bool, ) -> None: if not self.log: return if isinstance(content, (bytes, bytearray)): log_content = f"<{len(content)} bytes>" elif inspect.isasyncgen(content): size = headers.get("Content-Length", None) log_content = f"<{size} async bytes>" if size else f"" elif sensitive: log_content = f"<{len(content)} sensitive bytes>" else: log_content = content as_user = query_params.get("user_id", None) level = 5 if url.path.endswith("/v3/sync") else 10 self.log.log( level, f"req #{req_id}: {method} {url} {log_content}".strip(" "), extra={ "matrix_http_request": { "req_id": req_id, "method": str(method), "url": str(url), "content": ( orig_content if isinstance(orig_content, (dict, list)) and not sensitive else log_content ), "user": as_user, } }, ) def _log_request_done( self, path: PathBuilder | str, req_id: int, duration: float, status: int ) -> None: level = 5 if path == Path.v3.sync else 10 duration_str = f"{duration * 1000:.1f}ms" if duration < 1 else f"{duration:.3f}s" path_without_prefix = f"/{path}".replace("/_matrix/client", "") self.log.log( level, f"req #{req_id} ({path_without_prefix}) completed in {duration_str} " f"with status {status}", ) def _full_path(self, path: PathBuilder | str) -> str: path = str(path) if path and path[0] == "/": path = path[1:] base_path = self.base_url.raw_path if base_path[-1] != "/": base_path += "/" return urllib_join(base_path, path) def log_download_request(self, url: URL, query_params: dict[str, str]) -> int: req_id = _next_global_req_id() self._log_request(Method.GET, url, None, None, query_params, {}, req_id, False) return req_id def log_download_request_done( self, url: URL, req_id: int, duration: float, status: int ) -> None: self._log_request_done(url.path.removeprefix("/_matrix/media/"), req_id, duration, status) async def request( self, method: Method, path: PathBuilder | str, content: dict | list | bytes | bytearray | str | AsyncBody | None = None, headers: dict[str, str] | None = None, query_params: Mapping[str, str] | None = None, retry_count: int | None = None, 
metrics_method: str = "", min_iter_size: int = 25 * 1024 * 1024, sensitive: bool = False, ) -> JSON: """ Make a raw Matrix API request. Args: method: The HTTP method to use. path: The full API endpoint to call (including the _matrix/... prefix) content: The content to post as a dict/list (will be serialized as JSON) or bytes/str (will be sent as-is). headers: A dict of HTTP headers to send. If the headers don't contain ``Content-Type``, it'll be set to ``application/json``. The ``Authorization`` header is always overridden if :attr:`token` is set. query_params: A dict of query parameters to send. retry_count: Number of times to retry if the homeserver isn't reachable. Defaults to :attr:`default_retry_count`. metrics_method: Name of the method to include in Prometheus timing metrics. min_iter_size: If the request body is larger than this value, it will be passed to aiohttp as an async iterable to stop it from copying the whole thing in memory. sensitive: If True, the request content will not be logged. Returns: The parsed response JSON. """ headers = headers or {} if self.token: headers["Authorization"] = f"Bearer {self.token}" query_params = query_params or {} if isinstance(query_params, dict): query_params = {k: v for k, v in query_params.items() if v is not None} if self.as_user_id: query_params["user_id"] = self.as_user_id if self.as_device_id: query_params["org.matrix.msc3202.device_id"] = self.as_device_id query_params["device_id"] = self.as_device_id if method != Method.GET: content = content or {} if "Content-Type" not in headers: headers["Content-Type"] = "application/json" orig_content = content is_json = headers.get("Content-Type", None) == "application/json" if is_json and isinstance(content, (dict, list)): content = json.dumps(content) else: orig_content = content = None full_url = self.base_url.with_path(self._full_path(path), encoded=True) req_id = _next_global_req_id() if retry_count is None: retry_count = self.default_retry_count if inspect.isasyncgen(content): # Can't retry with non-static body retry_count = 0 do_fake_iter = content and hasattr(content, "__len__") and len(content) > min_iter_size if do_fake_iter: headers["Content-Length"] = str(len(content)) backoff = 4 log_url = full_url.with_query(query_params) while True: self._log_request( method, log_url, content, orig_content, query_params, headers, req_id, sensitive ) API_CALLS.labels(method=metrics_method).inc() req_content = async_iter_bytes(content) if do_fake_iter else content start = time.monotonic() try: resp_data, resp = await self._send( method, full_url, req_content, query_params, headers or {} ) self._log_request_done(path, req_id, time.monotonic() - start, resp.status) return resp_data except MatrixRequestError as e: API_CALLS_FAILED.labels(method=metrics_method).inc() if retry_count > 0 and e.http_status in (502, 503, 504): self.log.warning( f"Request #{req_id} failed with HTTP {e.http_status}, " f"retrying in {backoff} seconds" ) else: self._log_request_done(path, req_id, time.monotonic() - start, e.http_status) raise except ClientError as e: API_CALLS_FAILED.labels(method=metrics_method).inc() if retry_count > 0: self.log.warning( f"Request #{req_id} failed with {e}, retrying in {backoff} seconds" ) else: raise MatrixConnectionError(str(e)) from e except Exception: API_CALLS_FAILED.labels(method=metrics_method).inc() raise await asyncio.sleep(backoff) backoff *= 2 retry_count -= 1 def get_txn_id(self) -> str: """Get a new unique transaction ID.""" self.txn_id += 1 return 
f"mautrix-python_{time.time_ns()}_{self.txn_id}" def get_download_url( self, mxc_uri: str, download_type: Literal["download", "thumbnail"] = "download", file_name: str | None = None, authenticated: bool = False, ) -> URL: """ Get the full HTTP URL to download a ``mxc://`` URI. Args: mxc_uri: The MXC URI whose full URL to get. download_type: The type of download ("download" or "thumbnail"). file_name: Optionally, a file name to include in the download URL. authenticated: Whether to use the new authenticated download endpoint in Matrix v1.11. Returns: The full HTTP URL. Raises: ValueError: If `mxc_uri` doesn't begin with ``mxc://``. Examples: >>> api = HTTPAPI(base_url="https://matrix-client.matrix.org", ...) >>> api.get_download_url("mxc://matrix.org/pqjkOuKZ1ZKRULWXgz2IVZV6") "https://matrix-client.matrix.org/_matrix/media/v3/download/matrix.org/pqjkOuKZ1ZKRULWXgz2IVZV6" >>> api.get_download_url("mxc://matrix.org/pqjkOuKZ1ZKRULWXgz2IVZV6", file_name="hello.png") "https://matrix-client.matrix.org/_matrix/media/v3/download/matrix.org/pqjkOuKZ1ZKRULWXgz2IVZV6/hello.png" """ server_name, media_id = self.parse_mxc_uri(mxc_uri) if authenticated: url = self.base_url / str(APIPath.CLIENT) / "v1" / "media" else: url = self.base_url / str(APIPath.MEDIA) / "v3" url = url / download_type / server_name / media_id if file_name: url /= file_name return url @staticmethod def parse_mxc_uri(mxc_uri: str) -> tuple[str, str]: """ Parse a ``mxc://`` URI. Args: mxc_uri: The MXC URI to parse. Returns: A tuple containing the server and media ID of the MXC URI. Raises: ValueError: If `mxc_uri` doesn't begin with ``mxc://``. """ if mxc_uri.startswith("mxc://"): server_name, media_id = mxc_uri[6:].split("/") return server_name, media_id else: raise ValueError("MXC URI did not begin with `mxc://`") python-0.20.7/mautrix/appservice/000077500000000000000000000000001473573527000167705ustar00rootroot00000000000000python-0.20.7/mautrix/appservice/__init__.py000066400000000000000000000006241473573527000211030ustar00rootroot00000000000000from .api import DOUBLE_PUPPET_SOURCE_KEY, AppServiceAPI, ChildAppServiceAPI, IntentAPI from .appservice import AppService from .as_handler import AppServiceServerMixin from .state_store import ASStateStore __all__ = [ "AppService", "AppServiceAPI", "ChildAppServiceAPI", "IntentAPI", "ASStateStore", "AppServiceServerMixin", "DOUBLE_PUPPET_SOURCE_KEY", "state_store", ] python-0.20.7/mautrix/appservice/api/000077500000000000000000000000001473573527000175415ustar00rootroot00000000000000python-0.20.7/mautrix/appservice/api/__init__.py000066400000000000000000000001621473573527000216510ustar00rootroot00000000000000from .appservice import AppServiceAPI, ChildAppServiceAPI from .intent import DOUBLE_PUPPET_SOURCE_KEY, IntentAPI python-0.20.7/mautrix/appservice/api/appservice.py000066400000000000000000000246501473573527000222630ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations from typing import Any, Awaitable from datetime import datetime, timezone import asyncio from aiohttp import ClientSession from yarl import URL from mautrix.api import HTTPAPI, Method, PathBuilder from mautrix.types import UserID from mautrix.util.logging import TraceLogger from .. 
import api as as_api, state_store as ss class AppServiceAPI(HTTPAPI): """ AppServiceAPI is an extension to HTTPAPI that provides appservice-specific features, such as child instances and easy access to IntentAPIs. """ base_log: TraceLogger identity: UserID | None bot_mxid: UserID state_store: ss.ASStateStore txn_id: int children: dict[str, ChildAppServiceAPI] real_users: dict[str, AppServiceAPI] is_real_user: bool bridge_name: str | None _bot_intent: as_api.IntentAPI | None def __init__( self, base_url: URL | str, bot_mxid: UserID = None, token: str = None, identity: UserID | None = None, log: TraceLogger = None, state_store: ss.ASStateStore = None, client_session: ClientSession = None, child: bool = False, real_user: bool = False, real_user_as_token: bool = False, bridge_name: str | None = None, default_retry_count: int = None, loop: asyncio.AbstractEventLoop | None = None, ) -> None: """ Args: base_url: The base URL of the homeserver client-server API to use. bot_mxid: The Matrix user ID of the appservice bot. token: The access token to use. identity: The ID of the Matrix user to act as. log: The logging.Logger instance to log requests with. state_store: The StateStore instance to use. client_session: The aiohttp ClientSession to use. child: Whether or not this is instance is a child of another AppServiceAPI. real_user: Whether or not this is a real (non-appservice-managed) user. real_user_as_token: Whether this real user is actually using another ``as_token``. bridge_name: The name of the bridge to put in the ``fi.mau.double_puppet_source`` field in outgoing message events sent through real users. """ self.base_log = log api_log = self.base_log.getChild("api").getChild(identity or "bot") super().__init__( base_url=base_url, token=token, loop=loop, log=api_log, client_session=client_session, txn_id=0 if not child else None, default_retry_count=default_retry_count, ) self.identity = identity self.bot_mxid = bot_mxid self._bot_intent = None self.state_store = state_store self.is_real_user = real_user self.is_real_user_as_token = real_user_as_token self.bridge_name = bridge_name if not child: self.txn_id = 0 if not real_user: self.children = {} self.real_users = {} def user(self, user: UserID) -> ChildAppServiceAPI: """ Get the AppServiceAPI for an appservice-managed user. Args: user: The Matrix user ID of the user whose AppServiceAPI to get. Returns: The ChildAppServiceAPI object for the user. """ if self.is_real_user: raise ValueError("Can't get child of real user") try: return self.children[user] except KeyError: child = ChildAppServiceAPI(user, self) self.children[user] = child return child def real_user( self, mxid: UserID, token: str, base_url: URL | None = None, as_token: bool = False ) -> AppServiceAPI: """ Get the AppServiceAPI for a real (non-appservice-managed) Matrix user. Args: mxid: The Matrix user ID of the user whose AppServiceAPI to get. token: The access token for the user. base_url: The base URL of the homeserver client-server API to use. Defaults to the appservice homeserver URL. as_token: Whether the token is actually an as_token (meaning the ``user_id`` query parameter needs to be used). Returns: The AppServiceAPI object for the user. Raises: ValueError: When this AppServiceAPI instance is a real user. 
""" if self.is_real_user: raise ValueError("Can't get child of real user") try: child = self.real_users[mxid] child.base_url = base_url or child.base_url child.token = token or child.token child.is_real_user_as_token = as_token except KeyError: child = type(self)( base_url=base_url or self.base_url, token=token, identity=mxid, log=self.base_log, state_store=self.state_store, client_session=self.session, real_user=True, real_user_as_token=as_token, bridge_name=self.bridge_name, default_retry_count=self.default_retry_count, ) self.real_users[mxid] = child return child def bot_intent(self) -> as_api.IntentAPI: """ Get the intent API for the appservice bot. Returns: The IntentAPI object for the appservice bot """ if not self._bot_intent: self._bot_intent = as_api.IntentAPI(self.bot_mxid, self, state_store=self.state_store) return self._bot_intent def intent( self, user: UserID = None, token: str | None = None, base_url: str | None = None, real_user_as_token: bool = False, ) -> as_api.IntentAPI: """ Get the intent API of a child user. Args: user: The Matrix user ID whose intent API to get. token: The access token to use. Only applicable for non-appservice-managed users. base_url: The base URL of the homeserver client-server API to use. Only applicable for non-appservice users. Defaults to the appservice homeserver URL. real_user_as_token: When providing a token, whether it's actually another as_token (meaning the ``user_id`` query parameter needs to be used). Returns: The IntentAPI object for the given user. Raises: ValueError: When this AppServiceAPI instance is a real user. """ if self.is_real_user: raise ValueError("Can't get child intent of real user") if token: return as_api.IntentAPI( user, self.real_user(user, token, base_url, as_token=real_user_as_token), self.bot_intent(), self.state_store, ) return as_api.IntentAPI(user, self.user(user), self.bot_intent(), self.state_store) def request( self, method: Method, path: PathBuilder, content: dict | bytes | str | None = None, timestamp: int | None = None, headers: dict[str, str] | None = None, query_params: dict[str, Any] | None = None, retry_count: int | None = None, metrics_method: str | None = "", min_iter_size: int = 25 * 1024 * 1024, ) -> Awaitable[dict]: """ Make a raw Matrix API request, acting as the appservice user assigned to this AppServiceAPI instance and optionally including timestamp massaging. Args: method: The HTTP method to use. path: The full API endpoint to call (including the _matrix/... prefix) content: The content to post as a dict/list (will be serialized as JSON) or bytes/str (will be sent as-is). timestamp: The timestamp query param used for timestamp massaging. headers: A dict of HTTP headers to send. If the headers don't contain ``Content-Type``, it'll be set to ``application/json``. The ``Authorization`` header is always overridden if :attr:`token` is set. query_params: A dict of query parameters to send. retry_count: Number of times to retry if the homeserver isn't reachable. Defaults to :attr:`default_retry_count`. metrics_method: Name of the method to include in Prometheus timing metrics. min_iter_size: If the request body is larger than this value, it will be passed to aiohttp as an async iterable to stop it from copying the whole thing in memory. Returns: The parsed response JSON. 
""" query_params = query_params or {} if timestamp is not None: if isinstance(timestamp, datetime): timestamp = int(timestamp.replace(tzinfo=timezone.utc).timestamp() * 1000) query_params["ts"] = timestamp if not self.is_real_user or self.is_real_user_as_token: query_params["user_id"] = self.identity or self.bot_mxid return super().request( method, path, content, headers, query_params, retry_count, metrics_method ) class ChildAppServiceAPI(AppServiceAPI): """ ChildAppServiceAPI is a simple way to copy AppServiceAPIs while maintaining a shared txn_id. """ parent: AppServiceAPI def __init__(self, user: UserID, parent: AppServiceAPI) -> None: """ Args: user: The Matrix user ID of the child user. parent: The parent AppServiceAPI instance. """ super().__init__( parent.base_url, parent.bot_mxid, parent.token, user, parent.base_log, parent.state_store, parent.session, child=True, bridge_name=parent.bridge_name, default_retry_count=parent.default_retry_count, ) self.parent = parent @property def txn_id(self) -> int: return self.parent.txn_id @txn_id.setter def txn_id(self, value: int) -> None: self.parent.txn_id = value python-0.20.7/mautrix/appservice/api/intent.py000066400000000000000000000667531473573527000214350ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations from typing import Any, Awaitable, Iterable, TypeVar from urllib.parse import quote as urllib_quote from mautrix.api import Method, Path from mautrix.client import ClientAPI, StoreUpdatingAPI from mautrix.errors import ( IntentError, MAlreadyJoined, MatrixRequestError, MBadState, MForbidden, MNotFound, MUserInUse, ) from mautrix.types import ( JSON, BatchID, BatchSendEvent, BatchSendResponse, BatchSendStateEvent, BeeperBatchSendResponse, ContentURI, EventContent, EventID, EventType, JoinRule, JoinRulesStateEventContent, Member, Membership, PowerLevelStateEventContent, PresenceState, RoomAvatarStateEventContent, RoomID, RoomNameStateEventContent, RoomPinnedEventsStateEventContent, RoomTopicStateEventContent, StateEventContent, UserID, ) from mautrix.util.logging import TraceLogger from .. 
import api as as_api, state_store as ss def quote(*args, **kwargs): return urllib_quote(*args, **kwargs, safe="") _bridgebot = object() ENSURE_REGISTERED_METHODS = ( # Room methods ClientAPI.create_room, ClientAPI.add_room_alias, ClientAPI.remove_room_alias, ClientAPI.resolve_room_alias, ClientAPI.get_joined_rooms, StoreUpdatingAPI.join_room_by_id, StoreUpdatingAPI.join_room, StoreUpdatingAPI.leave_room, ClientAPI.set_room_directory_visibility, ClientAPI.forget_room, # User data methods ClientAPI.search_users, ClientAPI.set_displayname, ClientAPI.set_avatar_url, ClientAPI.beeper_update_profile, ClientAPI.create_mxc, ClientAPI.upload_media, ClientAPI.send_receipt, ClientAPI.set_fully_read_marker, ) ENSURE_JOINED_METHODS = ( # Room methods StoreUpdatingAPI.invite_user, # Event methods ClientAPI.get_event, StoreUpdatingAPI.get_state_event, StoreUpdatingAPI.get_state, ClientAPI.get_joined_members, ClientAPI.get_messages, StoreUpdatingAPI.send_state_event, ClientAPI.send_message_event, ClientAPI.redact, ) DOUBLE_PUPPET_SOURCE_KEY = "fi.mau.double_puppet_source" T = TypeVar("T") class IntentAPI(StoreUpdatingAPI): """ IntentAPI is a high-level wrapper around the AppServiceAPI that provides many easy-to-use functions for accessing the client-server API. It is designed for appservices and will automatically handle many things like missing invites using the appservice bot. """ api: as_api.AppServiceAPI state_store: ss.ASStateStore bot: IntentAPI log: TraceLogger def __init__( self, mxid: UserID, api: as_api.AppServiceAPI, bot: IntentAPI = None, state_store: ss.ASStateStore = None, ) -> None: super().__init__(mxid=mxid, api=api, state_store=state_store) self.bot = bot if bot is not None: self.versions_cache = bot.versions_cache self.log = api.base_log.getChild("intent") for method in ENSURE_REGISTERED_METHODS: method = getattr(self, method.__name__) async def wrapper(*args, __self=self, __method=method, **kwargs): await __self.ensure_registered() return await __method(*args, **kwargs) setattr(self, method.__name__, wrapper) for method in ENSURE_JOINED_METHODS: method = getattr(self, method.__name__) async def wrapper(*args, __self=self, __method=method, **kwargs): room_id = kwargs.get("room_id", None) if not room_id: room_id = args[0] ensure_joined = kwargs.pop("ensure_joined", True) if ensure_joined: await __self.ensure_joined(room_id) return await __method(*args, **kwargs) setattr(self, method.__name__, wrapper) def user( self, user_id: UserID, token: str | None = None, base_url: str | None = None, as_token: bool = False, ) -> IntentAPI: """ Get the intent API for a specific user. This is just a proxy to :meth:`AppServiceAPI.intent`. You should only call this method for the bot user. Calling it with child intent APIs will result in a warning log. Args: user_id: The Matrix ID of the user whose intent API to get. token: The access token to use for the Matrix ID. base_url: An optional URL to use for API requests. as_token: Whether the provided token is actually another as_token (meaning the ``user_id`` query parameter needs to be used). Returns: The IntentAPI for the given user. 
""" if not self.bot: return self.api.intent(user_id, token, base_url, real_user_as_token=as_token) else: self.log.warning("Called IntentAPI#user() of child intent object.") return self.bot.api.intent(user_id, token, base_url, real_user_as_token=as_token) # region User actions async def set_presence( self, presence: PresenceState = PresenceState.ONLINE, status: str | None = None, ignore_cache: bool = False, ): """ Set the online status of the user. See also: `API reference `__ Args: presence: The online status of the user. status: The status message. ignore_cache: Whether to set presence even if the cache says the presence is already set to that value. """ await self.ensure_registered() if not ignore_cache and self.state_store.has_presence(self.mxid, status): return await super().set_presence(presence, status) self.state_store.set_presence(self.mxid, status) # endregion # region Room actions def _add_source_key(self, content: T = None) -> T: if self.api.is_real_user and self.api.bridge_name: if not content: content = {} content[DOUBLE_PUPPET_SOURCE_KEY] = self.api.bridge_name return content async def invite_user( self, room_id: RoomID, user_id: UserID, reason: str | None = None, check_cache: bool = False, extra_content: dict[str, Any] | None = None, ) -> None: """ Invite a user to participate in a particular room. They do not start participating in the room until they actually join the room. Only users currently in the room can invite other users to join that room. If the user was invited to the room, the homeserver will add a `m.room.member`_ event to the room. See also: `API reference `__ .. _m.room.member: https://spec.matrix.org/v1.1/client-server-api/#mroommember Args: room_id: The ID of the room to which to invite the user. user_id: The fully qualified user ID of the invitee. reason: The reason the user was invited. This will be supplied as the ``reason`` on the `m.room.member`_ event. check_cache: If ``True``, the function will first check the state store, and not do anything if the state store says the user is already invited or joined. extra_content: Additional properties for the invite event content. If a non-empty dict is passed, the invite event will be created using the ``PUT /state/m.room.member/...`` endpoint instead of ``POST /invite``. 
""" try: ok_states = (Membership.INVITE, Membership.JOIN) do_invite = not check_cache or ( await self.state_store.get_membership(room_id, user_id) not in ok_states ) if do_invite: extra_content = self._add_source_key(extra_content) await super().invite_user( room_id, user_id, reason=reason, extra_content=extra_content ) await self.state_store.invited(room_id, user_id) except MAlreadyJoined as e: await self.state_store.joined(room_id, user_id) except MatrixRequestError as e: # TODO remove this once MSC3848 is released and minimum spec version is bumped if e.errcode == "M_FORBIDDEN" and ( "already in the room" in e.message or "is already joined to room" in e.message ): await self.state_store.joined(room_id, user_id) else: raise async def kick_user( self, room_id: RoomID, user_id: UserID, reason: str = "", extra_content: dict[str, JSON] | None = None, ) -> None: extra_content = self._add_source_key(extra_content) await super().kick_user(room_id, user_id, reason=reason, extra_content=extra_content) async def ban_user( self, room_id: RoomID, user_id: UserID, reason: str = "", extra_content: dict[str, JSON] | None = None, ) -> None: extra_content = self._add_source_key(extra_content) await super().ban_user(room_id, user_id, reason=reason, extra_content=extra_content) async def unban_user( self, room_id: RoomID, user_id: UserID, reason: str = "", extra_content: dict[str, JSON] | None = None, ) -> None: extra_content = self._add_source_key(extra_content) await super().unban_user(room_id, user_id, reason=reason, extra_content=extra_content) async def join_room_by_id( self, room_id: RoomID, third_party_signed: JSON = None, extra_content: dict[str, JSON] | None = None, ) -> RoomID: extra_content = self._add_source_key(extra_content) return await super().join_room_by_id( room_id, third_party_signed=third_party_signed, extra_content=extra_content ) async def leave_room( self, room_id: RoomID, reason: str | None = None, extra_content: dict[str, JSON] | None = None, raise_not_in_room: bool = False, ) -> None: extra_content = self._add_source_key(extra_content) await super().leave_room(room_id, reason, extra_content, raise_not_in_room) def set_room_avatar( self, room_id: RoomID, avatar_url: ContentURI | None, **kwargs ) -> Awaitable[EventID]: content = RoomAvatarStateEventContent(url=avatar_url) content = self._add_source_key(content) return self.send_state_event(room_id, EventType.ROOM_AVATAR, content, **kwargs) def set_room_name(self, room_id: RoomID, name: str, **kwargs) -> Awaitable[EventID]: content = RoomNameStateEventContent(name=name) content = self._add_source_key(content) return self.send_state_event(room_id, EventType.ROOM_NAME, content, **kwargs) def set_room_topic(self, room_id: RoomID, topic: str, **kwargs) -> Awaitable[EventID]: content = RoomTopicStateEventContent(topic=topic) content = self._add_source_key(content) return self.send_state_event(room_id, EventType.ROOM_TOPIC, content, **kwargs) async def get_power_levels( self, room_id: RoomID, ignore_cache: bool = False, ensure_joined: bool = True ) -> PowerLevelStateEventContent: if ensure_joined: await self.ensure_joined(room_id) if not ignore_cache: levels = await self.state_store.get_power_levels(room_id) if levels: return levels try: levels = await self.get_state_event(room_id, EventType.ROOM_POWER_LEVELS) except MNotFound: levels = PowerLevelStateEventContent() except MForbidden: if not ensure_joined: return PowerLevelStateEventContent() raise await self.state_store.set_power_levels(room_id, levels) return levels async def 
set_power_levels( self, room_id: RoomID, content: PowerLevelStateEventContent, **kwargs ) -> EventID: content = self._add_source_key(content) response = await self.send_state_event( room_id, EventType.ROOM_POWER_LEVELS, content, **kwargs ) if response: await self.state_store.set_power_levels(room_id, content) return response async def get_pinned_messages(self, room_id: RoomID) -> list[EventID]: await self.ensure_joined(room_id) try: content = await self.get_state_event(room_id, EventType.ROOM_PINNED_EVENTS) except MNotFound: return [] return content["pinned"] def set_pinned_messages( self, room_id: RoomID, events: list[EventID], **kwargs ) -> Awaitable[EventID]: content = RoomPinnedEventsStateEventContent(pinned=events) content = self._add_source_key(content) return self.send_state_event(room_id, EventType.ROOM_PINNED_EVENTS, content, **kwargs) async def pin_message(self, room_id: RoomID, event_id: EventID) -> None: events = await self.get_pinned_messages(room_id) if event_id not in events: events.append(event_id) await self.set_pinned_messages(room_id, events) async def unpin_message(self, room_id: RoomID, event_id: EventID): events = await self.get_pinned_messages(room_id) if event_id in events: events.remove(event_id) await self.set_pinned_messages(room_id, events) async def set_join_rule(self, room_id: RoomID, join_rule: JoinRule, **kwargs): content = JoinRulesStateEventContent(join_rule=join_rule) content = self._add_source_key(content) await self.send_state_event(room_id, EventType.ROOM_JOIN_RULES, content, **kwargs) async def get_room_displayname( self, room_id: RoomID, user_id: UserID, ignore_cache=False ) -> str: return (await self.get_room_member_info(room_id, user_id, ignore_cache)).displayname async def get_room_avatar_url( self, room_id: RoomID, user_id: UserID, ignore_cache=False ) -> str: return (await self.get_room_member_info(room_id, user_id, ignore_cache)).avatar_url async def get_room_member_info( self, room_id: RoomID, user_id: UserID, ignore_cache=False ) -> Member: member = await self.state_store.get_member(room_id, user_id) if not member or not member.membership or ignore_cache: member = await self.get_state_event(room_id, EventType.ROOM_MEMBER, user_id) return member async def set_typing( self, room_id: RoomID, timeout: int = 0, ) -> None: await self.ensure_joined(room_id) await super().set_typing(room_id, timeout) async def error_and_leave( self, room_id: RoomID, text: str | None = None, html: str | None = None ) -> None: await self.ensure_joined(room_id) await self.send_notice(room_id, text, html=html) await self.leave_room(room_id) async def send_message_event( self, room_id: RoomID, event_type: EventType, content: EventContent, **kwargs ) -> EventID: await self._ensure_has_power_level_for(room_id, event_type) content = self._add_source_key(content) return await super().send_message_event(room_id, event_type, content, **kwargs) async def redact( self, room_id: RoomID, event_id: EventID, reason: str | None = None, extra_content: dict[str, JSON] | None = None, **kwargs, ) -> EventID: await self._ensure_has_power_level_for(room_id, EventType.ROOM_REDACTION) extra_content = self._add_source_key(extra_content) return await super().redact( room_id, event_id, reason, extra_content=extra_content, **kwargs ) async def send_state_event( self, room_id: RoomID, event_type: EventType, content: StateEventContent | dict[str, JSON], state_key: str = "", **kwargs, ) -> EventID: await self._ensure_has_power_level_for(room_id, event_type, state_key=state_key) content = 
self._add_source_key(content) return await super().send_state_event(room_id, event_type, content, state_key, **kwargs) async def get_room_members( self, room_id: RoomID, allowed_memberships: tuple[Membership, ...] = (Membership.JOIN,) ) -> list[UserID]: if len(allowed_memberships) == 1 and allowed_memberships[0] == Membership.JOIN: memberships = await self.get_joined_members(room_id) return list(memberships.keys()) member_events = await self.get_members(room_id) return [ UserID(evt.state_key) for evt in member_events if evt.content.membership in allowed_memberships ] async def mark_read( self, room_id: RoomID, event_id: EventID, extra_content: dict[str, JSON] | None = None ) -> None: if self.state_store.get_read(room_id, self.mxid) != event_id: if self.api.is_real_user and self.api.bridge_name: if not extra_content: extra_content = {} double_puppet_indicator = { DOUBLE_PUPPET_SOURCE_KEY: self.api.bridge_name, } extra_content.update( { "com.beeper.fully_read.extra": double_puppet_indicator, "com.beeper.read.extra": double_puppet_indicator, } ) await self.set_fully_read_marker( room_id, fully_read=event_id, read_receipt=event_id, extra_content=extra_content, ) self.state_store.set_read(room_id, self.mxid, event_id) async def appservice_ping(self, appservice_id: str, txn_id: str | None = None) -> int: resp = await self.api.request( Method.POST, Path.v1.appservice[appservice_id].ping, content={"transaction_id": txn_id} if txn_id is not None else {}, ) return resp.get("duration_ms") or -1 async def batch_send( self, room_id: RoomID, prev_event_id: EventID, *, batch_id: BatchID | None = None, events: Iterable[BatchSendEvent], state_events_at_start: Iterable[BatchSendStateEvent] = (), beeper_new_messages: bool = False, beeper_mark_read_by: UserID | None = None, ) -> BatchSendResponse: """ Send a batch of historical events into a room. See `MSC2716`_ for more info. .. _MSC2716: https://github.com/matrix-org/matrix-doc/pull/2716 .. versionadded:: v0.12.5 .. deprecated:: v0.20.3 MSC2716 was abandoned by upstream and Beeper has forked the endpoint. Args: room_id: The room ID to send the events to. prev_event_id: The anchor event. The batch will be inserted immediately after this event. batch_id: The batch ID for sending a continuation of an earlier batch. If provided, the new batch will be inserted between the prev event and the previous batch. events: The events to send. state_events_at_start: The state events to send at the start of the batch. These will be sent as outlier events, which means they won't be a part of the actual room state. beeper_new_messages: Custom flag to tell the server that the messages can be sent to the end of the room as normal messages instead of history. Returns: All the event IDs generated, plus a batch ID that can be passed back to this method. 
""" path = Path.unstable["org.matrix.msc2716"].rooms[room_id].batch_send query: JSON = {"prev_event_id": prev_event_id} if batch_id: query["batch_id"] = batch_id if beeper_new_messages: query["com.beeper.new_messages"] = "true" if beeper_mark_read_by: query["com.beeper.mark_read_by"] = beeper_mark_read_by resp = await self.api.request( Method.POST, path, query_params=query, content={ "events": [evt.serialize() for evt in events], "state_events_at_start": [evt.serialize() for evt in state_events_at_start], }, ) return BatchSendResponse.deserialize(resp) async def beeper_batch_send( self, room_id: RoomID, events: Iterable[BatchSendEvent], *, forward: bool = False, forward_if_no_messages: bool = False, send_notification: bool = False, mark_read_by: UserID | None = None, ) -> BeeperBatchSendResponse: """ Send a batch of events into a room. Only for Beeper/hungryserv. .. versionadded:: v0.20.3 Args: room_id: The room ID to send the events to. events: The events to send. forward: Send events to the end of the room instead of the beginning forward_if_no_messages: Send events to the end of the room, but only if there are no messages in the room. If there are messages, send the new messages to the beginning. send_notification: Send a push notification for the new messages. Only applies when sending to the end of the room. mark_read_by: Send a read receipt from the given user ID atomically. Returns: All the event IDs generated. """ body = { "events": [evt.serialize() for evt in events], } if forward: body["forward"] = forward elif forward_if_no_messages: body["forward_if_no_messages"] = forward_if_no_messages if send_notification: body["send_notification"] = send_notification if mark_read_by: body["mark_read_by"] = mark_read_by resp = await self.api.request( Method.POST, Path.unstable["com.beeper.backfill"].rooms[room_id].batch_send, content=body, ) return BeeperBatchSendResponse.deserialize(resp) async def beeper_delete_room(self, room_id: RoomID) -> None: versions = await self.versions() if not versions.supports("com.beeper.room_yeeting"): raise RuntimeError("Homeserver does not support yeeting rooms") await self.api.request(Method.POST, Path.unstable["com.beeper.yeet"].rooms[room_id].delete) # endregion # region Ensure functions async def ensure_joined( self, room_id: RoomID, ignore_cache: bool = False, bot: IntentAPI | None = _bridgebot ) -> bool: """ Ensure the user controlled by this intent is joined to the given room. If the user is not in the room and the room is invite-only or the user is banned, this will first invite and/or unban the user using the bridge bot account. Args: room_id: The room to join. ignore_cache: Should the Matrix state store be checked first? If ``False`` and the store says the user is in the room, no requests will be made. bot: An optional override account to use as the bridge bot. This is useful if you know the bridge bot is not an admin in the room, but some other ghost user is. Returns: ``False`` if the cache said the user is already in the room, ``True`` if the user was successfully added to the room just now. 
""" if not room_id: raise ValueError("Room ID not given") if not ignore_cache and await self.state_store.is_joined(room_id, self.mxid): return False if bot is _bridgebot: bot = self.bot if bot is self: bot = None await self.ensure_registered() try: await self.join_room(room_id, max_retries=0) await self.state_store.joined(room_id, self.mxid) except MForbidden as e: if not bot: raise IntentError(f"Failed to join room {room_id} as {self.mxid}") from e try: await bot.invite_user(room_id, self.mxid) await self.join_room(room_id, max_retries=0) await self.state_store.joined(room_id, self.mxid) except MatrixRequestError as e2: raise IntentError(f"Failed to join room {room_id} as {self.mxid}") from e2 except MBadState as e: if not bot: raise IntentError(f"Failed to join room {room_id} as {self.mxid}") from e try: await bot.unban_user(room_id, self.mxid) await bot.invite_user(room_id, self.mxid) await self.join_room(room_id, max_retries=0) await self.state_store.joined(room_id, self.mxid) except MatrixRequestError as e2: raise IntentError(f"Failed to join room {room_id} as {self.mxid}") from e2 except MatrixRequestError as e: raise IntentError(f"Failed to join room {room_id} as {self.mxid}") from e return True def _register(self) -> Awaitable[dict]: content = { "username": self.localpart, "type": "m.login.application_service", "inhibit_login": True, } query_params = {"kind": "user"} return self.api.request(Method.POST, Path.v3.register, content, query_params=query_params) async def ensure_registered(self) -> None: """ Ensure the user controlled by this intent has been registered on the homeserver. This will always check the state store first, but the ``M_USER_IN_USE`` error will also be silently ignored, so it's fine if the state store isn't accurate. However, if using double puppeting, the state store should always return ``True`` for those users. """ if await self.state_store.is_registered(self.mxid): return try: await self._register() except MUserInUse: pass await self.state_store.registered(self.mxid) async def _ensure_has_power_level_for( self, room_id: RoomID, event_type: EventType, state_key: str = "" ) -> None: if not room_id: raise ValueError("Room ID not given") elif not event_type: raise ValueError("Event type not given") if event_type == EventType.ROOM_MEMBER: # TODO: if state_key doesn't equal self.mxid, check invite/kick/ban permissions return if not await self.state_store.has_power_levels_cached(room_id): # TODO add option to not try to fetch power levels from server await self.get_power_levels(room_id, ignore_cache=True, ensure_joined=False) if not await self.state_store.has_power_level(room_id, self.mxid, event_type): # TODO implement something better raise IntentError( f"Power level of {self.mxid} is not enough for {event_type} in {room_id}" ) # self.log.warning( # f"Power level of {self.mxid} is not enough for {event_type} in {room_id}") # endregion python-0.20.7/mautrix/appservice/appservice.py000066400000000000000000000154551473573527000215150ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. 
# Partly based on github.com/Cadair/python-appservice-framework (MIT license) from __future__ import annotations from typing import Awaitable, Callable, Optional import asyncio import logging from aiohttp import web import aiohttp from mautrix.types import JSON, RoomAlias, UserID, VersionsResponse from mautrix.util.logging import TraceLogger from ..api import HTTPAPI from .api import AppServiceAPI, IntentAPI from .as_handler import AppServiceServerMixin from .state_store import ASStateStore, FileASStateStore try: import ssl except ImportError: ssl = None QueryFunc = Callable[[web.Request], Awaitable[Optional[web.Response]]] class AppService(AppServiceServerMixin): """The main AppService container.""" server: str domain: str id: str verify_ssl: bool tls_cert: str | None tls_key: str | None as_token: str hs_token: str bot_mxid: UserID default_ua: str default_http_retry_count: int bridge_name: str | None state_store: ASStateStore transactions: set[str] query_user: Callable[[UserID], JSON] query_alias: Callable[[RoomAlias], JSON] ready: bool live: bool loop: asyncio.AbstractEventLoop log: TraceLogger app: web.Application runner: web.AppRunner | None _http_session: aiohttp.ClientSession | None _intent: IntentAPI | None def __init__( self, server: str, domain: str, as_token: str, hs_token: str, bot_localpart: str, id: str, loop: asyncio.AbstractEventLoop | None = None, log: logging.Logger | str | None = None, verify_ssl: bool = True, tls_cert: str | None = None, tls_key: str | None = None, query_user: QueryFunc = None, query_alias: QueryFunc = None, bridge_name: str | None = None, state_store: ASStateStore = None, aiohttp_params: dict = None, ephemeral_events: bool = False, encryption_events: bool = False, default_ua: str = HTTPAPI.default_ua, default_http_retry_count: int = 0, connection_limit: int | None = None, ) -> None: super().__init__(ephemeral_events=ephemeral_events, encryption_events=encryption_events) self.server = server self.domain = domain self.id = id self.verify_ssl = verify_ssl self.tls_cert = tls_cert self.tls_key = tls_key self.connection_limit = connection_limit or 100 self.as_token = as_token self.hs_token = hs_token self.bot_mxid = UserID(f"@{bot_localpart}:{domain}") self.default_ua = default_ua self.default_http_retry_count = default_http_retry_count self.bridge_name = bridge_name if not state_store: file = state_store if isinstance(state_store, str) else "mx-state.json" self.state_store = FileASStateStore(path=file, binary=False) elif isinstance(state_store, ASStateStore): self.state_store = state_store else: raise ValueError(f"Unsupported state store {type(state_store)}") self._http_session = None self._intent = None self.loop = loop or asyncio.get_event_loop() self.log = ( logging.getLogger(log) if isinstance(log, str) else log or logging.getLogger("mau.appservice") ) self.query_user = query_user or self.query_user self.query_alias = query_alias or self.query_alias self.live = True self.ready = False self.app = web.Application(loop=self.loop, **aiohttp_params if aiohttp_params else {}) self.app.router.add_route("GET", "/_matrix/mau/live", self._liveness_probe) self.app.router.add_route("GET", "/_matrix/mau/ready", self._readiness_probe) self.register_routes(self.app) self.matrix_event_handler(self.state_store.update_state) @property def http_session(self) -> aiohttp.ClientSession: if self._http_session is None: raise AttributeError("the http_session attribute can only be used after starting") else: return self._http_session @property def intent(self) -> IntentAPI: if 
self._intent is None: raise AttributeError("the intent attribute can only be used after starting") else: return self._intent async def __aenter__(self) -> None: await self.start() async def __aexit__(self) -> None: await self.stop() async def start(self, host: str = "127.0.0.1", port: int = 8080) -> None: await self.state_store.open() self.log.debug(f"Starting appservice web server on {host}:{port}") if self.server.startswith("https://") and not self.verify_ssl: connector = aiohttp.TCPConnector(limit=self.connection_limit, verify_ssl=False) else: connector = aiohttp.TCPConnector(limit=self.connection_limit) default_headers = {"User-Agent": self.default_ua} self._http_session = aiohttp.ClientSession( loop=self.loop, connector=connector, headers=default_headers ) self._intent = AppServiceAPI( base_url=self.server, bot_mxid=self.bot_mxid, log=self.log, token=self.as_token, state_store=self.state_store, bridge_name=self.bridge_name, client_session=self._http_session, default_retry_count=self.default_http_retry_count, ).bot_intent() ssl_ctx = None if self.tls_cert and self.tls_key: ssl_ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) ssl_ctx.load_cert_chain(self.tls_cert, self.tls_key) self.runner = web.AppRunner(self.app) await self.runner.setup() site = web.TCPSite(self.runner, host, port, ssl_context=ssl_ctx) await site.start() async def stop(self) -> None: self.log.debug("Stopping appservice web server") if self.runner: await self.runner.cleanup() self._intent = None if self._http_session: await self._http_session.close() self._http_session = None await self.state_store.close() async def _liveness_probe(self, _: web.Request) -> web.Response: return web.Response(status=200 if self.live else 500, text="{}") async def _readiness_probe(self, _: web.Request) -> web.Response: return web.Response(status=200 if self.ready else 500, text="{}") async def ping_self(self, txn_id: str | None = None) -> int: return await self.intent.appservice_ping(self.id, txn_id=txn_id) python-0.20.7/mautrix/appservice/as_handler.py000066400000000000000000000317251473573527000214520ustar00rootroot00000000000000# Copyright (c) 2023 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. 
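# Illustrative bootstrap of the AppService container defined in appservice.py above
# (the domain, tokens, localpart and port are placeholders, not values from this project):
#
#     appservice = AppService(
#         server="https://matrix.example.com",
#         domain="example.com",
#         as_token="<as_token from the registration file>",
#         hs_token="<hs_token from the registration file>",
#         bot_localpart="bridgebot",
#         id="examplebridge",
#     )
#
#     @appservice.matrix_event_handler
#     async def handle_event(evt):
#         print("received event of type", evt.type)
#
#     await appservice.start(host="127.0.0.1", port=29310)
#     ...
#     await appservice.stop()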
# Partly based on github.com/Cadair/python-appservice-framework (MIT license) from __future__ import annotations from typing import Any, Awaitable, Callable from json import JSONDecodeError import json import logging from aiohttp import web from mautrix.types import ( JSON, ASToDeviceEvent, DeviceID, DeviceLists, DeviceOTKCount, EphemeralEvent, Event, EventType, RoomAlias, SerializerError, UserID, ) from mautrix.util import background_task HandlerFunc = Callable[[Event], Awaitable] class AppServiceServerMixin: log: logging.Logger hs_token: str ephemeral_events: bool encryption_events: bool synchronous_handlers: bool query_user: Callable[[UserID], JSON] query_alias: Callable[[RoomAlias], JSON] transactions: set[str] event_handlers: list[HandlerFunc] to_device_handler: HandlerFunc | None otk_handler: Callable[[dict[UserID, dict[DeviceID, DeviceOTKCount]]], Awaitable] | None device_list_handler: Callable[[DeviceLists], Awaitable] | None def __init__( self, ephemeral_events: bool = False, encryption_events: bool = False, log: logging.Logger | None = None, hs_token: str | None = None, ) -> None: if log is not None: self.log = log if hs_token is not None: self.hs_token = hs_token self.transactions = set() self.event_handlers = [] self.to_device_handler = None self.otk_handler = None self.device_list_handler = None self.ephemeral_events = ephemeral_events self.encryption_events = encryption_events self.synchronous_handlers = False async def default_query_handler(_): return None self.query_user = default_query_handler self.query_alias = default_query_handler def register_routes(self, app: web.Application) -> None: app.router.add_route( "PUT", "/transactions/{transaction_id}", self._http_handle_transaction ) app.router.add_route("GET", "/rooms/{alias}", self._http_query_alias) app.router.add_route("GET", "/users/{user_id}", self._http_query_user) app.router.add_route( "PUT", "/_matrix/app/v1/transactions/{transaction_id}", self._http_handle_transaction ) app.router.add_route("GET", "/_matrix/app/v1/rooms/{alias}", self._http_query_alias) app.router.add_route("GET", "/_matrix/app/v1/users/{user_id}", self._http_query_user) app.router.add_route("POST", "/_matrix/app/v1/ping", self._http_ping) def _check_token(self, request: web.Request) -> bool: try: token = request.rel_url.query["access_token"] except KeyError: try: token = request.headers["Authorization"].removeprefix("Bearer ") except KeyError: self.log.debug("No access_token nor Authorization header in request") return False if token != self.hs_token: self.log.debug(f"Incorrect hs_token in request") return False return True async def _http_query_user(self, request: web.Request) -> web.Response: if not self._check_token(request): return web.json_response({"error": "Invalid auth token"}, status=401) try: user_id = request.match_info["user_id"] except KeyError: return web.json_response({"error": "Missing user_id parameter"}, status=400) try: response = await self.query_user(user_id) except Exception: self.log.exception("Exception in user query handler") return web.json_response({"error": "Internal appservice error"}, status=500) if not response: return web.json_response({}, status=404) return web.json_response(response) async def _http_query_alias(self, request: web.Request) -> web.Response: if not self._check_token(request): return web.json_response({"error": "Invalid auth token"}, status=401) try: alias = request.match_info["alias"] except KeyError: return web.json_response({"error": "Missing alias parameter"}, status=400) try: response = await 
self.query_alias(alias) except Exception: self.log.exception("Exception in alias query handler") return web.json_response({"error": "Internal appservice error"}, status=500) if not response: return web.json_response({}, status=404) return web.json_response(response) async def _http_ping(self, request: web.Request) -> web.Response: if not self._check_token(request): raise web.HTTPUnauthorized( content_type="application/json", text=json.dumps({"error": "Invalid auth token", "errcode": "M_UNKNOWN_TOKEN"}), ) try: body = await request.json() except JSONDecodeError: raise web.HTTPBadRequest( content_type="application/json", text=json.dumps({"error": "Body is not JSON", "errcode": "M_NOT_JSON"}), ) txn_id = body.get("transaction_id") self.log.info(f"Received ping from homeserver with transaction ID {txn_id}") return web.json_response({}) @staticmethod def _get_with_fallback( json: dict[str, Any], field: str, unstable_prefix: str, default: Any = None ) -> Any: try: return json.pop(field) except KeyError: try: return json.pop(f"{unstable_prefix}.{field}") except KeyError: return default async def _read_transaction_header(self, request: web.Request) -> tuple[str, dict[str, Any]]: if not self._check_token(request): raise web.HTTPUnauthorized( content_type="application/json", text=json.dumps({"error": "Invalid auth token", "errcode": "M_UNKNOWN_TOKEN"}), ) transaction_id = request.match_info["transaction_id"] if transaction_id in self.transactions: raise web.HTTPOk(content_type="application/json", text="{}") try: return transaction_id, await request.json() except JSONDecodeError: raise web.HTTPBadRequest( content_type="application/json", text=json.dumps({"error": "Body is not JSON", "errcode": "M_NOT_JSON"}), ) async def _http_handle_transaction(self, request: web.Request) -> web.Response: transaction_id, data = await self._read_transaction_header(request) txn_content_log = [] try: events = data.pop("events") if events: txn_content_log.append(f"{len(events)} PDUs") except KeyError: raise web.HTTPBadRequest( content_type="application/json", text=json.dumps( {"error": "Missing events object in body", "errcode": "M_BAD_JSON"} ), ) if self.ephemeral_events: ephemeral = self._get_with_fallback(data, "ephemeral", "de.sorunome.msc2409") if ephemeral: txn_content_log.append(f"{len(ephemeral)} EDUs") else: ephemeral = None if self.encryption_events: to_device = self._get_with_fallback(data, "to_device", "de.sorunome.msc2409") device_lists = DeviceLists.deserialize( self._get_with_fallback(data, "device_lists", "org.matrix.msc3202") ) otk_counts = { user_id: { device_id: DeviceOTKCount.deserialize(count) for device_id, count in devices.items() } for user_id, devices in self._get_with_fallback( data, "device_one_time_keys_count", "org.matrix.msc3202", default={} ).items() } if to_device: txn_content_log.append(f"{len(to_device)} to-device events") if device_lists.changed: txn_content_log.append(f"{len(device_lists.changed)} device list changes") if otk_counts: txn_content_log.append( f"{sum(len(vals) for vals in otk_counts.values())} OTK counts" ) else: otk_counts = {} device_lists = None to_device = None if len(txn_content_log) > 2: txn_content_log = [", ".join(txn_content_log[:-1]), txn_content_log[-1]] if not txn_content_log: txn_description = "nothing?" 
else: txn_description = " and ".join(txn_content_log) self.log.debug(f"Handling transaction {transaction_id} with {txn_description}") try: output = await self.handle_transaction( transaction_id, events=events, extra_data=data, ephemeral=ephemeral, to_device=to_device, device_lists=device_lists, otk_counts=otk_counts, ) except Exception: self.log.exception("Exception in transaction handler") output = None finally: self.log.debug(f"Finished handling transaction {transaction_id}") self.transactions.add(transaction_id) return web.json_response(output or {}) @staticmethod def _fix_prev_content(raw_event: JSON) -> None: try: if raw_event["unsigned"] is None: del raw_event["unsigned"] except KeyError: pass try: raw_event["unsigned"]["prev_content"] except KeyError: try: raw_event.setdefault("unsigned", {})["prev_content"] = raw_event["prev_content"] except KeyError: pass async def handle_transaction( self, txn_id: str, *, events: list[JSON], extra_data: JSON, ephemeral: list[JSON] | None = None, to_device: list[JSON] | None = None, otk_counts: dict[UserID, dict[DeviceID, DeviceOTKCount]] | None = None, device_lists: DeviceLists | None = None, ) -> JSON: for raw_td in to_device or []: try: td = ASToDeviceEvent.deserialize(raw_td) except SerializerError: self.log.exception("Failed to deserialize to-device event %s", raw_td) else: try: await self.to_device_handler(td) except Exception: self.log.exception("Exception in Matrix to-device event handler") if device_lists and self.device_list_handler: try: await self.device_list_handler(device_lists) except Exception: self.log.exception("Exception in Matrix device list change handler") if otk_counts and self.otk_handler: try: await self.otk_handler(otk_counts) except Exception: self.log.exception("Exception in Matrix OTK count handler") for raw_edu in ephemeral or []: try: edu = EphemeralEvent.deserialize(raw_edu) except SerializerError: self.log.exception("Failed to deserialize ephemeral event %s", raw_edu) else: await self.handle_matrix_event(edu, ephemeral=True) for raw_event in events: try: self._fix_prev_content(raw_event) event = Event.deserialize(raw_event) except SerializerError: self.log.exception("Failed to deserialize event %s", raw_event) else: await self.handle_matrix_event(event) return {} async def handle_matrix_event(self, event: Event, ephemeral: bool = False) -> None: if ephemeral: event.type = event.type.with_class(EventType.Class.EPHEMERAL) elif getattr(event, "state_key", None) is not None: event.type = event.type.with_class(EventType.Class.STATE) else: event.type = event.type.with_class(EventType.Class.MESSAGE) async def try_handle(handler_func: HandlerFunc): try: await handler_func(event) except Exception: self.log.exception("Exception in Matrix event handler") if self.synchronous_handlers: for handler in self.event_handlers: await handler(event) else: for handler in self.event_handlers: background_task.create(try_handle(handler)) def matrix_event_handler(self, func: HandlerFunc) -> HandlerFunc: self.event_handlers.append(func) return func python-0.20.7/mautrix/appservice/state_store/000077500000000000000000000000001473573527000213245ustar00rootroot00000000000000python-0.20.7/mautrix/appservice/state_store/__init__.py000066400000000000000000000001771473573527000234420ustar00rootroot00000000000000from .file import FileASStateStore from .memory import ASStateStore __all__ = ["FileASStateStore", "ASStateStore", "asyncpg"] 
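# Shape of the transaction body accepted by _http_handle_transaction in as_handler.py
# above (values are illustrative; the ephemeral/encryption fields are only read when
# ephemeral_events/encryption_events are enabled, and may also arrive under their
# de.sorunome.msc2409 / org.matrix.msc3202 unstable prefixes):
#
#     PUT /_matrix/app/v1/transactions/<txn_id>
#     Authorization: Bearer <hs_token>
#
#     {
#         "events": [{"type": "m.room.message", "...": "..."}],
#         "ephemeral": [{"type": "m.typing", "...": "..."}],
#         "to_device": [{"type": "m.room.encrypted", "...": "..."}],
#         "device_lists": {"changed": ["@user:example.com"]},
#         "device_one_time_keys_count": {
#             "@user:example.com": {"DEVICEID": {"signed_curve25519": 50}}
#         }
#     }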
python-0.20.7/mautrix/appservice/state_store/asyncpg.py000066400000000000000000000011041473573527000233360ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from mautrix.client.state_store.asyncpg import PgStateStore as PgClientStateStore from mautrix.util.async_db import Database from .memory import ASStateStore class PgASStateStore(PgClientStateStore, ASStateStore): def __init__(self, db: Database) -> None: PgClientStateStore.__init__(self, db) ASStateStore.__init__(self) python-0.20.7/mautrix/appservice/state_store/file.py000066400000000000000000000021001473573527000226060ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations from typing import IO, Any from pathlib import Path from mautrix.client.state_store import FileStateStore from mautrix.util.file_store import Filer from .memory import ASStateStore class FileASStateStore(FileStateStore, ASStateStore): def __init__( self, path: str | Path | IO, filer: Filer | None = None, binary: bool = True, save_interval: float = 60.0, ) -> None: FileStateStore.__init__(self, path, filer, binary, save_interval) ASStateStore.__init__(self) def serialize(self) -> dict[str, Any]: return { "registered": self._registered, **super().serialize(), } def deserialize(self, data: dict[str, Any]) -> None: self._registered = data["registered"] super().deserialize(data) python-0.20.7/mautrix/appservice/state_store/memory.py000066400000000000000000000043121473573527000232060ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from typing import Dict, Optional, Tuple from abc import ABC import time from mautrix.client.state_store import StateStore as ClientStateStore from mautrix.types import EventID, RoomID, UserID class ASStateStore(ClientStateStore, ABC): _presence: Dict[UserID, str] _read: Dict[Tuple[RoomID, UserID], EventID] _registered: Dict[UserID, bool] def __init__(self) -> None: self._registered = {} # Non-persistent storage self._presence = {} self._read = {} async def is_registered(self, user_id: UserID) -> bool: """ Check if a given user is registered. This should always return ``True`` for double puppets, because they're always registered beforehand and shouldn't be attempted to register by the bridge. Args: user_id: The user ID to check. Returns: ``True`` if the user is registered, ``False`` otherwise. """ if not user_id: raise ValueError("user_id is empty") return self._registered.get(user_id, False) async def registered(self, user_id: UserID) -> None: """ Mark the given user ID as registered. Args: user_id: The user ID to mark as registered. 
""" if not user_id: raise ValueError("user_id is empty") self._registered[user_id] = True def set_presence(self, user_id: UserID, presence: str) -> None: self._presence[user_id] = presence def has_presence(self, user_id: UserID, presence: str) -> bool: try: return self._presence[user_id] == presence except KeyError: return False def set_read(self, room_id: RoomID, user_id: UserID, event_id: EventID) -> None: self._read[(room_id, user_id)] = event_id def get_read(self, room_id: RoomID, user_id: UserID) -> Optional[EventID]: try: return self._read[(room_id, user_id)] except KeyError: return None python-0.20.7/mautrix/bridge/000077500000000000000000000000001473573527000160635ustar00rootroot00000000000000python-0.20.7/mautrix/bridge/__init__.py000066400000000000000000000021341473573527000201740ustar00rootroot00000000000000from ..util.async_getter_lock import async_getter_lock from .bridge import Bridge, HomeserverSoftware from .config import BaseBridgeConfig from .custom_puppet import ( AutologinError, CustomPuppetError, CustomPuppetMixin, EncryptionKeysFound, HomeserverURLNotFound, InvalidAccessToken, OnlyLoginSelf, OnlyLoginTrustedDomain, ) from .disappearing_message import AbstractDisappearingMessage from .matrix import BaseMatrixHandler from .notification_disabler import NotificationDisabler from .portal import BasePortal, DMCreateError, IgnoreMatrixInvite, RejectMatrixInvite from .puppet import BasePuppet from .user import BaseUser __all__ = [ "async_getter_lock", "Bridge", "HomeserverSoftware", "BaseBridgeConfig", "AutologinError", "CustomPuppetError", "CustomPuppetMixin", "HomeserverURLNotFound", "InvalidAccessToken", "OnlyLoginSelf", "OnlyLoginTrustedDomain", "AbstractDisappearingMessage", "BaseMatrixHandler", "NotificationDisabler", "BasePortal", "BasePuppet", "BaseUser", "state_store", "commands", ] python-0.20.7/mautrix/bridge/bridge.py000066400000000000000000000317721473573527000177030ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations from typing import Any from abc import ABC, abstractmethod from enum import Enum import asyncio import sys from aiohttp import web from mautrix import __version__ as __mautrix_version__ from mautrix.api import HTTPAPI from mautrix.appservice import AppService, ASStateStore from mautrix.client.state_store.asyncpg import PgStateStore as PgClientStateStore from mautrix.errors import MExclusive, MUnknownToken from mautrix.types import RoomID, UserID from mautrix.util.async_db import Database, DatabaseException, UpgradeTable from mautrix.util.bridge_state import BridgeState, BridgeStateEvent, GlobalBridgeState from mautrix.util.program import Program from .. 
import bridge as br from .state_store.asyncpg import PgBridgeStateStore try: import uvloop except ImportError: uvloop = None class HomeserverSoftware(Enum): STANDARD = "standard" ASMUX = "asmux" HUNGRY = "hungry" @property def is_hungry(self) -> bool: return self == self.HUNGRY @property def is_asmux(self) -> bool: return self == self.ASMUX class Bridge(Program, ABC): db: Database az: AppService state_store_class: type[ASStateStore] = PgBridgeStateStore state_store: ASStateStore upgrade_table: UpgradeTable config_class: type[br.BaseBridgeConfig] config: br.BaseBridgeConfig matrix_class: type[br.BaseMatrixHandler] matrix: br.BaseMatrixHandler repo_url: str markdown_version: str manhole: br.commands.manhole.ManholeState | None homeserver_software: HomeserverSoftware beeper_network_name: str | None = None beeper_service_name: str | None = None def __init__( self, module: str = None, name: str = None, description: str = None, command: str = None, version: str = None, config_class: type[br.BaseBridgeConfig] = None, matrix_class: type[br.BaseMatrixHandler] = None, state_store_class: type[ASStateStore] = None, ) -> None: super().__init__(module, name, description, command, version, config_class) if matrix_class: self.matrix_class = matrix_class if state_store_class: self.state_store_class = state_store_class self.manhole = None def prepare_arg_parser(self) -> None: super().prepare_arg_parser() self.parser.add_argument( "-g", "--generate-registration", action="store_true", help="generate registration and quit", ) self.parser.add_argument( "-r", "--registration", type=str, default="registration.yaml", metavar="", help=( "the path to save the generated registration to " "(not needed for running the bridge)" ), ) self.parser.add_argument( "--ignore-unsupported-database", action="store_true", help="Run even if the database schema is too new", ) self.parser.add_argument( "--ignore-foreign-tables", action="store_true", help="Run even if the database contains tables from other programs (like Synapse)", ) def preinit(self) -> None: super().preinit() if self.args.generate_registration: self.generate_registration() sys.exit(0) def prepare(self) -> None: if self.config.env: self.log.debug( "Loaded config overrides from environment: %s", list(self.config.env.keys()) ) super().prepare() try: self.homeserver_software = HomeserverSoftware(self.config["homeserver.software"]) except Exception: self.log.fatal("Invalid value for homeserver.software in config") sys.exit(11) self.prepare_db() self.prepare_appservice() self.prepare_bridge() def prepare_config(self) -> None: self.config = self.config_class( self.args.config, self.args.registration, self.base_config_path, env_prefix=self.module.upper(), ) if self.args.generate_registration: self.config._check_tokens = False self.load_and_update_config() def generate_registration(self) -> None: self.config.generate_registration() self.config.save() print(f"Registration generated and saved to {self.config.registration_path}") def make_state_store(self) -> None: if self.state_store_class is None: raise RuntimeError("state_store_class is not set") elif issubclass(self.state_store_class, PgBridgeStateStore): self.state_store = self.state_store_class( self.db, self.get_puppet, self.get_double_puppet ) else: self.state_store = self.state_store_class() def prepare_appservice(self) -> None: self.make_state_store() mb = 1024**2 default_http_retry_count = self.config.get("homeserver.http_retry_count", None) if self.name not in HTTPAPI.default_ua: HTTPAPI.default_ua = 
f"{self.name}/{self.version} {HTTPAPI.default_ua}" self.az = AppService( server=self.config["homeserver.address"], domain=self.config["homeserver.domain"], verify_ssl=self.config["homeserver.verify_ssl"], connection_limit=self.config["homeserver.connection_limit"], id=self.config["appservice.id"], as_token=self.config["appservice.as_token"], hs_token=self.config["appservice.hs_token"], tls_cert=self.config.get("appservice.tls_cert", None), tls_key=self.config.get("appservice.tls_key", None), bot_localpart=self.config["appservice.bot_username"], ephemeral_events=self.config["appservice.ephemeral_events"], encryption_events=self.config["bridge.encryption.appservice"], default_ua=HTTPAPI.default_ua, default_http_retry_count=default_http_retry_count, log="mau.as", loop=self.loop, state_store=self.state_store, bridge_name=self.name, aiohttp_params={"client_max_size": self.config["appservice.max_body_size"] * mb}, ) self.az.app.router.add_post("/_matrix/app/com.beeper.bridge_state", self.get_bridge_state) def prepare_db(self) -> None: if not hasattr(self, "upgrade_table") or not self.upgrade_table: raise RuntimeError("upgrade_table is not set") self.db = Database.create( self.config["appservice.database"], upgrade_table=self.upgrade_table, db_args=self.config["appservice.database_opts"], owner_name=self.name, ignore_foreign_tables=self.args.ignore_foreign_tables, ) def prepare_bridge(self) -> None: self.matrix = self.matrix_class(bridge=self) def _log_db_error(self, e: Exception) -> None: self.log.critical("Failed to initialize database", exc_info=e) if isinstance(e, DatabaseException) and e.explanation: self.log.info(e.explanation) sys.exit(25) async def start_db(self) -> None: if hasattr(self, "db") and isinstance(self.db, Database): self.log.debug("Starting database...") ignore_unsupported = self.args.ignore_unsupported_database self.db.upgrade_table.allow_unsupported = ignore_unsupported try: await self.db.start() if isinstance(self.state_store, PgClientStateStore): self.state_store.upgrade_table.allow_unsupported = ignore_unsupported await self.state_store.upgrade_table.upgrade(self.db) if self.matrix.e2ee: self.matrix.e2ee.crypto_db.allow_unsupported = ignore_unsupported self.matrix.e2ee.crypto_db.override_pool(self.db) except Exception as e: self._log_db_error(e) async def stop_db(self) -> None: if hasattr(self, "db") and isinstance(self.db, Database): await self.db.stop() async def start(self) -> None: await self.start_db() self.log.debug("Starting appservice...") await self.az.start(self.config["appservice.hostname"], self.config["appservice.port"]) try: await self.matrix.wait_for_connection() except MUnknownToken: self.log.critical( "The as_token was not accepted. Is the registration file installed " "in your homeserver correctly?" ) sys.exit(16) except MExclusive: self.log.critical( "The as_token was accepted, but the /register request was not. " "Are the homeserver domain and username template in the config " "correct, and do they match the values in the registration?" 
) sys.exit(16) except Exception: self.log.critical("Failed to check connection to homeserver", exc_info=True) sys.exit(16) await self.matrix.init_encryption() self.add_startup_actions(self.matrix.init_as_bot()) await super().start() self.az.ready = True status_endpoint = self.config["homeserver.status_endpoint"] if status_endpoint and await self.count_logged_in_users() == 0: state = BridgeState(state_event=BridgeStateEvent.UNCONFIGURED).fill() while not await state.send(status_endpoint, self.az.as_token, self.log): await asyncio.sleep(5) async def system_exit(self) -> None: if hasattr(self, "db") and isinstance(self.db, Database): self.log.debug("Stopping database due to SystemExit") await self.db.stop() self.log.debug("Database stopped") elif getattr(self, "db", None): self.log.trace("Database not started at SystemExit") async def stop(self) -> None: if self.manhole: self.manhole.close() self.manhole = None await self.az.stop() await super().stop() if self.matrix.e2ee: await self.matrix.e2ee.stop() await self.stop_db() async def get_bridge_state(self, req: web.Request) -> web.Response: if not self.az._check_token(req): return web.json_response({"error": "Invalid auth token"}, status=401) try: user = await self.get_user(UserID(req.url.query["user_id"]), create=False) except KeyError: user = None if user is None: return web.json_response({"error": "User not found"}, status=404) try: states = await user.get_bridge_states() except NotImplementedError: return web.json_response({"error": "Bridge status not implemented"}, status=501) for state in states: await user.fill_bridge_state(state) global_state = BridgeState(state_event=BridgeStateEvent.RUNNING).fill() evt = GlobalBridgeState( remote_states={state.remote_id: state for state in states}, bridge_state=global_state ) return web.json_response(evt.serialize()) @abstractmethod async def get_user(self, user_id: UserID, create: bool = True) -> br.BaseUser | None: pass @abstractmethod async def get_portal(self, room_id: RoomID) -> br.BasePortal | None: pass @abstractmethod async def get_puppet(self, user_id: UserID, create: bool = False) -> br.BasePuppet | None: pass @abstractmethod async def get_double_puppet(self, user_id: UserID) -> br.BasePuppet | None: pass @abstractmethod def is_bridge_ghost(self, user_id: UserID) -> bool: pass @abstractmethod async def count_logged_in_users(self) -> int: return 0 async def manhole_global_namespace(self, user_id: UserID) -> dict[str, Any]: own_user = await self.get_user(user_id, create=False) try: own_puppet = await own_user.get_puppet() except NotImplementedError: own_puppet = None return { "bridge": self, "manhole": self.manhole, "own_user": own_user, "own_puppet": own_puppet, } @property def manhole_banner_python_version(self) -> str: return f"Python {sys.version} on {sys.platform}" @property def manhole_banner_program_version(self) -> str: return f"{self.name} {self.version} with mautrix-python {__mautrix_version__}" def manhole_banner(self, user_id: UserID) -> str: return ( f"{self.manhole_banner_python_version}\n" f"{self.manhole_banner_program_version}\n\n" f"Manhole opened by {user_id}\n" ) python-0.20.7/mautrix/bridge/commands/000077500000000000000000000000001473573527000176645ustar00rootroot00000000000000python-0.20.7/mautrix/bridge/commands/__init__.py000066400000000000000000000012531473573527000217760ustar00rootroot00000000000000from .handler import ( SECTION_ADMIN, SECTION_AUTH, SECTION_GENERAL, SECTION_RELAY, CommandEvent, CommandHandler, CommandHandlerFunc, CommandProcessor, HelpCacheKey, 
HelpSection, command_handler, ) from .meta import cancel, help_cmd, unknown_command from . import ( # isort: skip admin, clean_rooms, crypto, delete_portal, login_matrix, manhole, relay, ) __all__ = [ "HelpSection", "HelpCacheKey", "command_handler", "CommandHandler", "CommandProcessor", "CommandHandlerFunc", "CommandEvent", "SECTION_GENERAL", "SECTION_ADMIN", "SECTION_AUTH", "SECTION_RELAY", ] python-0.20.7/mautrix/bridge/commands/admin.py000066400000000000000000000126421473573527000213330ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations from mautrix.errors import IntentError, MatrixRequestError, MForbidden from mautrix.types import ContentURI, EventID, UserID from ... import bridge as br from .handler import SECTION_ADMIN, CommandEvent, command_handler @command_handler( needs_admin=True, needs_auth=False, name="set-pl", help_section=SECTION_ADMIN, help_args="[_mxid_] <_level_>", help_text="Set a temporary power level without affecting the remote platform.", ) async def set_power_level(evt: CommandEvent) -> EventID: try: user_id = UserID(evt.args[0]) except IndexError: return await evt.reply(f"**Usage:** `$cmdprefix+sp set-pl [mxid] `") if user_id.startswith("@"): evt.args.pop(0) else: user_id = evt.sender.mxid try: level = int(evt.args[0]) except (KeyError, IndexError): return await evt.reply("**Usage:** `$cmdprefix+sp set-pl [mxid] `") except ValueError: return await evt.reply("The level must be an integer.") levels = await evt.main_intent.get_power_levels(evt.room_id, ignore_cache=True) levels.users[user_id] = level try: return await evt.main_intent.set_power_levels(evt.room_id, levels) except MForbidden as e: await evt.reply(f"I don't seem to have permission to update power levels: {e.message}") except (MatrixRequestError, IntentError): evt.log.exception("Failed to update power levels") return await evt.reply("Failed to update power levels (see logs for more details)") async def _get_mxid_param( evt: CommandEvent, args: str ) -> tuple[br.BasePuppet | None, EventID | None]: try: user_id = UserID(evt.args[0]) except IndexError: return None, await evt.reply(f"**Usage:** `$cmdprefix+sp {evt.command} {args}`") if user_id.startswith("@") and ":" in user_id: # TODO support parsing mention pills instead of requiring a plaintext mxid puppet = await evt.bridge.get_puppet(user_id) if not puppet: return None, await evt.reply("The given user ID is not a valid ghost user.") evt.args.pop(0) return puppet, None elif evt.is_portal and (puppet := await evt.portal.get_dm_puppet()): return puppet, None return None, await evt.reply( "This is not a private chat portal, you must pass a user ID explicitly." 
) @command_handler( needs_admin=True, needs_auth=False, name="set-avatar", help_section=SECTION_ADMIN, help_args="[_mxid_] <_mxc:// uri_>", help_text="Set an avatar for a ghost user.", ) async def set_ghost_avatar(evt: CommandEvent) -> EventID | None: puppet, err = await _get_mxid_param(evt, "[mxid] ") if err: return err try: mxc_uri = ContentURI(evt.args[0]) except IndexError: return await evt.reply("**Usage:** `$cmdprefix+sp set-avatar [mxid] `") if not mxc_uri.startswith("mxc://"): return await evt.reply("The avatar URL must start with `mxc://`") try: return await puppet.default_mxid_intent.set_avatar_url(mxc_uri) except (MatrixRequestError, IntentError): evt.log.exception("Failed to set avatar.") return await evt.reply("Failed to set avatar (see logs for more details).") @command_handler( needs_admin=True, needs_auth=False, name="remove-avatar", help_section=SECTION_ADMIN, help_args="[_mxid_]", help_text="Remove the avatar for a ghost user.", ) async def remove_ghost_avatar(evt: CommandEvent) -> EventID | None: puppet, err = await _get_mxid_param(evt, "[mxid]") if err: return err try: return await puppet.default_mxid_intent.set_avatar_url(ContentURI("")) except (MatrixRequestError, IntentError): evt.log.exception("Failed to remove avatar.") return await evt.reply("Failed to remove avatar (see logs for more details).") @command_handler( needs_admin=True, needs_auth=False, name="set-displayname", help_section=SECTION_ADMIN, help_args="[_mxid_] <_displayname_>", help_text="Set the display name for a ghost user.", ) async def set_ghost_display_name(evt: CommandEvent) -> EventID | None: puppet, err = await _get_mxid_param(evt, "[mxid] ") if err: return err try: return await puppet.default_mxid_intent.set_displayname(" ".join(evt.args)) except (MatrixRequestError, IntentError): evt.log.exception("Failed to set display name.") return await evt.reply("Failed to set display name (see logs for more details).") @command_handler( needs_admin=True, needs_auth=False, name="remove-displayname", help_section=SECTION_ADMIN, help_args="[_mxid_]", help_text="Remove the display name for a ghost user.", ) async def remove_ghost_display_name(evt: CommandEvent) -> EventID | None: puppet, err = await _get_mxid_param(evt, "[mxid]") if err: return err try: return await puppet.default_mxid_intent.set_displayname("") except (MatrixRequestError, IntentError): evt.log.exception("Failed to remove display name.") return await evt.reply("Failed to remove display name (see logs for more details).") python-0.20.7/mautrix/bridge/commands/clean_rooms.py000066400000000000000000000175171473573527000225520ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations from typing import List, NamedTuple, Optional, Union from mautrix.appservice import IntentAPI from mautrix.errors import MatrixRequestError from mautrix.types import EventID, EventType, RoomID, UserID from ... 
import bridge as br from .handler import SECTION_ADMIN, CommandEvent, command_handler class ManagementRoom(NamedTuple): room_id: RoomID user_id: UserID class RoomSearchResults(NamedTuple): management_rooms: List[ManagementRoom] unidentified_rooms: List[RoomID] tombstoned_rooms: List[RoomID] portals: List[br.BasePortal] empty_portals: List[br.BasePortal] async def _find_rooms(bridge: br.Bridge, intent: Optional[IntentAPI] = None) -> RoomSearchResults: results = RoomSearchResults([], [], [], [], []) intent = intent or bridge.az.intent rooms = await intent.get_joined_rooms() for room_id in rooms: portal = await bridge.get_portal(room_id) if not portal: try: tombstone = await intent.get_state_event(room_id, EventType.ROOM_TOMBSTONE) if tombstone and tombstone.replacement_room: results.tombstoned_rooms.append(room_id) continue except MatrixRequestError: pass try: members = await intent.get_room_members(room_id) except MatrixRequestError: members = [] if len(members) == 2: other_member = members[0] if members[0] != intent.mxid else members[1] if bridge.is_bridge_ghost(other_member): results.unidentified_rooms.append(room_id) else: results.management_rooms.append(ManagementRoom(room_id, other_member)) else: results.unidentified_rooms.append(room_id) else: members = await portal.get_authenticated_matrix_users() if len(members) == 0: results.empty_portals.append(portal) else: results.portals.append(portal) return results @command_handler( needs_admin=True, needs_auth=False, management_only=True, name="clean-rooms", help_section=SECTION_ADMIN, help_text="Clean up unused portal/management rooms.", ) async def clean_rooms(evt: CommandEvent) -> EventID: results = await _find_rooms(evt.bridge) reply = ["#### Management rooms (M)"] reply += [ f"{n + 1}. [M{n + 1}](https://matrix.to/#/{room}) (with {other_member}" for n, (room, other_member) in enumerate(results.management_rooms) ] or ["No management rooms found."] reply.append("#### Active portal rooms (A)") reply += [ f"{n + 1}. [A{n + 1}](https://matrix.to/#/{portal.mxid}) " f'(to remote chat "{portal.name}")' for n, portal in enumerate(results.portals) ] or ["No active portal rooms found."] reply.append("#### Unidentified rooms (U)") reply += [ f"{n + 1}. [U{n + 1}](https://matrix.to/#/{room})" for n, room in enumerate(results.unidentified_rooms) ] or ["No unidentified rooms found."] reply.append("#### Tombstoned rooms (T)") reply += [ f"{n + 1}. [T{n + 1}](https://matrix.to/#/{room})" for n, room in enumerate(results.tombstoned_rooms) ] or ["No tombstoned rooms found."] reply.append("#### Inactive portal rooms (I)") reply += [ f'{n}. [I{n}](https://matrix.to/#/{portal.mxid}) (to remote chat "{portal.name}")' for n, portal in enumerate(results.empty_portals) ] or ["No inactive portal rooms found."] reply += [ "#### Usage", ( "To clean the recommended set of rooms (unidentified & inactive portals), " "type `$cmdprefix+sp clean-recommended`" ), "", ( "To clean other groups of rooms, type `$cmdprefix+sp clean-groups ` " "where `letters` are the first letters of the group names (M, A, U, I, T)" ), "", ( "To clean specific rooms, type `$cmdprefix+sp clean-range ` " "where `range` is the range (e.g. `5-21`) prefixed with the first letter of" "the group name. (e.g. `I2-6`)" ), "", ( "Please note that you will have to re-run `$cmdprefix+sp clean-rooms` " "between each use of the commands above." 
), ] evt.sender.command_status = { "next": lambda clean_evt: set_rooms_to_clean(clean_evt, results), "action": "Room cleaning", } return await evt.reply("\n".join(reply)) async def set_rooms_to_clean(evt, results: RoomSearchResults) -> None: command = evt.args[0] rooms_to_clean: List[Union[br.BasePortal, RoomID]] = [] if command == "clean-recommended": rooms_to_clean += results.empty_portals rooms_to_clean += results.unidentified_rooms elif command == "clean-groups": if len(evt.args) < 2: return await evt.reply("**Usage:** $cmdprefix+sp clean-groups [M][A][U][I]") groups_to_clean = evt.args[1].upper() if "M" in groups_to_clean: rooms_to_clean += [room_id for (room_id, user_id) in results.management_rooms] if "A" in groups_to_clean: rooms_to_clean += results.portals if "U" in groups_to_clean: rooms_to_clean += results.unidentified_rooms if "I" in groups_to_clean: rooms_to_clean += results.empty_portals if "T" in groups_to_clean: rooms_to_clean += results.tombstoned_rooms elif command == "clean-range": try: clean_range = evt.args[1] group, clean_range = clean_range[0], clean_range[1:] start, end = clean_range.split("-") start, end = int(start), int(end) if group == "M": group = [room_id for (room_id, user_id) in results.management_rooms] elif group == "A": group = results.portals elif group == "U": group = results.unidentified_rooms elif group == "I": group = results.empty_portals elif group == "T": group = results.tombstoned_rooms else: raise ValueError("Unknown group") rooms_to_clean = group[start - 1 : end] except (KeyError, ValueError): return await evt.reply("**Usage:** $cmdprefix+sp clean-range <_M|A|U|I_>") else: return await evt.reply( f"Unknown room cleaning action `{command}`. " "Use `$cmdprefix+sp cancel` to cancel room cleaning." ) evt.sender.command_status = { "next": lambda confirm: execute_room_cleanup(confirm, rooms_to_clean), "action": "Room cleaning", } await evt.reply( f"To confirm cleaning up {len(rooms_to_clean)} rooms, type `$cmdprefix+sp confirm-clean`." ) async def execute_room_cleanup(evt, rooms_to_clean: List[Union[br.BasePortal, RoomID]]) -> None: if len(evt.args) > 0 and evt.args[0] == "confirm-clean": await evt.reply(f"Cleaning {len(rooms_to_clean)} rooms. This might take a while.") cleaned = 0 for room in rooms_to_clean: if isinstance(room, br.BasePortal): await room.cleanup_and_delete() cleaned += 1 else: await br.BasePortal.cleanup_room(evt.az.intent, room, "Room deleted") cleaned += 1 evt.sender.command_status = None await evt.reply(f"{cleaned} rooms cleaned up successfully.") else: await evt.reply("Room cleaning cancelled.") python-0.20.7/mautrix/bridge/commands/crypto.py000066400000000000000000000015041473573527000215560ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. 
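# A minimal sketch (not part of the upstream source) of the follow-up command
# pattern that clean-rooms above relies on: when a message doesn't match any
# registered command, CommandProcessor.handle() falls back to the handler stored
# in ``sender.command_status["next"]``. The names ``example_wizard`` and
# ``example_wizard_confirm`` are hypothetical.
#
#     @command_handler(needs_admin=True, help_section=SECTION_ADMIN,
#                      help_text="Hypothetical multi-step command.")
#     async def example_wizard(evt: CommandEvent) -> None:
#         evt.sender.command_status = {
#             "next": example_wizard_confirm,
#             "action": "Example wizard",
#         }
#         await evt.reply("Reply with `$cmdprefix+sp confirm` to continue.")
#
#     async def example_wizard_confirm(evt: CommandEvent) -> None:
#         evt.sender.command_status = None
#         if evt.args and evt.args[0] == "confirm":
#             await evt.reply("Confirmed.")
#         else:
#             await evt.reply("Cancelled.")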
from .handler import SECTION_ADMIN, CommandEvent, command_handler @command_handler( needs_admin=True, needs_auth=False, help_section=SECTION_ADMIN, help_text="Reset the bridge's megolm session in this room", ) async def discard_megolm_session(evt: CommandEvent) -> None: if not evt.bridge.matrix.e2ee: await evt.reply("End-to-bridge encryption is not enabled on this bridge instance") return await evt.bridge.matrix.e2ee.crypto_store.remove_outbound_group_session(evt.room_id) await evt.reply("Successfully removed outbound group session for this room") python-0.20.7/mautrix/bridge/commands/delete_portal.py000066400000000000000000000020361473573527000230620ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from .handler import SECTION_ADMIN, CommandEvent, command_handler @command_handler( needs_auth=False, needs_puppeting=False, needs_admin=True, help_section=SECTION_ADMIN, help_text="Remove all users from the current portal room and forget the portal.", ) async def delete_portal(evt: CommandEvent) -> None: if not evt.portal: await evt.reply("This is not a portal room") return await evt.portal.cleanup_and_delete() @command_handler( needs_auth=False, needs_puppeting=False, help_section=SECTION_ADMIN, help_text="Remove puppets from the current portal room and forget the portal.", ) async def unbridge(evt: CommandEvent) -> None: if not evt.portal: await evt.reply("This is not a portal room") return await evt.portal.unbridge() python-0.20.7/mautrix/bridge/commands/handler.py000066400000000000000000000437511473573527000216650ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations from typing import Any, Awaitable, Callable, NamedTuple, Type import asyncio import logging import time import traceback from mautrix.appservice import AppService, IntentAPI from mautrix.errors import MForbidden from mautrix.types import EventID, MessageEventContent, RoomID from mautrix.util import markdown from mautrix.util.logging import TraceLogger from ... import bridge as br command_handlers: dict[str, CommandHandler] = {} command_aliases: dict[str, CommandHandler] = {} HelpSection = NamedTuple("HelpSection", name=str, order=int, description=str) HelpCacheKey = NamedTuple( "HelpCacheKey", is_management=bool, is_portal=bool, is_admin=bool, is_logged_in=bool ) SECTION_GENERAL = HelpSection("General", 0, "") SECTION_AUTH = HelpSection("Authentication", 10, "") SECTION_ADMIN = HelpSection("Administration", 50, "") SECTION_RELAY = HelpSection("Relay mode management", 15, "") def ensure_trailing_newline(s: str) -> str: """Returns the passed string, but with a guaranteed trailing newline.""" return s + ("" if s[-1] == "\n" else "\n") class CommandEvent: """Holds information about a command issued in a Matrix room. When a Matrix command was issued to the bot, CommandEvent will hold information regarding the event. Attributes: room_id: The id of the Matrix room in which the command was issued. event_id: The id of the matrix event which contained the command. sender: The user who issued the command. command: The issued command. args: Arguments given with the issued command. 
content: The raw content in the command event. portal: The portal the command was sent to. is_management: Determines whether the room in which the command was issued in is a management room. has_bridge_bot: Whether or not the bridge bot is in the room. """ bridge: bridge.Bridge az: AppService log: TraceLogger loop: asyncio.AbstractEventLoop config: br.BaseBridgeConfig processor: CommandProcessor command_prefix: str room_id: RoomID event_id: EventID sender: br.BaseUser command: str args: list[str] content: MessageEventContent portal: br.BasePortal | None is_management: bool has_bridge_bot: bool def __init__( self, processor: CommandProcessor, room_id: RoomID, event_id: EventID, sender: br.BaseUser, command: str, args: list[str], content: MessageEventContent, portal: br.BasePortal | None, is_management: bool, has_bridge_bot: bool, ) -> None: self.bridge = processor.bridge self.az = processor.az self.log = processor.log self.loop = processor.loop self.config = processor.config self.processor = processor self.command_prefix = processor.command_prefix self.room_id = room_id self.event_id = event_id self.sender = sender self.command = command self.args = args self.content = content self.portal = portal self.is_management = is_management self.has_bridge_bot = has_bridge_bot @property def is_portal(self) -> bool: return self.portal is not None async def get_help_key(self) -> HelpCacheKey: """ Get the help cache key for the given CommandEvent. Help messages are generated dynamically from the CommandHandlers that have been added so that they would only contain relevant commands. The help cache key is tuple-unpacked and passed to :meth:`CommandHandler.has_permission` when generating the help page. After the first generation, the page is cached using the help cache key. If you override this property or :meth:`CommandHandler.has_permission`, make sure to override the other too to handle the changes properly. When you override this property or otherwise extend CommandEvent, remember to pass the extended CommandEvent class when initializing your CommandProcessor. """ return HelpCacheKey( is_management=self.is_management, is_portal=self.portal is not None, is_admin=self.sender.is_admin, is_logged_in=await self.sender.is_logged_in(), ) @property def print_error_traceback(self) -> bool: """ Whether or not the stack traces of unhandled exceptions during the handling of this command should be sent to the user. If false, the error message will simply tell the user to check the logs. Bridges may want to limit tracebacks to bridge admins. """ return self.sender.is_admin @property def main_intent(self) -> IntentAPI: return self.portal.main_intent if self.portal else self.az.intent async def redact(self, reason: str | None = None) -> None: """ Try to redact the command. If the redaction fails with M_FORBIDDEN, the error will be logged and ignored. """ try: if self.has_bridge_bot: await self.az.intent.redact(self.room_id, self.event_id, reason=reason) else: await self.main_intent.redact(self.room_id, self.event_id, reason=reason) except MForbidden as e: self.log.warning(f"Failed to redact command {self.command}: {e}") except Exception: self.log.warning(f"Failed to redact command {self.command}", exc_info=True) def reply( self, message: str, allow_html: bool = False, render_markdown: bool = True ) -> Awaitable[EventID]: """Write a reply to the room in which the command was issued. 
        Replaces occurrences of "$cmdprefix" in the message with the command prefix and
        replaces occurrences of "$cmdprefix+sp " with the command prefix if the command was
        not issued in a management room.

        If allow_html and render_markdown are both False, the message will not be rendered to
        html and sending of html is disabled.

        Args:
            message: The message to post in the room.
            allow_html: Escape html in the message or don't render html at all if markdown
                is disabled.
            render_markdown: Use markdown formatting to render the passed message to html.

        Returns:
            Handler for the message sending function.
        """
        message = self._replace_command_prefix(message)
        html = self._render_message(
            message, allow_html=allow_html, render_markdown=render_markdown
        )
        if self.has_bridge_bot:
            return self.az.intent.send_notice(self.room_id, message, html=html)
        else:
            return self.main_intent.send_notice(self.room_id, message, html=html)

    async def mark_read(self) -> None:
        """Marks the command as read by the bot."""
        if self.has_bridge_bot:
            await self.az.intent.mark_read(self.room_id, self.event_id)

    def _replace_command_prefix(self, message: str) -> str:
        """Returns the string with the proper command prefix entered."""
        message = message.replace(
            "$cmdprefix+sp ", "" if self.is_management else f"{self.command_prefix} "
        )
        return message.replace("$cmdprefix", self.command_prefix)

    @staticmethod
    def _render_message(message: str, allow_html: bool, render_markdown: bool) -> str | None:
        """Renders the message as HTML.

        Args:
            allow_html: Flag to allow custom HTML in the message.
            render_markdown: If true, markdown styling is applied to the message.

        Returns:
            The message rendered as HTML. None is returned if no styled output is required.
        """
        html = ""
        if render_markdown:
            html = markdown.render(message, allow_html=allow_html)
        elif allow_html:
            html = message
        return ensure_trailing_newline(html) if html else None


CommandHandlerFunc = Callable[[CommandEvent], Awaitable[Any]]
IsEnabledForFunc = Callable[[CommandEvent], bool]


class CommandHandler:
    """A command which can be executed from a Matrix room.

    The command manages its permission and help texts. When called, it will check the permission
    of the command event and execute the command or, in case of error, report back to the user.

    Attributes:
        management_only: Whether the command can exclusively be issued in a management room.
        name: The name of this command.
        help_section: Section of the help in which this command will appear.
    """

    name: str

    management_only: bool
    needs_admin: bool
    needs_auth: bool
    is_enabled_for: IsEnabledForFunc

    _help_text: str
    _help_args: str
    help_section: HelpSection

    def __init__(
        self,
        handler: CommandHandlerFunc,
        management_only: bool,
        name: str,
        help_text: str,
        help_args: str,
        help_section: HelpSection,
        needs_auth: bool,
        needs_admin: bool,
        is_enabled_for: IsEnabledForFunc = lambda _: True,
        **kwargs,
    ) -> None:
        """
        Args:
            handler: The function handling the execution of this command.
            management_only: Whether the command can exclusively be issued in a management room.
            needs_auth: Whether the command needs the bridge to be authed already.
            needs_admin: Whether the command needs the issuer to be a bridge admin.
            name: The name of this command.
            help_text: The text displayed in the help for this command.
            help_args: Help text for the arguments of this command.
            help_section: Section of the help in which this command will appear.
""" for key, value in kwargs.items(): setattr(self, key, value) self._handler = handler self.management_only = management_only self.needs_admin = needs_admin self.needs_auth = needs_auth self.name = name self._help_text = help_text self._help_args = help_args self.help_section = help_section self.is_enabled_for = is_enabled_for async def get_permission_error(self, evt: CommandEvent) -> str | None: """Returns the reason why the command could not be issued. Args: evt: The event for which to get the error information. Returns: A string describing the error or None if there was no error. """ if self.management_only and not evt.is_management: return ( f"`{evt.command}` is a restricted command: " "you may only run it in management rooms." ) elif self.needs_admin and not evt.sender.is_admin: return "That command is limited to bridge administrators." elif self.needs_auth and not await evt.sender.is_logged_in(): return "That command requires you to be logged in." return None def has_permission(self, key: HelpCacheKey) -> bool: """Checks the permission for this command with the given status. Args: key: The help cache key. See meth:`CommandEvent.get_cache_key`. Returns: True if a user with the given state is allowed to issue the command. """ return ( (not self.management_only or key.is_management) and (not self.needs_admin or key.is_admin) and (not self.needs_auth or key.is_logged_in) ) async def __call__(self, evt: CommandEvent) -> Any: """Executes the command if evt was issued with proper rights. Args: evt: The CommandEvent for which to check permissions. Returns: The result of the command or the error message function. """ error = await self.get_permission_error(evt) if error is not None: return await evt.reply(error) return await self._handler(evt) @property def has_help(self) -> bool: """Returns true if this command has a help text.""" return bool(self.help_section) and bool(self._help_text) @property def help(self) -> str: """Returns the help text to this command.""" return f"**{self.name}** {self._help_args} - {self._help_text}" def command_handler( _func: CommandHandlerFunc | None = None, *, management_only: bool = False, name: str | None = None, help_text: str = "", help_args: str = "", help_section: HelpSection = None, aliases: list[str] | None = None, _handler_class: Type[CommandHandler] = CommandHandler, needs_auth: bool = True, needs_admin: bool = False, is_enabled_for: IsEnabledForFunc = lambda _: True, **kwargs, ) -> Callable[[CommandHandlerFunc], CommandHandler]: """Decorator to create CommandHandlers""" def decorator(func: CommandHandlerFunc) -> CommandHandler: actual_name = name or func.__name__.replace("_", "-") handler = _handler_class( func, management_only=management_only, name=actual_name, help_text=help_text, help_args=help_args, help_section=help_section, needs_auth=needs_auth, needs_admin=needs_admin, is_enabled_for=is_enabled_for, **kwargs, ) command_handlers[handler.name] = handler if aliases: for alias in aliases: command_aliases[alias] = handler return handler return decorator if _func is None else decorator(_func) class CommandProcessor: """Handles the raw commands issued by a user to the Matrix bot.""" log: TraceLogger = logging.getLogger("mau.commands") az: AppService config: br.BaseBridgeConfig loop: asyncio.AbstractEventLoop event_class: Type[CommandEvent] bridge: bridge.Bridge _ref_no: int def __init__( self, bridge: bridge.Bridge, event_class: Type[CommandEvent] = CommandEvent ) -> None: self.az = bridge.az self.config = bridge.config self.loop = bridge.loop or 
asyncio.get_event_loop()
        self.command_prefix = self.config["bridge.command_prefix"]
        self.bridge = bridge
        self.event_class = event_class
        self._ref_no = int(time.time())

    @property
    def ref_no(self) -> int:
        """
        Reference number for a command handling exception to help sysadmins find the error when
        receiving user reports.
        """
        self._ref_no += 1
        return self._ref_no

    @staticmethod
    def _run_handler(
        handler: Callable[[CommandEvent], Awaitable[Any]], evt: CommandEvent
    ) -> Awaitable[Any]:
        return handler(evt)

    async def handle(
        self,
        room_id: RoomID,
        event_id: EventID,
        sender: br.BaseUser,
        command: str,
        args: list[str],
        content: MessageEventContent,
        portal: br.BasePortal | None,
        is_management: bool,
        has_bridge_bot: bool,
    ) -> None:
        """Handles the raw commands issued by a user to the Matrix bot.

        If the command is not known, it might be a followup command and is delegated to a
        command handler registered for that purpose in the sender's command_status as "next".

        Args:
            room_id: ID of the Matrix room in which the command was issued.
            event_id: ID of the event by which the command was issued.
            sender: The sender who issued the command.
            command: The issued command, case insensitive.
            args: Arguments given with the command.
            content: The raw content in the command event.
            portal: The portal the command was sent to.
            is_management: Whether the room is a management room.
            has_bridge_bot: Whether or not the bridge bot is in the room.

        Returns:
            The result of the error message function or None if no error occurred.
            Unknown and delegated commands do not count as errors.
        """
        if not command_handlers or "unknown-command" not in command_handlers:
            raise ValueError("command_handlers are not properly initialized.")

        evt = self.event_class(
            processor=self,
            room_id=room_id,
            event_id=event_id,
            sender=sender,
            command=command,
            args=args,
            content=content,
            portal=portal,
            is_management=is_management,
            has_bridge_bot=has_bridge_bot,
        )
        orig_command = command
        command = command.lower()
        handler = command_handlers.get(command, command_aliases.get(command))
        if handler is None or not handler.is_enabled_for(evt):
            if sender.command_status and "next" in sender.command_status:
                args.insert(0, orig_command)
                evt.command = ""
                handler = sender.command_status["next"]
            else:
                handler = command_handlers["unknown-command"]
        try:
            await self._run_handler(handler, evt)
        except Exception:
            ref_no = self.ref_no
            self.log.exception(
                "Unhandled error while handling command "
                f"{evt.command} {' '.join(args)} from {sender.mxid} (ref: {ref_no})"
            )
            if evt.print_error_traceback:
                await evt.reply(
                    "Unhandled error while handling command:\n\n"
                    "```traceback\n"
                    f"{traceback.format_exc()}"
                    "```"
                )
            else:
                await evt.reply(
                    "Unhandled error while handling command. "
                    f"Check logs for more details (ref: {ref_no})."
                )
            raise
        return None
python-0.20.7/mautrix/bridge/commands/login_matrix.py000066400000000000000000000051201473573527000227300ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
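# Descriptive note (not part of the upstream source): this module lets a bridge
# user replace their appservice ghost with their real Matrix account (double
# puppeting). Rough usage sketch from a management room, where no command prefix
# is required (elsewhere, prefix with the configured bridge.command_prefix):
#
#     login-matrix <access token>   switch bridging to your own Matrix account
#     logout-matrix                 revert to the appservice-owned puppet
#     ping-matrix                   check that the stored access token still works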
from mautrix.client import Client from mautrix.types import EventID from ..custom_puppet import AutologinError, CustomPuppetError, InvalidAccessToken from .handler import SECTION_AUTH, CommandEvent, command_handler @command_handler( needs_auth=True, management_only=True, help_args="<_access token_>", help_section=SECTION_AUTH, help_text="Enable double puppeting.", ) async def login_matrix(evt: CommandEvent) -> None: if len(evt.args) == 0: await evt.reply("**Usage:** `$cmdprefix+sp login-matrix `") return try: puppet = await evt.sender.get_puppet() except NotImplementedError: await evt.reply("This bridge has not implemented the login-matrix command.") return _, homeserver = Client.parse_user_id(evt.sender.mxid) try: await puppet.switch_mxid(evt.args[0], evt.sender.mxid) await evt.reply("Successfully enabled double puppeting.") except AutologinError as e: await evt.reply(f"Failed to create an access token: {e}") except CustomPuppetError as e: await evt.reply(str(e)) @command_handler( needs_auth=True, management_only=True, help_section=SECTION_AUTH, help_text="Disable double puppeting.", ) async def logout_matrix(evt: CommandEvent) -> EventID: try: puppet = await evt.sender.get_puppet() except NotImplementedError: return await evt.reply("This bridge has not implemented the logout-matrix command.") if not puppet or not puppet.is_real_user: return await evt.reply("You don't have double puppeting enabled.") await puppet.switch_mxid(None, None) return await evt.reply("Successfully disabled double puppeting.") @command_handler( needs_auth=True, help_section=SECTION_AUTH, help_text="Pings the Matrix server with the double puppet.", ) async def ping_matrix(evt: CommandEvent) -> EventID: try: puppet = await evt.sender.get_puppet() except NotImplementedError: return await evt.reply("This bridge has not implemented the ping-matrix command.") if not puppet.is_real_user: return await evt.reply("You are not logged in with your Matrix account.") try: await puppet.start() except InvalidAccessToken: return await evt.reply("Your access token is invalid.") return await evt.reply("Your Matrix login is working.") python-0.20.7/mautrix/bridge/commands/manhole.py000066400000000000000000000102001473573527000216520ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations from typing import Callable import asyncio import os from attr import dataclass from mautrix.errors import MatrixConnectionError from mautrix.types import UserID from mautrix.util.manhole import start_manhole from . 
import SECTION_ADMIN, CommandEvent, command_handler @dataclass class ManholeState: server: asyncio.AbstractServer opened_by: UserID close: Callable[[], None] whitelist: set[int] @command_handler( needs_auth=False, needs_admin=True, help_section=SECTION_ADMIN, help_text="Open a manhole into the bridge.", help_args="<_uid..._>", ) async def open_manhole(evt: CommandEvent) -> None: if not evt.config["manhole.enabled"]: await evt.reply("The manhole has been disabled in the config.") return elif len(evt.args) == 0: await evt.reply("**Usage:** `$cmdprefix+sp open-manhole `") return whitelist = set() whitelist_whitelist = evt.config["manhole.whitelist"] for arg in evt.args: try: uid = int(arg) except ValueError: await evt.reply(f"{arg} is not an integer.") return if whitelist_whitelist and uid not in whitelist_whitelist: await evt.reply(f"{uid} is not in the list of allowed UIDs.") return whitelist.add(uid) if evt.bridge.manhole: added = [uid for uid in whitelist if uid not in evt.bridge.manhole.whitelist] evt.bridge.manhole.whitelist |= set(added) if len(added) == 0: await evt.reply( f"There's an existing manhole opened by {evt.bridge.manhole.opened_by}" " and all the given UIDs are already whitelisted." ) else: added_str = ( f"{', '.join(str(uid) for uid in added[:-1])} and {added[-1]}" if len(added) > 1 else added[0] ) await evt.reply( f"There's an existing manhole opened by {evt.bridge.manhole.opened_by}" f". Added {added_str} to the whitelist." ) evt.log.info(f"{evt.sender.mxid} added {added_str} to the manhole whitelist.") return namespace = await evt.bridge.manhole_global_namespace(evt.sender.mxid) banner = evt.bridge.manhole_banner(evt.sender.mxid) path = evt.config["manhole.path"] wl_list = list(whitelist) whitelist_str = ( f"{', '.join(str(uid) for uid in wl_list[:-1])} and {wl_list[-1]}" if len(wl_list) > 1 else wl_list[0] ) evt.log.info(f"{evt.sender.mxid} opened a manhole with {whitelist_str} whitelisted.") server, close = await start_manhole( path=path, banner=banner, namespace=namespace, loop=evt.loop, whitelist=whitelist ) evt.bridge.manhole = ManholeState( server=server, opened_by=evt.sender.mxid, close=close, whitelist=whitelist ) plrl = "s" if len(whitelist) != 1 else "" await evt.reply(f"Opened manhole at unix://{path} with UID{plrl} {whitelist_str} whitelisted") await server.wait_closed() evt.bridge.manhole = None try: os.unlink(path) except FileNotFoundError: pass evt.log.info(f"{evt.sender.mxid}'s manhole was closed.") try: await evt.reply("Your manhole was closed.") except (AttributeError, MatrixConnectionError) as e: evt.log.warning(f"Failed to send manhole close notification: {e}") @command_handler( needs_auth=False, needs_admin=True, help_section=SECTION_ADMIN, help_text="Close an open manhole.", ) async def close_manhole(evt: CommandEvent) -> None: if not evt.bridge.manhole: await evt.reply("There is no open manhole.") return opened_by = evt.bridge.manhole.opened_by evt.bridge.manhole.close() evt.bridge.manhole = None if opened_by != evt.sender.mxid: await evt.reply(f"Closed manhole opened by {opened_by}") python-0.20.7/mautrix/bridge/commands/meta.py000066400000000000000000000060221473573527000211640ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. 
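# Descriptive note (not part of the upstream source): this module provides the
# built-in `help`, `cancel` and `version` commands plus the unknown-command
# fallback. Help pages are assembled from the registered CommandHandlers and
# cached per HelpCacheKey (management/portal/admin/logged-in state), so the
# output only lists commands the asking user is actually allowed to run.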
from __future__ import annotations from mautrix.types import EventID from .handler import ( SECTION_GENERAL, CommandEvent, HelpCacheKey, HelpSection, command_handler, command_handlers, ) @command_handler( needs_auth=False, help_section=SECTION_GENERAL, help_text="Cancel an ongoing action." ) async def cancel(evt: CommandEvent) -> EventID: if evt.sender.command_status: action = evt.sender.command_status["action"] evt.sender.command_status = None return await evt.reply(f"{action} cancelled.") else: return await evt.reply("No ongoing command.") @command_handler( needs_auth=False, help_section=SECTION_GENERAL, help_text="Get the bridge version." ) async def version(evt: CommandEvent) -> None: if not evt.processor.bridge: await evt.reply("Bridge version unknown") else: await evt.reply( f"[{evt.processor.bridge.name}]({evt.processor.bridge.repo_url}) " f"{evt.processor.bridge.markdown_version or evt.processor.bridge.version}" ) @command_handler(needs_auth=False) async def unknown_command(evt: CommandEvent) -> EventID: return await evt.reply("Unknown command. Try `$cmdprefix+sp help` for help.") help_cache: dict[HelpCacheKey, str] = {} async def _get_help_text(evt: CommandEvent) -> str: cache_key = await evt.get_help_key() if cache_key not in help_cache: help_sections: dict[HelpSection, list[str]] = {} for handler in command_handlers.values(): if ( handler.has_help and handler.has_permission(cache_key) and handler.is_enabled_for(evt) ): help_sections.setdefault(handler.help_section, []) help_sections[handler.help_section].append(handler.help + " ") help_sorted = sorted(help_sections.items(), key=lambda item: item[0].order) helps = ["#### {}\n{}\n".format(key.name, "\n".join(value)) for key, value in help_sorted] help_cache[cache_key] = "\n".join(helps) return help_cache[cache_key] def _get_management_status(evt: CommandEvent) -> str: if evt.is_management: return "This is a management room: prefixing commands with `$cmdprefix` is not required." elif evt.is_portal: return ( "**This is a portal room**: you must always prefix commands with `$cmdprefix`.\n" "Management commands will not be bridged." ) return "**This is not a management room**: you must prefix commands with `$cmdprefix`." @command_handler( name="help", needs_auth=False, help_section=SECTION_GENERAL, help_text="Show this help message.", ) async def help_cmd(evt: CommandEvent) -> EventID: return await evt.reply(_get_management_status(evt) + "\n" + await _get_help_text(evt)) python-0.20.7/mautrix/bridge/commands/relay.py000066400000000000000000000030231473573527000213500ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from mautrix.types import EventID from .handler import SECTION_RELAY, CommandEvent, command_handler @command_handler( needs_auth=True, management_only=False, name="set-relay", help_section=SECTION_RELAY, help_text="Relay messages in this room through your account.", is_enabled_for=lambda evt: evt.config["bridge.relay.enabled"], ) async def set_relay(evt: CommandEvent) -> EventID: if not evt.is_portal: return await evt.reply("This is not a portal room.") await evt.portal.set_relay_user(evt.sender) return await evt.reply( "Messages from non-logged-in users in this room will now be bridged " "through your account." 
) @command_handler( needs_auth=True, management_only=False, name="unset-relay", help_section=SECTION_RELAY, help_text="Stop relaying messages in this room.", is_enabled_for=lambda evt: evt.config["bridge.relay.enabled"], ) async def unset_relay(evt: CommandEvent) -> EventID: if not evt.is_portal: return await evt.reply("This is not a portal room.") elif not evt.portal.has_relay: return await evt.reply("This room does not have a relay user set.") await evt.portal.set_relay_user(None) return await evt.reply("Messages from non-logged-in users will no longer be bridged.") python-0.20.7/mautrix/bridge/config.py000066400000000000000000000217341473573527000177110ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations from typing import Any, ClassVar from abc import ABC import json import os import re import secrets import time from mautrix.util.config import ( BaseFileConfig, BaseValidatableConfig, ConfigUpdateHelper, ForbiddenDefault, yaml, ) class BaseBridgeConfig(BaseFileConfig, BaseValidatableConfig, ABC): env_prefix: str | None = None registration_path: str _registration: dict | None _check_tokens: bool env: dict[str, Any] def __init__( self, path: str, registration_path: str, base_path: str, env_prefix: str | None = None ) -> None: super().__init__(path, base_path) self.registration_path = registration_path self._registration = None self._check_tokens = True self.env = {} if not self.env_prefix: self.env_prefix = env_prefix if self.env_prefix: env_prefix = f"{self.env_prefix}_" for key, value in os.environ.items(): if not key.startswith(env_prefix): continue key = key.removeprefix(env_prefix) if value.startswith("json::"): value = json.loads(value.removeprefix("json::")) self.env[key] = value def __getitem__(self, item: str) -> Any: if self.env: try: sanitized_item = item.replace(".", "_").replace("[", "").replace("]", "").upper() return self.env[sanitized_item] except KeyError: pass return super().__getitem__(item) def save(self) -> None: super().save() if self._registration and self.registration_path: with open(self.registration_path, "w") as stream: yaml.dump(self._registration, stream) @staticmethod def _new_token() -> str: return secrets.token_urlsafe(48) @property def forbidden_defaults(self) -> list[ForbiddenDefault]: return [ ForbiddenDefault("homeserver.address", "https://example.com"), ForbiddenDefault("homeserver.address", "https://matrix.example.com"), ForbiddenDefault("homeserver.domain", "example.com"), ] + ( [ ForbiddenDefault( "appservice.as_token", "This value is generated when generating the registration", "Did you forget to generate the registration?", ), ForbiddenDefault( "appservice.hs_token", "This value is generated when generating the registration", "Did you forget to generate the registration?", ), ] if self._check_tokens else [] ) def do_update(self, helper: ConfigUpdateHelper) -> None: copy, copy_dict = helper.copy, helper.copy_dict copy("homeserver.address") copy("homeserver.domain") copy("homeserver.verify_ssl") copy("homeserver.http_retry_count") copy("homeserver.connection_limit") copy("homeserver.status_endpoint") copy("homeserver.message_send_checkpoint_endpoint") copy("homeserver.async_media") if self.get("homeserver.asmux", False): helper.base["homeserver.software"] = "asmux" else: copy("homeserver.software") 
copy("appservice.address") copy("appservice.hostname") copy("appservice.port") copy("appservice.max_body_size") copy("appservice.tls_cert") copy("appservice.tls_key") if "appservice.database" in self and self["appservice.database"].startswith("sqlite:///"): helper.base["appservice.database"] = self["appservice.database"].replace( "sqlite:///", "sqlite:" ) else: copy("appservice.database") copy("appservice.database_opts") copy("appservice.id") copy("appservice.bot_username") copy("appservice.bot_displayname") copy("appservice.bot_avatar") copy("appservice.as_token") copy("appservice.hs_token") copy("appservice.ephemeral_events") copy("bridge.management_room_text.welcome") copy("bridge.management_room_text.welcome_connected") copy("bridge.management_room_text.welcome_unconnected") copy("bridge.management_room_text.additional_help") copy("bridge.management_room_multiple_messages") copy("bridge.encryption.allow") copy("bridge.encryption.default") copy("bridge.encryption.require") copy("bridge.encryption.appservice") copy("bridge.encryption.delete_keys.delete_outbound_on_ack") copy("bridge.encryption.delete_keys.dont_store_outbound") copy("bridge.encryption.delete_keys.ratchet_on_decrypt") copy("bridge.encryption.delete_keys.delete_fully_used_on_decrypt") copy("bridge.encryption.delete_keys.delete_prev_on_new_session") copy("bridge.encryption.delete_keys.delete_on_device_delete") copy("bridge.encryption.delete_keys.periodically_delete_expired") copy("bridge.encryption.delete_keys.delete_outdated_inbound") copy("bridge.encryption.verification_levels.receive") copy("bridge.encryption.verification_levels.send") copy("bridge.encryption.verification_levels.share") copy("bridge.encryption.allow_key_sharing") if self.get("bridge.encryption.key_sharing.allow", False): helper.base["bridge.encryption.allow_key_sharing"] = True require_verif = self.get("bridge.encryption.key_sharing.require_verification", True) require_cs = self.get("bridge.encryption.key_sharing.require_cross_signing", False) if require_verif: helper.base["bridge.encryption.verification_levels.share"] = "verified" elif not require_cs: helper.base["bridge.encryption.verification_levels.share"] = "unverified" # else: default (cross-signed-tofu) copy("bridge.encryption.rotation.enable_custom") copy("bridge.encryption.rotation.milliseconds") copy("bridge.encryption.rotation.messages") copy("bridge.encryption.rotation.disable_device_change_key_rotation") copy("bridge.relay.enabled") copy_dict("bridge.relay.message_formats", override_existing_map=False) copy("manhole.enabled") copy("manhole.path") copy("manhole.whitelist") copy("logging") @property def namespaces(self) -> dict[str, list[dict[str, Any]]]: """ Generate the user ID and room alias namespace config for the registration as specified in https://matrix.org/docs/spec/application_service/r0.1.0.html#application-services """ homeserver = self["homeserver.domain"] regex_ph = f"regexplaceholder{int(time.time())}" username_format = self["bridge.username_template"].format(userid=regex_ph) alias_format = ( self["bridge.alias_template"].format(groupname=regex_ph) if "bridge.alias_template" in self else None ) return { "users": [ { "exclusive": True, "regex": re.escape(f"@{username_format}:{homeserver}").replace(regex_ph, ".*"), } ], "aliases": ( [ { "exclusive": True, "regex": re.escape(f"#{alias_format}:{homeserver}").replace( regex_ph, ".*" ), } ] if alias_format else [] ), } def generate_registration(self) -> None: self["appservice.as_token"] = self._new_token() self["appservice.hs_token"] 
= self._new_token() namespaces = self.namespaces bot_username = self["appservice.bot_username"] homeserver_domain = self["homeserver.domain"] namespaces.setdefault("users", []).append( { "exclusive": True, "regex": re.escape(f"@{bot_username}:{homeserver_domain}"), } ) self._registration = { "id": self["appservice.id"], "as_token": self["appservice.as_token"], "hs_token": self["appservice.hs_token"], "namespaces": namespaces, "url": self["appservice.address"], "sender_localpart": self._new_token(), "rate_limited": False, } if self["appservice.ephemeral_events"]: self._registration["de.sorunome.msc2409.push_ephemeral"] = True self._registration["push_ephemeral"] = True python-0.20.7/mautrix/bridge/crypto_state_store.py000066400000000000000000000035021473573527000223710ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations from typing import Awaitable, Callable from abc import ABC from mautrix import __optional_imports__ from mautrix.bridge.portal import BasePortal from mautrix.crypto import StateStore from mautrix.types import RoomEncryptionStateEventContent, RoomID, UserID from mautrix.util.async_db import Database GetPortalFunc = Callable[[RoomID], Awaitable[BasePortal]] class BaseCryptoStateStore(StateStore, ABC): get_portal: GetPortalFunc def __init__(self, get_portal: GetPortalFunc): self.get_portal = get_portal async def is_encrypted(self, room_id: RoomID) -> bool: portal = await self.get_portal(room_id) return portal.encrypted if portal else False class PgCryptoStateStore(BaseCryptoStateStore): db: Database def __init__(self, db: Database, get_portal: GetPortalFunc) -> None: super().__init__(get_portal) self.db = db async def find_shared_rooms(self, user_id: UserID) -> list[RoomID]: rows = await self.db.fetch( "SELECT room_id FROM mx_user_profile " "LEFT JOIN portal ON portal.mxid=mx_user_profile.room_id " "WHERE user_id=$1 AND portal.encrypted=true", user_id, ) return [row["room_id"] for row in rows] async def get_encryption_info(self, room_id: RoomID) -> RoomEncryptionStateEventContent | None: val = await self.db.fetchval( "SELECT encryption FROM mx_room_state WHERE room_id=$1", room_id ) if not val: return None return RoomEncryptionStateEventContent.parse_json(val) python-0.20.7/mautrix/bridge/custom_puppet.py000066400000000000000000000313061473573527000213470ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations from abc import ABC, abstractmethod import asyncio import hashlib import hmac import logging from yarl import URL from mautrix.appservice import AppService, IntentAPI from mautrix.client import ClientAPI from mautrix.errors import ( IntentError, MatrixError, MatrixInvalidToken, MatrixRequestError, WellKnownError, ) from mautrix.types import LoginType, MatrixUserIdentifier, RoomID, UserID from .. 
import bridge as br class CustomPuppetError(MatrixError): """Base class for double puppeting setup errors.""" class InvalidAccessToken(CustomPuppetError): def __init__(self): super().__init__("The given access token was invalid.") class OnlyLoginSelf(CustomPuppetError): def __init__(self): super().__init__("You may only enable double puppeting with your own Matrix account.") class EncryptionKeysFound(CustomPuppetError): def __init__(self): super().__init__( "The given access token is for a device that has encryption keys set up. " "Please provide a fresh token, don't reuse one from another client." ) class HomeserverURLNotFound(CustomPuppetError): def __init__(self, domain: str): super().__init__( f"Could not discover a valid homeserver URL for {domain}." " Please ensure a client .well-known file is set up, or ask the bridge administrator " "to add the homeserver URL to the bridge config." ) class OnlyLoginTrustedDomain(CustomPuppetError): def __init__(self): super().__init__( "This bridge doesn't allow double-puppeting with accounts on untrusted servers." ) class AutologinError(CustomPuppetError): pass class CustomPuppetMixin(ABC): """ Mixin for the Puppet class to enable Matrix puppeting. Attributes: sync_with_custom_puppets: Whether or not custom puppets should /sync allow_discover_url: Allow logging into other homeservers using .well-known discovery. homeserver_url_map: Static map from server name to URL that are always allowed to log in. only_handle_own_synced_events: Whether or not typing notifications and read receipts by other users should be filtered away before passing them to the Matrix event handler. az: The AppService object. loop: The asyncio event loop. log: The logger to use. mx: The Matrix event handler to send /sync events to. by_custom_mxid: A mapping from custom mxid to puppet object. default_mxid: The default user ID of the puppet. default_mxid_intent: The IntentAPI for the default user ID. custom_mxid: The user ID of the custom puppet. access_token: The access token for the custom puppet. intent: The primary IntentAPI. """ allow_discover_url: bool = False homeserver_url_map: dict[str, URL] = {} only_handle_own_synced_events: bool = True login_shared_secret_map: dict[str, bytes] = {} login_device_name: str | None = None az: AppService loop: asyncio.AbstractEventLoop log: logging.Logger mx: br.BaseMatrixHandler by_custom_mxid: dict[UserID, CustomPuppetMixin] = {} default_mxid: UserID default_mxid_intent: IntentAPI custom_mxid: UserID | None access_token: str | None base_url: URL | None intent: IntentAPI @abstractmethod async def save(self) -> None: """Save the information of this puppet. 
Called from :meth:`switch_mxid`""" @property def mxid(self) -> UserID: """The main Matrix user ID of this puppet.""" return self.custom_mxid or self.default_mxid @property def is_real_user(self) -> bool: """Whether this puppet uses a real Matrix user instead of an appservice-owned ID.""" return bool(self.custom_mxid and self.access_token) def _fresh_intent(self) -> IntentAPI: if self.custom_mxid: _, server = self.az.intent.parse_user_id(self.custom_mxid) try: self.base_url = self.homeserver_url_map[server] except KeyError: if server == self.az.domain: self.base_url = self.az.intent.api.base_url if self.access_token == "appservice-config" and self.custom_mxid: try: secret = self.login_shared_secret_map[server] except KeyError: raise AutologinError(f"No shared secret configured for {server}") self.log.debug(f"Using as_token for double puppeting {self.custom_mxid}") return self.az.intent.user( self.custom_mxid, secret.decode("utf-8").removeprefix("as_token:"), self.base_url, as_token=True, ) return ( self.az.intent.user(self.custom_mxid, self.access_token, self.base_url) if self.is_real_user else self.default_mxid_intent ) @classmethod def can_auto_login(cls, mxid: UserID) -> bool: _, server = cls.az.intent.parse_user_id(mxid) return server in cls.login_shared_secret_map and ( server in cls.homeserver_url_map or server == cls.az.domain ) @classmethod async def _login_with_shared_secret(cls, mxid: UserID) -> str: _, server = cls.az.intent.parse_user_id(mxid) try: secret = cls.login_shared_secret_map[server] except KeyError: raise AutologinError(f"No shared secret configured for {server}") if secret.startswith(b"as_token:"): return "appservice-config" try: base_url = cls.homeserver_url_map[server] except KeyError: if server == cls.az.domain: base_url = cls.az.intent.api.base_url else: raise AutologinError(f"No homeserver URL configured for {server}") client = ClientAPI(base_url=base_url) login_args = {} if secret == b"appservice": login_type = LoginType.APPSERVICE client.api.token = cls.az.as_token else: flows = await client.get_login_flows() flow = flows.get_first_of_type(LoginType.DEVTURE_SHARED_SECRET, LoginType.PASSWORD) if not flow: raise AutologinError("No supported shared secret auth login flows") login_type = flow.type token = hmac.new(secret, mxid.encode("utf-8"), hashlib.sha512).hexdigest() if login_type == LoginType.DEVTURE_SHARED_SECRET: login_args["token"] = token elif login_type == LoginType.PASSWORD: login_args["password"] = token resp = await client.login( identifier=MatrixUserIdentifier(user=mxid), device_id=cls.login_device_name, initial_device_display_name=cls.login_device_name, login_type=login_type, **login_args, store_access_token=False, update_hs_url=False, ) return resp.access_token async def switch_mxid( self, access_token: str | None, mxid: UserID | None, start_sync_task: bool = True ) -> None: """ Switch to a real Matrix user or away from one. Args: access_token: The access token for the custom account, or ``None`` to switch back to the appservice-owned ID. mxid: The expected Matrix user ID of the custom account, or ``None`` when ``access_token`` is None. 
""" if access_token == "auto": access_token = await self._login_with_shared_secret(mxid) if access_token != "appservice-config": self.log.debug(f"Logged in for {mxid} using shared secret") if mxid is not None: _, mxid_domain = self.az.intent.parse_user_id(mxid) if mxid_domain in self.homeserver_url_map: base_url = self.homeserver_url_map[mxid_domain] elif mxid_domain == self.az.domain: base_url = None else: if not self.allow_discover_url: raise OnlyLoginTrustedDomain() try: base_url = await IntentAPI.discover(mxid_domain, self.az.http_session) except WellKnownError as e: raise HomeserverURLNotFound(mxid_domain) from e if base_url is None: raise HomeserverURLNotFound(mxid_domain) else: base_url = None prev_mxid = self.custom_mxid self.custom_mxid = mxid self.access_token = access_token self.base_url = base_url self.intent = self._fresh_intent() await self.start(check_e2ee_keys=True) try: del self.by_custom_mxid[prev_mxid] except KeyError: pass if self.mxid != self.default_mxid: self.by_custom_mxid[self.mxid] = self try: await self._leave_rooms_with_default_user() except Exception: self.log.warning("Error when leaving rooms with default user", exc_info=True) await self.save() async def try_start(self, retry_auto_login: bool = True) -> None: try: await self.start(retry_auto_login=retry_auto_login) except Exception: self.log.exception("Failed to initialize custom mxid") async def _invalidate_double_puppet(self) -> None: if self.custom_mxid and self.by_custom_mxid.get(self.custom_mxid) == self: del self.by_custom_mxid[self.custom_mxid] self.custom_mxid = None self.access_token = None await self.save() self.intent = self._fresh_intent() async def start( self, retry_auto_login: bool = False, start_sync_task: bool = True, check_e2ee_keys: bool = False, ) -> None: """Initialize the custom account this puppet uses. Should be called at startup to start the /sync task. Is called by :meth:`switch_mxid` automatically.""" if not self.is_real_user: return try: whoami = await self.intent.whoami() except MatrixInvalidToken as e: if retry_auto_login and self.custom_mxid and self.can_auto_login(self.custom_mxid): self.log.debug(f"Got {e.errcode} while trying to initialize custom mxid") await self.switch_mxid("auto", self.custom_mxid) return self.log.warning(f"Got {e.errcode} while trying to initialize custom mxid") whoami = None if not whoami or whoami.user_id != self.custom_mxid: prev_custom_mxid = self.custom_mxid await self._invalidate_double_puppet() if whoami and whoami.user_id != prev_custom_mxid: raise OnlyLoginSelf() raise InvalidAccessToken() if check_e2ee_keys: try: devices = await self.intent.query_keys({whoami.user_id: [whoami.device_id]}) device_keys = devices.device_keys.get(whoami.user_id, {}).get(whoami.device_id) except Exception: self.log.warning( "Failed to query keys to check if double puppeting token was reused", exc_info=True, ) else: if device_keys and len(device_keys.keys) > 0: await self._invalidate_double_puppet() raise EncryptionKeysFound() self.log.info(f"Initialized custom mxid: {whoami.user_id}") def stop(self) -> None: """ No-op .. deprecated:: 0.20.1 """ async def default_puppet_should_leave_room(self, room_id: RoomID) -> bool: """ Whether or not the default puppet user should leave the given room when this puppet is switched to using a custom user account. Args: room_id: The room to check. Returns: Whether or not the default user account should leave. 
""" return True async def _leave_rooms_with_default_user(self) -> None: for room_id in await self.default_mxid_intent.get_joined_rooms(): try: if await self.default_puppet_should_leave_room(room_id): await self.default_mxid_intent.leave_room(room_id) await self.intent.ensure_joined(room_id) except (IntentError, MatrixRequestError): pass python-0.20.7/mautrix/bridge/disappearing_message.py000066400000000000000000000024611473573527000226120ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations from typing import TypeVar from abc import ABC, abstractmethod import time from attr import dataclass from mautrix.types import EventID, RoomID @dataclass class AbstractDisappearingMessage(ABC): room_id: RoomID event_id: EventID expiration_seconds: int expiration_ts: int | None = None @abstractmethod async def insert(self) -> None: pass @abstractmethod async def update(self) -> None: pass def start_timer(self) -> None: self.expiration_ts = int(time.time() * 1000) + (self.expiration_seconds * 1000) @abstractmethod async def delete(self) -> None: pass @classmethod @abstractmethod async def get_all_scheduled(cls: type[DisappearingMessage]) -> list[DisappearingMessage]: pass @classmethod @abstractmethod async def get_unscheduled_for_room( cls: type[DisappearingMessage], room_id: RoomID ) -> list[DisappearingMessage]: pass DisappearingMessage = TypeVar("DisappearingMessage", bound=AbstractDisappearingMessage) python-0.20.7/mautrix/bridge/e2ee.py000066400000000000000000000370251473573527000172640ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations import asyncio import logging import sys from mautrix import __optional_imports__ from mautrix.appservice import AppService from mautrix.client import Client, InternalEventType, SyncStore from mautrix.crypto import CryptoStore, OlmMachine, PgCryptoStore, RejectKeyShare, StateStore from mautrix.errors import EncryptionError, MForbidden, MNotFound, SessionNotFound from mautrix.types import ( JSON, DeviceIdentity, EncryptedEvent, EncryptedMegolmEventContent, EventFilter, EventType, Filter, LoginType, MessageEvent, RequestedKeyInfo, RoomEventFilter, RoomFilter, RoomID, RoomKeyWithheldCode, Serializable, StateEvent, StateFilter, TrustState, ) from mautrix.util import background_task from mautrix.util.async_db import Database from mautrix.util.logging import TraceLogger from .. 
import bridge as br from .crypto_state_store import PgCryptoStateStore class EncryptionManager: loop: asyncio.AbstractEventLoop log: TraceLogger = logging.getLogger("mau.bridge.e2ee") client: Client crypto: OlmMachine crypto_store: CryptoStore | SyncStore crypto_db: Database | None state_store: StateStore min_send_trust: TrustState key_sharing_enabled: bool appservice_mode: bool periodically_delete_expired_keys: bool delete_outdated_inbound: bool bridge: br.Bridge az: AppService _id_prefix: str _id_suffix: str _share_session_events: dict[RoomID, asyncio.Event] _key_delete_task: asyncio.Task | None def __init__( self, bridge: br.Bridge, homeserver_address: str, user_id_prefix: str, user_id_suffix: str, db_url: str, ) -> None: self.loop = bridge.loop or asyncio.get_event_loop() self.bridge = bridge self.az = bridge.az self.device_name = bridge.name self._id_prefix = user_id_prefix self._id_suffix = user_id_suffix self._share_session_events = {} pickle_key = "mautrix.bridge.e2ee" self.crypto_db = Database.create( url=db_url, upgrade_table=PgCryptoStore.upgrade_table, log=logging.getLogger("mau.crypto.db"), ) self.crypto_store = PgCryptoStore("", pickle_key, self.crypto_db) self.state_store = PgCryptoStateStore(self.crypto_db, bridge.get_portal) default_http_retry_count = bridge.config.get("homeserver.http_retry_count", None) self.client = Client( base_url=homeserver_address, mxid=self.az.bot_mxid, loop=self.loop, sync_store=self.crypto_store, log=self.log.getChild("client"), default_retry_count=default_http_retry_count, state_store=self.bridge.state_store, ) self.crypto = OlmMachine(self.client, self.crypto_store, self.state_store) self.client.add_event_handler(InternalEventType.SYNC_STOPPED, self._exit_on_sync_fail) self.crypto.allow_key_share = self.allow_key_share verification_levels = bridge.config["bridge.encryption.verification_levels"] self.min_send_trust = TrustState.parse(verification_levels["send"]) self.crypto.share_keys_min_trust = TrustState.parse(verification_levels["share"]) self.crypto.send_keys_min_trust = TrustState.parse(verification_levels["receive"]) self.key_sharing_enabled = bridge.config["bridge.encryption.allow_key_sharing"] self.appservice_mode = bridge.config["bridge.encryption.appservice"] if self.appservice_mode: self.az.otk_handler = self.crypto.handle_as_otk_counts self.az.device_list_handler = self.crypto.handle_as_device_lists self.az.to_device_handler = self.crypto.handle_as_to_device_event self.periodically_delete_expired_keys = False self.delete_outdated_inbound = False self._key_delete_task = None del_cfg = bridge.config["bridge.encryption.delete_keys"] if del_cfg: self.crypto.delete_outbound_keys_on_ack = del_cfg["delete_outbound_on_ack"] self.crypto.dont_store_outbound_keys = del_cfg["dont_store_outbound"] self.crypto.delete_previous_keys_on_receive = del_cfg["delete_prev_on_new_session"] self.crypto.ratchet_keys_on_decrypt = del_cfg["ratchet_on_decrypt"] self.crypto.delete_fully_used_keys_on_decrypt = del_cfg["delete_fully_used_on_decrypt"] self.crypto.delete_keys_on_device_delete = del_cfg["delete_on_device_delete"] self.periodically_delete_expired_keys = del_cfg["periodically_delete_expired"] self.delete_outdated_inbound = del_cfg["delete_outdated_inbound"] self.crypto.disable_device_change_key_rotation = bridge.config[ "bridge.encryption.rotation.disable_device_change_key_rotation" ] async def _exit_on_sync_fail(self, data) -> None: if data["error"]: self.log.critical("Exiting due to crypto sync error") sys.exit(32) async def allow_key_share(self, 
device: DeviceIdentity, request: RequestedKeyInfo) -> bool: if not self.key_sharing_enabled: self.log.debug( f"Key sharing not enabled, ignoring key request from " f"{device.user_id}/{device.device_id}" ) return False elif device.trust == TrustState.BLACKLISTED: raise RejectKeyShare( f"Rejecting key request from blacklisted device " f"{device.user_id}/{device.device_id}", code=RoomKeyWithheldCode.BLACKLISTED, reason="Your device has been blacklisted by the bridge", ) elif await self.crypto.resolve_trust(device) >= self.crypto.share_keys_min_trust: portal = await self.bridge.get_portal(request.room_id) if portal is None: raise RejectKeyShare( f"Rejecting key request for {request.session_id} from " f"{device.user_id}/{device.device_id}: room is not a portal", code=RoomKeyWithheldCode.UNAVAILABLE, reason="Requested room is not a portal", ) user = await self.bridge.get_user(device.user_id) if not await user.is_in_portal(portal): raise RejectKeyShare( f"Rejecting key request for {request.session_id} from " f"{device.user_id}/{device.device_id}: user is not in portal", code=RoomKeyWithheldCode.UNAUTHORIZED, reason="You're not in that portal", ) self.log.debug( f"Accepting key request for {request.session_id} from " f"{device.user_id}/{device.device_id}" ) return True else: raise RejectKeyShare( f"Rejecting key request from unverified device " f"{device.user_id}/{device.device_id}", code=RoomKeyWithheldCode.UNVERIFIED, reason="Your device is not trusted by the bridge", ) def _ignore_user(self, user_id: str) -> bool: return ( user_id.startswith(self._id_prefix) and user_id.endswith(self._id_suffix) and user_id != self.az.bot_mxid ) async def handle_member_event(self, evt: StateEvent) -> None: if self._ignore_user(evt.state_key): # We don't want to invalidate group sessions because a ghost left or joined return await self.crypto.handle_member_event(evt) async def _share_session_lock(self, room_id: RoomID) -> bool: try: event = self._share_session_events[room_id] except KeyError: self._share_session_events[room_id] = asyncio.Event() return True else: await event.wait() return False async def encrypt( self, room_id: RoomID, event_type: EventType, content: Serializable | JSON ) -> tuple[EventType, EncryptedMegolmEventContent]: try: encrypted = await self.crypto.encrypt_megolm_event(room_id, event_type, content) except EncryptionError: self.log.debug("Got EncryptionError, sharing group session and trying again") if await self._share_session_lock(room_id): try: users = await self.az.state_store.get_members_filtered( room_id, self._id_prefix, self._id_suffix, self.az.bot_mxid ) await self.crypto.share_group_session(room_id, users) finally: self._share_session_events.pop(room_id).set() encrypted = await self.crypto.encrypt_megolm_event(room_id, event_type, content) return EventType.ROOM_ENCRYPTED, encrypted async def decrypt(self, evt: EncryptedEvent, wait_session_timeout: int = 5) -> MessageEvent: try: decrypted = await self.crypto.decrypt_megolm_event(evt) except SessionNotFound as e: if not wait_session_timeout: raise self.log.debug( f"Couldn't find session {e.session_id} trying to decrypt {evt.event_id}," f" waiting {wait_session_timeout} seconds..." 
) got_keys = await self.crypto.wait_for_session( evt.room_id, e.session_id, timeout=wait_session_timeout ) if got_keys: self.log.debug( f"Got session {e.session_id} after waiting, " f"trying to decrypt {evt.event_id} again" ) decrypted = await self.crypto.decrypt_megolm_event(evt) else: raise self.log.trace("Decrypted event %s: %s", evt.event_id, decrypted) return decrypted async def start(self) -> None: flows = await self.client.get_login_flows() if not flows.supports_type(LoginType.APPSERVICE): self.log.critical( "Encryption enabled in config, but homeserver does not support appservice login" ) sys.exit(30) self.log.debug("Logging in with bridge bot user") if self.crypto_db: try: await self.crypto_db.start() except Exception as e: self.bridge._log_db_error(e) await self.crypto_store.open() device_id = await self.crypto_store.get_device_id() if device_id: self.log.debug(f"Found device ID in database: {device_id}") # We set the API token to the AS token here to authenticate the appservice login # It'll get overridden after the login self.client.api.token = self.az.as_token await self.client.login( login_type=LoginType.APPSERVICE, device_name=self.device_name, device_id=device_id, store_access_token=True, update_hs_url=False, ) await self.crypto.load() if not device_id: await self.crypto_store.put_device_id(self.client.device_id) self.log.debug(f"Logged in with new device ID {self.client.device_id}") elif self.crypto.account.shared: await self._verify_keys_are_on_server() if self.appservice_mode: self.log.info("End-to-bridge encryption support is enabled (appservice mode)") else: _ = self.client.start(self._filter) self.log.info("End-to-bridge encryption support is enabled (sync mode)") if self.delete_outdated_inbound: deleted = await self.crypto_store.redact_outdated_group_sessions() if len(deleted) > 0: self.log.debug( f"Deleted {len(deleted)} inbound keys which lacked expiration metadata" ) if self.periodically_delete_expired_keys: self._key_delete_task = background_task.create(self._periodically_delete_keys()) background_task.create(self._resync_encryption_info()) async def _resync_encryption_info(self) -> None: rows = await self.crypto_db.fetch( """SELECT room_id FROM mx_room_state WHERE encryption='{"resync":true}'""" ) room_ids = [row["room_id"] for row in rows] if not room_ids: return self.log.debug(f"Resyncing encryption state event in rooms: {room_ids}") for room_id in room_ids: try: evt = await self.client.get_state_event(room_id, EventType.ROOM_ENCRYPTION) except (MNotFound, MForbidden) as e: self.log.debug(f"Failed to get encryption state in {room_id}: {e}") q = """ UPDATE mx_room_state SET encryption=NULL WHERE room_id=$1 AND encryption='{"resync":true}' """ await self.crypto_db.execute(q, room_id) else: self.log.debug(f"Resynced encryption state in {room_id}: {evt}") q = """ UPDATE crypto_megolm_inbound_session SET max_age=$1, max_messages=$2 WHERE room_id=$3 AND max_age IS NULL and max_messages IS NULL """ await self.crypto_db.execute( q, evt.rotation_period_ms, evt.rotation_period_msgs, room_id ) async def _verify_keys_are_on_server(self) -> None: self.log.debug("Making sure keys are still on server") try: resp = await self.client.query_keys([self.client.mxid]) except Exception: self.log.critical( "Failed to query own keys to make sure device still exists", exc_info=True ) sys.exit(33) try: own_keys = resp.device_keys[self.client.mxid][self.client.device_id] if len(own_keys.keys) > 0: return except KeyError: pass self.log.critical("Existing device doesn't have keys on 
server, resetting crypto") await self.crypto.crypto_store.delete() await self.client.logout_all() sys.exit(34) async def stop(self) -> None: if self._key_delete_task: self._key_delete_task.cancel() self._key_delete_task = None self.client.stop() await self.crypto_store.close() if self.crypto_db: await self.crypto_db.stop() @property def _filter(self) -> Filter: all_events = EventType.find("*") return Filter( account_data=EventFilter(types=[all_events]), presence=EventFilter(not_types=[all_events]), room=RoomFilter( include_leave=False, state=StateFilter(not_types=[all_events]), timeline=RoomEventFilter(not_types=[all_events]), account_data=RoomEventFilter(not_types=[all_events]), ephemeral=RoomEventFilter(not_types=[all_events]), ), ) async def _periodically_delete_keys(self) -> None: while True: deleted = await self.crypto_store.redact_expired_group_sessions() if deleted: self.log.info(f"Deleted expired megolm sessions: {deleted}") else: self.log.debug("No expired megolm sessions found") await asyncio.sleep(24 * 60 * 60) python-0.20.7/mautrix/bridge/matrix.py000066400000000000000000001223161473573527000177460ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations from collections import defaultdict import asyncio import logging import sys import time from mautrix import __optional_imports__ from mautrix.appservice import DOUBLE_PUPPET_SOURCE_KEY, AppService from mautrix.errors import ( DecryptionError, IntentError, MatrixError, MExclusive, MForbidden, MUnknownToken, SessionNotFound, ) from mautrix.types import ( BaseRoomEvent, BeeperMessageStatusEventContent, EncryptedEvent, Event, EventID, EventType, MediaRepoConfig, Membership, MemberStateEventContent, MessageEvent, MessageEventContent, MessageStatus, MessageStatusReason, MessageType, PresenceEvent, ReactionEvent, ReceiptEvent, ReceiptType, RedactionEvent, RelatesTo, RelationType, RoomID, RoomType, SingleReceiptEventContent, SpecVersions, StateEvent, StateUnsigned, TextMessageEventContent, TrustState, TypingEvent, UserID, Version, VersionsResponse, ) from mautrix.util import background_task, markdown from mautrix.util.logging import TraceLogger from mautrix.util.message_send_checkpoint import ( CHECKPOINT_TYPES, MessageSendCheckpoint, MessageSendCheckpointReportedBy, MessageSendCheckpointStatus, MessageSendCheckpointStep, ) from mautrix.util.opt_prometheus import Histogram from .. import bridge as br from . 
import commands as cmd encryption_import_error = None media_encrypt_import_error = None try: from .e2ee import EncryptionManager except ImportError as e: if __optional_imports__: raise encryption_import_error = e EncryptionManager = None try: from mautrix.crypto.attachments import encrypt_attachment except ImportError as e: if __optional_imports__: raise media_encrypt_import_error = e encrypt_attachment = None EVENT_TIME = Histogram( "bridge_matrix_event", "Time spent processing Matrix events", ["event_type"] ) class UnencryptedMessageError(DecryptionError): def __init__(self) -> None: super().__init__("unencrypted message") @property def human_message(self) -> str: return "the message is not encrypted" class EncryptionUnsupportedError(DecryptionError): def __init__(self) -> None: super().__init__("encryption is not supported") @property def human_message(self) -> str: return "the bridge is not configured to support encryption" class DeviceUntrustedError(DecryptionError): def __init__(self, trust: TrustState) -> None: explanation = { TrustState.BLACKLISTED: "device is blacklisted", TrustState.UNVERIFIED: "unverified", TrustState.UNKNOWN_DEVICE: "device info not found", TrustState.FORWARDED: "keys were forwarded from an unknown device", TrustState.CROSS_SIGNED_UNTRUSTED: ( "cross-signing keys changed after setting up the bridge" ), }.get(trust) base = "your device is not trusted" self.message = f"{base} ({explanation})" if explanation else base super().__init__(self.message) @property def human_message(self) -> str: return self.message class BaseMatrixHandler: log: TraceLogger = logging.getLogger("mau.mx") az: AppService commands: cmd.CommandProcessor config: config.BaseBridgeConfig bridge: br.Bridge e2ee: EncryptionManager | None require_e2ee: bool media_config: MediaRepoConfig versions: VersionsResponse minimum_spec_version: Version = SpecVersions.V11 room_locks: dict[str, asyncio.Lock] user_id_prefix: str user_id_suffix: str def __init__( self, command_processor: cmd.CommandProcessor | None = None, bridge: br.Bridge | None = None, ) -> None: self.az = bridge.az self.config = bridge.config self.bridge = bridge self.commands = command_processor or cmd.CommandProcessor(bridge=bridge) self.media_config = MediaRepoConfig(upload_size=50 * 1024 * 1024) self.versions = VersionsResponse.deserialize({"versions": ["v1.3"]}) self.az.matrix_event_handler(self.int_handle_event) self.room_locks = defaultdict(asyncio.Lock) self.e2ee = None self.require_e2ee = False if self.config["bridge.encryption.allow"]: if not EncryptionManager: self.log.fatal( "Encryption enabled in config, but dependencies not installed.", exc_info=encryption_import_error, ) sys.exit(31) if not encrypt_attachment: self.log.fatal( "Encryption enabled in config, but media encryption dependencies " "not installed.", exc_info=media_encrypt_import_error, ) sys.exit(31) self.e2ee = EncryptionManager( bridge=bridge, user_id_prefix=self.user_id_prefix, user_id_suffix=self.user_id_suffix, homeserver_address=self.config["homeserver.address"], db_url=self.config["appservice.database"], ) self.require_e2ee = self.config["bridge.encryption.require"] self.management_room_text = self.config.get( "bridge.management_room_text", { "welcome": "Hello, I'm a bridge bot.", "welcome_connected": "Use `help` for help.", "welcome_unconnected": "Use `help` for help on how to log in.", }, ) self.management_room_multiple_messages = self.config.get( "bridge.management_room_multiple_messages", False, ) async def check_versions(self) -> None: if not 
self.versions.supports_at_least(self.minimum_spec_version): self.log.fatal( "The homeserver is outdated " "(server supports Matrix %s, but the bridge requires at least %s)", self.versions.latest_version, self.minimum_spec_version, ) sys.exit(18) if self.bridge.homeserver_software.is_hungry and not self.versions.supports( "com.beeper.hungry" ): self.log.fatal( "The config claims the homeserver is hungryserv, " "but the /versions response didn't confirm it" ) sys.exit(18) async def wait_for_connection(self) -> None: self.log.info("Ensuring connectivity to homeserver") while True: try: self.versions = await self.az.intent.versions() break except MForbidden: self.log.debug( "/versions endpoint returned M_FORBIDDEN, " "trying to register bridge bot before retrying..." ) await self.az.intent.ensure_registered() except Exception: self.log.exception("Connection to homeserver failed, retrying in 10 seconds") await asyncio.sleep(10) await self.check_versions() try: await self.az.intent.whoami() except MForbidden: self.log.debug( "Whoami endpoint returned M_FORBIDDEN, " "trying to register bridge bot before retrying..." ) await self.az.intent.ensure_registered() await self.az.intent.whoami() if self.versions.supports("fi.mau.msc2659.stable") or self.versions.supports_at_least( SpecVersions.V17 ): try: txn_id = self.az.intent.api.get_txn_id() duration = await self.az.ping_self(txn_id) self.log.debug( "Homeserver->bridge connection works, " f"roundtrip time is {duration} ms (txn ID: {txn_id})" ) except Exception: self.log.exception("Error checking homeserver -> bridge connection") sys.exit(16) else: self.log.debug( "Homeserver does not support checking status of homeserver -> bridge connection" ) try: self.media_config = await self.az.intent.get_media_repo_config() except Exception: self.log.warning("Failed to fetch media repo config", exc_info=True) async def init_as_bot(self) -> None: self.log.debug("Initializing appservice bot") displayname = self.config["appservice.bot_displayname"] if displayname: try: await self.az.intent.set_displayname( displayname if displayname != "remove" else "" ) except Exception: self.log.exception("Failed to set bot displayname") avatar = self.config["appservice.bot_avatar"] if avatar: try: await self.az.intent.set_avatar_url(avatar if avatar != "remove" else "") except Exception: self.log.exception("Failed to set bot avatar") if self.bridge.homeserver_software.is_hungry and self.bridge.beeper_network_name: self.log.debug("Setting contact info on the appservice bot") await self.az.intent.beeper_update_profile( { "com.beeper.bridge.service": self.bridge.beeper_service_name, "com.beeper.bridge.network": self.bridge.beeper_network_name, "com.beeper.bridge.is_bridge_bot": True, } ) async def init_encryption(self) -> None: if self.e2ee: await self.e2ee.start() async def allow_message(self, user: br.BaseUser) -> bool: return user.is_whitelisted or ( self.config["bridge.relay.enabled"] and user.relay_whitelisted ) @staticmethod async def allow_command(user: br.BaseUser) -> bool: return user.is_whitelisted @staticmethod async def allow_bridging_message(user: br.BaseUser, portal: br.BasePortal) -> bool: return await user.is_logged_in() or (user.relay_whitelisted and portal.has_relay) @staticmethod async def allow_puppet_invite(user: br.BaseUser, puppet: br.BasePuppet) -> bool: return await user.is_logged_in() async def handle_leave(self, room_id: RoomID, user_id: UserID, event_id: EventID) -> None: pass async def handle_kick( self, room_id: RoomID, user_id: UserID, kicked_by: 
UserID, reason: str, event_id: EventID ) -> None: pass async def handle_ban( self, room_id: RoomID, user_id: UserID, banned_by: UserID, reason: str, event_id: EventID ) -> None: pass async def handle_unban( self, room_id: RoomID, user_id: UserID, unbanned_by: UserID, reason: str, event_id: EventID ) -> None: pass async def handle_join(self, room_id: RoomID, user_id: UserID, event_id: EventID) -> None: pass async def handle_knock( self, room_id: RoomID, user_id: UserID, reason: str, event_id: EventID ) -> None: pass async def handle_retract_knock( self, room_id: RoomID, user_id: UserID, reason: str, event_id: EventID ) -> None: pass async def handle_reject_knock( self, room_id: RoomID, user_id: UserID, sender: UserID, reason: str, event_id: EventID ) -> None: pass async def handle_accept_knock( self, room_id: RoomID, user_id: UserID, sender: UserID, reason: str, event_id: EventID ) -> None: pass async def handle_member_info_change( self, room_id: RoomID, user_id: UserID, content: MemberStateEventContent, prev_content: MemberStateEventContent, event_id: EventID, ) -> None: pass async def handle_puppet_group_invite( self, room_id: RoomID, puppet: br.BasePuppet, invited_by: br.BaseUser, evt: StateEvent, members: list[UserID], ) -> None: if self.az.bot_mxid not in members: await puppet.default_mxid_intent.leave_room( room_id, reason="This ghost does not join multi-user rooms without the bridge bot." ) async def handle_puppet_dm_invite( self, room_id: RoomID, puppet: br.BasePuppet, invited_by: br.BaseUser, evt: StateEvent ) -> None: portal = await invited_by.get_portal_with(puppet) if portal: await portal.accept_matrix_dm(room_id, invited_by, puppet) else: await puppet.default_mxid_intent.leave_room( room_id, reason="This bridge does not support creating DMs." ) async def handle_puppet_space_invite( self, room_id: RoomID, puppet: br.BasePuppet, invited_by: br.BaseUser, evt: StateEvent ) -> None: await puppet.default_mxid_intent.leave_room( room_id, reason="This ghost does not join spaces." 
) async def handle_puppet_nonportal_invite( self, room_id: RoomID, puppet: br.BasePuppet, invited_by: br.BaseUser, evt: StateEvent ) -> None: intent = puppet.default_mxid_intent await intent.join_room(room_id) try: create_evt = await intent.get_state_event(room_id, EventType.ROOM_CREATE) members = await intent.get_room_members(room_id) except MatrixError: self.log.exception(f"Failed to get state after joining {room_id} as {intent.mxid}") background_task.create(intent.leave_room(room_id, reason="Internal error")) return if create_evt.type == RoomType.SPACE: await self.handle_puppet_space_invite(room_id, puppet, invited_by, evt) elif len(members) > 2 or not evt.content.is_direct: await self.handle_puppet_group_invite(room_id, puppet, invited_by, evt, members) else: await self.handle_puppet_dm_invite(room_id, puppet, invited_by, evt) async def handle_puppet_invite( self, room_id: RoomID, puppet: br.BasePuppet, invited_by: br.BaseUser, evt: StateEvent ) -> None: intent = puppet.default_mxid_intent if not await self.allow_puppet_invite(invited_by, puppet): self.log.debug(f"Rejecting invite for {intent.mxid} to {room_id}: user can't invite") await intent.leave_room(room_id, reason="You're not allowed to invite this ghost.") return async with self.room_locks[room_id]: portal = await self.bridge.get_portal(room_id) if portal: try: await portal.handle_matrix_invite(invited_by, puppet) except br.RejectMatrixInvite as e: await intent.leave_room(room_id, reason=e.message) except br.IgnoreMatrixInvite: pass else: await intent.join_room(room_id) return else: await self.handle_puppet_nonportal_invite(room_id, puppet, invited_by, evt) async def handle_invite( self, room_id: RoomID, user_id: UserID, invited_by: br.BaseUser, evt: StateEvent ) -> None: pass async def handle_reject( self, room_id: RoomID, user_id: UserID, reason: str, event_id: EventID ) -> None: pass async def handle_disinvite( self, room_id: RoomID, user_id: UserID, disinvited_by: UserID, reason: str, event_id: EventID, ) -> None: pass async def handle_event(self, evt: Event) -> None: """ Called by :meth:`int_handle_event` for message events other than m.room.message. **N.B.** You may need to add the event class to :attr:`allowed_event_classes` or override :meth:`allow_matrix_event` for it to reach here. """ async def handle_state_event(self, evt: StateEvent) -> None: """ Called by :meth:`int_handle_event` for state events other than m.room.membership. **N.B.** You may need to add the event class to :attr:`allowed_event_classes` or override :meth:`allow_matrix_event` for it to reach here. """ async def handle_ephemeral_event( self, evt: ReceiptEvent | PresenceEvent | TypingEvent ) -> None: if evt.type == EventType.RECEIPT: await self.handle_receipt(evt) async def send_permission_error(self, room_id: RoomID) -> None: await self.az.intent.send_notice( room_id, text=( "You are not whitelisted to use this bridge.\n\n" "If you are the owner of this bridge, see the bridge.permissions " "section in your config file." ), html=( "

<p>You are not whitelisted to use this bridge.</p>"
                "<p>If you are the owner of this bridge, see the "
                "<code>bridge.permissions</code> section in your config file.</p>
" ), ) async def accept_bot_invite(self, room_id: RoomID, inviter: br.BaseUser) -> None: try: await self.az.intent.join_room(room_id) except Exception: self.log.exception(f"Failed to join room {room_id} as bridge bot") return if not await self.allow_command(inviter): await self.send_permission_error(room_id) await self.az.intent.leave_room(room_id) return await self.send_welcome_message(room_id, inviter) async def send_welcome_message(self, room_id: RoomID, inviter: br.BaseUser) -> None: has_two_members, bridge_bot_in_room = await self._is_direct_chat(room_id) is_management = has_two_members and bridge_bot_in_room welcome_messages = [self.management_room_text.get("welcome")] if is_management: if await inviter.is_logged_in(): welcome_messages.append(self.management_room_text.get("welcome_connected")) else: welcome_messages.append(self.management_room_text.get("welcome_unconnected")) additional_help = self.management_room_text.get("additional_help") if additional_help: welcome_messages.append(additional_help) else: cmd_prefix = self.commands.command_prefix welcome_messages.append(f"Use `{cmd_prefix} help` for help.") if self.management_room_multiple_messages: for m in welcome_messages: await self.az.intent.send_notice(room_id, text=m, html=markdown.render(m)) else: combined = "\n".join(welcome_messages) combined_html = "".join(map(markdown.render, welcome_messages)) await self.az.intent.send_notice(room_id, text=combined, html=combined_html) async def int_handle_invite(self, evt: StateEvent) -> None: self.log.debug(f"{evt.sender} invited {evt.state_key} to {evt.room_id}") inviter = await self.bridge.get_user(evt.sender) if inviter is None: self.log.exception(f"Failed to find user with Matrix ID {evt.sender}") return elif evt.state_key == self.az.bot_mxid: await self.accept_bot_invite(evt.room_id, inviter) return puppet = await self.bridge.get_puppet(UserID(evt.state_key)) if puppet: await self.handle_puppet_invite(evt.room_id, puppet, inviter, evt) return await self.handle_invite(evt.room_id, UserID(evt.state_key), inviter, evt) def is_command(self, message: MessageEventContent) -> tuple[bool, str]: text = message.body prefix = self.config["bridge.command_prefix"] is_command = text.startswith(prefix) if is_command: text = text[len(prefix) + 1 :].lstrip() return is_command, text async def _send_mss( self, evt: Event, status: MessageStatus, reason: MessageStatusReason | None = None, error: str | None = None, message: str | None = None, ) -> None: if not self.config.get("bridge.message_status_events", False): return status_content = BeeperMessageStatusEventContent( network="", # TODO set network properly relates_to=RelatesTo(rel_type=RelationType.REFERENCE, event_id=evt.event_id), status=status, reason=reason, error=error, message=message, ) await self.az.intent.send_message_event( evt.room_id, EventType.BEEPER_MESSAGE_STATUS, status_content ) async def _send_crypto_status_error( self, evt: Event, err: DecryptionError | None = None, retry_num: int = 0, is_final: bool = True, edit: EventID | None = None, wait_for: int | None = None, ) -> EventID | None: msg = str(err) if isinstance(err, (SessionNotFound, UnencryptedMessageError)): msg = err.human_message self._send_message_checkpoint( evt, MessageSendCheckpointStep.DECRYPTED, msg, permanent=is_final, retry_num=retry_num ) if wait_for: msg += f". The bridge will retry for {wait_for} seconds" full_msg = f"\u26a0 Your message was not bridged: {msg}." 
if isinstance(err, EncryptionUnsupportedError): full_msg = "🔒️ This bridge has not been configured to support encryption" event_id = None if self.config.get("bridge.delivery_error_reports", True): try: content = TextMessageEventContent(msgtype=MessageType.NOTICE, body=full_msg) if edit: content.set_edit(edit) event_id = await self.az.intent.send_message(evt.room_id, content) except IntentError: self.log.debug("IntentError while sending encryption error", exc_info=True) self.log.error( "Got IntentError while trying to send encryption error message. " "This likely means the bridge bot is not in the room, which can " "happen if you force-enable e2ee on the homeserver without enabling " "it by default on the bridge (bridge -> encryption -> default)." ) await self._send_mss( evt, status=MessageStatus.RETRIABLE if is_final else MessageStatus.PENDING, reason=MessageStatusReason.UNDECRYPTABLE, error=str(err), message=err.human_message if err else None, ) return event_id async def handle_message(self, evt: MessageEvent, was_encrypted: bool = False) -> None: room_id = evt.room_id user_id = evt.sender event_id = evt.event_id message = evt.content if not was_encrypted and self.require_e2ee: self.log.warning(f"Dropping {event_id} from {user_id} as it's not encrypted!") await self._send_crypto_status_error(evt, UnencryptedMessageError(), 0) return sender = await self.bridge.get_user(user_id) if not sender or not await self.allow_message(sender): self.log.debug( f"Ignoring message {event_id} from {user_id} to {room_id}:" " user is not whitelisted." ) self._send_message_checkpoint( evt, MessageSendCheckpointStep.BRIDGE, "user is not whitelisted" ) return self.log.debug(f"Received Matrix event {event_id} from {sender.mxid} in {room_id}") self.log.trace("Event %s content: %s", event_id, message) if isinstance(message, TextMessageEventContent): message.trim_reply_fallback() is_command, text = self.is_command(message) portal = await self.bridge.get_portal(room_id) if not is_command and portal: if await self.allow_bridging_message(sender, portal): await portal.handle_matrix_message(sender, message, event_id) else: self.log.debug( f"Ignoring event {event_id} from {sender.mxid}:" " not allowed to send to portal" ) self._send_message_checkpoint( evt, MessageSendCheckpointStep.BRIDGE, "user is not allowed to send to the portal", ) return if message.msgtype != MessageType.TEXT: self.log.debug( f"Ignoring event {event_id}: not a portal room and not a m.text message" ) self._send_message_checkpoint( evt, MessageSendCheckpointStep.BRIDGE, "not a portal room and not a m.text message" ) return elif not await self.allow_command(sender): self.log.debug( f"Ignoring command {event_id} from {sender.mxid}: not allowed to run commands" ) self._send_message_checkpoint( evt, MessageSendCheckpointStep.COMMAND, "not allowed to run commands" ) return has_two_members, bridge_bot_in_room = await self._is_direct_chat(room_id) is_management = has_two_members and bridge_bot_in_room if is_command or is_management: try: command, arguments = text.split(" ", 1) args = arguments.split(" ") except ValueError: # Not enough values to unpack, i.e. 
no arguments command = text args = [] try: await self.commands.handle( room_id, event_id, sender, command, args, message, portal, is_management, bridge_bot_in_room, ) except Exception as e: self.log.debug(f"Error handling command {command} from {sender}: {e}") self._send_message_checkpoint(evt, MessageSendCheckpointStep.COMMAND, e) await self._send_mss( evt, status=MessageStatus.FAIL, reason=MessageStatusReason.GENERIC_ERROR, error="", message="Command execution failed", ) else: await MessageSendCheckpoint( event_id=event_id, room_id=room_id, step=MessageSendCheckpointStep.COMMAND, timestamp=int(time.time() * 1000), status=MessageSendCheckpointStatus.SUCCESS, reported_by=MessageSendCheckpointReportedBy.BRIDGE, event_type=EventType.ROOM_MESSAGE, message_type=message.msgtype, ).send( self.bridge.config["homeserver.message_send_checkpoint_endpoint"], self.az.as_token, self.log, ) await self._send_mss(evt, status=MessageStatus.SUCCESS) else: self.log.debug( f"Ignoring event {event_id} from {sender.mxid}:" " not a command and not a portal room" ) self._send_message_checkpoint( evt, MessageSendCheckpointStep.COMMAND, "not a command and not a portal room" ) await self._send_mss( evt, status=MessageStatus.FAIL, reason=MessageStatusReason.UNSUPPORTED, error="Unknown room", message="Unknown room", ) async def _is_direct_chat(self, room_id: RoomID) -> tuple[bool, bool]: try: members = await self.az.intent.get_room_members(room_id) return len(members) == 2, self.az.bot_mxid in members except MatrixError: return False, False async def handle_receipt(self, evt: ReceiptEvent) -> None: for event_id, receipts in evt.content.items(): for user_id, data in receipts.get(ReceiptType.READ, {}).items(): user = await self.bridge.get_user(user_id, create=False) if not user or not await user.is_logged_in(): continue portal = await self.bridge.get_portal(evt.room_id) if not portal: continue await portal.schedule_disappearing() if ( data.get(DOUBLE_PUPPET_SOURCE_KEY) == self.az.bridge_name and await self.bridge.get_double_puppet(user_id) is not None ): continue await self.handle_read_receipt(user, portal, event_id, data) async def handle_read_receipt( self, user: br.BaseUser, portal: br.BasePortal, event_id: EventID, data: SingleReceiptEventContent, ) -> None: pass async def try_handle_sync_event(self, evt: Event) -> None: try: if isinstance(evt, (ReceiptEvent, PresenceEvent, TypingEvent)): await self.handle_ephemeral_event(evt) else: self.log.trace("Unknown event type received from sync: %s", evt) except Exception: self.log.exception("Error handling manually received Matrix event") async def _post_decrypt( self, evt: Event, retry_num: int = 0, error_event_id: EventID | None = None ) -> None: trust_state = evt["mautrix"]["trust_state"] if trust_state < self.e2ee.min_send_trust: self.log.warning( f"Dropping {evt.event_id} from {evt.sender} due to insufficient verification level" f" (event: {trust_state}, required: {self.e2ee.min_send_trust})" ) await self._send_crypto_status_error( evt, retry_num=retry_num, err=DeviceUntrustedError(trust_state), edit=error_event_id, ) return self._send_message_checkpoint( evt, MessageSendCheckpointStep.DECRYPTED, retry_num=retry_num ) if error_event_id: await self.az.intent.redact(evt.room_id, error_event_id) await self.int_handle_event(evt, was_encrypted=True) async def handle_encrypted(self, evt: EncryptedEvent) -> None: if not self.e2ee: self.log.debug( "Got encrypted message %s from %s, but encryption is not enabled", evt.event_id, evt.sender, ) await 
self._send_crypto_status_error(evt, EncryptionUnsupportedError()) return try: decrypted = await self.e2ee.decrypt(evt, wait_session_timeout=3) except SessionNotFound as e: await self._handle_encrypted_wait(evt, e, wait=22) except DecryptionError as e: self.log.warning(f"Failed to decrypt {evt.event_id}: {e}") self.log.trace("%s decryption traceback:", evt.event_id, exc_info=True) await self._send_crypto_status_error(evt, e) else: await self._post_decrypt(decrypted) async def _handle_encrypted_wait( self, evt: EncryptedEvent, err: SessionNotFound, wait: int ) -> None: self.log.debug( f"Couldn't find session {err.session_id} trying to decrypt {evt.event_id}," " waiting even longer" ) background_task.create( self.e2ee.crypto.request_room_key( evt.room_id, evt.content.sender_key, evt.content.session_id, from_devices={evt.sender: [evt.content.device_id]}, ) ) event_id = await self._send_crypto_status_error(evt, err, is_final=False, wait_for=wait) got_keys = await self.e2ee.crypto.wait_for_session( evt.room_id, err.session_id, timeout=wait ) if got_keys: self.log.debug( f"Got session {err.session_id} after waiting more, " f"trying to decrypt {evt.event_id} again" ) try: decrypted = await self.e2ee.decrypt(evt, wait_session_timeout=0) except DecryptionError as e: await self._send_crypto_status_error(evt, e, retry_num=1, edit=event_id) self.log.warning(f"Failed to decrypt {evt.event_id}: {e}") self.log.trace("%s decryption traceback:", evt.event_id, exc_info=True) else: await self._post_decrypt(decrypted, retry_num=1, error_event_id=event_id) return else: self.log.warning(f"Didn't get {err.session_id}, giving up on {evt.event_id}") await self._send_crypto_status_error( evt, SessionNotFound(err.session_id), retry_num=1, edit=event_id ) async def handle_encryption(self, evt: StateEvent) -> None: await self.az.state_store.set_encryption_info(evt.room_id, evt.content) portal = await self.bridge.get_portal(evt.room_id) if portal: portal.encrypted = True await portal.save() if portal.is_direct: portal.log.debug("Received encryption event in direct portal: %s", evt.content) await portal.enable_dm_encryption() def _send_message_checkpoint( self, evt: Event, step: MessageSendCheckpointStep, err: Exception | str | None = None, permanent: bool = True, retry_num: int = 0, ) -> None: endpoint = self.bridge.config["homeserver.message_send_checkpoint_endpoint"] if not endpoint: return if evt.type not in CHECKPOINT_TYPES: return self.log.debug(f"Sending message send checkpoint for {evt.event_id} (step: {step})") status = MessageSendCheckpointStatus.SUCCESS if err: status = ( MessageSendCheckpointStatus.PERM_FAILURE if permanent else MessageSendCheckpointStatus.WILL_RETRY ) checkpoint = MessageSendCheckpoint( event_id=evt.event_id, room_id=evt.room_id, step=step, timestamp=int(time.time() * 1000), status=status, reported_by=MessageSendCheckpointReportedBy.BRIDGE, event_type=evt.type, message_type=evt.content.msgtype if evt.type == EventType.ROOM_MESSAGE else None, info=str(err) if err else None, retry_num=retry_num, ) background_task.create(checkpoint.send(endpoint, self.az.as_token, self.log)) allowed_event_classes: tuple[type, ...] = ( MessageEvent, StateEvent, ReactionEvent, EncryptedEvent, RedactionEvent, ReceiptEvent, TypingEvent, PresenceEvent, ) async def allow_matrix_event(self, evt: Event) -> bool: # If the event is not one of the allowed classes, ignore it. if not isinstance(evt, self.allowed_event_classes): return False # For room events, make sure the message didn't originate from the bridge. 
if isinstance(evt, BaseRoomEvent): # If the event is from a bridge ghost, ignore it. if evt.sender == self.az.bot_mxid or self.bridge.is_bridge_ghost(evt.sender): return False # If the event is marked as double puppeted and we can confirm that we are in fact # double puppeting that user ID, ignore it. if ( evt.content.get(DOUBLE_PUPPET_SOURCE_KEY) == self.az.bridge_name and await self.bridge.get_double_puppet(evt.sender) is not None ): return False # For non-room events and non-bridge-originated room events, allow. return True async def int_handle_event(self, evt: Event, was_encrypted: bool = False) -> None: if isinstance(evt, StateEvent) and evt.type == EventType.ROOM_MEMBER and self.e2ee: await self.e2ee.handle_member_event(evt) if not await self.allow_matrix_event(evt): return self.log.trace("Received event: %s", evt) if not was_encrypted: self._send_message_checkpoint(evt, MessageSendCheckpointStep.BRIDGE) start_time = time.time() if evt.type == EventType.ROOM_MEMBER: evt: StateEvent unsigned = evt.unsigned or StateUnsigned() prev_content = unsigned.prev_content or MemberStateEventContent() prev_membership = prev_content.membership if prev_content else Membership.JOIN if evt.content.membership == Membership.INVITE: if prev_membership == Membership.KNOCK: await self.handle_accept_knock( evt.room_id, UserID(evt.state_key), evt.sender, evt.content.reason, evt.event_id, ) else: await self.int_handle_invite(evt) elif evt.content.membership == Membership.LEAVE: if prev_membership == Membership.BAN: await self.handle_unban( evt.room_id, UserID(evt.state_key), evt.sender, evt.content.reason, evt.event_id, ) elif prev_membership == Membership.INVITE: if evt.sender == evt.state_key: await self.handle_reject( evt.room_id, UserID(evt.state_key), evt.content.reason, evt.event_id ) else: await self.handle_disinvite( evt.room_id, UserID(evt.state_key), evt.sender, evt.content.reason, evt.event_id, ) elif prev_membership == Membership.KNOCK: if evt.sender == evt.state_key: await self.handle_retract_knock( evt.room_id, UserID(evt.state_key), evt.content.reason, evt.event_id ) else: await self.handle_reject_knock( evt.room_id, UserID(evt.state_key), evt.sender, evt.content.reason, evt.event_id, ) elif evt.sender == evt.state_key: await self.handle_leave(evt.room_id, UserID(evt.state_key), evt.event_id) else: await self.handle_kick( evt.room_id, UserID(evt.state_key), evt.sender, evt.content.reason, evt.event_id, ) elif evt.content.membership == Membership.BAN: await self.handle_ban( evt.room_id, UserID(evt.state_key), evt.sender, evt.content.reason, evt.event_id, ) elif evt.content.membership == Membership.JOIN: if prev_membership != Membership.JOIN: await self.handle_join(evt.room_id, UserID(evt.state_key), evt.event_id) else: await self.handle_member_info_change( evt.room_id, UserID(evt.state_key), evt.content, prev_content, evt.event_id ) elif evt.content.membership == Membership.KNOCK: await self.handle_knock( evt.room_id, UserID(evt.state_key), evt.content.reason, evt.event_id, ) elif evt.type in (EventType.ROOM_MESSAGE, EventType.STICKER): evt: MessageEvent if evt.type != EventType.ROOM_MESSAGE: evt.content.msgtype = MessageType(str(evt.type)) await self.handle_message(evt, was_encrypted=was_encrypted) elif evt.type == EventType.ROOM_ENCRYPTED: await self.handle_encrypted(evt) elif evt.type == EventType.ROOM_ENCRYPTION: await self.handle_encryption(evt) else: if evt.type.is_state and isinstance(evt, StateEvent): await self.handle_state_event(evt) elif evt.type.is_ephemeral and isinstance( evt, 
(PresenceEvent, TypingEvent, ReceiptEvent) ): await self.handle_ephemeral_event(evt) else: await self.handle_event(evt) await self.log_event_handle_duration(evt, time.time() - start_time) async def log_event_handle_duration(self, evt: Event, duration: float) -> None: EVENT_TIME.labels(event_type=str(evt.type)).observe(duration) python-0.20.7/mautrix/bridge/notification_disabler.py000066400000000000000000000052471473573527000230000ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations from typing import Type import logging from mautrix.api import Method, Path, PathBuilder from mautrix.appservice import IntentAPI from mautrix.types import RoomID, UserID from mautrix.util.logging import TraceLogger from .puppet import BasePuppet from .user import BaseUser class NotificationDisabler: puppet_cls: Type[BasePuppet] config_enabled: bool = False log: TraceLogger = logging.getLogger("mau.notification_disabler") user_id: UserID room_id: RoomID intent: IntentAPI | None enabled: bool def __init__(self, room_id: RoomID, user: BaseUser) -> None: self.user_id = user.mxid self.room_id = room_id self.enabled = False @property def _path(self) -> PathBuilder: return Path.v3.pushrules["global"].override[ f"net.maunium.silence_while_backfilling:{self.room_id}" ] @property def _rule(self) -> dict: return { "actions": ["dont_notify"], "conditions": [ { "kind": "event_match", "key": "room_id", "pattern": self.room_id, } ], } async def __aenter__(self) -> None: puppet = await self.puppet_cls.get_by_custom_mxid(self.user_id) self.intent = puppet.intent if puppet and puppet.is_real_user else None if not self.intent or not self.config_enabled: return self.enabled = True try: self.log.debug(f"Disabling notifications in {self.room_id} for {self.intent.mxid}") await self.intent.api.request(Method.PUT, self._path, content=self._rule) except Exception: self.log.warning( f"Failed to disable notifications in {self.room_id} " f"for {self.intent.mxid} while backfilling", exc_info=True, ) raise async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: if not self.enabled: return try: self.log.debug(f"Re-enabling notifications in {self.room_id} for {self.intent.mxid}") await self.intent.api.request(Method.DELETE, self._path) except Exception: self.log.warning( f"Failed to re-enable notifications in {self.room_id} " f"for {self.intent.mxid} after backfilling", exc_info=True, ) python-0.20.7/mautrix/bridge/portal.py000066400000000000000000000501401473573527000177360ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. 
from __future__ import annotations from typing import Any, NamedTuple from abc import ABC, abstractmethod from collections import defaultdict from string import Template import asyncio import html import logging import time from mautrix.appservice import AppService, IntentAPI from mautrix.errors import MatrixError, MatrixRequestError, MForbidden, MNotFound from mautrix.types import ( JSON, EncryptionAlgorithm, EventID, EventType, Format, MessageEventContent, MessageType, RoomEncryptionStateEventContent, RoomID, RoomTombstoneStateEventContent, TextMessageEventContent, UserID, ) from mautrix.util import background_task from mautrix.util.logging import TraceLogger from mautrix.util.simple_lock import SimpleLock from .. import bridge as br class RelaySender(NamedTuple): sender: br.BaseUser | None is_relay: bool class RejectMatrixInvite(Exception): def __init__(self, message: str) -> None: super().__init__(message) self.message = message class IgnoreMatrixInvite(Exception): pass class DMCreateError(RejectMatrixInvite): """ An error raised by :meth:`BasePortal.prepare_dm` if the DM can't be set up. The message in the exception will be sent to the user as a message before the ghost leaves. """ class BasePortal(ABC): log: TraceLogger = logging.getLogger("mau.portal") _async_get_locks: dict[Any, asyncio.Lock] = defaultdict(lambda: asyncio.Lock()) disappearing_msg_class: type[br.AbstractDisappearingMessage] | None = None _disappearing_lock: asyncio.Lock | None az: AppService matrix: br.BaseMatrixHandler bridge: br.Bridge loop: asyncio.AbstractEventLoop main_intent: IntentAPI mxid: RoomID | None name: str | None encrypted: bool is_direct: bool backfill_lock: SimpleLock relay_user_id: UserID | None _relay_user: br.BaseUser | None relay_emote_to_text: bool = True relay_formatted_body: bool = True def __init__(self) -> None: self._disappearing_lock = asyncio.Lock() if self.disappearing_msg_class else None @abstractmethod async def save(self) -> None: pass @abstractmethod async def get_dm_puppet(self) -> br.BasePuppet | None: """ Get the ghost representing the other end of this direct chat. Returns: A puppet entity, or ``None`` if this is not a 1:1 chat. """ @abstractmethod async def handle_matrix_message( self, sender: br.BaseUser, message: MessageEventContent, event_id: EventID ) -> None: pass async def prepare_remote_dm( self, room_id: RoomID, invited_by: br.BaseUser, puppet: br.BasePuppet ) -> str: """ Do whatever is needed on the remote platform to set up a direct chat between the user and the ghost. By default, this does nothing (and lets :meth:`setup_matrix_dm` handle everything). Args: room_id: The room ID that will be used. invited_by: The Matrix user who invited the ghost. puppet: The ghost who was invited. Returns: A simple message indicating what was done (will be sent as a notice to the room). If empty, the message won't be sent. Raises: DMCreateError: if the DM could not be created and the ghost should leave the room. """ return "Portal to private chat created." async def postprocess_matrix_dm(self, user: br.BaseUser, puppet: br.BasePuppet) -> None: await self.update_bridge_info() async def reject_duplicate_dm( self, room_id: RoomID, invited_by: br.BaseUser, puppet: br.BasePuppet ) -> None: try: await puppet.default_mxid_intent.send_notice( room_id, text=f"You already have a private chat with me: {self.mxid}", html=( "You already have a private chat with me: " f"
Link to room" ), ) except Exception as e: self.log.debug(f"Failed to send notice to duplicate private chat room: {e}") try: await puppet.default_mxid_intent.send_state_event( room_id, event_type=EventType.ROOM_TOMBSTONE, content=RoomTombstoneStateEventContent( replacement_room=self.mxid, body="You already have a private chat with me", ), ) except Exception as e: self.log.debug(f"Failed to send tombstone to duplicate private chat room: {e}") await puppet.default_mxid_intent.leave_room(room_id) async def accept_matrix_dm( self, room_id: RoomID, invited_by: br.BaseUser, puppet: br.BasePuppet ) -> None: """ Set up a room as a direct chat portal. The ghost has already accepted the invite at this point, so this method needs to make it leave if the DM can't be created for some reason. By default, this checks if there's an existing portal and redirects the user there if it does exist. If a portal doesn't exist, this will call :meth:`prepare_matrix_dm` and then save the room ID, enable encryption and update bridge info. If the portal exists, but isn't usable, the old room will be cleaned up and the function will continue. Args: room_id: The room ID that will be used. invited_by: The Matrix user who invited the ghost. puppet: The ghost who was invited. """ if self.mxid: try: portal_members = await self.main_intent.get_room_members(self.mxid) except (MForbidden, MNotFound): portal_members = [] if invited_by.mxid in portal_members: await self.reject_duplicate_dm(room_id, invited_by, puppet) return self.log.debug( f"{invited_by.mxid} isn't in old portal room {self.mxid}," " cleaning up and accepting new room as the DM portal" ) await self.cleanup_portal( message="User seems to have left DM portal", puppets_only=True ) try: message = await self.prepare_remote_dm(room_id, invited_by, puppet) except DMCreateError as e: if e.message: await puppet.default_mxid_intent.send_notice(room_id, text=e.message) await puppet.default_mxid_intent.leave_room(room_id, reason="Failed to create DM") return self.mxid = room_id e2be_ok = await self.check_dm_encryption() await self.save() if e2be_ok is False: message += "\n\nWarning: Failed to enable end-to-bridge encryption." if message: await self._send_message( puppet.default_mxid_intent, TextMessageEventContent( msgtype=MessageType.NOTICE, body=message, ), ) await self.postprocess_matrix_dm(invited_by, puppet) async def handle_matrix_invite(self, invited_by: br.BaseUser, puppet: br.BasePuppet) -> None: """ Called when a Matrix user invites a bridge ghost to a room to process the invite (and check if it should be accepted). Args: invited_by: The user who invited the ghost. puppet: The ghost who was invited. Raises: RejectMatrixInvite: if the invite should be rejected. IgnoreMatrixInvite: if the invite should be ignored (e.g. if it was already accepted). 
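
        Example (sketch of a hypothetical subclass override; a real bridge would also
        set up the corresponding chat on the remote network)::

            class Portal(BasePortal):
                async def handle_matrix_invite(
                    self, invited_by: br.BaseUser, puppet: br.BasePuppet
                ) -> None:
                    if self.is_direct:
                        raise RejectMatrixInvite("You can't invite users to private chats.")
                    # Returning without raising makes the caller join the ghost to the room.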
""" if self.is_direct: raise RejectMatrixInvite("You can't invite additional users to private chats.") raise RejectMatrixInvite("This bridge does not implement inviting users to portals.") async def update_bridge_info(self) -> None: """Resend the ``m.bridge`` event into the room.""" @property def _relay_is_implemented(self) -> bool: return hasattr(self, "relay_user_id") and hasattr(self, "_relay_user") @property def has_relay(self) -> bool: return ( self._relay_is_implemented and self.bridge.config["bridge.relay.enabled"] and bool(self.relay_user_id) ) async def get_relay_user(self) -> br.BaseUser | None: if not self.has_relay: return None if self._relay_user is None: self._relay_user = await self.bridge.get_user(self.relay_user_id) return self._relay_user if await self._relay_user.is_logged_in() else None async def set_relay_user(self, user: br.BaseUser | None) -> None: if not self._relay_is_implemented or not self.bridge.config["bridge.relay.enabled"]: raise RuntimeError("Can't set_relay_user() when relay mode is not enabled") self._relay_user = user self.relay_user_id = user.mxid if user else None await self.save() async def get_relay_sender(self, sender: br.BaseUser, evt_identifier: str) -> RelaySender: if not await sender.needs_relay(self): return RelaySender(sender, False) if not self.has_relay: self.log.debug( f"Ignoring {evt_identifier} from non-logged-in user {sender.mxid} " f"in chat with no relay user" ) return RelaySender(None, True) relay_sender = await self.get_relay_user() if not relay_sender: self.log.debug( f"Ignoring {evt_identifier} from non-logged-in user {sender.mxid} " f"relay user {self.relay_user_id} is not set up correctly" ) return RelaySender(None, True) return RelaySender(relay_sender, True) async def apply_relay_message_format( self, sender: br.BaseUser, content: MessageEventContent ) -> None: if self.relay_formatted_body and content.get("format", None) != Format.HTML: content["format"] = Format.HTML content["formatted_body"] = html.escape(content.body).replace("\n", "
") tpl = self.bridge.config["bridge.relay.message_formats"].get( content.msgtype.value, "$sender_displayname: $message" ) displayname = await self.get_displayname(sender) username, _ = self.az.intent.parse_user_id(sender.mxid) tpl_args = { "sender_mxid": sender.mxid, "sender_username": username, "sender_displayname": html.escape(displayname), "formatted_body": content["formatted_body"], "body": content.body, "message": content.body, } content.body = Template(tpl).safe_substitute(tpl_args) if self.relay_formatted_body and "formatted_body" in content: tpl_args["message"] = content["formatted_body"] content["formatted_body"] = Template(tpl).safe_substitute(tpl_args) if self.relay_emote_to_text and content.msgtype == MessageType.EMOTE: content.msgtype = MessageType.TEXT async def get_displayname(self, user: br.BaseUser) -> str: return await self.main_intent.get_room_displayname(self.mxid, user.mxid) or user.mxid async def check_dm_encryption(self) -> bool | None: try: evt = await self.main_intent.get_state_event(self.mxid, EventType.ROOM_ENCRYPTION) self.log.debug("Found existing encryption event in direct portal: %s", evt) if evt and evt.algorithm == EncryptionAlgorithm.MEGOLM_V1: self.encrypted = True except MNotFound: pass if ( self.is_direct and self.matrix.e2ee and (self.bridge.config["bridge.encryption.default"] or self.encrypted) ): return await self.enable_dm_encryption() return None def get_encryption_state_event_json(self) -> JSON: evt = RoomEncryptionStateEventContent(EncryptionAlgorithm.MEGOLM_V1) if self.bridge.config["bridge.encryption.rotation.enable_custom"]: evt.rotation_period_ms = self.bridge.config["bridge.encryption.rotation.milliseconds"] evt.rotation_period_msgs = self.bridge.config["bridge.encryption.rotation.messages"] return evt.serialize() async def enable_dm_encryption(self) -> bool: self.log.debug("Inviting bridge bot to room for end-to-bridge encryption") try: await self.main_intent.invite_user(self.mxid, self.az.bot_mxid) await self.az.intent.join_room_by_id(self.mxid) if not self.encrypted: await self.main_intent.send_state_event( self.mxid, EventType.ROOM_ENCRYPTION, self.get_encryption_state_event_json(), ) except Exception: self.log.warning(f"Failed to enable end-to-bridge encryption", exc_info=True) return False self.encrypted = True await self.update_info_from_puppet() return True async def update_info_from_puppet(self, puppet: br.BasePuppet | None = None) -> None: """ Update the room metadata to match the ghost's name/avatar. This is called after enabling encryption, as the bridge bot needs to join for e2ee, but that messes up the default name generation. If/when canonical DMs happen, this might not be necessary anymore. Args: puppet: The ghost that is the other participant in the room. If ``None``, the entity should be fetched as necessary. 
""" @property def disappearing_enabled(self) -> bool: return bool(self.disappearing_msg_class) async def _disappear_event(self, msg: br.AbstractDisappearingMessage) -> None: sleep_time = (msg.expiration_ts / 1000) - time.time() self.log.trace(f"Sleeping {sleep_time:.3f} seconds before redacting {msg.event_id}") await asyncio.sleep(sleep_time) try: await msg.delete() except Exception: self.log.exception( f"Failed to delete disappearing message record for {msg.event_id} from database" ) if self.mxid != msg.room_id: self.log.debug( f"Not redacting expired event {msg.event_id}, " f"portal room seems to have changed ({self.mxid!r} != {msg.room_id!r})" ) return try: await self._do_disappear(msg.event_id) self.log.debug(f"Expired event {msg.event_id} disappeared successfully") except Exception as e: self.log.warning(f"Failed to make expired event {msg.event_id} disappear: {e}") async def _do_disappear(self, event_id: EventID) -> None: await self.main_intent.redact(self.mxid, event_id) @classmethod async def restart_scheduled_disappearing(cls) -> None: """ Restart disappearing message timers for all messages that were already scheduled to disappear earlier. This should be called at bridge startup. """ if not cls.disappearing_msg_class: return msgs = await cls.disappearing_msg_class.get_all_scheduled() for msg in msgs: portal = await cls.bridge.get_portal(msg.room_id) if portal and portal.mxid: background_task.create(portal._disappear_event(msg)) else: await msg.delete() async def schedule_disappearing(self) -> None: """ Start the disappearing message timer for all unscheduled messages in this room. This is automatically called from :meth:`MatrixHandler.handle_receipt`. """ if not self.disappearing_msg_class: return async with self._disappearing_lock: msgs = await self.disappearing_msg_class.get_unscheduled_for_room(self.mxid) for msg in msgs: msg.start_timer() await msg.update() background_task.create(self._disappear_event(msg)) async def _send_message( self, intent: IntentAPI, content: MessageEventContent, event_type: EventType = EventType.ROOM_MESSAGE, **kwargs, ) -> EventID: if self.encrypted and self.matrix.e2ee: event_type, content = await self.matrix.e2ee.encrypt(self.mxid, event_type, content) event_id = await intent.send_message_event(self.mxid, event_type, content, **kwargs) if intent.api.is_real_user: background_task.create(intent.mark_read(self.mxid, event_id)) return event_id @property @abstractmethod def bridge_info_state_key(self) -> str: pass @property @abstractmethod def bridge_info(self) -> dict[str, Any]: pass # region Matrix room cleanup @abstractmethod async def delete(self) -> None: pass @classmethod async def cleanup_room( cls, intent: IntentAPI, room_id: RoomID, message: str = "Cleaning room", puppets_only: bool = False, ) -> None: if not puppets_only and cls.bridge.homeserver_software.is_hungry: try: await intent.beeper_delete_room(room_id) return except MNotFound as err: cls.log.debug(f"Hungryserv yeet returned {err}, assuming the room is already gone") return except Exception: cls.log.warning( f"Failed to delete {room_id} using hungryserv yeet endpoint, " f"falling back to normal method", exc_info=True, ) try: members = await intent.get_room_members(room_id) except MatrixError: members = [] for user_id in members: if user_id == intent.mxid: continue puppet = await cls.bridge.get_puppet(user_id, create=False) if puppet: await puppet.default_mxid_intent.leave_room(room_id) continue if not puppets_only: custom_puppet = await cls.bridge.get_double_puppet(user_id) left = False 
if custom_puppet: try: await custom_puppet.intent.leave_room(room_id) await custom_puppet.intent.forget_room(room_id) except MatrixError: pass else: left = True if not left: try: await intent.kick_user(room_id, user_id, message) except MatrixError: pass try: await intent.leave_room(room_id) except MatrixError: cls.log.warning(f"Failed to leave room {room_id} when cleaning up room", exc_info=True) async def cleanup_portal(self, message: str, puppets_only: bool = False) -> None: await self.cleanup_room(self.main_intent, self.mxid, message, puppets_only) await self.delete() async def unbridge(self) -> None: await self.cleanup_portal("Room unbridged", puppets_only=True) async def cleanup_and_delete(self) -> None: await self.cleanup_portal("Portal deleted") async def get_authenticated_matrix_users(self) -> list[UserID]: """ Get the list of Matrix user IDs who can be bridged. This is used to determine if the portal is empty (and should be cleaned up) or not. Bridges should override this to check that the users are either logged in or the portal has a relaybot. """ try: members = await self.main_intent.get_room_members(self.mxid) except MatrixRequestError: return [] return [ member for member in members if (not self.bridge.is_bridge_ghost(member) and member != self.az.bot_mxid) ] # endregion python-0.20.7/mautrix/bridge/puppet.py000066400000000000000000000022071473573527000177530ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations from typing import Any from abc import ABC, abstractmethod from collections import defaultdict import asyncio import logging from mautrix.appservice import AppService, IntentAPI from mautrix.types import UserID from mautrix.util.logging import TraceLogger from .. import bridge as br from .custom_puppet import CustomPuppetMixin class BasePuppet(CustomPuppetMixin, ABC): log: TraceLogger = logging.getLogger("mau.puppet") _async_get_locks: dict[Any, asyncio.Lock] = defaultdict(lambda: asyncio.Lock()) az: AppService loop: asyncio.AbstractEventLoop mx: br.BaseMatrixHandler is_registered: bool mxid: str intent: IntentAPI @classmethod @abstractmethod async def get_by_mxid(cls, mxid: UserID) -> BasePuppet: pass @classmethod @abstractmethod async def get_by_custom_mxid(cls, mxid: UserID) -> BasePuppet: pass python-0.20.7/mautrix/bridge/state_store/000077500000000000000000000000001473573527000204175ustar00rootroot00000000000000python-0.20.7/mautrix/bridge/state_store/__init__.py000066400000000000000000000000261473573527000225260ustar00rootroot00000000000000__all__ = ["asyncpg"] python-0.20.7/mautrix/bridge/state_store/asyncpg.py000066400000000000000000000027231473573527000224410ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. 
from __future__ import annotations from typing import Awaitable, Callable, Union from mautrix.appservice.state_store.asyncpg import PgASStateStore from mautrix.types import UserID from mautrix.util.async_db import Database from ..puppet import BasePuppet GetPuppetFunc = Union[ Callable[[UserID], Awaitable[BasePuppet]], Callable[[UserID, bool], Awaitable[BasePuppet]] ] class PgBridgeStateStore(PgASStateStore): def __init__( self, db: Database, get_puppet: GetPuppetFunc, get_double_puppet: GetPuppetFunc ) -> None: super().__init__(db) self.get_puppet = get_puppet self.get_double_puppet = get_double_puppet async def is_registered(self, user_id: UserID) -> bool: puppet = await self.get_puppet(user_id) if puppet: return puppet.is_registered custom_puppet = await self.get_double_puppet(user_id) if custom_puppet: return True return await super().is_registered(user_id) async def registered(self, user_id: UserID) -> None: puppet = await self.get_puppet(user_id, True) if puppet: puppet.is_registered = True await puppet.save() else: await super().registered(user_id) python-0.20.7/mautrix/bridge/user.py000066400000000000000000000231321473573527000174140ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations from typing import Any, NamedTuple from abc import ABC, abstractmethod from collections import defaultdict, deque import asyncio import logging import time from mautrix.api import Method, Path from mautrix.appservice import AppService from mautrix.errors import MNotFound from mautrix.types import EventID, EventType, Membership, MessageType, RoomID, UserID from mautrix.util import background_task from mautrix.util.bridge_state import BridgeState, BridgeStateEvent from mautrix.util.logging import TraceLogger from mautrix.util.message_send_checkpoint import ( MessageSendCheckpoint, MessageSendCheckpointReportedBy, MessageSendCheckpointStatus, MessageSendCheckpointStep, ) from mautrix.util.opt_prometheus import Gauge from .. import bridge as br AsmuxPath = Path.unstable["com.beeper.asmux"] class WrappedTask(NamedTuple): task: asyncio.Task | None class BaseUser(ABC): log: TraceLogger = logging.getLogger("mau.user") _async_get_locks: dict[Any, asyncio.Lock] = defaultdict(lambda: asyncio.Lock()) az: AppService bridge: br.Bridge loop: asyncio.AbstractEventLoop is_whitelisted: bool is_admin: bool relay_whitelisted: bool mxid: UserID dm_update_lock: asyncio.Lock command_status: dict[str, Any] | None _metric_value: dict[Gauge, bool] _prev_bridge_status: BridgeState | None _bridge_state_queue: deque[BridgeState] _bridge_state_loop: asyncio.Task | None def __init__(self) -> None: self.dm_update_lock = asyncio.Lock() self.command_status = None self._metric_value = defaultdict(lambda: False) self._prev_bridge_status = None self.log = self.log.getChild(self.mxid) self.relay_whitelisted = False self._bridge_state_queue = deque() self._bridge_state_loop = None @abstractmethod async def is_logged_in(self) -> bool: raise NotImplementedError() @abstractmethod async def get_puppet(self) -> br.BasePuppet | None: """ Get the ghost that represents this Matrix user on the remote network. Returns: The puppet entity, or ``None`` if the user is not logged in, or it's otherwise not possible to find the remote ghost. 
""" raise NotImplementedError() @abstractmethod async def get_portal_with( self, puppet: br.BasePuppet, create: bool = True ) -> br.BasePortal | None: """ Get a private chat portal between this user and the given ghost. Args: puppet: The ghost who the portal should be with. create: ``True`` if the portal entity should be created if it doesn't exist. Returns: The portal entity, or ``None`` if it can't be found, or doesn't exist and ``create`` is ``False``. """ async def needs_relay(self, portal: br.BasePortal) -> bool: return not await self.is_logged_in() async def is_in_portal(self, portal: br.BasePortal) -> bool: try: member_event = await portal.main_intent.get_state_event( portal.mxid, EventType.ROOM_MEMBER, self.mxid ) except MNotFound: return False return member_event and member_event.membership in (Membership.JOIN, Membership.INVITE) async def get_direct_chats(self) -> dict[UserID, list[RoomID]]: raise NotImplementedError() async def update_direct_chats(self, dms: dict[UserID, list[RoomID]] | None = None) -> None: """ Update the m.direct account data of the user. Args: dms: DMs to _add_ to the list. If not provided, the list is _replaced_ with the result of :meth:`get_direct_chats`. """ if not self.bridge.config["bridge.sync_direct_chat_list"]: return puppet = await self.bridge.get_double_puppet(self.mxid) if not puppet or not puppet.is_real_user: return self.log.debug("Updating m.direct list on homeserver") replace = dms is None dms = dms or await self.get_direct_chats() if self.bridge.homeserver_software.is_asmux: # This uses a secret endpoint for atomically updating the DM list await puppet.intent.api.request( Method.PUT if replace else Method.PATCH, AsmuxPath.dms, content=dms, headers={"X-Asmux-Auth": self.az.as_token}, ) else: async with self.dm_update_lock: try: current_dms = await puppet.intent.get_account_data(EventType.DIRECT) except MNotFound: current_dms = {} if replace: # Filter away all existing DM statuses with bridge users filtered_dms = { user: rooms for user, rooms in current_dms.items() if not self.bridge.is_bridge_ghost(user) } else: filtered_dms = current_dms # Add DM statuses for all rooms in our database new_dms = {**filtered_dms, **dms} if current_dms != new_dms: await puppet.intent.set_account_data(EventType.DIRECT, new_dms) def _track_metric(self, metric: Gauge, value: bool) -> None: if self._metric_value[metric] != value: if value: metric.inc(1) else: metric.dec(1) self._metric_value[metric] = value async def fill_bridge_state(self, state: BridgeState) -> None: state.user_id = self.mxid state.fill() async def get_bridge_states(self) -> list[BridgeState]: raise NotImplementedError() async def push_bridge_state( self, state_event: BridgeStateEvent, error: str | None = None, message: str | None = None, ttl: int | None = None, remote_id: str | None = None, info: dict[str, Any] | None = None, reason: str | None = None, ) -> None: if not self.bridge.config["homeserver.status_endpoint"]: return state = BridgeState( state_event=state_event, error=error, message=message, ttl=ttl, remote_id=remote_id, info=info, reason=reason, ) await self.fill_bridge_state(state) if state.should_deduplicate(self._prev_bridge_status): return self._prev_bridge_status = state self._bridge_state_queue.append(state) if not self._bridge_state_loop or self._bridge_state_loop.done(): self.log.trace(f"Starting bridge state loop") self._bridge_state_loop = asyncio.create_task(self._start_bridge_state_send_loop()) else: self.log.debug(f"Queued bridge state to send later: {state.state_event}") 
async def _start_bridge_state_send_loop(self): url = self.bridge.config["homeserver.status_endpoint"] while self._bridge_state_queue: state = self._bridge_state_queue.popleft() success = await state.send(url, self.az.as_token, self.log) if not success: if state.send_attempts_ <= 10: retry_seconds = state.send_attempts_**2 self.log.warning( f"Attempt #{state.send_attempts_} of sending bridge state " f"{state.state_event} failed, retrying in {retry_seconds} seconds" ) await asyncio.sleep(retry_seconds) self._bridge_state_queue.appendleft(state) else: self.log.error( f"Failed to send bridge state {state.state_event} " f"after {state.send_attempts_} attempts, giving up" ) self._bridge_state_loop = None def send_remote_checkpoint( self, status: MessageSendCheckpointStatus, event_id: EventID, room_id: RoomID, event_type: EventType, message_type: MessageType | None = None, error: str | Exception | None = None, retry_num: int = 0, ) -> WrappedTask: """ Send a remote checkpoint for the given ``event_id``. This function spaws an :class:`asyncio.Task`` to send the checkpoint. :returns: the checkpoint send task. This can be awaited if you want to block on the checkpoint send. """ if not self.bridge.config["homeserver.message_send_checkpoint_endpoint"]: return WrappedTask(task=None) task = background_task.create( MessageSendCheckpoint( event_id=event_id, room_id=room_id, step=MessageSendCheckpointStep.REMOTE, timestamp=int(time.time() * 1000), status=status, reported_by=MessageSendCheckpointReportedBy.BRIDGE, event_type=event_type, message_type=message_type, info=str(error) if error else None, retry_num=retry_num, ).send( self.bridge.config["homeserver.message_send_checkpoint_endpoint"], self.az.as_token, self.log, ) ) return WrappedTask(task=task) python-0.20.7/mautrix/client/000077500000000000000000000000001473573527000161055ustar00rootroot00000000000000python-0.20.7/mautrix/client/__init__.py000066400000000000000000000014521473573527000202200ustar00rootroot00000000000000from .api import ClientAPI from .client import Client from .dispatcher import Dispatcher, MembershipEventDispatcher, SimpleDispatcher from .encryption_manager import DecryptionDispatcher, EncryptingAPI from .state_store import FileStateStore, MemoryStateStore, MemorySyncStore, StateStore, SyncStore from .store_updater import StoreUpdatingAPI from .syncer import EventHandler, InternalEventType, Syncer, SyncStream __all__ = [ "ClientAPI", "Client", "Dispatcher", "MembershipEventDispatcher", "SimpleDispatcher", "DecryptionDispatcher", "EncryptingAPI", "FileStateStore", "MemoryStateStore", "MemorySyncStore", "StateStore", "SyncStore", "StoreUpdatingAPI", "EventHandler", "InternalEventType", "Syncer", "SyncStream", "state_store", ] python-0.20.7/mautrix/client/api/000077500000000000000000000000001473573527000166565ustar00rootroot00000000000000python-0.20.7/mautrix/client/api/__init__.py000066400000000000000000000004111473573527000207630ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from .client import ClientAPI python-0.20.7/mautrix/client/api/authentication.py000066400000000000000000000161571473573527000222610ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. 
If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations from mautrix.api import Method, Path from mautrix.errors import MatrixResponseError from mautrix.types import ( DeviceID, LoginFlowList, LoginResponse, LoginType, MatrixUserIdentifier, UserID, UserIdentifier, WhoamiResponse, ) from .base import BaseClientAPI class ClientAuthenticationMethods(BaseClientAPI): """ Methods in section 5 Authentication of the spec. These methods are used for setting and getting user metadata and searching for users. See also: `API reference `__ """ # region 5.5 Login # API reference: https://matrix.org/docs/spec/client_server/r0.6.1.html#login async def get_login_flows(self) -> LoginFlowList: """ Get login flows supported by the homeserver. See also: `API reference `__ Returns: The list of login flows that the homeserver supports. """ resp = await self.api.request(Method.GET, Path.v3.login) try: return LoginFlowList.deserialize(resp) except KeyError: raise MatrixResponseError("`flows` not in response.") async def login( self, identifier: UserIdentifier | UserID | None = None, login_type: LoginType = LoginType.PASSWORD, device_name: str | None = None, device_id: str | None = None, password: str | None = None, store_access_token: bool = True, update_hs_url: bool = False, **kwargs: str, ) -> LoginResponse: """ Authenticates the user, and issues an access token they can use to authorize themself in subsequent requests. See also: `API reference `__ Args: login_type: The login type being used. identifier: Identification information for the user. device_name: A display name to assign to the newly-created device. Ignored if ``device_id`` correspnods to a known device. device_id: ID of the client device. If this does not correspond to a known client device, a new device will be created. The server will auto-generate a device_id if this is not specified. password: The user's password. Required when `type` is `m.login.password`. store_access_token: Whether or not mautrix-python should store the returned access token in this ClientAPI instance for future requests. update_hs_url: Whether or not mautrix-python should use the returned homeserver URL in this ClientAPI instance for future requests. **kwargs: Additional arguments for other login types. Returns: The login response. """ if identifier is None or isinstance(identifier, str): identifier = MatrixUserIdentifier(identifier or self.mxid) if password is not None: kwargs["password"] = password if device_name is not None: kwargs["initial_device_display_name"] = device_name if device_id: kwargs["device_id"] = device_id elif self.device_id: kwargs["device_id"] = self.device_id resp = await self.api.request( Method.POST, Path.v3.login, { "type": str(login_type), "identifier": identifier.serialize(), **kwargs, }, sensitive="password" in kwargs or "token" in kwargs, ) resp_data = LoginResponse.deserialize(resp) if store_access_token: self.mxid = resp_data.user_id self.device_id = resp_data.device_id self.api.token = resp_data.access_token if update_hs_url: base_url = resp_data.well_known.homeserver.base_url if base_url and base_url != self.api.base_url: self.log.debug( "Login response contained new base URL, switching from " f"{self.api.base_url} to {base_url}" ) self.api.base_url = base_url.rstrip("/") return resp_data async def logout(self, clear_access_token: bool = True) -> None: """ Invalidates an existing access token, so that it can no longer be used for authorization. 
The device associated with the access token is also deleted. `Device keys `__ for the device are deleted alongside the device. See also: `API reference `__ Args: clear_access_token: Whether or not mautrix-python should forget the stored access token. """ await self.api.request(Method.POST, Path.v3.logout) if clear_access_token: self.api.token = "" self.device_id = DeviceID("") async def logout_all(self, clear_access_token: bool = True) -> None: """ Invalidates all access tokens for a user, so that they can no longer be used for authorization. This includes the access token that made this request. All devices for the user are also deleted. `Device keys `__ for the device are deleted alongside the device. This endpoint does not require UI (user-interactive) authorization because UI authorization is designed to protect against attacks where the someone gets hold of a single access token then takes over the account. This endpoint invalidates all access tokens for the user, including the token used in the request, and therefore the attacker is unable to take over the account in this way. See also: `API reference `__ Args: clear_access_token: Whether or not mautrix-python should forget the stored access token. """ await self.api.request(Method.POST, Path.v3.logout.all) if clear_access_token: self.api.token = "" self.device_id = DeviceID("") # endregion # TODO other sections # region 5.7 Current account information # API reference: https://matrix.org/docs/spec/client_server/r0.6.1.html#current-account-information async def whoami(self) -> WhoamiResponse: """ Get information about the current user. Returns: The user ID and device ID of the current user. """ resp = await self.api.request(Method.GET, Path.v3.account.whoami) return WhoamiResponse.deserialize(resp) # endregion python-0.20.7/mautrix/client/api/base.py000066400000000000000000000147641473573527000201560ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations import json from aiohttp import ClientError, ClientSession, ContentTypeError from yarl import URL from mautrix.api import HTTPAPI, Method, Path from mautrix.errors import ( WellKnownInvalidVersionsResponse, WellKnownMissingHomeserver, WellKnownNotJSON, WellKnownNotURL, WellKnownUnexpectedStatus, WellKnownUnsupportedScheme, ) from mautrix.types import DeviceID, SerializerError, UserID, VersionsResponse from mautrix.util.logging import TraceLogger class BaseClientAPI: """ BaseClientAPI is the base class for :class:`ClientAPI`. This is separate from the main ClientAPI class so that the ClientAPI methods can be split into multiple classes (that inherit this class).All those section-specific method classes are inherited by the main ClientAPI class to create the full class. """ localpart: str domain: str _mxid: UserID device_id: DeviceID api: HTTPAPI log: TraceLogger versions_cache: VersionsResponse | None def __init__( self, mxid: UserID = "", device_id: DeviceID = "", api: HTTPAPI | None = None, **kwargs ) -> None: """ Initialize a ClientAPI. You must either provide the ``api`` parameter with an existing :class:`mautrix.api.HTTPAPI` instance, or provide the ``base_url`` and other arguments for creating it as kwargs. Args: mxid: The Matrix ID of the user. This is used for things like setting profile metadata. 
                Additionally, the homeserver domain is extracted from this string and used for
                setting aliases and such. This can be changed later using `set_mxid`.
            device_id: The device ID corresponding to the access token used.
            api: The :class:`mautrix.api.HTTPAPI` instance to use. You can also pass the
                ``kwargs`` to create a HTTPAPI instance rather than creating the instance
                yourself.
            kwargs: If ``api`` is not specified, then the arguments to pass when creating a
                HTTPAPI.
        """
        if mxid:
            self.mxid = mxid
        else:
            self._mxid = None
            self.localpart = None
            self.domain = None
        self.fill_member_event_callback = None
        self.versions_cache = None
        self.device_id = device_id
        self.api = api or HTTPAPI(**kwargs)
        self.log = self.api.log

    @classmethod
    def parse_user_id(cls, mxid: UserID) -> tuple[str, str]:
        """
        Parse the localpart and server name from a Matrix user ID.

        Args:
            mxid: The Matrix user ID.

        Returns:
            A tuple of (localpart, server_name).

        Raises:
            ValueError: if the given user ID is invalid.
        """
        if len(mxid) == 0:
            raise ValueError("User ID is empty")
        elif mxid[0] != "@":
            raise ValueError("User IDs start with @")
        try:
            sep = mxid.index(":")
        except ValueError as e:
            raise ValueError("User ID must contain domain separator") from e
        if sep == len(mxid) - 1:
            raise ValueError("User ID must contain domain")
        return mxid[1:sep], mxid[sep + 1 :]

    @property
    def mxid(self) -> UserID:
        return self._mxid

    @mxid.setter
    def mxid(self, mxid: UserID) -> None:
        self.localpart, self.domain = self.parse_user_id(mxid)
        self._mxid = mxid

    async def versions(self, no_cache: bool = False) -> VersionsResponse:
        """
        Get client-server spec versions supported by the server.

        Args:
            no_cache: If true, the versions will always be fetched from the server rather than
                using cached results when available.

        Returns:
            The supported Matrix spec versions and unstable features.
        """
        if no_cache or not self.versions_cache:
            resp = await self.api.request(Method.GET, Path.versions)
            self.versions_cache = VersionsResponse.deserialize(resp)
        return self.versions_cache

    @classmethod
    async def discover(cls, domain: str, session: ClientSession | None = None) -> URL | None:
        """
        Follow the server discovery spec to find the actual URL when given a Matrix server name.

        Args:
            domain: The server name (end of user ID) to discover.
            session: Optionally, the aiohttp ClientSession object to use.

        Returns:
            The parsed URL if the discovery succeeded.
            ``None`` if the request returned a 404 status.
Raises: WellKnownError: for other errors """ if session is None: async with ClientSession(headers={"User-Agent": HTTPAPI.default_ua}) as sess: return await cls._discover(domain, sess) else: return await cls._discover(domain, session) @classmethod async def _discover(cls, domain: str, session: ClientSession) -> URL | None: well_known = URL.build(scheme="https", host=domain, path="/.well-known/matrix/client") async with session.get(well_known) as resp: if resp.status == 404: return None elif resp.status != 200: raise WellKnownUnexpectedStatus(resp.status) try: data = await resp.json(content_type=None) except (json.JSONDecodeError, ContentTypeError) as e: raise WellKnownNotJSON() from e try: homeserver_url = data["m.homeserver"]["base_url"] except KeyError as e: raise WellKnownMissingHomeserver() from e parsed_url = URL(homeserver_url) if not parsed_url.is_absolute(): raise WellKnownNotURL() elif parsed_url.scheme not in ("http", "https"): raise WellKnownUnsupportedScheme(parsed_url.scheme) try: async with session.get(parsed_url / "_matrix/client/versions") as resp: data = VersionsResponse.deserialize(await resp.json()) if len(data.versions) == 0: raise ValueError("no versions defined in /_matrix/client/versions response") except (ClientError, json.JSONDecodeError, SerializerError, ValueError) as e: raise WellKnownInvalidVersionsResponse() from e return parsed_url python-0.20.7/mautrix/client/api/client.py000066400000000000000000000025461473573527000205150ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from .authentication import ClientAuthenticationMethods from .events import EventMethods from .filtering import FilteringMethods from .modules import ModuleMethods from .rooms import RoomMethods from .user_data import UserDataMethods class ClientAPI( ClientAuthenticationMethods, FilteringMethods, RoomMethods, EventMethods, UserDataMethods, ModuleMethods, ): """ ClientAPI is a medium-level wrapper around the HTTPAPI that provides many easy-to-use functions for accessing the client-server API. This class can be used directly, but generally you should use the higher-level wrappers that inherit from this class, such as :class:`mautrix.client.Client` or :class:`mautrix.appservice.IntentAPI`. Examples: >>> from mautrix.client import ClientAPI >>> client = ClientAPI("@user:matrix.org", base_url="https://matrix-client.matrix-org", token="syt_123_456") >>> await client.whoami() WhoamiResponse(user_id="@user:matrix.org", device_id="DEV123") >>> await client.get_joined_rooms() ["!roomid:matrix.org"] """ python-0.20.7/mautrix/client/api/events.py000066400000000000000000000675571473573527000205600ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. 
from __future__ import annotations from typing import Awaitable import json from mautrix.api import Method, Path from mautrix.errors import MatrixResponseError from mautrix.types import ( JSON, BaseFileInfo, ContentURI, Event, EventContent, EventContext, EventID, EventType, FilterID, Format, ImageInfo, MediaMessageEventContent, Member, Membership, MessageEventContent, MessageType, Obj, PaginatedMessages, PaginationDirection, PresenceState, ReactionEventContent, RelatesTo, RelationType, RoomEventFilter, RoomID, Serializable, SerializerError, StateEvent, StateEventContent, SyncToken, TextMessageEventContent, UserID, ) from mautrix.util.formatter import parse_html from .base import BaseClientAPI class EventMethods(BaseClientAPI): """ Methods in section 8 Events of the spec. Includes ``/sync``'ing, getting messages and state, setting state, sending messages and redacting messages. See also: `Events API reference`_ .. _Events API reference: https://spec.matrix.org/v1.1/client-server-api/#events """ # region 8.4 Syncing # API reference: https://spec.matrix.org/v1.1/client-server-api/#syncing def sync( self, since: SyncToken | None = None, timeout: int = 30000, filter_id: FilterID | None = None, full_state: bool = False, set_presence: PresenceState | None = None, ) -> Awaitable[JSON]: """ Perform a sync request. See also: `/sync API reference`_ This method doesn't parse the response at all. You should use :class:`mautrix.client.Syncer` to parse sync responses and dispatch the data into event handlers. :class:`mautrix.client.Client` includes ``Syncer``. Args: since (str): Optional. A token which specifies where to continue a sync from. timeout (int): Optional. The time in milliseconds to wait. filter_id (int): A filter ID. full_state (bool): Return the full state for every room the user has joined Defaults to false. set_presence (str): Should the client be marked as "online" or" offline" .. _/sync API reference: https://spec.matrix.org/v1.1/client-server-api/#get_matrixclientv3sync """ request = {"timeout": timeout} if since: request["since"] = str(since) if filter_id: request["filter"] = str(filter_id) if full_state: request["full_state"] = "true" if full_state else "false" if set_presence: request["set_presence"] = str(set_presence) return self.api.request( Method.GET, Path.v3.sync, query_params=request, retry_count=0, metrics_method="sync" ) # endregion # region 7.5 Getting events for a room # API reference: https://spec.matrix.org/v1.1/client-server-api/#getting-events-for-a-room async def get_event(self, room_id: RoomID, event_id: EventID) -> Event: """ Get a single event based on ``room_id``/``event_id``. You must have permission to retrieve this event e.g. by being a member in the room for this event. See also: `API reference `__ Args: room_id: The ID of the room the event is in. event_id: The event ID to get. Returns: The event. """ content = await self.api.request( Method.GET, Path.v3.rooms[room_id].event[event_id], metrics_method="getEvent" ) try: return Event.deserialize(content) except SerializerError as e: raise MatrixResponseError("Invalid event in response") from e async def get_event_context( self, room_id: RoomID, event_id: EventID, limit: int | None = 10, filter: RoomEventFilter | None = None, ) -> EventContext: """ Get a number of events that happened just before and after the specified event. This allows clients to get the context surrounding an event, as well as get the state at an event and paginate in either direction. Args: room_id: The room to get events from. 
event_id: The event to get context around. limit: The maximum number of events to return. The limit applies to the total number of events before and after the requested event. A limit of 0 means no other events are returned, while 2 means one event before and one after are returned. filter: A JSON RoomEventFilter_ to filter returned events with. Returns: The event itself, up to ``limit/2`` events before and after the event, the room state at the event, and pagination tokens to scroll up and down. .. _RoomEventFilter: https://spec.matrix.org/v1.1/client-server-api/#filtering """ query_params = {} if limit is not None: query_params["limit"] = str(limit) if filter is not None: query_params["filter"] = ( filter.serialize() if isinstance(filter, Serializable) else filter ) resp = await self.api.request( Method.GET, Path.v3.rooms[room_id].context[event_id], query_params=query_params, metrics_method="get_event_context", ) return EventContext.deserialize(resp) async def get_state_event( self, room_id: RoomID, event_type: EventType, state_key: str = "", ) -> StateEventContent: """ Looks up the contents of a state event in a room. If the user is joined to the room then the state is taken from the current state of the room. If the user has left the room then the state is taken from the state of the room when they left. See also: `API reference `__ Args: room_id: The ID of the room to look up the state in. event_type: The type of state to look up. state_key: The key of the state to look up. Defaults to empty string. Returns: The state event. """ content = await self.api.request( Method.GET, Path.v3.rooms[room_id].state[event_type][state_key], metrics_method="getStateEvent", ) content["__mautrix_event_type"] = event_type try: return StateEvent.deserialize_content(content) except SerializerError as e: raise MatrixResponseError("Invalid state event in response") from e async def get_state(self, room_id: RoomID) -> list[StateEvent]: """ Get the state events for the current state of a room. See also: `API reference `__ Args: room_id: The ID of the room to look up the state for. Returns: A list of state events with the most recent of each event_type/state_key pair. """ content = await self.api.request( Method.GET, Path.v3.rooms[room_id].state, metrics_method="getState" ) try: return [StateEvent.deserialize(event) for event in content] except SerializerError as e: raise MatrixResponseError("Invalid state events in response") from e async def get_members( self, room_id: RoomID, at: SyncToken | None = None, membership: Membership | None = None, not_membership: Membership | None = None, ) -> list[StateEvent]: """ Get the list of members for a room. See also: `API reference `__ Args: room_id: The ID of the room to get the member events for. at: The point in time (pagination token) to return members for in the room. This token can be obtained from a ``prev_batch`` token returned for each room by the sync API. Defaults to the current state of the room, as determined by the server. membership: The kind of membership to filter for. Defaults to no filtering if unspecified. When specified alongside ``not_membership``, the two parameters create an 'or' condition: either the ``membership`` is the same as membership or is not the same as ``not_membership``. not_membership: The kind of membership to exclude from the results. Defaults to no filtering if unspecified. Returns: A list of most recent member events for each user. 
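        Examples:
            A minimal illustrative sketch, not taken from the library itself: ``client`` is
            assumed to be an authenticated :class:`ClientAPI` and the room ID and user IDs
            are placeholders.

            >>> from mautrix.types import Membership
            >>> joined = await client.get_members("!room:example.com", membership=Membership.JOIN)
            >>> [evt.state_key for evt in joined]
            ['@alice:example.com', '@bob:example.com']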
""" query = {} if at: query["at"] = at if membership: query["membership"] = membership.value if not_membership: query["not_membership"] = not_membership.value content = await self.api.request( Method.GET, Path.v3.rooms[room_id].members, query_params=query, metrics_method="getMembers", ) try: return [StateEvent.deserialize(event) for event in content["chunk"]] except KeyError: raise MatrixResponseError("`chunk` not in response.") except SerializerError as e: raise MatrixResponseError("Invalid state events in response") from e async def get_joined_members(self, room_id: RoomID) -> dict[UserID, Member]: """ Get a user ID -> member info map for a room. The current user must be in the room for it to work, unless it is an Application Service in which case any of the AS's users must be in the room. This API is primarily for Application Services and should be faster to respond than `/members`_ as it can be implemented more efficiently on the server. See also: `API reference `__ Args: room_id: The ID of the room to get the members of. Returns: A dictionary from user IDs to Member info objects. .. _/members: https://spec.matrix.org/v1.1/client-server-api/#get_matrixclientv3roomsroomidmembers """ content = await self.api.request( Method.GET, Path.v3.rooms[room_id].joined_members, metrics_method="getJoinedMembers" ) try: return { user_id: Member( membership=Membership.JOIN, displayname=member.get("display_name", ""), avatar_url=member.get("avatar_url", ""), ) for user_id, member in content["joined"].items() } except KeyError: raise MatrixResponseError("`joined` not in response.") except SerializerError as e: raise MatrixResponseError("Invalid member objects in response") from e async def get_messages( self, room_id: RoomID, direction: PaginationDirection, from_token: SyncToken | None = None, to_token: SyncToken | None = None, limit: int | None = None, filter_json: str | dict | RoomEventFilter | None = None, ) -> PaginatedMessages: """ Get a list of message and state events for a room. Pagination parameters are used to paginate history in the room. See also: `API reference `__ Args: room_id: The ID of the room to get events from. direction: The direction to return events from. from_token: The token to start returning events from. This token can be obtained from a ``prev_batch`` token returned for each room by the `sync endpoint`_, or from a ``start`` or ``end`` token returned by a previous request to this endpoint. Starting from Matrix v1.3, this field can be omitted to fetch events from the beginning or end of the room. to_token: The token to stop returning events at. limit: The maximum number of events to return. Defaults to 10. filter_json: A JSON RoomEventFilter_ to filter returned events with. Returns: .. _RoomEventFilter: https://spec.matrix.org/v1.3/client-server-api/#filtering .. 
_sync endpoint: https://spec.matrix.org/v1.3/client-server-api/#get_matrixclientv3sync
        """
        if isinstance(filter_json, Serializable):
            filter_json = filter_json.json()
        elif isinstance(filter_json, dict):
            filter_json = json.dumps(filter_json)
        query_params = {
            "from": from_token,
            "dir": direction.value,
            "to": to_token,
            "limit": str(limit) if limit else None,
            "filter": filter_json,
        }
        content = await self.api.request(
            Method.GET,
            Path.v3.rooms[room_id].messages,
            query_params=query_params,
            metrics_method="getMessages",
        )
        try:
            return PaginatedMessages(
                content["start"],
                content["end"],
                [Event.deserialize(event) for event in content["chunk"]],
            )
        except KeyError:
            if "start" not in content:
                raise MatrixResponseError("`start` not in response.")
            elif "end" not in content:
                raise MatrixResponseError("`end` not in response.")
            raise MatrixResponseError("`chunk` not in response.")
        except SerializerError as e:
            raise MatrixResponseError("Invalid events in response") from e

    # endregion
    # region 7.6 Sending events to a room
    # API reference: https://spec.matrix.org/v1.1/client-server-api/#sending-events-to-a-room

    async def send_state_event(
        self,
        room_id: RoomID,
        event_type: EventType,
        content: StateEventContent,
        state_key: str = "",
        ensure_joined: bool = True,
        **kwargs,
    ) -> EventID:
        """
        Send a state event to a room. State events with the same ``room_id``, ``event_type``
        and ``state_key`` will be overridden.

        See also: `API reference `__

        Args:
            room_id: The ID of the room to set the state in.
            event_type: The type of state to send.
            content: The content to send.
            state_key: The key for the state to send. Defaults to empty string.
            ensure_joined: Used by IntentAPI to determine if it should ensure the user is joined
                before sending the event.
            **kwargs: Optional parameters to pass to the :meth:`HTTPAPI.request` method. Used by
                :class:`IntentAPI` to pass the timestamp massaging field to
                :meth:`AppServiceAPI.request`.

        Returns:
            The ID of the event that was sent.
        """
        content = content.serialize() if isinstance(content, Serializable) else content
        resp = await self.api.request(
            Method.PUT,
            Path.v3.rooms[room_id].state[event_type][state_key],
            content,
            **kwargs,
            metrics_method="sendStateEvent",
        )
        try:
            return resp["event_id"]
        except KeyError:
            raise MatrixResponseError("`event_id` not in response.")

    async def send_message_event(
        self,
        room_id: RoomID,
        event_type: EventType,
        content: EventContent,
        txn_id: str | None = None,
        **kwargs,
    ) -> EventID:
        """
        Send a message event to a room. Message events allow access to historical events and
        pagination, making them suited for "once-off" activity in a room.

        See also: `API reference `__

        Args:
            room_id: The ID of the room to send the message to.
            event_type: The type of message to send.
            content: The content to send.
            txn_id: The transaction ID to use. If not provided, a random ID will be generated.
            **kwargs: Optional parameters to pass to the :meth:`HTTPAPI.request` method. Used by
                :class:`IntentAPI` to pass the timestamp massaging field to
                :meth:`AppServiceAPI.request`.

        Returns:
            The ID of the event that was sent.
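        Examples:
            A minimal illustrative sketch (``client`` is assumed to be an authenticated
            :class:`ClientAPI`; the room ID and returned event ID are placeholders):

            >>> from mautrix.types import EventType, MessageType, TextMessageEventContent
            >>> content = TextMessageEventContent(msgtype=MessageType.TEXT, body="Hello, world!")
            >>> await client.send_message_event("!room:example.com", EventType.ROOM_MESSAGE, content)
            '$placeholder_event_id'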
""" if not room_id: raise ValueError("Room ID not given") elif not event_type: raise ValueError("Event type not given") url = Path.v3.rooms[room_id].send[event_type][txn_id or self.api.get_txn_id()] content = content.serialize() if isinstance(content, Serializable) else content resp = await self.api.request( Method.PUT, url, content, **kwargs, metrics_method="sendMessageEvent" ) try: return resp["event_id"] except KeyError: raise MatrixResponseError("`event_id` not in response.") # region Message send helper functions def send_message( self, room_id: RoomID, content: MessageEventContent, **kwargs, ) -> Awaitable[EventID]: """ Send a message to a room. Args: room_id: The ID of the room to send the message to. content: The content to send. **kwargs: Optional parameters to pass to the :meth:`HTTPAPI.request` method. Returns: The ID of the event that was sent. """ return self.send_message_event(room_id, EventType.ROOM_MESSAGE, content, **kwargs) def react(self, room_id: RoomID, event_id: EventID, key: str, **kwargs) -> Awaitable[EventID]: content = ReactionEventContent( relates_to=RelatesTo(rel_type=RelationType.ANNOTATION, event_id=event_id, key=key) ) return self.send_message_event(room_id, EventType.REACTION, content, **kwargs) async def send_text( self, room_id: RoomID, text: str | None = None, html: str | None = None, msgtype: MessageType = MessageType.TEXT, relates_to: RelatesTo | None = None, **kwargs, ) -> EventID: """ Send a text message to a room. Args: room_id: The ID of the room to send the message to. text: The text to send. If set to ``None``, the given HTML will be parsed to generate a plaintext representation. html: The HTML to send. msgtype: The message type to send. Defaults to :attr:`MessageType.TEXT` (normal text message). relates_to: Message relation metadata used for things like replies. **kwargs: Optional parameters to pass to the :meth:`HTTPAPI.request` method. Returns: The ID of the event that was sent. Raises: ValueError: if both ``text`` and ``html`` are ``None``. """ if html is not None: if text is None: text = await parse_html(html) content = TextMessageEventContent( msgtype=msgtype, body=text, format=Format.HTML, formatted_body=html ) elif text is not None: content = TextMessageEventContent(msgtype=msgtype, body=text) else: raise TypeError("send_text() requires either text or html to be set") if relates_to: content.relates_to = relates_to return await self.send_message(room_id, content, **kwargs) def send_notice( self, room_id: RoomID, text: str | None = None, html: str | None = None, relates_to: RelatesTo | None = None, **kwargs, ) -> Awaitable[EventID]: """ Send a notice text message to a room. Notices are like normal text messages, but usually sent by bots to tell other bots not to react to them. If you're a bot, please send notices instead of normal text, unless there is a reason to do something else. Args: room_id: The ID of the room to send the message to. text: The text to send. If set to ``None``, the given HTML will be parsed to generate a plaintext representation. html: The HTML to send. relates_to: Message relation metadata used for things like replies. **kwargs: Optional parameters to pass to the :meth:`HTTPAPI.request` method. Returns: The ID of the event that was sent. Raises: ValueError: if both ``text`` and ``html`` are ``None``. 
""" return self.send_text(room_id, text, html, MessageType.NOTICE, relates_to, **kwargs) def send_emote( self, room_id: RoomID, text: str | None = None, html: str | None = None, relates_to: RelatesTo | None = None, **kwargs, ) -> Awaitable[EventID]: """ Send an emote to a room. Emotes are usually displayed by prepending a star and the user's display name to the message, which means they're usually written in the third person. Args: room_id: The ID of the room to send the message to. text: The text to send. If set to ``None``, the given HTML will be parsed to generate a plaintext representation. html: The HTML to send. relates_to: Message relation metadata used for things like replies. **kwargs: Optional parameters to pass to the :meth:`HTTPAPI.request` method. Returns: The ID of the event that was sent. Raises: ValueError: if both ``text`` and ``html`` are ``None``. """ return self.send_text(room_id, text, html, MessageType.EMOTE, relates_to, **kwargs) def send_file( self, room_id: RoomID, url: ContentURI, info: BaseFileInfo | None = None, file_name: str | None = None, file_type: MessageType = MessageType.FILE, relates_to: RelatesTo | None = None, **kwargs, ) -> Awaitable[EventID]: """ Send a file to a room. Args: room_id: The ID of the room to send the message to. url: The Matrix content repository URI of the file. You can upload files using :meth:`~MediaRepositoryMethods.upload_media`. info: Additional metadata about the file, e.g. mimetype, image size, video duration, etc file_name: The name for the file to send. file_type: The general file type to send. The file type can be further specified by setting the ``mimetype`` field of the ``info`` parameter. Defaults to :attr:`MessageType.FILE` (unspecified file type, e.g. document) relates_to: Message relation metadata used for things like replies. **kwargs: Optional parameters to pass to the :meth:`HTTPAPI.request` method. Returns: The ID of the event that was sent. """ return self.send_message( room_id, MediaMessageEventContent( url=url, info=info, body=file_name, relates_to=relates_to, msgtype=file_type ), **kwargs, ) def send_sticker( self, room_id: RoomID, url: ContentURI, info: ImageInfo | None = None, text: str = "", relates_to: RelatesTo | None = None, **kwargs, ) -> Awaitable[EventID]: """ Send a sticker to a room. Stickers are basically images, but they're usually rendered slightly differently. Args: room_id: The ID of the room to send the message to. url: The Matrix content repository URI of the sticker. You can upload files using :meth:`~MediaRepositoryMethods.upload_media`. info: Additional metadata about the sticker, e.g. mimetype and image size text: A textual description of the sticker. relates_to: Message relation metadata used for things like replies. **kwargs: Optional parameters to pass to the :meth:`HTTPAPI.request` method. Returns: The ID of the event that was sent. """ return self.send_message_event( room_id, EventType.STICKER, MediaMessageEventContent(url=url, info=info, body=text, relates_to=relates_to), **kwargs, ) def send_image( self, room_id: RoomID, url: ContentURI, info: ImageInfo | None = None, file_name: str = None, relates_to: RelatesTo | None = None, **kwargs, ) -> Awaitable[EventID]: """ Send an image to a room. Args: room_id: The ID of the room to send the message to. url: The Matrix content repository URI of the image. You can upload files using :meth:`~MediaRepositoryMethods.upload_media`. info: Additional metadata about the image, e.g. 
mimetype and image size file_name: The file name for the image to send. relates_to: Message relation metadata used for things like replies. **kwargs: Optional parameters to pass to the :meth:`HTTPAPI.request` method. Returns: The ID of the event that was sent. """ return self.send_file( room_id, url, info, file_name, MessageType.IMAGE, relates_to, **kwargs ) # endregion # endregion # region 7.7 Redactions # API reference: https://spec.matrix.org/v1.1/client-server-api/#redactions async def redact( self, room_id: RoomID, event_id: EventID, reason: str | None = None, extra_content: dict[str, JSON] | None = None, **kwargs, ) -> EventID: """ Send an event to redact a previous event. Redacting an event strips all information out of an event which isn't critical to the integrity of the server-side representation of the room. This cannot be undone. Users may redact their own events, and any user with a power level greater than or equal to the redact power level of the room may redact events there. See also: `API reference `__ Args: room_id: The ID of the room the event is in. event_id: The ID of the event to redact. reason: The reason for the event being redacted. extra_content: Extra content for the event. **kwargs: Optional parameters to pass to the :meth:`HTTPAPI.request` method. Used by :class:`IntentAPI` to pass the timestamp massaging field to :meth:`AppServiceAPI.request`. Returns: The ID of the event that was sent to redact the other event. """ url = Path.v3.rooms[room_id].redact[event_id][self.api.get_txn_id()] content = extra_content or {} if reason: content["reason"] = reason resp = await self.api.request( Method.PUT, url, content=content, **kwargs, metrics_method="redact" ) try: return resp["event_id"] except KeyError: raise MatrixResponseError("`event_id` not in response.") # endregion python-0.20.7/mautrix/client/api/filtering.py000066400000000000000000000042551473573527000212210ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from mautrix.api import Method, Path from mautrix.errors import MatrixResponseError from mautrix.types import Filter, FilterID, Serializable from .base import BaseClientAPI class FilteringMethods(BaseClientAPI): """ Methods in section 7 Filtering of the spec. Filters can be created on the server and can be passed as as a parameter to APIs which return events. These filters alter the data returned from those APIs. Not all APIs accept filters. See also: `API reference `__ """ async def get_filter(self, filter_id: FilterID) -> Filter: """ Download a filter. See also: `API reference `__ Args: filter_id: The filter ID to download. Returns: The filter data. """ content = await self.api.request(Method.GET, Path.v3.user[self.mxid].filter[filter_id]) return Filter.deserialize(content) async def create_filter(self, filter_params: Filter) -> FilterID: """ Upload a new filter definition to the homeserver. See also: `API reference `__ Args: filter_params: The filter data. Returns: A filter ID that can be used in future requests to refer to the uploaded filter. 
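        Examples:
            A minimal illustrative sketch (``client`` is assumed to be an authenticated
            :class:`ClientAPI`; the filter below simply limits timeline pagination to 10 events):

            >>> from mautrix.types import Filter, RoomFilter, RoomEventFilter
            >>> filter_id = await client.create_filter(
            ...     Filter(room=RoomFilter(timeline=RoomEventFilter(limit=10)))
            ... )
            >>> await client.get_filter(filter_id)
            Filter(...)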
""" resp = await self.api.request( Method.POST, Path.v3.user[self.mxid].filter, ( filter_params.serialize() if isinstance(filter_params, Serializable) else filter_params ), ) try: return resp["filter_id"] except KeyError: raise MatrixResponseError("`filter_id` not in response.") # endregion python-0.20.7/mautrix/client/api/modules/000077500000000000000000000000001473573527000203265ustar00rootroot00000000000000python-0.20.7/mautrix/client/api/modules/__init__.py000066400000000000000000000015321473573527000224400ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from .account_data import AccountDataMethods from .crypto import CryptoMethods from .media_repository import MediaRepositoryMethods from .misc import MiscModuleMethods from .push_rules import PushRuleMethods from .room_tag import RoomTaggingMethods class ModuleMethods( MediaRepositoryMethods, CryptoMethods, AccountDataMethods, MiscModuleMethods, RoomTaggingMethods, PushRuleMethods, ): """ Methods in section 13 Modules of the spec. See also: `API reference `__ """ # TODO: subregions 15, 21, 26, 27, others? python-0.20.7/mautrix/client/api/modules/account_data.py000066400000000000000000000051761473573527000233360ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations from mautrix.api import Method, Path from mautrix.types import JSON, AccountDataEventContent, EventType, RoomID, Serializable from ..base import BaseClientAPI class AccountDataMethods(BaseClientAPI): """ Methods in section 13.9 Client Config of the spec. These methods are used for storing user-local data on the homeserver to synchronize client configuration across sessions. See also: `API reference `__""" async def get_account_data(self, type: EventType | str, room_id: RoomID | None = None) -> JSON: """ Get a specific account data event from the homeserver. See also: `API reference `__ Args: type: The type of the account data event to get. room_id: Optionally, the room ID to get per-room account data. Returns: The data in the event. """ if isinstance(type, EventType) and not type.is_account_data: raise ValueError("Event type is not an account data event type") base_path = Path.v3.user[self.mxid] if room_id: base_path = base_path.rooms[room_id] return await self.api.request(Method.GET, base_path.account_data[type]) async def set_account_data( self, type: EventType | str, data: AccountDataEventContent | dict[str, JSON], room_id: RoomID | None = None, ) -> None: """ Store account data on the homeserver. See also: `API reference `__ Args: type: The type of the account data event to set. data: The content to store in that account data event. room_id: Optionally, the room ID to set per-room account data. 
""" if isinstance(type, EventType) and not type.is_account_data: raise ValueError("Event type is not an account data event type") base_path = Path.v3.user[self.mxid] if room_id: base_path = base_path.rooms[room_id] await self.api.request( Method.PUT, base_path.account_data[type], data.serialize() if isinstance(data, Serializable) else data, ) python-0.20.7/mautrix/client/api/modules/crypto.py000066400000000000000000000157741473573527000222360ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations from typing import Any from mautrix.api import Method, Path from mautrix.errors import MatrixResponseError from mautrix.types import ( ClaimKeysResponse, DeviceID, EncryptionKeyAlgorithm, EventType, QueryKeysResponse, Serializable, SyncToken, ToDeviceEventContent, UserID, ) from ..base import BaseClientAPI class CryptoMethods(BaseClientAPI): """ Methods in section `13.9 Send-to-Device messaging `__ and `13.11 End-to-End Encryption of the spec `__. """ async def send_to_device( self, event_type: EventType, messages: dict[UserID, dict[DeviceID, ToDeviceEventContent]] ) -> None: """ Send to-device events to a set of client devices. See also: `API reference `__ Args: event_type: The type of event to send. messages: The messages to send. A map from user ID, to a map from device ID to message body. The device ID may also be ``*``, meaning all known devices for the user. """ if not event_type.is_to_device: raise ValueError("Event type must be a to-device event type") await self.api.request( Method.PUT, Path.v3.sendToDevice[event_type][self.api.get_txn_id()], { "messages": { user_id: { device_id: ( content.serialize() if isinstance(content, Serializable) else content ) for device_id, content in devices.items() } for user_id, devices in messages.items() }, }, ) async def send_to_one_device( self, event_type: EventType, user_id: UserID, device_id: DeviceID, message: ToDeviceEventContent, ) -> None: """ Send a to-device event to a single device. Args: event_type: The type of event to send. user_id: The user whose device to send the event to. device_id: The device ID to send the event to. message: The event content to send. """ return await self.send_to_device(event_type, {user_id: {device_id: message}}) async def upload_keys( self, one_time_keys: dict[str, Any] | None = None, device_keys: dict[str, Any] | None = None, ) -> dict[EncryptionKeyAlgorithm, int]: """ Publishes end-to-end encryption keys for the device. See also: `API reference `__ Args: one_time_keys: One-time public keys for "pre-key" messages. The names of the properties should be in the format ``:``. The format of the key is determined by the key algorithm. device_keys: Identity keys for the device. May be absent if no new identity keys are required. Returns: For each key algorithm, the number of unclaimed one-time keys of that type currently held on the server for this device. 
""" data = {} if device_keys: data["device_keys"] = device_keys if one_time_keys: data["one_time_keys"] = one_time_keys resp = await self.api.request(Method.POST, Path.v3.keys.upload, data) try: return { EncryptionKeyAlgorithm.deserialize(alg): count for alg, count in resp["one_time_key_counts"].items() } except KeyError as e: raise MatrixResponseError("`one_time_key_counts` not in response.") from e except AttributeError as e: raise MatrixResponseError("Invalid `one_time_key_counts` field in response.") from e async def query_keys( self, device_keys: list[UserID] | set[UserID] | dict[UserID, list[DeviceID]], token: SyncToken = "", timeout: int = 10000, ) -> QueryKeysResponse: """ Fetch devices and their identity keys for the given users. See also: `API reference `__ Args: device_keys: The keys to be downloaded. A map from user ID, to a list of device IDs, or to an empty list to indicate all devices for the corresponding user. token: If the client is fetching keys as a result of a device update received in a sync request, this should be the 'since' token of that sync request, or any later sync token. This allows the server to ensure its response contains the keys advertised by the notification in that sync. timeout: The time (in milliseconds) to wait when downloading keys from remote servers. Returns: Information on the queried devices and errors for homeservers that could not be reached. """ if isinstance(device_keys, (list, set)): device_keys = {user_id: [] for user_id in device_keys} data = { "timeout": timeout, "device_keys": device_keys, } if token: data["token"] = token resp = await self.api.request(Method.POST, Path.v3.keys.query, data) return QueryKeysResponse.deserialize(resp) async def claim_keys( self, one_time_keys: dict[UserID, dict[DeviceID, EncryptionKeyAlgorithm]], timeout: int = 10000, ) -> ClaimKeysResponse: """ Claim one-time keys for use in pre-key messages. See also: `API reference `__ Args: one_time_keys: The keys to be claimed. A map from user ID, to a map from device ID to algorithm name. timeout: The time (in milliseconds) to wait when downloading keys from remote servers. Returns: One-time keys for the queried devices and errors for homeservers that could not be reached. """ resp = await self.api.request( Method.POST, Path.v3.keys.claim, { "timeout": timeout, "one_time_keys": { user_id: {device_id: alg.serialize() for device_id, alg in devices.items()} for user_id, devices in one_time_keys.items() }, }, ) return ClaimKeysResponse.deserialize(resp) python-0.20.7/mautrix/client/api/modules/media_repository.py000066400000000000000000000340231473573527000242600ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. 
from __future__ import annotations from typing import Any, AsyncIterable, Literal from contextlib import contextmanager import asyncio import time from yarl import URL from mautrix import __optional_imports__ from mautrix.api import MediaPath, Method from mautrix.errors import MatrixResponseError, make_request_error from mautrix.types import ( ContentURI, MediaCreateResponse, MediaRepoConfig, MXOpenGraph, SerializerError, SpecVersions, ) from mautrix.util import background_task from mautrix.util.async_body import async_iter_bytes from mautrix.util.opt_prometheus import Histogram from ..base import BaseClientAPI try: from mautrix.util import magic except ImportError: if __optional_imports__: raise magic = None # type: ignore UPLOAD_TIME = Histogram( "bridge_media_upload_time", "Time spent uploading media (milliseconds per megabyte)", buckets=[10, 25, 50, 100, 250, 500, 750, 1000, 2500, 5000, 10000], ) class MediaRepositoryMethods(BaseClientAPI): """ Methods in section 13.8 Content Repository of the spec. These methods are used for uploading and downloading content from the media repository and for getting URL previews without leaking client IPs. See also: `API reference `__ """ async def create_mxc(self) -> MediaCreateResponse: """ Create a media ID for uploading media to the homeserver. See also: `API reference `__ Returns: MediaCreateResponse Containing the MXC URI that can be used to upload a file to later """ resp = await self.api.request(Method.POST, MediaPath.v1.create) return MediaCreateResponse.deserialize(resp) @contextmanager def _observe_upload_time(self, size: int | None, mxc: ContentURI | None = None) -> None: start = time.monotonic_ns() yield duration = time.monotonic_ns() - start if mxc: duration_sec = duration / 1000**3 self.log.debug(f"Completed asynchronous upload of {mxc} in {duration_sec:.3f} seconds") if size: UPLOAD_TIME.observe(duration / size) async def upload_media( self, data: bytes | bytearray | AsyncIterable[bytes], mime_type: str | None = None, filename: str | None = None, size: int | None = None, mxc: ContentURI | None = None, async_upload: bool = False, ) -> ContentURI: """ Upload a file to the content repository. See also: `API reference `__ Args: data: The data to upload. mime_type: The MIME type to send with the upload request. filename: The filename to send with the upload request. size: The file size to send with the upload request. mxc: An existing MXC URI which doesn't have content yet to upload into. async_upload: Should the media be uploaded in the background? If ``True``, this will create a MXC URI using :meth:`create_mxc`, start uploading in the background, and then immediately return the created URI. This is mutually exclusive with manually passing the ``mxc`` parameter. Returns: The MXC URI to the uploaded file. Raises: MatrixResponseError: If the response does not contain a ``content_uri`` field. ValueError: if both ``async_upload`` and ``mxc`` are provided at the same time. 
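Examples:
    A minimal usage sketch. The ``client`` instance and the file being read are
    hypothetical; the parameters correspond to the signature documented above::

        with open("cat.png", "rb") as file:
            data = file.read()
        # Upload the raw bytes and get back an mxc:// URI pointing at them.
        mxc = await client.upload_media(data, mime_type="image/png", filename="cat.png")
        # The returned ContentURI can then be referenced from message content,
        # avatars, and similar fields.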
""" if magic and isinstance(data, bytes): mime_type = mime_type or magic.mimetype(data) headers = {} if mime_type: headers["Content-Type"] = mime_type if size: headers["Content-Length"] = str(size) elif isinstance(data, (bytes, bytearray)): size = len(data) query = {} if filename: query["filename"] = filename upload_url = None if async_upload: if mxc: raise ValueError("async_upload and mxc can't be provided simultaneously") create_response = await self.create_mxc() mxc = create_response.content_uri upload_url = create_response.unstable_upload_url path = MediaPath.v3.upload method = Method.POST if mxc: server_name, media_id = self.api.parse_mxc_uri(mxc) if upload_url is None: path = MediaPath.v3.upload[server_name][media_id] method = Method.PUT else: path = ( MediaPath.unstable["com.beeper.msc3870"].upload[server_name][media_id].complete ) if upload_url is not None: task = self._upload_to_url(upload_url, path, headers, data, post_upload_query=query) else: task = self.api.request( method, path, content=data, headers=headers, query_params=query ) if async_upload: async def _try_upload(): try: with self._observe_upload_time(size, mxc): await task except Exception as e: self.log.error(f"Failed to upload {mxc}: {type(e).__name__}: {e}") background_task.create(_try_upload()) return mxc else: with self._observe_upload_time(size): resp = await task try: return resp["content_uri"] except KeyError: raise MatrixResponseError("`content_uri` not in response.") async def download_media(self, url: ContentURI, timeout_ms: int | None = None) -> bytes: """ Download a file from the content repository. See also: `API reference `__ Args: url: The MXC URI to download. timeout_ms: The maximum number of milliseconds that the client is willing to wait to start receiving data. Used for asynchronous uploads. Returns: The raw downloaded data. """ authenticated = (await self.versions()).supports(SpecVersions.V111) url = self.api.get_download_url(url, authenticated=authenticated) query_params: dict[str, Any] = {"allow_redirect": "true"} if timeout_ms is not None: query_params["timeout_ms"] = timeout_ms headers: dict[str, str] = {} if authenticated: headers["Authorization"] = f"Bearer {self.api.token}" if self.api.as_user_id: query_params["user_id"] = self.api.as_user_id req_id = self.api.log_download_request(url, query_params) start = time.monotonic() async with self.api.session.get(url, params=query_params, headers=headers) as response: try: response.raise_for_status() return await response.read() finally: self.api.log_download_request_done( url, req_id, time.monotonic() - start, response.status ) async def download_thumbnail( self, url: ContentURI, width: int | None = None, height: int | None = None, resize_method: Literal["crop", "scale"] = None, allow_remote: bool | None = None, timeout_ms: int | None = None, ): """ Download a thumbnail for a file in the content repository. See also: `API reference `__ Args: url: The MXC URI to download. width: The _desired_ width of the thumbnail. The actual thumbnail may not match the size specified. height: The _desired_ height of the thumbnail. The actual thumbnail may not match the size specified. resize_method: The desired resizing method. Either ``crop`` or ``scale``. allow_remote: Indicates to the server that it should not attempt to fetch the media if it is deemed remote. This is to prevent routing loops where the server contacts itself. timeout_ms: The maximum number of milliseconds that the client is willing to wait to start receiving data. Used for asynchronous Uploads. 
Returns: The raw downloaded data. """ authenticated = (await self.versions()).supports(SpecVersions.V111) url = self.api.get_download_url( url, download_type="thumbnail", authenticated=authenticated ) query_params: dict[str, Any] = {"allow_redirect": "true"} if width is not None: query_params["width"] = width if height is not None: query_params["height"] = height if resize_method is not None: query_params["method"] = resize_method if allow_remote is not None: query_params["allow_remote"] = str(allow_remote).lower() if timeout_ms is not None: query_params["timeout_ms"] = timeout_ms headers: dict[str, str] = {} if authenticated: headers["Authorization"] = f"Bearer {self.api.token}" if self.api.as_user_id: query_params["user_id"] = self.api.as_user_id req_id = self.api.log_download_request(url, query_params) start = time.monotonic() async with self.api.session.get(url, params=query_params, headers=headers) as response: try: response.raise_for_status() return await response.read() finally: self.api.log_download_request_done( url, req_id, time.monotonic() - start, response.status ) async def get_url_preview(self, url: str, timestamp: int | None = None) -> MXOpenGraph: """ Get information about a URL for a client. See also: `API reference `__ Args: url: The URL to get a preview of. timestamp: The preferred point in time to return a preview for. The server may return a newer version if it does not have the requested version available. """ query_params = {"url": url} if timestamp is not None: query_params["ts"] = timestamp content = await self.api.request( Method.GET, MediaPath.v3.preview_url, query_params=query_params ) try: return MXOpenGraph.deserialize(content) except SerializerError as e: raise MatrixResponseError("Invalid MXOpenGraph in response.") from e async def get_media_repo_config(self) -> MediaRepoConfig: """ This endpoint allows clients to retrieve the configuration of the content repository, such as upload limitations. Clients SHOULD use this as a guide when using content repository endpoints. All values are intentionally left optional. Clients SHOULD follow the advice given in the field description when the field is not available. **NOTE:** Both clients and server administrators should be aware that proxies between the client and the server may affect the apparent behaviour of content repository APIs, for example, proxies may enforce a lower upload size limit than is advertised by the server on this endpoint. See also: `API reference `__ Returns: The media repository config. 
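Examples:
    A minimal usage sketch (the ``client`` instance and ``data`` bytes are
    hypothetical). Because all fields are optional, a missing limit should be
    treated as "unknown" rather than "unlimited"::

        config = await client.get_media_repo_config()
        if config.upload_size is not None and len(data) > config.upload_size:
            raise ValueError("File exceeds the advertised upload size limit")
        mxc = await client.upload_media(data)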
""" content = await self.api.request(Method.GET, MediaPath.v3.config) try: return MediaRepoConfig.deserialize(content) except SerializerError as e: raise MatrixResponseError("Invalid MediaRepoConfig in response") from e async def _upload_to_url( self, upload_url: str, post_upload_path: str, headers: dict[str, str], data: bytes | bytearray | AsyncIterable[bytes], post_upload_query: dict[str, str], min_iter_size: int = 25 * 1024 * 1024, ) -> None: retry_count = self.api.default_retry_count backoff = 2 do_fake_iter = data and hasattr(data, "__len__") and len(data) > min_iter_size if do_fake_iter: headers["Content-Length"] = str(len(data)) while True: self.log.debug("Uploading media to external URL %s", upload_url) upload_response = None try: req_data = async_iter_bytes(data) if do_fake_iter else data upload_response = await self.api.session.put( upload_url, data=req_data, headers=headers ) upload_response.raise_for_status() except Exception as e: if retry_count <= 0: raise make_request_error( http_status=upload_response.status if upload_response else -1, text=(await upload_response.text()) if upload_response else "", errcode="COM.BEEPER.EXTERNAL_UPLOAD_ERROR", message=None, ) self.log.warning( f"Uploading media to external URL {upload_url} failed: {e}, " f"retrying in {backoff} seconds", ) await asyncio.sleep(backoff) backoff *= 2 retry_count -= 1 else: break await self.api.request(Method.POST, post_upload_path, query_params=post_upload_query) python-0.20.7/mautrix/client/api/modules/misc.py000066400000000000000000000134031473573527000216340ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations from mautrix.api import Method, Path from mautrix.errors import MatrixResponseError from mautrix.types import ( JSON, EventID, PresenceEventContent, PresenceState, RoomID, SerializerError, UserID, ) from ..base import BaseClientAPI class MiscModuleMethods(BaseClientAPI): """ Miscellaneous subsections in the `Modules section`_ of the API spec. Currently included subsections: * 13.4 `Typing Notifications`_ * 13.5 `Receipts`_ * 13.6 `Fully Read Markers`_ * 13.7 `Presence`_ .. _Modules section: https://matrix.org/docs/spec/client_server/r0.4.0.html#modules .. _Typing Notifications: https://matrix.org/docs/spec/client_server/r0.4.0.html#id95 .. _Receipts: https://matrix.org/docs/spec/client_server/r0.4.0.html#id99 .. _Fully Read Markers: https://matrix.org/docs/spec/client_server/r0.4.0.html#fully-read-markers .. _Presence: https://matrix.org/docs/spec/client_server/r0.4.0.html#id107 """ # region 13.4 Typing Notifications async def set_typing(self, room_id: RoomID, timeout: int = 0) -> None: """ This tells the server that the user is typing for the next N milliseconds where N is the value specified in the timeout key. If the timeout is equal to or less than zero, it tells the server that the user has stopped typing. See also: `API reference `__ Args: room_id: The ID of the room in which the user is typing. timeout: The length of time in milliseconds to mark this user as typing. 
""" if timeout > 0: content = {"typing": True, "timeout": timeout} else: content = {"typing": False} await self.api.request(Method.PUT, Path.v3.rooms[room_id].typing[self.mxid], content) # endregion # region 13.5 Receipts async def send_receipt( self, room_id: RoomID, event_id: EventID, receipt_type: str = "m.read", ) -> None: """ Update the marker for the given receipt type to the event ID specified. See also: `API reference `__ Args: room_id: The ID of the room which to send the receipt to. event_id: The last event ID to acknowledge. receipt_type: The type of receipt to send. Currently only ``m.read`` is supported. """ await self.api.request(Method.POST, Path.v3.rooms[room_id].receipt[receipt_type][event_id]) # endregion # region 13.6 Fully read markers async def set_fully_read_marker( self, room_id: RoomID, fully_read: EventID, read_receipt: EventID | None = None, extra_content: dict[str, JSON] | None = None, ) -> None: """ Set the position of the read marker for the given room, and optionally send a new read receipt. See also: `API reference `__ Args: room_id: The ID of the room which to set the read marker in. fully_read: The last event up to which the user has either read all events or is not interested in reading the events. read_receipt: The new position for the user's normal read receipt, i.e. the last event the user has seen. extra_content: Additional fields to include in the ``/read_markers`` request. """ content = { "m.fully_read": fully_read, } if read_receipt: content["m.read"] = read_receipt if extra_content: content.update(extra_content) await self.api.request(Method.POST, Path.v3.rooms[room_id].read_markers, content) # endregion # region 13.7 Presence async def set_presence( self, presence: PresenceState = PresenceState.ONLINE, status: str | None = None ) -> None: """ Set the current user's presence state. When setting the status, the activity time is updated to reflect that activity; the client does not need to specify :attr:`Presence.last_active_ago`. See also: `API reference `__ Args: presence: The new presence state to set. status: The status message to attach to this state. """ content = { "presence": presence.value, } if status: content["status_msg"] = status await self.api.request(Method.PUT, Path.v3.presence[self.mxid].status, content) async def get_presence(self, user_id: UserID) -> PresenceEventContent: """ Get the presence info of a user. See also: `API reference `__ Args: user_id: The ID of the user whose presence info to get. Returns: The presence info of the given user. """ content = await self.api.request(Method.GET, Path.v3.presence[user_id].status) try: return PresenceEventContent.deserialize(content) except SerializerError: raise MatrixResponseError("Invalid presence in response") # endregion python-0.20.7/mautrix/client/api/modules/push_rules.py000066400000000000000000000066661473573527000231070ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations from mautrix.api import Method, Path from mautrix.types import ( PushAction, PushCondition, PushRule, PushRuleID, PushRuleKind, PushRuleScope, ) from ..base import BaseClientAPI class PushRuleMethods(BaseClientAPI): """ Methods in section 13.13 Push Notifications of the spec. These methods are used for modifying what triggers push notifications. 
See also: `API reference `__""" async def get_push_rule( self, scope: PushRuleScope, kind: PushRuleKind, rule_id: PushRuleID ) -> PushRule: """ Retrieve a single specified push rule. See also: `API reference `__ Args: scope: The scope of the push rule. kind: The kind of rule. rule_id: The identifier of the rule. Returns: The push rule information. """ resp = await self.api.request(Method.GET, Path.v3.pushrules[scope][kind][rule_id]) return PushRule.deserialize(resp) async def set_push_rule( self, scope: PushRuleScope, kind: PushRuleKind, rule_id: PushRuleID, actions: list[PushAction], pattern: str | None = None, before: PushRuleID | None = None, after: PushRuleID | None = None, conditions: list[PushCondition] = None, ) -> None: """ Create or modify a push rule. See also: `API reference `__ Args: scope: The scope of the push rule. kind: The kind of rule. rule_id: The identifier for the rule. before: after: actions: The actions to perform when the conditions for the rule are met. pattern: The glob-style pattern to match against for ``content`` rules. conditions: The conditions for the rule for ``underride`` and ``override`` rules. """ query = {} if after: query["after"] = after if before: query["before"] = before content = {"actions": [act.serialize() for act in actions]} if conditions: content["conditions"] = [cond.serialize() for cond in conditions] if pattern: content["pattern"] = pattern await self.api.request( Method.PUT, Path.v3.pushrules[scope][kind][rule_id], query_params=query, content=content, ) async def remove_push_rule( self, scope: PushRuleScope, kind: PushRuleKind, rule_id: PushRuleID ) -> None: """ Remove a push rule. See also: `API reference `__ Args: scope: The scope of the push rule. kind: The kind of rule. rule_id: The identifier of the rule. """ await self.api.request(Method.DELETE, Path.v3.pushrules[scope][kind][rule_id]) python-0.20.7/mautrix/client/api/modules/room_tag.py000066400000000000000000000060511473573527000225110ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations from mautrix.api import Method, Path from mautrix.types import RoomID, RoomTagAccountDataEventContent, RoomTagInfo, Serializable from ..base import BaseClientAPI class RoomTaggingMethods(BaseClientAPI): """ Methods in section 13.18 Room Tagging of the spec. These methods are used for organizing rooms into tags for the local user. See also: `API reference `__""" async def get_room_tags(self, room_id: RoomID) -> RoomTagAccountDataEventContent: """ Get all tags for a specific room. This is equivalent to getting the ``m.tag`` account data event for the room. See also: `API reference `__ Args: room_id: The room ID to get tags from. Returns: The m.tag account data event. """ resp = await self.api.request(Method.GET, Path.v3.user[self.mxid].rooms[room_id].tags) return RoomTagAccountDataEventContent.deserialize(resp) async def get_room_tag(self, room_id: RoomID, tag: str) -> RoomTagInfo | None: """ Get the info of a specific tag for a room. Args: room_id: The room to get the tag from. tag: The tag to get. Returns: The info about the tag, or ``None`` if the room does not have the specified tag. 
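Examples:
    A minimal usage sketch (the ``client`` instance and ``room_id`` are hypothetical)::

        info = await client.get_room_tag(room_id, "m.favourite")
        if info is None:
            # The room isn't tagged as a favourite yet, so add the tag.
            await client.set_room_tag(room_id, "m.favourite")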
""" resp = await self.get_room_tags(room_id) try: return resp.tags[tag] except KeyError: return None async def set_room_tag( self, room_id: RoomID, tag: str, info: RoomTagInfo | None = None ) -> None: """ Add or update a tag for a specific room. See also: `API reference `__ Args: room_id: The room ID to add the tag to. tag: The tag to add. info: Optionally, information like ordering within the tag. """ await self.api.request( Method.PUT, Path.v3.user[self.mxid].rooms[room_id].tags[tag], content=(info.serialize() if isinstance(info, Serializable) else (info or {})), ) async def remove_room_tag(self, room_id: RoomID, tag: str) -> None: """ Remove a tag from a specific room. See also: `API reference `__ Args: room_id: The room ID to remove the tag from. tag: The tag to remove. """ await self.api.request(Method.DELETE, Path.v3.user[self.mxid].rooms[room_id].tags[tag]) python-0.20.7/mautrix/client/api/rooms.py000066400000000000000000001023111473573527000203650ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations from typing import Any, Awaitable, Callable import asyncio from multidict import CIMultiDict from mautrix.api import Method, Path from mautrix.errors import ( MatrixRequestError, MatrixResponseError, MNotFound, MNotJoined, MRoomInUse, ) from mautrix.types import ( JSON, DirectoryPaginationToken, EventID, EventType, Membership, MemberStateEventContent, PowerLevelStateEventContent, RoomAlias, RoomAliasInfo, RoomCreatePreset, RoomCreateStateEventContent, RoomDirectoryResponse, RoomDirectoryVisibility, RoomID, Serializable, StateEvent, StrippedStateEvent, UserID, ) from .base import BaseClientAPI from .events import EventMethods class RoomMethods(EventMethods, BaseClientAPI): """ Methods in section 8 Rooms of the spec. These methods are used for creating rooms, interacting with the room directory and using the easy room metadata editing endpoints. Generic state setting and sending events are in the :class:`EventMethods` (section 7) module. See also: `API reference `__ """ # region 8.1 Creation # API reference: https://spec.matrix.org/v1.1/client-server-api/#creation async def create_room( self, alias_localpart: str | None = None, visibility: RoomDirectoryVisibility = RoomDirectoryVisibility.PRIVATE, preset: RoomCreatePreset = RoomCreatePreset.PRIVATE, name: str | None = None, topic: str | None = None, is_direct: bool = False, invitees: list[UserID] | None = None, initial_state: list[StateEvent | StrippedStateEvent | dict[str, JSON]] | None = None, room_version: str = None, creation_content: RoomCreateStateEventContent | dict[str, JSON] | None = None, power_level_override: PowerLevelStateEventContent | dict[str, JSON] | None = None, beeper_auto_join_invites: bool = False, custom_request_fields: dict[str, Any] | None = None, ) -> RoomID: """ Create a new room with various configuration options. See also: `API reference `__ Args: alias_localpart: The desired room alias **local part**. If this is included, a room alias will be created and mapped to the newly created room. The alias will belong on the same homeserver which created the room. For example, if this was set to "foo" and sent to the homeserver "example.com" the complete room alias would be ``#foo:example.com``. 
visibility: A ``public`` visibility indicates that the room will be shown in the published room list. A ``private`` visibility will hide the room from the published room list. Defaults to ``private``. **NB:** This should not be confused with ``join_rules`` which also uses the word ``public``. preset: Convenience parameter for setting various default state events based on a preset. Defaults to private (invite-only). name: If this is included, an ``m.room.name`` event will be sent into the room to indicate the name of the room. See `Room Events`_ for more information on ``m.room.name``. topic: If this is included, an ``m.room.topic`` event will be sent into the room to indicate the topic for the room. See `Room Events`_ for more information on ``m.room.topic``. is_direct: This flag makes the server set the ``is_direct`` flag on the `m.room.member`_ events sent to the users in ``invite`` and ``invite_3pid``. See `Direct Messaging`_ for more information. invitees: A list of user IDs to invite to the room. This will tell the server to invite everyone in the list to the newly created room. initial_state: A list of state events to set in the new room. This allows the user to override the default state events set in the new room. The expected format of the state events are an object with type, state_key and content keys set. Takes precedence over events set by ``is_public``, but gets overriden by ``name`` and ``topic keys``. room_version: The room version to set for the room. If not provided, the homeserver will use its configured default. creation_content: Extra keys, such as ``m.federate``, to be added to the ``m.room.create`` event. The server will ignore ``creator`` and ``room_version``. Future versions of the specification may allow the server to ignore other keys. power_level_override: The power level content to override in the default power level event. This object is applied on top of the generated ``m.room.power_levels`` event content prior to it being sent to the room. Defaults to overriding nothing. beeper_auto_join_invites: A Beeper-specific extension which auto-joins all members in the invite array instead of sending invites. custom_request_fields: Additional fields to put in the top-level /createRoom content. Non-custom fields take precedence over fields here. Returns: The ID of the newly created room. Raises: MatrixResponseError: If the response does not contain a ``room_id`` field. .. _Room Events: https://spec.matrix.org/v1.1/client-server-api/#room-events .. _Direct Messaging: https://spec.matrix.org/v1.1/client-server-api/#direct-messaging .. _m.room.create: https://spec.matrix.org/v1.1/client-server-api/#mroomcreate .. 
_m.room.member: https://spec.matrix.org/v1.1/client-server-api/#mroommember """ content = { **(custom_request_fields or {}), "visibility": visibility.value, "is_direct": is_direct, "preset": preset.value, } if alias_localpart: content["room_alias_name"] = alias_localpart if invitees: content["invite"] = invitees if beeper_auto_join_invites: content["com.beeper.auto_join_invites"] = True if name: content["name"] = name if topic: content["topic"] = topic if initial_state: content["initial_state"] = [ event.serialize() if isinstance(event, Serializable) else event for event in initial_state ] if room_version: content["room_version"] = room_version if creation_content: content["creation_content"] = ( creation_content.serialize() if isinstance(creation_content, Serializable) else creation_content ) # Remove keys that the server will ignore anyway content["creation_content"].pop("room_version", None) if power_level_override: content["power_level_content_override"] = ( power_level_override.serialize() if isinstance(power_level_override, Serializable) else power_level_override ) resp = await self.api.request(Method.POST, Path.v3.createRoom, content) try: return resp["room_id"] except KeyError: raise MatrixResponseError("`room_id` not in response.") # endregion # region 8.2 Room aliases # API reference: https://spec.matrix.org/v1.1/client-server-api/#room-aliases async def add_room_alias( self, room_id: RoomID, alias_localpart: str, override: bool = False ) -> None: """ Create a new mapping from room alias to room ID. See also: `API reference `__ Args: room_id: The room ID to set. alias_localpart: The localpart of the room alias to set. override: Whether or not the alias should be removed and the request retried if the server responds with HTTP 409 Conflict """ room_alias = f"#{alias_localpart}:{self.domain}" content = {"room_id": room_id} try: await self.api.request(Method.PUT, Path.v3.directory.room[room_alias], content) except MatrixRequestError as e: if e.http_status == 409: if override: await self.remove_room_alias(alias_localpart) await self.api.request(Method.PUT, Path.v3.directory.room[room_alias], content) else: raise MRoomInUse(e.http_status, e.message) from e else: raise async def remove_room_alias(self, alias_localpart: str, raise_404: bool = False) -> None: """ Remove a mapping of room alias to room ID. Servers may choose to implement additional access control checks here, for instance that room aliases can only be deleted by their creator or server administrator. See also: `API reference `__ Args: alias_localpart: The room alias to remove. raise_404: Whether 404 errors should be raised as exceptions instead of ignored. """ room_alias = f"#{alias_localpart}:{self.domain}" try: await self.api.request(Method.DELETE, Path.v3.directory.room[room_alias]) except MNotFound: if raise_404: raise # else: ignore async def resolve_room_alias(self, room_alias: RoomAlias) -> RoomAliasInfo: """ Request the server to resolve a room alias to a room ID. The server will use the federation API to resolve the alias if the domain part of the alias does not correspond to the server's own domain. See also: `API reference `__ Args: room_alias: The room alias. Returns: The room ID and a list of servers that are aware of the room. 
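Examples:
    A minimal usage sketch (the ``client`` instance and the alias are hypothetical)::

        info = await client.resolve_room_alias("#matrix:matrix.org")
        # info.room_id is the resolved room ID and info.servers lists servers
        # that are aware of the room, which can be passed along to join_room.
        await client.join_room(info.room_id, servers=info.servers)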
""" content = await self.api.request(Method.GET, Path.v3.directory.room[room_alias]) return RoomAliasInfo.deserialize(content) # endregion # region 8.4 Room membership # API reference: https://spec.matrix.org/v1.1/client-server-api/#room-membership async def get_joined_rooms(self) -> list[RoomID]: """Get the list of rooms the user is in.""" content = await self.api.request(Method.GET, Path.v3.joined_rooms) try: return content["joined_rooms"] except KeyError: raise MatrixResponseError("`joined_rooms` not in response.") # region 8.4.1 Joining rooms # API reference: https://spec.matrix.org/v1.1/client-server-api/#joining-rooms async def join_room_by_id( self, room_id: RoomID, third_party_signed: JSON = None, extra_content: dict[str, Any] | None = None, ) -> RoomID: """ Start participating in a room, i.e. join it by its ID. See also: `API reference `__ Args: room_id: The ID of the room to join. third_party_signed: A signature of an ``m.third_party_invite`` token to prove that this user owns a third party identity which has been invited to the room. extra_content: Additional properties for the join event content. If a non-empty dict is passed, the join event will be created using the ``PUT /state/m.room.member/...`` endpoint instead of ``POST /join``. Returns: The ID of the room the user joined. """ if extra_content: await self.send_member_event( room_id, self.mxid, Membership.JOIN, extra_content=extra_content ) return room_id content = await self.api.request( Method.POST, Path.v3.rooms[room_id].join, {"third_party_signed": third_party_signed} if third_party_signed is not None else None, ) try: return content["room_id"] except KeyError: raise MatrixResponseError("`room_id` not in response.") async def join_room( self, room_id_or_alias: RoomID | RoomAlias, servers: list[str] | None = None, third_party_signed: JSON = None, max_retries: int = 4, ) -> RoomID: """ Start participating in a room, i.e. join it by its ID or alias, with an optional list of servers to ask about the ID from. See also: `API reference `__ Args: room_id_or_alias: The ID of the room to join, or an alias pointing to the room. servers: A list of servers to ask about the room ID to join. Not applicable for aliases, as aliases already contain the necessary server information. third_party_signed: A signature of an ``m.third_party_invite`` token to prove that this user owns a third party identity which has been invited to the room. max_retries: The maximum number of retries. Used to circumvent a Synapse bug with accepting invites over federation. 0 means only one join call will be attempted. See: `matrix-org/synapse#2807 `__ Returns: The ID of the room the user joined. """ max_retries = max(0, max_retries) tries = 0 content = ( {"third_party_signed": third_party_signed} if third_party_signed is not None else None ) query_params = CIMultiDict() for server_name in servers or []: query_params.add("server_name", server_name) while tries <= max_retries: try: content = await self.api.request( Method.POST, Path.v3.join[room_id_or_alias], content=content, query_params=query_params, ) break except MatrixRequestError: tries += 1 if tries <= max_retries: wait = tries * 10 self.log.exception( f"Failed to join room {room_id_or_alias}, retrying in {wait} seconds..." 
) await asyncio.sleep(wait) else: raise try: return content["room_id"] except KeyError: raise MatrixResponseError("`room_id` not in response.") fill_member_event_callback: ( Callable[ [RoomID, UserID, MemberStateEventContent], Awaitable[MemberStateEventContent | None] ] | None ) async def fill_member_event( self, room_id: RoomID, user_id: UserID, content: MemberStateEventContent ) -> MemberStateEventContent | None: """ Fill a membership event content that is going to be sent in :meth:`send_member_event`. This is used to set default fields like the displayname and avatar, which are usually set by the server in the sugar membership endpoints like /join and /invite, but are not set automatically when sending member events manually. This default implementation only calls :attr:`fill_member_event_callback`. Args: room_id: The room where the member event is going to be sent. user_id: The user whose membership is changing. content: The new member event content. Returns: The filled member event content. """ if self.fill_member_event_callback is not None: return await self.fill_member_event_callback(room_id, user_id, content) return None async def send_member_event( self, room_id: RoomID, user_id: UserID, membership: Membership, extra_content: dict[str, JSON] | None = None, ) -> EventID: """ Send a membership event manually. Args: room_id: The room to send the event to. user_id: The user whose membership to change. membership: The membership status. extra_content: Additional content to put in the member event. Returns: The event ID of the new member event. """ content = MemberStateEventContent(membership=membership) for key, value in extra_content.items(): content[key] = value content = await self.fill_member_event(room_id, user_id, content) or content return await self.send_state_event( room_id, EventType.ROOM_MEMBER, content=content, state_key=user_id, ensure_joined=False ) async def invite_user( self, room_id: RoomID, user_id: UserID, reason: str | None = None, extra_content: dict[str, JSON] | None = None, ) -> None: """ Invite a user to participate in a particular room. They do not start participating in the room until they actually join the room. Only users currently in the room can invite other users to join that room. If the user was invited to the room, the homeserver will add a `m.room.member`_ event to the room. See also: `API reference `__ .. _m.room.member: https://spec.matrix.org/v1.1/client-server-api/#mroommember Args: room_id: The ID of the room to which to invite the user. user_id: The fully qualified user ID of the invitee. reason: The reason the user was invited. This will be supplied as the ``reason`` on the `m.room.member`_ event. extra_content: Additional properties for the invite event content. If a non-empty dict is passed, the invite event will be created using the ``PUT /state/m.room.member/...`` endpoint instead of ``POST /invite``. """ if extra_content: await self.send_member_event( room_id, user_id, Membership.INVITE, extra_content=extra_content ) else: data = {"user_id": user_id} if reason: data["reason"] = reason await self.api.request(Method.POST, Path.v3.rooms[room_id].invite, content=data) # endregion # region 8.4.2 Leaving rooms # API reference: https://spec.matrix.org/v1.1/client-server-api/#leaving-rooms async def leave_room( self, room_id: RoomID, reason: str | None = None, extra_content: dict[str, JSON] | None = None, raise_not_in_room: bool = False, ) -> None: """ Stop participating in a particular room, i.e. leave the room. 
If the user was already in the room, they will no longer be able to see new events in the room. If the room requires an invite to join, they will need to be re-invited before they can re-join. If the user was invited to the room, but had not joined, this call serves to reject the invite. The user will still be allowed to retrieve history from the room which they were previously allowed to see. See also: `API reference `__ Args: room_id: The ID of the room to leave. reason: The reason for leaving the room. This will be supplied as the ``reason`` on the updated `m.room.member`_ event. extra_content: Additional properties for the leave event content. If a non-empty dict is passed, the leave event will be created using the ``PUT /state/m.room.member/...`` endpoint instead of ``POST /leave``. raise_not_in_room: Should errors about the user not being in the room be raised? """ try: if extra_content: await self.send_member_event( room_id, self.mxid, Membership.LEAVE, extra_content=extra_content ) else: data = {} if reason: data["reason"] = reason await self.api.request(Method.POST, Path.v3.rooms[room_id].leave, content=data) except MNotJoined: if raise_not_in_room: raise except MatrixRequestError as e: # TODO remove this once MSC3848 is released and minimum spec version is bumped if "not in room" not in e.message or raise_not_in_room: raise async def knock_room( self, room_id_or_alias: RoomID | RoomAlias, reason: str | None = None, servers: list[str] | None = None, ) -> RoomID: """ Knock on a room, i.e. request to join it by its ID or alias, with an optional list of servers to ask about the ID from. See also: `API reference `__ Args: room_id_or_alias: The ID of the room to knock on, or an alias pointing to the room. reason: The reason for knocking on the room. This will be supplied as the ``reason`` on the updated `m.room.member`_ event. servers: A list of servers to ask about the room ID to knock. Not applicable for aliases, as aliases already contain the necessary server information. Returns: The ID of the room the user knocked on. """ data = {} if reason: data["reason"] = reason query_params = CIMultiDict() for server_name in servers or []: query_params.add("server_name", server_name) content = await self.api.request( Method.POST, Path.v3.knock[room_id_or_alias], content=data, query_params=query_params, ) try: return content["room_id"] except KeyError: raise MatrixResponseError("`room_id` not in response.") async def forget_room(self, room_id: RoomID) -> None: """ Stop remembering a particular room, i.e. forget it. In general, history is a first class citizen in Matrix. After this API is called, however, a user will no longer be able to retrieve history for this room. If all users on a homeserver forget a room, the room is eligible for deletion from that homeserver. If the user is currently joined to the room, they must leave the room before calling this API. See also: `API reference `__ Args: room_id: The ID of the room to forget. """ await self.api.request(Method.POST, Path.v3.rooms[room_id].forget) async def kick_user( self, room_id: RoomID, user_id: UserID, reason: str = "", extra_content: dict[str, JSON] | None = None, ) -> None: """ Kick a user from the room. The caller must have the required power level in order to perform this operation. Kicking a user adjusts the target member's membership state to be ``leave`` with an optional ``reason``. 
Like with other membership changes, a user can directly adjust the target member's state by calling :meth:`EventMethods.send_state_event` with :attr:`EventType.ROOM_MEMBER` as the event type and the ``user_id`` as the state key. See also: `API reference `__ Args: room_id: The ID of the room from which the user should be kicked. user_id: The fully qualified user ID of the user being kicked. reason: The reason the user has been kicked. This will be supplied as the ``reason`` on the target's updated `m.room.member`_ event. extra_content: Additional properties for the kick event content. If a non-empty dict is passed, the kick event will be created using the ``PUT /state/m.room.member/...`` endpoint instead of ``POST /kick``. .. _m.room.member: https://spec.matrix.org/v1.1/client-server-api/#mroommember """ if extra_content: if reason and "reason" not in extra_content: extra_content["reason"] = reason await self.send_member_event( room_id, user_id, Membership.LEAVE, extra_content=extra_content ) return await self.api.request( Method.POST, Path.v3.rooms[room_id].kick, {"user_id": user_id, "reason": reason} ) # endregion # region 8.4.2.1 Banning users in a room # API reference: https://spec.matrix.org/v1.1/client-server-api/#banning-users-in-a-room async def ban_user( self, room_id: RoomID, user_id: UserID, reason: str = "", extra_content: dict[str, JSON] | None = None, ) -> None: """ Ban a user in the room. If the user is currently in the room, also kick them. When a user is banned from a room, they may not join it or be invited to it until they are unbanned. The caller must have the required power level in order to perform this operation. See also: `API reference `__ Args: room_id: The ID of the room from which the user should be banned. user_id: The fully qualified user ID of the user being banned. reason: The reason the user has been kicked. This will be supplied as the ``reason`` on the target's updated `m.room.member`_ event. extra_content: Additional properties for the ban event content. If a non-empty dict is passed, the ban will be created using the ``PUT /state/m.room.member/...`` endpoint instead of ``POST /ban``. .. _m.room.member: https://spec.matrix.org/v1.1/client-server-api/#mroommember """ if extra_content: if reason and "reason" not in extra_content: extra_content["reason"] = reason await self.send_member_event( room_id, user_id, Membership.BAN, extra_content=extra_content ) return await self.api.request( Method.POST, Path.v3.rooms[room_id].ban, {"user_id": user_id, "reason": reason} ) async def unban_user( self, room_id: RoomID, user_id: UserID, reason: str = "", extra_content: dict[str, JSON] | None = None, ) -> None: """ Unban a user from the room. This allows them to be invited to the room, and join if they would otherwise be allowed to join according to its join rules. The caller must have the required power level in order to perform this operation. See also: `API reference `__ Args: room_id: The ID of the room from which the user should be unbanned. user_id: The fully qualified user ID of the user being banned. reason: The reason the user has been unbanned. This will be supplied as the ``reason`` on the target's updated `m.room.member`_ event. extra_content: Additional properties for the unban (leave) event content. If a non-empty dict is passed, the unban will be created using the ``PUT /state/m.room.member/...`` endpoint instead of ``POST /unban``. 
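Examples:
    A minimal usage sketch (the ``client``, ``room_id`` and ``user_id`` values are
    hypothetical)::

        # Lift the ban so the user can be re-invited or join again on their own.
        await client.unban_user(room_id, user_id, reason="Ban appeal accepted")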
""" if extra_content: if reason and "reason" not in extra_content: extra_content["reason"] = reason await self.send_member_event( room_id, user_id, Membership.LEAVE, extra_content=extra_content ) return await self.api.request( Method.POST, Path.v3.rooms[room_id].unban, {"user_id": user_id, "reason": reason} ) # endregion # endregion # region 8.5 Listing rooms # API reference: https://spec.matrix.org/v1.1/client-server-api/#listing-rooms async def get_room_directory_visibility(self, room_id: RoomID) -> RoomDirectoryVisibility: """ Get the visibility of the room on the server's public room directory. See also: `API reference `__ Args: room_id: The ID of the room. Returns: The visibility of the room in the directory. """ resp = await self.api.request(Method.GET, Path.v3.directory.list.room[room_id]) try: return RoomDirectoryVisibility(resp["visibility"]) except KeyError: raise MatrixResponseError("`visibility` not in response.") except ValueError: raise MatrixResponseError( f"Invalid value for `visibility` in response: {resp['visibility']}" ) async def set_room_directory_visibility( self, room_id: RoomID, visibility: RoomDirectoryVisibility ) -> None: """ Set the visibility of the room in the server's public room directory. Servers may choose to implement additional access control checks here, for instance that room visibility can only be changed by the room creator or a server administrator. Args: room_id: The ID of the room. visibility: The new visibility setting for the room. .. _API reference: https://spec.matrix.org/v1.1/client-server-api/#put_matrixclientv3directorylistroomroomid """ await self.api.request( Method.PUT, Path.v3.directory.list.room[room_id], { "visibility": visibility.value, }, ) async def get_room_directory( self, limit: int | None = None, server: str | None = None, since: DirectoryPaginationToken | None = None, search_query: str | None = None, include_all_networks: bool | None = None, third_party_instance_id: str | None = None, ) -> RoomDirectoryResponse: """ Get a list of public rooms from the server's room directory. See also: `API reference `__ Args: limit: The maximum number of results to return. server: The server to fetch the room directory from. Defaults to the user's server. since: A pagination token from a previous request, allowing clients to get the next (or previous) batch of rooms. The direction of pagination is specified solely by which token is supplied, rather than via an explicit flag. search_query: A string to search for in the room metadata, e.g. name, topic, canonical alias etc. include_all_networks: Whether or not to include rooms from all known networks/protocols from application services on the homeserver. Defaults to false. third_party_instance_id: The specific third party network/protocol to request from the homeserver. Can only be used if ``include_all_networks`` is false. Returns: The relevant pagination tokens, an estimate of the total number of public rooms and the paginated chunk of public rooms. 
""" method = ( Method.GET if ( search_query is None and include_all_networks is None and third_party_instance_id is None ) else Method.POST ) content = {} if limit is not None: content["limit"] = limit if since is not None: content["since"] = since if search_query is not None: content["filter"] = {"generic_search_term": search_query} if include_all_networks is not None: content["include_all_networks"] = include_all_networks if third_party_instance_id is not None: content["third_party_instance_id"] = third_party_instance_id query_params = {"server": server} if server is not None else None content = await self.api.request( method, Path.v3.publicRooms, content, query_params=query_params ) return RoomDirectoryResponse.deserialize(content) # endregion python-0.20.7/mautrix/client/api/user_data.py000066400000000000000000000154721473573527000212100ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations from typing import Any from mautrix.api import Method, Path from mautrix.errors import MatrixResponseError, MNotFound from mautrix.types import ContentURI, Member, SerializerError, User, UserID, UserSearchResults from .base import BaseClientAPI class UserDataMethods(BaseClientAPI): """ Methods in section 10 User Data of the spec. These methods are used for setting and getting user metadata and searching for users. See also: `API reference `__ """ # region 10.1 User Directory # API reference: https://matrix.org/docs/spec/client_server/r0.4.0.html#user-directory async def search_users(self, search_query: str, limit: int | None = 10) -> UserSearchResults: """ Performs a search for users on the homeserver. The homeserver may determine which subset of users are searched, however the homeserver MUST at a minimum consider the users the requesting user shares a room with and those who reside in public rooms (known to the homeserver). The search MUST consider local users to the homeserver, and SHOULD query remote users as part of the search. The search is performed case-insensitively on user IDs and display names preferably using a collation determined based upon the Accept-Language header provided in the request, if present. See also: `API reference `__ Args: search_query: The query to search for. limit: The maximum number of results to return. Returns: The results of the search and whether or not the results were limited. """ content = await self.api.request( Method.POST, Path.v3.user_directory.search, { "search_term": search_query, "limit": limit, }, ) try: return UserSearchResults( [User.deserialize(user) for user in content["results"]], content["limited"] ) except SerializerError as e: raise MatrixResponseError("Invalid user in search results") from e except KeyError: if "results" not in content: raise MatrixResponseError("`results` not in content.") elif "limited" not in content: raise MatrixResponseError("`limited` not in content.") raise # endregion # region 10.2 Profiles # API reference: https://matrix.org/docs/spec/client_server/r0.4.0.html#profiles async def set_displayname(self, displayname: str, check_current: bool = True) -> None: """ Set the display name of the current user. See also: `API reference `__ Args: displayname: The new display name for the user. check_current: Whether or not to check if the displayname is already set. 
""" if check_current and await self.get_displayname(self.mxid) == displayname: return await self.api.request( Method.PUT, Path.v3.profile[self.mxid].displayname, { "displayname": displayname, }, ) async def get_displayname(self, user_id: UserID) -> str | None: """ Get the display name of a user. See also: `API reference `__ Args: user_id: The ID of the user whose display name to get. Returns: The display name of the given user. """ try: content = await self.api.request(Method.GET, Path.v3.profile[user_id].displayname) except MNotFound: return None try: return content["displayname"] except KeyError: return None async def set_avatar_url(self, avatar_url: ContentURI, check_current: bool = True) -> None: """ Set the avatar of the current user. See also: `API reference `__ Args: avatar_url: The ``mxc://`` URI to the new avatar. check_current: Whether or not to check if the avatar is already set. """ if check_current and await self.get_avatar_url(self.mxid) == avatar_url: return await self.api.request( Method.PUT, Path.v3.profile[self.mxid].avatar_url, { "avatar_url": avatar_url, }, ) async def get_avatar_url(self, user_id: UserID) -> ContentURI | None: """ Get the avatar URL of a user. See also: `API reference `__ Args: user_id: The ID of the user whose avatar to get. Returns: The ``mxc://`` URI to the user's avatar. """ try: content = await self.api.request(Method.GET, Path.v3.profile[user_id].avatar_url) except MNotFound: return None try: return content["avatar_url"] except KeyError: return None async def get_profile(self, user_id: UserID) -> Member: """ Get the combined profile information for a user. See also: `API reference `__ Args: user_id: The ID of the user whose profile to get. Returns: The profile information of the given user. """ content = await self.api.request(Method.GET, Path.v3.profile[user_id]) try: return Member.deserialize(content) except SerializerError as e: raise MatrixResponseError("Invalid member in response") from e # endregion # region Beeper Custom Fields API async def beeper_update_profile(self, custom_fields: dict[str, Any]) -> None: """ Set custom fields on the user's profile. Only works on Hungryserv. Args: custom_fields: A dictionary of fields to set in the custom content of the profile. """ await self.api.request(Method.PATCH, Path.v3.profile[self.mxid], custom_fields) # endregion python-0.20.7/mautrix/client/client.py000066400000000000000000000035001473573527000177330ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations from mautrix import __optional_imports__ from mautrix.types import Event, EventType, StateEvent from .encryption_manager import DecryptionDispatcher, EncryptingAPI from .state_store import StateStore, SyncStore from .syncer import Syncer if __optional_imports__: from .. 
import crypto as crypt class Client(EncryptingAPI, Syncer): """Client is a high-level wrapper around the client API.""" def __init__( self, *args, sync_store: SyncStore | None = None, state_store: StateStore | None = None, **kwargs, ) -> None: EncryptingAPI.__init__(self, *args, state_store=state_store, **kwargs) Syncer.__init__(self, sync_store) self.add_event_handler(EventType.ALL, self._update_state) async def _update_state(self, evt: Event) -> None: if not isinstance(evt, StateEvent) or not self.state_store: return await self.state_store.update_state(evt) @EncryptingAPI.crypto.setter def crypto(self, crypto: crypt.OlmMachine | None) -> None: """ Set the olm machine and enable the automatic event decryptor. Args: crypto: The olm machine to use for crypto Raises: ValueError: if :attr:`state_store` is not set. """ if not self.state_store: raise ValueError("State store must be set to use encryption") self._crypto = crypto if self.crypto_enabled: self.add_dispatcher(DecryptionDispatcher) else: self.remove_dispatcher(DecryptionDispatcher) python-0.20.7/mautrix/client/dispatcher.py000066400000000000000000000047101473573527000206070ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations from typing import ClassVar from abc import ABC, abstractmethod from mautrix.types import Event, EventType, Membership, StateEvent from . import syncer class Dispatcher(ABC): client: syncer.Syncer def __init__(self, client: syncer.Syncer) -> None: self.client = client @abstractmethod def register(self) -> None: pass @abstractmethod def unregister(self) -> None: pass class SimpleDispatcher(Dispatcher, ABC): event_type: ClassVar[EventType] def register(self) -> None: self.client.add_event_handler(self.event_type, self.handle) def unregister(self) -> None: self.client.remove_event_handler(self.event_type, self.handle) @abstractmethod async def handle(self, evt: Event) -> None: pass class MembershipEventDispatcher(SimpleDispatcher): event_type = EventType.ROOM_MEMBER async def handle(self, evt: StateEvent) -> None: if evt.type != EventType.ROOM_MEMBER: return if evt.content.membership == Membership.JOIN: if evt.prev_content.membership != Membership.JOIN: change_type = syncer.InternalEventType.JOIN else: change_type = syncer.InternalEventType.PROFILE_CHANGE elif evt.content.membership == Membership.INVITE: change_type = syncer.InternalEventType.INVITE elif evt.content.membership == Membership.LEAVE: if evt.prev_content.membership == Membership.BAN: change_type = syncer.InternalEventType.UNBAN elif evt.prev_content.membership == Membership.INVITE: if evt.state_key == evt.sender: change_type = syncer.InternalEventType.REJECT_INVITE else: change_type = syncer.InternalEventType.DISINVITE elif evt.state_key == evt.sender: change_type = syncer.InternalEventType.LEAVE else: change_type = syncer.InternalEventType.KICK elif evt.content.membership == Membership.BAN: change_type = syncer.InternalEventType.BAN else: return self.client.dispatch_manual_event(change_type, evt) python-0.20.7/mautrix/client/encryption_manager.py000066400000000000000000000157231473573527000223530ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. 
If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations import asyncio import logging from mautrix import __optional_imports__ from mautrix.errors import DecryptionError, EncryptionError, MNotFound from mautrix.types import ( EncryptedEvent, EncryptedMegolmEventContent, EventContent, EventID, EventType, RoomID, ) from mautrix.util.logging import TraceLogger from . import client, dispatcher, store_updater if __optional_imports__: from .. import crypto as crypt class EncryptingAPI(store_updater.StoreUpdatingAPI): """ EncryptingAPI is a wrapper around StoreUpdatingAPI that automatically encrypts messages. For automatic decryption, see :class:`DecryptionDispatcher`. """ _crypto: crypt.OlmMachine | None encryption_blacklist: set[EventType] = {EventType.REACTION} """A set of event types which shouldn't be encrypted even in encrypted rooms.""" crypto_log: TraceLogger = logging.getLogger("mau.client.crypto") """The logger to use for crypto-related things.""" _share_session_events: dict[RoomID, asyncio.Event] def __init__(self, *args, crypto_log: TraceLogger | None = None, **kwargs) -> None: super().__init__(*args, **kwargs) if crypto_log: self.crypto_log = crypto_log self._crypto = None self._share_session_events = {} @property def crypto(self) -> crypt.OlmMachine | None: """The :class:`crypto.OlmMachine` to use for e2ee stuff.""" return self._crypto @crypto.setter def crypto(self, crypto: crypt.OlmMachine) -> None: """ Args: crypto: The olm machine to use for crypto Raises: ValueError: if :attr:`state_store` is not set. """ if not self.state_store: raise ValueError("State store must be set to use encryption") self._crypto = crypto @property def crypto_enabled(self) -> bool: """``True`` if both the olm machine and state store are set properly.""" return bool(self.crypto) and bool(self.state_store) async def encrypt( self, room_id: RoomID, event_type: EventType, content: EventContent ) -> EncryptedMegolmEventContent: """ Encrypt a message for the given room. Automatically creates and shares a group session if necessary. Args: room_id: The room to encrypt the event to. event_type: The type of event. content: The content of the event. Returns: The content of the encrypted event. """ try: return await self.crypto.encrypt_megolm_event(room_id, event_type, content) except EncryptionError: self.crypto_log.debug("Got EncryptionError, sharing group session and trying again") await self.share_group_session(room_id) self.crypto_log.trace( f"Shared group session, now trying to encrypt in {room_id} again" ) return await self.crypto.encrypt_megolm_event(room_id, event_type, content) async def _share_session_lock(self, room_id: RoomID) -> bool: try: event = self._share_session_events[room_id] except KeyError: self._share_session_events[room_id] = asyncio.Event() return True else: await event.wait() return False async def share_group_session(self, room_id: RoomID) -> None: """ Create and share a Megolm session for the given room. Args: room_id: The room to share the session for. 
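Examples:
    A sketch of the usual (implicit) call path; the ``client`` and ``room_id``
    values are hypothetical. Calling this method directly is rarely necessary,
    because :meth:`encrypt` shares a fresh group session automatically when
    encryption fails with an ``EncryptionError``::

        from mautrix.types import EventType, MessageType, TextMessageEventContent

        content = TextMessageEventContent(msgtype=MessageType.TEXT, body="hello")
        # send_message_event -> encrypt -> share_group_session (only when needed)
        await client.send_message_event(room_id, EventType.ROOM_MESSAGE, content)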
""" if not await self._share_session_lock(room_id): self.log.silly("Group session was already being shared, so didn't share new one") return try: if not await self.state_store.has_full_member_list(room_id): self.crypto_log.trace( f"Don't have full member list for {room_id}, fetching from server" ) members = list((await self.get_joined_members(room_id)).keys()) else: self.crypto_log.trace(f"Fetching member list for {room_id} from state store") members = await self.state_store.get_members(room_id) await self.crypto.share_group_session(room_id, members) finally: self._share_session_events.pop(room_id).set() async def send_message_event( self, room_id: RoomID, event_type: EventType, content: EventContent, disable_encryption: bool = False, **kwargs, ) -> EventID: """ A wrapper around :meth:`ClientAPI.send_message_event` that encrypts messages if the target room is encrypted. Args: room_id: The room to send the message to. event_type: The unencrypted event type. content: The unencrypted event content. disable_encryption: Set to ``True`` if you want to force-send an unencrypted message. **kwargs: Additional parameters to pass to :meth:`ClientAPI.send_message_event`. Returns: The ID of the event that was sent. """ if self.crypto and event_type not in self.encryption_blacklist and not disable_encryption: is_encrypted = await self.state_store.is_encrypted(room_id) if is_encrypted is None: try: await self.get_state_event(room_id, EventType.ROOM_ENCRYPTION) is_encrypted = True except MNotFound: is_encrypted = False if is_encrypted: content = await self.encrypt(room_id, event_type, content) event_type = EventType.ROOM_ENCRYPTED return await super().send_message_event(room_id, event_type, content, **kwargs) class DecryptionDispatcher(dispatcher.SimpleDispatcher): """ DecryptionDispatcher is a dispatcher that can be used with a :class:`client.Syncer` to automatically decrypt events and dispatch the unencrypted versions for event handlers. The easiest way to use this is with :class:`client.Client`, which automatically registers this dispatcher when :attr:`EncryptingAPI.crypto` is set. """ event_type = EventType.ROOM_ENCRYPTED client: client.Client async def handle(self, evt: EncryptedEvent) -> None: try: self.client.crypto_log.trace(f"Decrypting {evt.event_id} in {evt.room_id}...") decrypted = await self.client.crypto.decrypt_megolm_event(evt) except DecryptionError as e: self.client.crypto_log.warning(f"Failed to decrypt {evt.event_id}: {e}") return self.client.crypto_log.trace(f"Decrypted {evt.event_id}: {decrypted}") self.client.dispatch_event(decrypted, evt.source) python-0.20.7/mautrix/client/state_store/000077500000000000000000000000001473573527000204415ustar00rootroot00000000000000python-0.20.7/mautrix/client/state_store/__init__.py000066400000000000000000000004321473573527000225510ustar00rootroot00000000000000from .abstract import StateStore from .file import FileStateStore from .memory import MemoryStateStore from .sync import MemorySyncStore, SyncStore __all__ = [ "StateStore", "FileStateStore", "MemoryStateStore", "MemorySyncStore", "SyncStore", "asyncpg", ] python-0.20.7/mautrix/client/state_store/abstract.py000066400000000000000000000133131473573527000226170ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. 
from __future__ import annotations from typing import Any, Awaitable from abc import ABC, abstractmethod from mautrix.types import ( EventType, Member, Membership, MemberStateEventContent, PowerLevelStateEventContent, RoomEncryptionStateEventContent, RoomID, StateEvent, UserID, ) class StateStore(ABC): async def open(self) -> None: pass async def close(self) -> None: await self.flush() async def flush(self) -> None: pass @abstractmethod async def get_member(self, room_id: RoomID, user_id: UserID) -> Member | None: pass @abstractmethod async def set_member( self, room_id: RoomID, user_id: UserID, member: Member | MemberStateEventContent ) -> None: pass @abstractmethod async def set_membership( self, room_id: RoomID, user_id: UserID, membership: Membership ) -> None: pass @abstractmethod async def get_member_profiles( self, room_id: RoomID, memberships: tuple[Membership, ...] = (Membership.JOIN, Membership.INVITE), ) -> dict[UserID, Member]: pass async def get_members( self, room_id: RoomID, memberships: tuple[Membership, ...] = (Membership.JOIN, Membership.INVITE), ) -> list[UserID]: profiles = await self.get_member_profiles(room_id, memberships) return list(profiles.keys()) async def get_members_filtered( self, room_id: RoomID, not_prefix: str, not_suffix: str, not_id: str, memberships: tuple[Membership, ...] = (Membership.JOIN, Membership.INVITE), ) -> list[UserID]: """ A filtered version of get_members that only returns user IDs that aren't operated by a bridge. This should return the same as :meth:`get_members`, except users where the user ID is equal to not_id OR it starts with not_prefix AND ends with not_suffix. The default implementation simply calls :meth:`get_members`, but databases can implement this more efficiently. Args: room_id: The room ID to find. not_prefix: The user ID prefix to disallow. not_suffix: The user ID suffix to disallow. not_id: The user ID to disallow. memberships: The membership states to include. 
""" members = await self.get_members(room_id, memberships=memberships) return [ user_id for user_id in members if user_id != not_id and not (user_id.startswith(not_prefix) and user_id.endswith(not_suffix)) ] @abstractmethod async def set_members( self, room_id: RoomID, members: dict[UserID, Member | MemberStateEventContent], only_membership: Membership | None = None, ) -> None: pass @abstractmethod async def has_full_member_list(self, room_id: RoomID) -> bool: pass @abstractmethod async def has_power_levels_cached(self, room_id: RoomID) -> bool: pass @abstractmethod async def get_power_levels(self, room_id: RoomID) -> PowerLevelStateEventContent | None: pass @abstractmethod async def set_power_levels( self, room_id: RoomID, content: PowerLevelStateEventContent ) -> None: pass @abstractmethod async def has_encryption_info_cached(self, room_id: RoomID) -> bool: pass @abstractmethod async def is_encrypted(self, room_id: RoomID) -> bool | None: pass @abstractmethod async def get_encryption_info(self, room_id: RoomID) -> RoomEncryptionStateEventContent | None: pass @abstractmethod async def set_encryption_info( self, room_id: RoomID, content: RoomEncryptionStateEventContent | dict[str, any] ) -> None: pass async def update_state(self, evt: StateEvent) -> None: if evt.type == EventType.ROOM_POWER_LEVELS: await self.set_power_levels(evt.room_id, evt.content) elif evt.type == EventType.ROOM_MEMBER: evt.unsigned["mautrix_prev_membership"] = await self.get_member( evt.room_id, UserID(evt.state_key) ) await self.set_member(evt.room_id, UserID(evt.state_key), evt.content) elif evt.type == EventType.ROOM_ENCRYPTION: await self.set_encryption_info(evt.room_id, evt.content) async def get_membership(self, room_id: RoomID, user_id: UserID) -> Membership: member = await self.get_member(room_id, user_id) return member.membership if member else Membership.LEAVE async def is_joined(self, room_id: RoomID, user_id: UserID) -> bool: return (await self.get_membership(room_id, user_id)) == Membership.JOIN def joined(self, room_id: RoomID, user_id: UserID) -> Awaitable[None]: return self.set_membership(room_id, user_id, Membership.JOIN) def invited(self, room_id: RoomID, user_id: UserID) -> Awaitable[None]: return self.set_membership(room_id, user_id, Membership.INVITE) def left(self, room_id: RoomID, user_id: UserID) -> Awaitable[None]: return self.set_membership(room_id, user_id, Membership.LEAVE) async def has_power_level( self, room_id: RoomID, user_id: UserID, event_type: EventType ) -> bool | None: room_levels = await self.get_power_levels(room_id) if not room_levels: return None return room_levels.get_user_level(user_id) >= room_levels.get_event_level(event_type) python-0.20.7/mautrix/client/state_store/asyncpg/000077500000000000000000000000001473573527000221055ustar00rootroot00000000000000python-0.20.7/mautrix/client/state_store/asyncpg/__init__.py000066400000000000000000000000741473573527000242170ustar00rootroot00000000000000from .store import PgStateStore __all__ = ["PgStateStore"] python-0.20.7/mautrix/client/state_store/asyncpg/store.py000066400000000000000000000244441473573527000236230ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. 
from __future__ import annotations from typing import Any, NamedTuple import json from mautrix.types import ( Member, Membership, MemberStateEventContent, PowerLevelStateEventContent, RoomEncryptionStateEventContent, RoomID, Serializable, UserID, ) from mautrix.util.async_db import Database, Scheme from ..abstract import StateStore from .upgrade import upgrade_table class RoomState(NamedTuple): is_encrypted: bool has_full_member_list: bool encryption: RoomEncryptionStateEventContent power_levels: PowerLevelStateEventContent class PgStateStore(StateStore): upgrade_table = upgrade_table db: Database def __init__(self, db: Database) -> None: self.db = db async def get_member(self, room_id: RoomID, user_id: UserID) -> Member | None: res = await self.db.fetchrow( "SELECT membership, displayname, avatar_url " "FROM mx_user_profile WHERE room_id=$1 AND user_id=$2", room_id, user_id, ) if res is None: return None return Member( membership=Membership.deserialize(res["membership"]), displayname=res["displayname"], avatar_url=res["avatar_url"], ) async def set_member( self, room_id: RoomID, user_id: UserID, member: Member | MemberStateEventContent ) -> None: q = ( "INSERT INTO mx_user_profile (room_id, user_id, membership, displayname, avatar_url) " "VALUES ($1, $2, $3, $4, $5)" "ON CONFLICT (room_id, user_id) DO UPDATE SET membership=$3, displayname=$4," " avatar_url=$5" ) await self.db.execute( q, room_id, user_id, member.membership.value, member.displayname, member.avatar_url ) async def set_membership( self, room_id: RoomID, user_id: UserID, membership: Membership ) -> None: q = ( "INSERT INTO mx_user_profile (room_id, user_id, membership) VALUES ($1, $2, $3) " "ON CONFLICT (room_id, user_id) DO UPDATE SET membership=$3" ) await self.db.execute(q, room_id, user_id, membership.value) async def get_members( self, room_id: RoomID, memberships: tuple[Membership, ...] = (Membership.JOIN, Membership.INVITE), ) -> list[UserID]: membership_values = [membership.value for membership in memberships] if self.db.scheme in (Scheme.POSTGRES, Scheme.COCKROACH): q = "SELECT user_id FROM mx_user_profile WHERE room_id=$1 AND membership=ANY($2)" res = await self.db.fetch(q, room_id, membership_values) else: membership_placeholders = ("?," * len(memberships)).rstrip(",") q = ( "SELECT user_id FROM mx_user_profile " f"WHERE room_id=? AND membership IN ({membership_placeholders})" ) res = await self.db.fetch(q, room_id, *membership_values) return [profile["user_id"] for profile in res] async def get_member_profiles( self, room_id: RoomID, memberships: tuple[Membership, ...] = (Membership.JOIN, Membership.INVITE), ) -> dict[UserID, Member]: membership_values = [membership.value for membership in memberships] if self.db.scheme in (Scheme.POSTGRES, Scheme.COCKROACH): q = ( "SELECT user_id, membership, displayname, avatar_url FROM mx_user_profile " "WHERE room_id=$1 AND membership=ANY($2)" ) res = await self.db.fetch(q, room_id, membership_values) else: membership_placeholders = ("?," * len(memberships)).rstrip(",") q = ( "SELECT user_id, membership, displayname, avatar_url FROM mx_user_profile " f"WHERE room_id=? AND membership IN ({membership_placeholders})" ) res = await self.db.fetch(q, room_id, *membership_values) return {profile["user_id"]: Member.deserialize(profile) for profile in res} async def get_members_filtered( self, room_id: RoomID, not_prefix: str, not_suffix: str, not_id: str, memberships: tuple[Membership, ...] 
= (Membership.JOIN, Membership.INVITE), ) -> list[UserID]: not_like = f"{not_prefix}%{not_suffix}" membership_values = [membership.value for membership in memberships] if self.db.scheme in (Scheme.POSTGRES, Scheme.COCKROACH): q = ( "SELECT user_id FROM mx_user_profile " "WHERE room_id=$1 AND membership=ANY($2)" "AND user_id != $3 AND user_id NOT LIKE $4" ) res = await self.db.fetch(q, room_id, membership_values, not_id, not_like) else: membership_placeholders = ("?," * len(memberships)).rstrip(",") q = ( "SELECT user_id FROM mx_user_profile " f"WHERE room_id=? AND membership IN ({membership_placeholders})" "AND user_id != ? AND user_id NOT LIKE ?" ) res = await self.db.fetch(q, room_id, *membership_values, not_id, not_like) return [profile["user_id"] for profile in res] async def set_members( self, room_id: RoomID, members: dict[UserID, Member | MemberStateEventContent], only_membership: Membership | None = None, ) -> None: columns = ["room_id", "user_id", "membership", "displayname", "avatar_url"] records = [ (room_id, user_id, str(member.membership), member.displayname, member.avatar_url) for user_id, member in members.items() ] async with self.db.acquire() as conn, conn.transaction(): del_q = "DELETE FROM mx_user_profile WHERE room_id=$1" if only_membership is None: await conn.execute(del_q, room_id) elif self.db.scheme in (Scheme.POSTGRES, Scheme.COCKROACH): del_q = f"{del_q} AND (membership=$2 OR user_id = ANY($3))" await conn.execute(del_q, room_id, only_membership.value, list(members.keys())) else: member_placeholders = ("?," * len(members)).rstrip(",") del_q = f"{del_q} AND (membership=? OR user_id IN ({member_placeholders}))" await conn.execute(del_q, room_id, only_membership.value, *members.keys()) if self.db.scheme == Scheme.POSTGRES: await conn.copy_records_to_table( "mx_user_profile", records=records, columns=columns ) else: q = ( "INSERT INTO mx_user_profile (room_id, user_id, membership, " "displayname, avatar_url) VALUES ($1, $2, $3, $4, $5)" ) await conn.executemany(q, records) if not only_membership or only_membership == Membership.JOIN: await conn.execute( "UPDATE mx_room_state SET has_full_member_list=true WHERE room_id=$1", room_id, ) async def find_shared_rooms(self, user_id: UserID) -> list[RoomID]: q = ( "SELECT mx_user_profile.room_id FROM mx_user_profile " "LEFT JOIN mx_room_state ON mx_room_state.room_id=mx_user_profile.room_id " "WHERE user_id=$1 AND mx_room_state.is_encrypted=true" ) rows = await self.db.fetch(q, user_id) return [row["room_id"] for row in rows] async def has_full_member_list(self, room_id: RoomID) -> bool: return bool( await self.db.fetchval( "SELECT has_full_member_list FROM mx_room_state WHERE room_id=$1", room_id ) ) async def has_power_levels_cached(self, room_id: RoomID) -> bool: return bool( await self.db.fetchval( "SELECT power_levels IS NOT NULL FROM mx_room_state WHERE room_id=$1", room_id ) ) async def get_power_levels(self, room_id: RoomID) -> PowerLevelStateEventContent | None: power_levels_json = await self.db.fetchval( "SELECT power_levels FROM mx_room_state WHERE room_id=$1", room_id ) if power_levels_json is None: return None return PowerLevelStateEventContent.parse_json(power_levels_json) async def set_power_levels( self, room_id: RoomID, content: PowerLevelStateEventContent | dict[str, Any] ) -> None: await self.db.execute( "INSERT INTO mx_room_state (room_id, power_levels) VALUES ($1, $2) " "ON CONFLICT (room_id) DO UPDATE SET power_levels=$2", room_id, json.dumps(content.serialize() if isinstance(content, Serializable) else 
content), ) async def has_encryption_info_cached(self, room_id: RoomID) -> bool: return bool( await self.db.fetchval( "SELECT encryption IS NULL FROM mx_room_state WHERE room_id=$1", room_id ) ) async def is_encrypted(self, room_id: RoomID) -> bool | None: return await self.db.fetchval( "SELECT is_encrypted FROM mx_room_state WHERE room_id=$1", room_id ) async def get_encryption_info(self, room_id: RoomID) -> RoomEncryptionStateEventContent | None: row = await self.db.fetchrow( "SELECT is_encrypted, encryption FROM mx_room_state WHERE room_id=$1", room_id ) if row is None or not row["is_encrypted"]: return None return RoomEncryptionStateEventContent.parse_json(row["encryption"]) async def set_encryption_info( self, room_id: RoomID, content: RoomEncryptionStateEventContent | dict[str, Any] ) -> None: q = ( "INSERT INTO mx_room_state (room_id, is_encrypted, encryption) VALUES ($1, true, $2) " "ON CONFLICT (room_id) DO UPDATE SET is_encrypted=true, encryption=$2" ) await self.db.execute( q, room_id, json.dumps(content.serialize() if isinstance(content, Serializable) else content), ) python-0.20.7/mautrix/client/state_store/asyncpg/upgrade.py000066400000000000000000000053331473573527000241120ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. import logging from mautrix.util.async_db import Connection, Scheme, UpgradeTable upgrade_table = UpgradeTable( version_table_name="mx_version", database_name="matrix state cache", log=logging.getLogger("mau.client.db.upgrade"), ) @upgrade_table.register(description="Latest revision", upgrades_to=3) async def upgrade_blank_to_v3(conn: Connection, scheme: Scheme) -> None: await conn.execute( """CREATE TABLE mx_room_state ( room_id TEXT PRIMARY KEY, is_encrypted BOOLEAN, has_full_member_list BOOLEAN, encryption TEXT, power_levels TEXT )""" ) membership_check = "" if scheme != Scheme.SQLITE: await conn.execute( "CREATE TYPE membership AS ENUM ('join', 'leave', 'invite', 'ban', 'knock')" ) else: membership_check = "CHECK (membership IN ('join', 'leave', 'invite', 'ban', 'knock'))" await conn.execute( f"""CREATE TABLE mx_user_profile ( room_id TEXT, user_id TEXT, membership membership NOT NULL {membership_check}, displayname TEXT, avatar_url TEXT, PRIMARY KEY (room_id, user_id) )""" ) @upgrade_table.register(description="Stop using size-limited string fields") async def upgrade_v2(conn: Connection, scheme: Scheme) -> None: if scheme == Scheme.SQLITE: # SQLite doesn't care about types return await conn.execute("ALTER TABLE mx_room_state ALTER COLUMN room_id TYPE TEXT") await conn.execute("ALTER TABLE mx_user_profile ALTER COLUMN room_id TYPE TEXT") await conn.execute("ALTER TABLE mx_user_profile ALTER COLUMN user_id TYPE TEXT") await conn.execute("ALTER TABLE mx_user_profile ALTER COLUMN displayname TYPE TEXT") await conn.execute("ALTER TABLE mx_user_profile ALTER COLUMN avatar_url TYPE TEXT") @upgrade_table.register(description="Mark rooms that need crypto state event resynced") async def upgrade_v3(conn: Connection) -> None: if await conn.table_exists("portal"): await conn.execute( """ INSERT INTO mx_room_state (room_id, encryption) SELECT portal.mxid, '{"resync":true}' FROM portal WHERE portal.encrypted=true AND portal.mxid IS NOT NULL ON CONFLICT (room_id) DO UPDATE SET encryption=excluded.encryption WHERE mx_room_state.encryption IS NULL """ ) 
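

# An illustrative sketch of how this upgrade table is meant to be wired up:
# Database.create() is given the table so that db.start() applies the
# migrations registered above before PgStateStore queries the schema. The DSN
# is a placeholder and _example_pg_store is a hypothetical helper name; the
# test suite in tests/store_test.py sets things up the same way.
async def _example_pg_store() -> None:
    # Imported lazily so this sketch doesn't create circular imports with
    # mautrix.client.state_store.asyncpg.store, which imports this module.
    from mautrix.client.state_store.asyncpg import PgStateStore
    from mautrix.types import Membership, RoomID, UserID
    from mautrix.util.async_db import Database

    db = Database.create(
        "postgres://user:password@localhost/matrix_state",  # placeholder DSN
        upgrade_table=upgrade_table,
        db_args={"min_size": 1, "max_size": 3},
    )
    await db.start()  # creates/upgrades mx_room_state and mx_user_profile
    store = PgStateStore(db)

    room_id = RoomID("!room:example.com")
    user_id = UserID("@alice:example.com")
    await store.set_membership(room_id, user_id, Membership.JOIN)
    assert await store.is_joined(room_id, user_id)

    await db.stop()


if __name__ == "__main__":
    import asyncio

    asyncio.run(_example_pg_store())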
python-0.20.7/mautrix/client/state_store/file.py000066400000000000000000000041171473573527000217350ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations from typing import IO, Any from pathlib import Path from mautrix.types import ( Member, Membership, MemberStateEventContent, PowerLevelStateEventContent, RoomEncryptionStateEventContent, RoomID, UserID, ) from mautrix.util.file_store import Filer, FileStore from .memory import MemoryStateStore class FileStateStore(MemoryStateStore, FileStore): def __init__( self, path: str | Path | IO, filer: Filer | None = None, binary: bool = True, save_interval: float = 60.0, ) -> None: FileStore.__init__(self, path, filer, binary, save_interval) MemoryStateStore.__init__(self) async def set_membership( self, room_id: RoomID, user_id: UserID, membership: Membership ) -> None: await super().set_membership(room_id, user_id, membership) self._time_limited_flush() async def set_member( self, room_id: RoomID, user_id: UserID, member: Member | MemberStateEventContent ) -> None: await super().set_member(room_id, user_id, member) self._time_limited_flush() async def set_members( self, room_id: RoomID, members: dict[UserID, Member | MemberStateEventContent], only_membership: Membership | None = None, ) -> None: await super().set_members(room_id, members, only_membership) self._time_limited_flush() async def set_encryption_info( self, room_id: RoomID, content: RoomEncryptionStateEventContent | dict[str, Any] ) -> None: await super().set_encryption_info(room_id, content) self._time_limited_flush() async def set_power_levels( self, room_id: RoomID, content: PowerLevelStateEventContent ) -> None: await super().set_power_levels(room_id, content) self._time_limited_flush() python-0.20.7/mautrix/client/state_store/memory.py000066400000000000000000000154141473573527000223300ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations from typing import Any, TypedDict from mautrix.types import ( Member, Membership, MemberStateEventContent, PowerLevelStateEventContent, RoomEncryptionStateEventContent, RoomID, UserID, ) from .abstract import StateStore class SerializedStateStore(TypedDict): members: dict[RoomID, dict[UserID, Any]] full_member_list: dict[RoomID, bool] power_levels: dict[RoomID, Any] encryption: dict[RoomID, Any] class MemoryStateStore(StateStore): members: dict[RoomID, dict[UserID, Member]] full_member_list: dict[RoomID, bool] power_levels: dict[RoomID, PowerLevelStateEventContent] encryption: dict[RoomID, RoomEncryptionStateEventContent | None] def __init__(self) -> None: self.members = {} self.full_member_list = {} self.power_levels = {} self.encryption = {} def serialize(self) -> SerializedStateStore: """ Convert the data in the store into a JSON-friendly dict. Returns: A dict that can be safely serialized with most object serialization methods. 
""" return { "members": { room_id: {user_id: member.serialize() for user_id, member in members.items()} for room_id, members in self.members.items() }, "full_member_list": self.full_member_list, "power_levels": { room_id: content.serialize() for room_id, content in self.power_levels.items() }, "encryption": { room_id: (content.serialize() if content is not None else None) for room_id, content in self.encryption.items() }, } def deserialize(self, data: SerializedStateStore) -> None: """ Parse a previously serialized dict into this state store. Args: data: A dict returned by :meth:`serialize`. """ self.members = { room_id: {user_id: Member.deserialize(member) for user_id, member in members.items()} for room_id, members in data["members"].items() } self.full_member_list = data["full_member_list"] self.power_levels = { room_id: PowerLevelStateEventContent.deserialize(content) for room_id, content in data["power_levels"].items() } self.encryption = { room_id: ( RoomEncryptionStateEventContent.deserialize(content) if content is not None else None ) for room_id, content in data["encryption"].items() } async def get_member(self, room_id: RoomID, user_id: UserID) -> Member | None: try: return self.members[room_id][user_id] except KeyError: return None async def set_member( self, room_id: RoomID, user_id: UserID, member: Member | MemberStateEventContent ) -> None: if not isinstance(member, Member): member = Member( membership=member.membership, avatar_url=member.avatar_url, displayname=member.displayname, ) try: self.members[room_id][user_id] = member except KeyError: self.members[room_id] = {user_id: member} async def set_membership( self, room_id: RoomID, user_id: UserID, membership: Membership ) -> None: try: room_members = self.members[room_id] except KeyError: self.members[room_id] = {user_id: Member(membership=membership)} return try: room_members[user_id].membership = membership except (KeyError, TypeError): room_members[user_id] = Member(membership=membership) async def get_member_profiles( self, room_id: RoomID, memberships: tuple[Membership, ...] 
= (Membership.JOIN, Membership.INVITE), ) -> dict[UserID, Member]: try: return { user_id: member for user_id, member in self.members[room_id].items() if member.membership in memberships } except KeyError: return {} async def set_members( self, room_id: RoomID, members: dict[UserID, Member | MemberStateEventContent], only_membership: Membership | None = None, ) -> None: old_members = {} if only_membership is not None: old_members = { user_id: member for user_id, member in self.members.get(room_id, {}).items() if member.membership != only_membership } self.members[room_id] = { user_id: ( member if isinstance(member, Member) else Member( membership=member.membership, avatar_url=member.avatar_url, displayname=member.displayname, ) ) for user_id, member in members.items() } self.members[room_id].update(old_members) self.full_member_list[room_id] = True async def has_full_member_list(self, room_id: RoomID) -> bool: return self.full_member_list.get(room_id, False) async def has_power_levels_cached(self, room_id: RoomID) -> bool: return room_id in self.power_levels async def get_power_levels(self, room_id: RoomID) -> PowerLevelStateEventContent | None: return self.power_levels.get(room_id) async def set_power_levels( self, room_id: RoomID, content: PowerLevelStateEventContent | dict[str, Any] ) -> None: if not isinstance(content, PowerLevelStateEventContent): content = PowerLevelStateEventContent.deserialize(content) self.power_levels[room_id] = content async def has_encryption_info_cached(self, room_id: RoomID) -> bool: return room_id in self.encryption async def is_encrypted(self, room_id: RoomID) -> bool | None: try: return self.encryption[room_id] is not None except KeyError: return None async def get_encryption_info(self, room_id: RoomID) -> RoomEncryptionStateEventContent: return self.encryption.get(room_id) async def set_encryption_info( self, room_id: RoomID, content: RoomEncryptionStateEventContent | dict[str, Any] ) -> None: if not isinstance(content, RoomEncryptionStateEventContent): content = RoomEncryptionStateEventContent.deserialize(content) self.encryption[room_id] = content python-0.20.7/mautrix/client/state_store/sync.py000066400000000000000000000020371473573527000217710ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. 
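# A brief, illustrative sketch of the serialize()/deserialize() pair on
# MemoryStateStore (memory.py): the store is dumped to a JSON-friendly dict
# and loaded back into a fresh instance, which is the same mechanism
# FileStateStore (file.py) relies on to persist state to disk. The IDs are
# placeholders and _example_round_trip is a hypothetical helper name.
#
#     import json
#
#     from mautrix.client.state_store import MemoryStateStore
#     from mautrix.types import Member, Membership, RoomID, UserID
#
#     async def _example_round_trip() -> None:
#         store = MemoryStateStore()
#         room_id = RoomID("!room:example.com")
#         user_id = UserID("@alice:example.com")
#         await store.set_member(
#             room_id, user_id, Member(membership=Membership.JOIN, displayname="Alice")
#         )
#
#         dumped = json.dumps(store.serialize())  # everything in it is JSON-safe
#
#         restored = MemoryStateStore()
#         restored.deserialize(json.loads(dumped))
#         assert await restored.get_membership(room_id, user_id) == Membership.JOIN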
from __future__ import annotations from abc import ABC, abstractmethod from mautrix.types import SyncToken class SyncStore(ABC): """SyncStore persists information used by /sync.""" @abstractmethod async def put_next_batch(self, next_batch: SyncToken) -> None: pass @abstractmethod async def get_next_batch(self) -> SyncToken: pass class MemorySyncStore(SyncStore): """MemorySyncStore is a :class:`SyncStore` implementation that stores the data in memory.""" def __init__(self, next_batch: SyncToken | None = None) -> None: self._next_batch: SyncToken | None = next_batch async def put_next_batch(self, next_batch: SyncToken) -> None: self._next_batch = next_batch async def get_next_batch(self) -> SyncToken: return self._next_batch python-0.20.7/mautrix/client/state_store/tests/000077500000000000000000000000001473573527000216035ustar00rootroot00000000000000python-0.20.7/mautrix/client/state_store/tests/__init__.py000066400000000000000000000000001473573527000237020ustar00rootroot00000000000000python-0.20.7/mautrix/client/state_store/tests/joined_members.json000066400000000000000000000017131473573527000254620ustar00rootroot00000000000000{ "!telegram-group:example.com": { "@telegrambot:example.com": { "avatar_url": "mxc://maunium.net/tJCRmUyJDsgRNgqhOgoiHWbX", "display_name": "Telegram bridge bot" }, "@telegram_84359547:example.com": { "avatar_url": "mxc://example.com/321cba", "display_name": "tulir (Telegram)" }, "@telegram_5647382910:example.com": { "avatar_url": "mxc://example.com/o3sSOTEE6F7aSZELL8PfP3N7", "display_name": "Tulir #4 (Telegram)" }, "@tulir:example.com": { "avatar_url": "mxc://example.com/123abc", "display_name": "tulir" }, "@telegram_374880943:example.com": { "avatar_url": "mxc://example.com/9yk1G8ZmSHbvP4JI2IFbfSAn", "display_name": "Mautrix (Telegram)" }, "@telegram_987654321:example.com": { "avatar_url": "mxc://example.com/TTi6q0SABMNNNlC38014KMBV", "display_name": "Mautrix / Testing (Telegram)" }, "@telegram_123456789:example.com": { "avatar_url": null, "display_name": "Tulir #3 (Telegram)" } } } python-0.20.7/mautrix/client/state_store/tests/members.json000066400000000000000000000117751473573527000241430ustar00rootroot00000000000000{ "!telegram-group:example.com": [ { "content": { "avatar_url": "mxc://maunium.net/tJCRmUyJDsgRNgqhOgoiHWbX", "displayname": "Telegram bridge bot", "membership": "join" }, "origin_server_ts": 1637017135583, "sender": "@telegrambot:example.com", "state_key": "@telegrambot:example.com", "type": "m.room.member", "event_id": "$ToMVBrkfMXcaMt6lrNn_AACxGtWGk9acDQplsYMjdu8" }, { "content": { "avatar_url": "mxc://example.com/321cba", "displayname": "tulir (Telegram)", "membership": "join" }, "origin_server_ts": 1637017138997, "sender": "@telegram_84359547:example.com", "state_key": "@telegram_84359547:example.com", "type": "m.room.member", "unsigned": { "replaces_state": "$tRU6IBdkgFoG9PD-PRjtKOFMPAQVMMFdU-ZfB09hm48", "prev_content": { "avatar_url": "mxc://example.com/321cba", "displayname": "tulir (Telegram)", "membership": "invite" }, "prev_sender": "@telegrambot:example.com" }, "event_id": "$CnpXeISvK8MkbzgryzlrMKKk134KLwuvdIDTLp8v1ME" }, { "content": { "avatar_url": "mxc://example.com/o3sSOTEE6F7aSZELL8PfP3N7", "displayname": "Tulir #4 (Telegram)", "membership": "join" }, "origin_server_ts": 1637017139441, "sender": "@telegram_5647382910:example.com", "state_key": "@telegram_5647382910:example.com", "type": "m.room.member", "unsigned": { "replaces_state": "$_yUbIfhy46AVZ8Aj3bEoGUvzJ6xC38mbFIZ79JOrIhw", "prev_content": { "avatar_url": 
"mxc://example.com/o3sSOTEE6F7aSZELL8PfP3N7", "displayname": "Tulir #4 (Telegram)", "membership": "invite" }, "prev_sender": "@telegrambot:example.com" }, "event_id": "$n_dhyHuytT4lwJ1rzdkGeRIAaCQNtdM-6KhI5rQcVSw" }, { "content": { "avatar_url": "mxc://example.com/123abc", "displayname": "tulir", "membership": "join" }, "origin_server_ts": 1637017136989, "sender": "@tulir:example.com", "state_key": "@tulir:example.com", "type": "m.room.member", "unsigned": { "replaces_state": "$jMwIGQPtxaRLfCfUm0XhYRgWK1WsVRnI-2shdpLpJqg", "prev_content": { "avatar_url": "mxc://example.com/123abc", "displayname": "tulir", "fi.mau.will_auto_accept": true, "is_direct": false, "membership": "invite" }, "prev_sender": "@telegrambot:example.com" }, "event_id": "$xv9FUkljVHVY2aHvQgtq9yPAo8YSD6IX80bs2eeLKQo" }, { "content": { "avatar_url": "mxc://example.com/9yk1G8ZmSHbvP4JI2IFbfSAn", "displayname": "Mautrix (Telegram)", "membership": "join" }, "origin_server_ts": 1637017140015, "sender": "@telegram_374880943:example.com", "state_key": "@telegram_374880943:example.com", "type": "m.room.member", "unsigned": { "replaces_state": "$phLlRQ3dHd7rVOZmfMWVscQCnwAT3k-CAhh8aByHJOQ", "prev_content": { "avatar_url": "mxc://example.com/9yk1G8ZmSHbvP4JI2IFbfSAn", "displayname": "Mautrix (Telegram)", "membership": "invite" }, "prev_sender": "@telegrambot:example.com" }, "event_id": "$agGmH9w9I-KgBnOoUBXp6RmXxemGaYwN3PnwOgIvmEk" }, { "content": { "avatar_url": "mxc://example.com/TTi6q0SABMNNNlC38014KMBV", "displayname": "Mautrix / Testing (Telegram)", "membership": "join" }, "origin_server_ts": 1637017141173, "sender": "@telegram_987654321:example.com", "state_key": "@telegram_987654321:example.com", "type": "m.room.member", "unsigned": { "replaces_state": "$VYToMy5ap8j2no4bL5neeYrDqAOPnt2JpY0yfdbnbWk", "prev_content": { "avatar_url": "mxc://example.com/TTi6q0SABMNNNlC38014KMBV", "displayname": "Mautrix / Testing (Telegram)", "membership": "invite" }, "prev_sender": "@telegrambot:example.com" }, "event_id": "$DZxbFRtQ8oKgw9hGiQRCpFLV-PF20uhS6DMDfFjOx38" }, { "content": { "displayname": "Tulir #3 (Telegram)", "membership": "join" }, "origin_server_ts": 1637017138285, "sender": "@telegram_123456789:example.com", "state_key": "@telegram_123456789:example.com", "type": "m.room.member", "unsigned": { "replaces_state": "$WLvGliXjUCuXxJV1wqbZd7suumC3HiOf-pVP94_JUAE", "prev_content": { "displayname": "Tulir #3 (Telegram)", "membership": "invite" }, "prev_sender": "@telegrambot:example.com" }, "event_id": "$DW23zfhHKNBTYJNq40R7leaIZAPzf54eyWgxmeQ7WYc" }, { "content": { "membership": "leave" }, "origin_server_ts": 1637017153623, "sender": "@telegram_476034259:example.com", "state_key": "@telegram_476034259:example.com", "type": "m.room.member", "unsigned": { "replaces_state": "$2yMikO4Th2GYcegL-AWLVMbH-5gBw7ru4pJIZG_XsHc" }, "event_id": "$r8T_b9upDSscRZqulUOJSQeK80RnN7LKDRSvzXflxlQ" }, { "content": { "membership": "ban", "avatar_url": "mxc://maunium.net/456def", "displayname": "WhatsApp bridge bot" }, "origin_server_ts": 1637406936936, "sender": "@tulir:example.com", "state_key": "@whatsappbot:example.com", "type": "m.room.member", "event_id": "$9zhRDMhfvo37qFlRYZ45H0afcx9KSyGkn-gr72S6TsU" } ] } python-0.20.7/mautrix/client/state_store/tests/new_state.json000066400000000000000000000133021473573527000244660ustar00rootroot00000000000000{ "!telegram-group:example.com": [ { "content": { "algorithm": "m.megolm.v1.aes-sha2" }, "origin_server_ts": 1637017136383, "sender": "@telegrambot:example.com", "state_key": "", "type": "m.room.encryption", 
"event_id": "$bFRKXissP1baFP9XmqnPqhNF5qqQPNFaxzToadnVGzA" }, { "content": { "ban": 50, "events": { "m.room.avatar": 0, "m.room.encryption": 50, "m.room.history_visibility": 75, "m.room.name": 0, "m.room.pinned_events": 0, "m.room.power_levels": 75, "m.room.tombstone": 99, "m.room.topic": 0, "m.sticker": 0 }, "events_default": 0, "invite": 0, "kick": 50, "redact": 50, "state_default": 50, "users": { "@telegrambot:example.com": 100, "@tulir:example.com": 95 }, "users_default": 0 }, "origin_server_ts": 1637017135659, "sender": "@telegrambot:example.com", "state_key": "", "type": "m.room.power_levels", "event_id": "$98TzHn07MQ_fogX-hzYQfsb-1KlTL2XMVIOe3vvWY3g" }, { "content": { "avatar_url": "mxc://example.com/123abc", "displayname": "tulir", "fi.mau.will_auto_accept": true, "is_direct": false, "membership": "invite" }, "origin_server_ts": 1637017136716, "sender": "@telegrambot:example.com", "state_key": "@tulir:example.com", "type": "m.room.member", "event_id": "$jMwIGQPtxaRLfCfUm0XhYRgWK1WsVRnI-2shdpLpJqg" }, { "content": { "avatar_url": "mxc://example.com/123abc", "displayname": "tulir", "membership": "join" }, "origin_server_ts": 1637017136989, "sender": "@tulir:example.com", "state_key": "@tulir:example.com", "type": "m.room.member", "unsigned": { "replaces_state": "$jMwIGQPtxaRLfCfUm0XhYRgWK1WsVRnI-2shdpLpJqg", "prev_content": { "avatar_url": "mxc://example.com/123abc", "displayname": "tulir", "fi.mau.will_auto_accept": true, "is_direct": false, "membership": "invite" }, "prev_sender": "@telegrambot:example.com" }, "event_id": "$xv9FUkljVHVY2aHvQgtq9yPAo8YSD6IX80bs2eeLKQo" }, { "content": { "avatar_url": "mxc://example.com/321cba", "displayname": "tulir (Telegram)", "membership": "invite" }, "origin_server_ts": 1637017138734, "sender": "@telegrambot:example.com", "state_key": "@telegram_84359547:example.com", "type": "m.room.member", "event_id": "$tRU6IBdkgFoG9PD-PRjtKOFMPAQVMMFdU-ZfB09hm48" }, { "content": { "avatar_url": "mxc://example.com/321cba", "displayname": "tulir (Telegram)", "membership": "join" }, "origin_server_ts": 1637017138997, "sender": "@telegram_84359547:example.com", "state_key": "@telegram_84359547:example.com", "type": "m.room.member", "unsigned": { "replaces_state": "$tRU6IBdkgFoG9PD-PRjtKOFMPAQVMMFdU-ZfB09hm48", "prev_content": { "avatar_url": "mxc://example.com/321cba", "displayname": "tulir (Telegram)", "membership": "invite" }, "prev_sender": "@telegrambot:example.com", "age": 1234 }, "event_id": "$CnpXeISvK8MkbzgryzlrMKKk134KLwuvdIDTLp8v1ME" } ], "!unencrypted-room:example.com": [ { "content": { "ban": 50, "events": { "m.room.avatar": 50, "m.room.canonical_alias": 50, "m.room.encryption": 100, "m.room.history_visibility": 100, "m.room.name": 50, "m.room.power_levels": 100, "m.room.server_acl": 100, "m.room.tombstone": 100 }, "events_default": 0, "historical": 100, "invite": 0, "kick": 50, "redact": 50, "state_default": 50, "users": { "@tulir:example.com": 9001 }, "users_default": 0 }, "origin_server_ts": 1637343767626, "sender": "@tulir:example.com", "state_key": "", "type": "m.room.power_levels", "event_id": "$TMUQddqg0CzcNviUBPT38_3i51-SC86p2w3oADNhyZA" }, { "content": { "avatar_url": "mxc://example.com/987zyx", "displayname": "Maubot", "membership": "invite" }, "origin_server_ts": 1637343773909, "sender": "@tulir:example.com", "state_key": "@maubot:example.com", "type": "m.room.member", "event_id": "$9XvlRmVmzCdNna0uvgm4NeMiVKJLmd-8c_tX10gmg4k" }, { "content": { "avatar_url": "mxc://example.com/987zyx", "displayname": "Maubot", "membership": "join" }, 
"origin_server_ts": 1637343773991, "sender": "@maubot:example.com", "state_key": "@maubot:example.com", "type": "m.room.member", "unsigned": { "replaces_state": "$9XvlRmVmzCdNna0uvgm4NeMiVKJLmd-8c_tX10gmg4k", "prev_content": { "avatar_url": "mxc://example.com/987zyx", "displayname": "Maubot", "membership": "invite" }, "prev_sender": "@tulir:example.com", "age": 1234 }, "event_id": "$XZMALYE9N30jP5_x8S1IWFlzt5F6tZB--W2kkoKGJDM" }, { "content": { "avatar_url": "mxc://maunium.net/456def", "displayname": "WhatsApp bridge bot", "membership": "invite" }, "origin_server_ts": 1637406933363, "sender": "@tulir:example.com", "state_key": "@whatsappbot:example.com", "type": "m.room.member", "event_id": "$jODwwttaZd-flc2eyh0JirR0p6EtDkX5BaJmPfaXcrc" }, { "content": { "membership": "ban", "avatar_url": "mxc://maunium.net/456def", "displayname": "WhatsApp bridge bot" }, "origin_server_ts": 1637406936936, "sender": "@tulir:example.com", "state_key": "@whatsappbot:example.com", "type": "m.room.member", "unsigned": { "replaces_state": "$jODwwttaZd-flc2eyh0JirR0p6EtDkX5BaJmPfaXcrc", "prev_content": { "avatar_url": "mxc://maunium.net/456def", "displayname": "WhatsApp bridge bot", "membership": "invite" }, "prev_sender": "@tulir:example.com", "age": 65 }, "event_id": "$9zhRDMhfvo37qFlRYZ45H0afcx9KSyGkn-gr72S6TsU" } ] } python-0.20.7/mautrix/client/state_store/tests/store_test.py000066400000000000000000000143171473573527000243560ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations from typing import AsyncContextManager, AsyncIterator, Callable from contextlib import asynccontextmanager import json import os import pathlib import random import string import time import asyncpg import pytest from mautrix.types import EncryptionAlgorithm, Member, Membership, RoomID, StateEvent, UserID from mautrix.util.async_db import Database from .. 
import MemoryStateStore, StateStore from ..asyncpg import PgStateStore @asynccontextmanager async def async_postgres_store() -> AsyncIterator[PgStateStore]: try: pg_url = os.environ["MEOW_TEST_PG_URL"] except KeyError: pytest.skip("Skipped Postgres tests (MEOW_TEST_PG_URL not specified)") return conn: asyncpg.Connection = await asyncpg.connect(pg_url) schema_name = "".join(random.choices(string.ascii_lowercase, k=8)) schema_name = f"test_schema_{schema_name}_{int(time.time())}" await conn.execute(f"CREATE SCHEMA {schema_name}") db = Database.create( pg_url, upgrade_table=PgStateStore.upgrade_table, db_args={"min_size": 1, "max_size": 3, "server_settings": {"search_path": schema_name}}, ) store = PgStateStore(db) await db.start() yield store await db.stop() await conn.execute(f"DROP SCHEMA {schema_name} CASCADE") await conn.close() @asynccontextmanager async def async_sqlite_store() -> AsyncIterator[PgStateStore]: db = Database.create( "sqlite::memory:", upgrade_table=PgStateStore.upgrade_table, db_args={"min_size": 1} ) store = PgStateStore(db) await db.start() yield store await db.stop() @asynccontextmanager async def memory_store() -> AsyncIterator[MemoryStateStore]: yield MemoryStateStore() @pytest.fixture(params=[async_postgres_store, async_sqlite_store, memory_store]) async def store(request) -> AsyncIterator[StateStore]: param: Callable[[], AsyncContextManager[StateStore]] = request.param async with param() as state_store: yield state_store def read_state_file(request, file) -> dict[RoomID, list[StateEvent]]: path = pathlib.Path(request.node.fspath).with_name(file) with path.open() as fp: content = json.load(fp) return { room_id: [StateEvent.deserialize({**evt, "room_id": room_id}) for evt in events] for room_id, events in content.items() } async def store_room_state(request, store: StateStore) -> None: room_state_changes = read_state_file(request, "new_state.json") for events in room_state_changes.values(): for evt in events: await store.update_state(evt) async def get_all_members(request, store: StateStore) -> None: room_state = read_state_file(request, "members.json") for room_id, member_events in room_state.items(): await store.set_members(room_id, {evt.state_key: evt.content for evt in member_events}) async def get_joined_members(request, store: StateStore) -> None: path = pathlib.Path(request.node.fspath).with_name("joined_members.json") with path.open() as fp: content = json.load(fp) for room_id, members in content.items(): parsed_members = { user_id: Member( membership=Membership.JOIN, displayname=member.get("display_name", ""), avatar_url=member.get("avatar_url", ""), ) for user_id, member in members.items() } await store.set_members(room_id, parsed_members, only_membership=Membership.JOIN) async def test_basic(store: StateStore) -> None: room_id = RoomID("!foo:example.com") user_id = UserID("@tulir:example.com") assert not await store.is_encrypted(room_id) assert not await store.is_joined(room_id, user_id) await store.joined(room_id, user_id) assert await store.is_joined(room_id, user_id) assert not await store.has_encryption_info_cached(RoomID("!unknown-room:example.com")) assert await store.is_encrypted(RoomID("!unknown-room:example.com")) is None async def test_basic_updated(request, store: StateStore) -> None: await store_room_state(request, store) test_group = RoomID("!telegram-group:example.com") assert await store.is_encrypted(test_group) assert (await store.get_encryption_info(test_group)).algorithm == EncryptionAlgorithm.MEGOLM_V1 assert not await 
store.is_encrypted(RoomID("!unencrypted-room:example.com")) async def test_updates(request, store: StateStore) -> None: await store_room_state(request, store) room_id = RoomID("!telegram-group:example.com") initial_members = {"@tulir:example.com", "@telegram_84359547:example.com"} joined_members = initial_members | { "@telegrambot:example.com", "@telegram_5647382910:example.com", "@telegram_374880943:example.com", "@telegram_987654321:example.com", "@telegram_123456789:example.com", } left_members = {"@telegram_476034259:example.com", "@whatsappbot:example.com"} full_members = joined_members | left_members any_membership = ( Membership.JOIN, Membership.INVITE, Membership.LEAVE, Membership.BAN, Membership.KNOCK, ) leave_memberships = (Membership.BAN, Membership.LEAVE) assert set(await store.get_members(room_id)) == initial_members await get_all_members(request, store) assert set(await store.get_members(room_id)) == joined_members assert set(await store.get_members(room_id, memberships=any_membership)) == full_members await get_joined_members(request, store) assert set(await store.get_members(room_id)) == joined_members assert set(await store.get_members(room_id, memberships=any_membership)) == full_members assert set(await store.get_members(room_id, memberships=leave_memberships)) == left_members assert set( await store.get_members_filtered( room_id, memberships=leave_memberships, not_id="", not_prefix="@telegram_", not_suffix=":example.com", ) ) == {"@whatsappbot:example.com"} python-0.20.7/mautrix/client/store_updater.py000066400000000000000000000257111473573527000213450ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations import asyncio from mautrix.errors import MForbidden, MNotFound from mautrix.types import ( JSON, EventID, EventType, Member, Membership, MemberStateEventContent, RoomAlias, RoomID, StateEvent, StateEventContent, SyncToken, UserID, ) from .api import ClientAPI from .state_store import StateStore class StoreUpdatingAPI(ClientAPI): """ StoreUpdatingAPI is a wrapper around the medium-level ClientAPI that optionally updates a client state store with outgoing state events (after they're successfully sent). 
""" state_store: StateStore | None def __init__(self, *args, state_store: StateStore | None = None, **kwargs) -> None: super().__init__(*args, **kwargs) self.state_store = state_store async def join_room_by_id( self, room_id: RoomID, third_party_signed: JSON = None, extra_content: dict[str, JSON] | None = None, ) -> RoomID: room_id = await super().join_room_by_id( room_id, third_party_signed=third_party_signed, extra_content=extra_content ) if room_id and not extra_content and self.state_store: await self.state_store.set_membership(room_id, self.mxid, Membership.JOIN) return room_id async def join_room( self, room_id_or_alias: RoomID | RoomAlias, servers: list[str] | None = None, third_party_signed: JSON = None, max_retries: int = 4, ) -> RoomID: room_id = await super().join_room( room_id_or_alias, servers, third_party_signed, max_retries ) if room_id and self.state_store: await self.state_store.set_membership(room_id, self.mxid, Membership.JOIN) return room_id async def leave_room( self, room_id: RoomID, reason: str | None = None, extra_content: dict[str, JSON] | None = None, raise_not_in_room: bool = False, ) -> None: await super().leave_room(room_id, reason, extra_content, raise_not_in_room) if not extra_content and self.state_store: await self.state_store.set_membership(room_id, self.mxid, Membership.LEAVE) async def knock_room( self, room_id_or_alias: RoomID | RoomAlias, reason: str | None = None, servers: list[str] | None = None, ) -> RoomID: room_id = await super().knock_room(room_id_or_alias, reason, servers) if room_id and self.state_store: await self.state_store.set_membership(room_id, self.mxid, Membership.KNOCK) return room_id async def invite_user( self, room_id: RoomID, user_id: UserID, reason: str | None = None, extra_content: dict[str, JSON] | None = None, ) -> None: await super().invite_user(room_id, user_id, reason, extra_content=extra_content) if not extra_content and self.state_store: await self.state_store.set_membership(room_id, user_id, Membership.INVITE) async def kick_user( self, room_id: RoomID, user_id: UserID, reason: str = "", extra_content: dict[str, JSON] | None = None, ) -> None: await super().kick_user(room_id, user_id, reason=reason, extra_content=extra_content) if not extra_content and self.state_store: await self.state_store.set_membership(room_id, user_id, Membership.LEAVE) async def ban_user( self, room_id: RoomID, user_id: UserID, reason: str = "", extra_content: dict[str, JSON] | None = None, ) -> None: await super().ban_user(room_id, user_id, reason=reason, extra_content=extra_content) if not extra_content and self.state_store: await self.state_store.set_membership(room_id, user_id, Membership.BAN) async def unban_user( self, room_id: RoomID, user_id: UserID, reason: str = "", extra_content: dict[str, JSON] | None = None, ) -> None: await super().unban_user(room_id, user_id, reason=reason, extra_content=extra_content) if self.state_store: await self.state_store.set_membership(room_id, user_id, Membership.LEAVE) async def get_state(self, room_id: RoomID) -> list[StateEvent]: state = await super().get_state(room_id) if self.state_store: update_members = self.state_store.set_members( room_id, {evt.state_key: evt.content for evt in state if evt.type == EventType.ROOM_MEMBER}, ) await asyncio.gather( update_members, *[ self.state_store.update_state(evt) for evt in state if evt.type != EventType.ROOM_MEMBER ], ) return state async def create_room(self, *args, **kwargs) -> RoomID: room_id = await super().create_room(*args, **kwargs) if self.state_store: 
invitee_membership = Membership.INVITE if kwargs.get("beeper_auto_join_invites"): invitee_membership = Membership.JOIN for user_id in kwargs.get("invitees", []): await self.state_store.set_membership(room_id, user_id, invitee_membership) for evt in kwargs.get("initial_state", []): await self.state_store.update_state( StateEvent( type=EventType.find(evt["type"], t_class=EventType.Class.STATE), room_id=room_id, event_id=EventID("$fake-create-id"), sender=self.mxid, state_key=evt.get("state_key", ""), timestamp=0, content=evt["content"], ) ) return room_id async def send_state_event( self, room_id: RoomID, event_type: EventType, content: StateEventContent | dict[str, JSON], state_key: str = "", **kwargs, ) -> EventID: event_id = await super().send_state_event( room_id, event_type, content, state_key, **kwargs ) if self.state_store: fake_event = StateEvent( type=event_type, room_id=room_id, event_id=event_id, sender=self.mxid, state_key=state_key, timestamp=0, content=content, ) await self.state_store.update_state(fake_event) return event_id async def get_state_event( self, room_id: RoomID, event_type: EventType, state_key: str = "" ) -> StateEventContent: event = await super().get_state_event(room_id, event_type, state_key) if self.state_store: fake_event = StateEvent( type=event_type, room_id=room_id, event_id=EventID(""), sender=UserID(""), state_key=state_key, timestamp=0, content=event, ) await self.state_store.update_state(fake_event) return event async def get_joined_members(self, room_id: RoomID) -> dict[UserID, Member]: members = await super().get_joined_members(room_id) if self.state_store: await self.state_store.set_members(room_id, members, only_membership=Membership.JOIN) return members async def get_members( self, room_id: RoomID, at: SyncToken | None = None, membership: Membership | None = None, not_membership: Membership | None = None, ) -> list[StateEvent]: member_events = await super().get_members(room_id, at, membership, not_membership) if self.state_store and not_membership != Membership.JOIN: await self.state_store.set_members( room_id, {evt.state_key: evt.content for evt in member_events}, only_membership=membership, ) return member_events async def fill_member_event( self, room_id: RoomID, user_id: UserID, content: MemberStateEventContent, ) -> MemberStateEventContent | None: """ Fill a membership event content that is going to be sent in :meth:`send_member_event`. This is used to set default fields like the displayname and avatar, which are usually set by the server in the sugar membership endpoints like /join and /invite, but are not set automatically when sending member events manually. This implementation in StoreUpdatingAPI will first try to call the default implementation (which calls :attr:`fill_member_event_callback`). If that doesn't return anything, this will try to get the profile from the current member event, and then fall back to fetching the global profile from the server. Args: room_id: The room where the member event is going to be sent. user_id: The user whose membership is changing. content: The new member event content. Returns: The filled member event content. 
""" callback_content = await super().fill_member_event(room_id, user_id, content) if callback_content is not None: self.log.trace("Filled new member event for %s using callback", user_id) return callback_content if content.displayname is None and content.avatar_url is None: existing_member = await self.state_store.get_member(room_id, user_id) if existing_member is not None: self.log.trace( "Found existing member event %s to fill new member event for %s", existing_member, user_id, ) content.displayname = existing_member.displayname content.avatar_url = existing_member.avatar_url return content try: profile = await self.get_profile(user_id) except (MNotFound, MForbidden): profile = None if profile: self.log.trace( "Fetched profile %s to fill new member event of %s", profile, user_id ) content.displayname = profile.displayname content.avatar_url = profile.avatar_url return content else: self.log.trace("Didn't find profile info to fill new member event of %s", user_id) else: self.log.trace( "Member event for %s already contains displayname or avatar, not re-filling", user_id, ) return None python-0.20.7/mautrix/client/syncer.py000066400000000000000000000444111473573527000177660ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations from typing import Any, Awaitable, Callable, Type, TypeVar from abc import ABC, abstractmethod from contextlib import suppress from enum import Enum, Flag, auto import asyncio import time from mautrix.errors import MUnknownToken from mautrix.types import ( JSON, AccountDataEvent, BaseMessageEventContentFuncs, DeviceLists, DeviceOTKCount, EphemeralEvent, Event, EventType, Filter, FilterID, GenericEvent, MessageEvent, PresenceState, SerializerError, StateEvent, StrippedStateEvent, SyncToken, ToDeviceEvent, UserID, ) from mautrix.util import background_task from mautrix.util.logging import TraceLogger from . 
import dispatcher from .state_store import MemorySyncStore, SyncStore EventHandler = Callable[[Event], Awaitable[None]] T = TypeVar("T", bound=Event) class SyncStream(Flag): INTERNAL = auto() JOINED_ROOM = auto() INVITED_ROOM = auto() LEFT_ROOM = auto() TIMELINE = auto() STATE = auto() EPHEMERAL = auto() ACCOUNT_DATA = auto() TO_DEVICE = auto() class InternalEventType(Enum): SYNC_STARTED = auto() SYNC_ERRORED = auto() SYNC_SUCCESSFUL = auto() SYNC_STOPPED = auto() JOIN = auto() PROFILE_CHANGE = auto() INVITE = auto() REJECT_INVITE = auto() DISINVITE = auto() LEAVE = auto() KICK = auto() BAN = auto() UNBAN = auto() DEVICE_LISTS = auto() DEVICE_OTK_COUNT = auto() class Syncer(ABC): loop: asyncio.AbstractEventLoop log: TraceLogger mxid: UserID global_event_handlers: list[tuple[EventHandler, bool]] event_handlers: dict[EventType | InternalEventType, list[tuple[EventHandler, bool]]] dispatchers: dict[Type[dispatcher.Dispatcher], dispatcher.Dispatcher] syncing_task: asyncio.Task | None ignore_initial_sync: bool ignore_first_sync: bool presence: PresenceState sync_store: SyncStore def __init__(self, sync_store: SyncStore) -> None: self.global_event_handlers = [] self.event_handlers = {} self.dispatchers = {} self.syncing_task = None self.ignore_initial_sync = False self.ignore_first_sync = False self.presence = PresenceState.ONLINE self.sync_store = sync_store or MemorySyncStore() def on( self, var: EventHandler | EventType | InternalEventType ) -> EventHandler | Callable[[EventHandler], EventHandler]: """ Add a new event handler. This method is for decorator usage. Use :meth:`add_event_handler` if you don't use a decorator. Args: var: Either the handler function or the event type to handle. Returns: If ``var`` was the handler function, the handler function is returned. If ``var`` was an event type, a function that takes the handler function as an argument is returned. Examples: >>> from mautrix.client import Client >>> cli = Client(...) >>> @cli.on(EventType.ROOM_MESSAGE) >>> def handler(event: MessageEvent) -> None: ... pass """ if isinstance(var, (EventType, InternalEventType)): def decorator(func: EventHandler) -> EventHandler: self.add_event_handler(var, func) return func return decorator else: self.add_event_handler(EventType.ALL, var) return var def add_dispatcher(self, dispatcher_type: Type[dispatcher.Dispatcher]) -> None: if dispatcher_type in self.dispatchers: return self.log.debug(f"Enabling {dispatcher_type.__name__}") self.dispatchers[dispatcher_type] = dispatcher_type(self) self.dispatchers[dispatcher_type].register() def remove_dispatcher(self, dispatcher_type: Type[dispatcher.Dispatcher]) -> None: if dispatcher_type not in self.dispatchers: return self.log.debug(f"Disabling {dispatcher_type.__name__}") self.dispatchers[dispatcher_type].unregister() del self.dispatchers[dispatcher_type] def add_event_handler( self, event_type: InternalEventType | EventType, handler: EventHandler, wait_sync: bool = False, ) -> None: """ Add a new event handler. Args: event_type: The event type to add. If not specified, the handler will be called for all event types. handler: The handler function to add. wait_sync: Whether or not the handler should be awaited before the next sync request. 
""" if not isinstance(event_type, (EventType, InternalEventType)): raise ValueError("Invalid event type") if event_type == EventType.ALL: self.global_event_handlers.append((handler, wait_sync)) else: self.event_handlers.setdefault(event_type, []).append((handler, wait_sync)) def remove_event_handler( self, event_type: EventType | InternalEventType, handler: EventHandler ) -> None: """ Remove an event handler. Args: handler: The handler function to remove. event_type: The event type to remove the handler function from. """ if not isinstance(event_type, (EventType, InternalEventType)): raise ValueError("Invalid event type") try: handler_list = ( self.global_event_handlers if event_type == EventType.ALL else self.event_handlers[event_type] ) except KeyError: # No handlers for this event type registered return # FIXME this is a bit hacky with suppress(ValueError): handler_list.remove((handler, True)) with suppress(ValueError): handler_list.remove((handler, False)) if len(handler_list) == 0 and event_type != EventType.ALL: del self.event_handlers[event_type] def dispatch_event(self, event: Event | None, source: SyncStream) -> list[asyncio.Task]: """ Send the given event to all applicable event handlers. Args: event: The event to send. source: The sync stream the event was received in. """ if event is None: return [] if isinstance(event.content, BaseMessageEventContentFuncs): event.content.trim_reply_fallback() if getattr(event, "state_key", None) is not None: event.type = event.type.with_class(EventType.Class.STATE) elif source & SyncStream.EPHEMERAL: event.type = event.type.with_class(EventType.Class.EPHEMERAL) elif source & SyncStream.ACCOUNT_DATA: event.type = event.type.with_class(EventType.Class.ACCOUNT_DATA) elif source & SyncStream.TO_DEVICE: event.type = event.type.with_class(EventType.Class.TO_DEVICE) else: event.type = event.type.with_class(EventType.Class.MESSAGE) setattr(event, "source", source) return self.dispatch_manual_event(event.type, event, include_global_handlers=True) async def _catch_errors(self, handler: EventHandler, data: Any) -> None: try: await handler(data) except Exception: self.log.exception("Failed to run handler") def dispatch_manual_event( self, event_type: EventType | InternalEventType, data: Any, include_global_handlers: bool = False, force_synchronous: bool = False, ) -> list[asyncio.Task]: handlers = self.event_handlers.get(event_type, []) if include_global_handlers: handlers = self.global_event_handlers + handlers tasks = [] for handler, wait_sync in handlers: if force_synchronous or wait_sync: tasks.append(asyncio.create_task(self._catch_errors(handler, data))) else: background_task.create(self._catch_errors(handler, data)) return tasks async def run_internal_event( self, event_type: InternalEventType, custom_type: Any = None, **kwargs: Any ) -> None: kwargs["source"] = SyncStream.INTERNAL tasks = self.dispatch_manual_event( event_type, custom_type if custom_type is not None else kwargs, include_global_handlers=False, ) await asyncio.gather(*tasks) def dispatch_internal_event( self, event_type: InternalEventType, custom_type: Any = None, **kwargs: Any ) -> list[asyncio.Task]: kwargs["source"] = SyncStream.INTERNAL return self.dispatch_manual_event( event_type, custom_type if custom_type is not None else kwargs, include_global_handlers=False, ) def _try_deserialize(self, type: Type[T], data: JSON) -> T | GenericEvent: try: return type.deserialize(data) except SerializerError as e: self.log.trace("Deserialization error traceback", exc_info=True) 
self.log.warning(f"Failed to deserialize {data} into {type.__name__}: {e}") try: return GenericEvent.deserialize(data) except SerializerError: return None def handle_sync(self, data: JSON) -> list[asyncio.Task]: """ Handle a /sync object. Args: data: The data from a /sync request. """ tasks = [] otk_count = data.get("device_one_time_keys_count", {}) tasks += self.dispatch_internal_event( InternalEventType.DEVICE_OTK_COUNT, custom_type=DeviceOTKCount( curve25519=otk_count.get("curve25519", 0), signed_curve25519=otk_count.get("signed_curve25519", 0), ), ) device_lists = data.get("device_lists", {}) tasks += self.dispatch_internal_event( InternalEventType.DEVICE_LISTS, custom_type=DeviceLists( changed=device_lists.get("changed", []), left=device_lists.get("left", []), ), ) for raw_event in data.get("account_data", {}).get("events", []): tasks += self.dispatch_event( self._try_deserialize(AccountDataEvent, raw_event), source=SyncStream.ACCOUNT_DATA ) for raw_event in data.get("ephemeral", {}).get("events", []): tasks += self.dispatch_event( self._try_deserialize(EphemeralEvent, raw_event), source=SyncStream.EPHEMERAL ) for raw_event in data.get("to_device", {}).get("events", []): tasks += self.dispatch_event( self._try_deserialize(ToDeviceEvent, raw_event), source=SyncStream.TO_DEVICE ) rooms = data.get("rooms", {}) for room_id, room_data in rooms.get("join", {}).items(): for raw_event in room_data.get("state", {}).get("events", []): raw_event["room_id"] = room_id tasks += self.dispatch_event( self._try_deserialize(StateEvent, raw_event), source=SyncStream.JOINED_ROOM | SyncStream.STATE, ) for raw_event in room_data.get("timeline", {}).get("events", []): raw_event["room_id"] = room_id tasks += self.dispatch_event( self._try_deserialize(Event, raw_event), source=SyncStream.JOINED_ROOM | SyncStream.TIMELINE, ) for raw_event in room_data.get("ephemeral", {}).get("events", []): raw_event["room_id"] = room_id tasks += self.dispatch_event( self._try_deserialize(EphemeralEvent, raw_event), source=SyncStream.JOINED_ROOM | SyncStream.EPHEMERAL, ) for room_id, room_data in rooms.get("invite", {}).items(): events: list[dict[str, JSON]] = room_data.get("invite_state", {}).get("events", []) for raw_event in events: raw_event["room_id"] = room_id raw_invite = next( raw_event for raw_event in events if raw_event.get("type", "") == "m.room.member" and raw_event.get("state_key", "") == self.mxid ) # These aren't required by the spec, so make sure they're set raw_invite.setdefault("event_id", None) raw_invite.setdefault("origin_server_ts", int(time.time() * 1000)) invite = self._try_deserialize(StateEvent, raw_invite) invite.unsigned.invite_room_state = [ self._try_deserialize(StrippedStateEvent, raw_event) for raw_event in events if raw_event != raw_invite ] tasks += self.dispatch_event(invite, source=SyncStream.INVITED_ROOM | SyncStream.STATE) for room_id, room_data in rooms.get("leave", {}).items(): for raw_event in room_data.get("timeline", {}).get("events", []): if "state_key" in raw_event: raw_event["room_id"] = room_id tasks += self.dispatch_event( self._try_deserialize(StateEvent, raw_event), source=SyncStream.LEFT_ROOM | SyncStream.TIMELINE, ) return tasks def start(self, filter_data: FilterID | Filter | None) -> asyncio.Future: """ Start syncing with the server. Can be stopped with :meth:`stop`. Args: filter_data: The filter data or filter ID to use for syncing. 
""" if self.syncing_task is not None: self.syncing_task.cancel() self.syncing_task = asyncio.create_task(self._try_start(filter_data)) return self.syncing_task async def _try_start(self, filter_data: FilterID | Filter | None) -> None: try: if isinstance(filter_data, Filter): filter_data = await self.create_filter(filter_data) await self._start(filter_data) except asyncio.CancelledError: self.log.debug("Syncing cancelled") except Exception as e: self.log.critical("Fatal error while syncing", exc_info=True) await self.run_internal_event(InternalEventType.SYNC_STOPPED, error=e) return except BaseException as e: self.log.warning( f"Syncing stopped with unexpected {e.__class__.__name__}", exc_info=True ) raise else: self.log.debug("Syncing stopped without exception") await self.run_internal_event(InternalEventType.SYNC_STOPPED, error=None) async def _start(self, filter_id: FilterID | None) -> None: fail_sleep = 5 is_first = True self.log.debug("Starting syncing") next_batch = await self.sync_store.get_next_batch() await self.run_internal_event(InternalEventType.SYNC_STARTED) timeout = 30 while True: current_batch = next_batch start = time.monotonic() try: data = await self.sync( since=current_batch, filter_id=filter_id, set_presence=self.presence, timeout=timeout * 1000, ) except (asyncio.CancelledError, MUnknownToken): raise except Exception as e: self.log.warning( f"Sync request errored: {type(e).__name__}: {e}, waiting {fail_sleep}" " seconds before continuing" ) await self.run_internal_event( InternalEventType.SYNC_ERRORED, error=e, sleep_for=fail_sleep ) await asyncio.sleep(fail_sleep) if fail_sleep < 320: fail_sleep *= 2 continue if fail_sleep != 5: self.log.debug("Sync error resolved") fail_sleep = 5 duration = time.monotonic() - start if current_batch and duration > timeout + 10: self.log.warning(f"Sync request ({current_batch}) took {duration:.3f} seconds") is_initial = not current_batch data["net.maunium.mautrix"] = { "is_initial": is_initial, "is_first": is_first, } next_batch = data.get("next_batch") try: await self.sync_store.put_next_batch(next_batch) except Exception: self.log.warning("Failed to store next batch", exc_info=True) await self.run_internal_event(InternalEventType.SYNC_SUCCESSFUL, data=data) if (self.ignore_first_sync and is_first) or (self.ignore_initial_sync and is_initial): is_first = False continue is_first = False self.log.silly(f"Starting sync handling ({current_batch})") start = time.monotonic() try: tasks = self.handle_sync(data) await asyncio.gather(*tasks) except Exception: self.log.exception(f"Sync handling ({current_batch}) errored") else: self.log.silly(f"Finished sync handling ({current_batch})") finally: duration = time.monotonic() - start if duration > 10: self.log.warning( f"Sync handling ({current_batch}) took {duration:.3f} seconds" ) def stop(self) -> None: """ Stop a sync started with :meth:`start`. 
""" if self.syncing_task: self.syncing_task.cancel() self.syncing_task = None @abstractmethod async def create_filter(self, filter_params: Filter) -> FilterID: pass @abstractmethod async def sync( self, since: SyncToken = None, timeout: int = 30000, filter_id: FilterID = None, full_state: bool = False, set_presence: PresenceState = None, ) -> JSON: pass python-0.20.7/mautrix/crypto/000077500000000000000000000000001473573527000161475ustar00rootroot00000000000000python-0.20.7/mautrix/crypto/__init__.py000066400000000000000000000011771473573527000202660ustar00rootroot00000000000000from .account import OlmAccount from .key_share import RejectKeyShare from .sessions import InboundGroupSession, OutboundGroupSession, RatchetSafety, Session # These have to be last from .store import ( # isort: skip CryptoStore, MemoryCryptoStore, PgCryptoStateStore, PgCryptoStore, StateStore, ) from .machine import OlmMachine # isort: skip __all__ = [ "OlmAccount", "RejectKeyShare", "InboundGroupSession", "OutboundGroupSession", "Session", "CryptoStore", "MemoryCryptoStore", "PgCryptoStateStore", "PgCryptoStore", "StateStore", "OlmMachine", "attachments", ] python-0.20.7/mautrix/crypto/account.py000066400000000000000000000075341473573527000201660ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from typing import Any, Dict, Optional, cast from datetime import datetime import olm from mautrix.types import ( DeviceID, EncryptionAlgorithm, EncryptionKeyAlgorithm, IdentityKey, SigningKey, UserID, ) from . import base from .sessions import Session class OlmAccount(olm.Account): shared: bool _signing_key: Optional[SigningKey] _identity_key: Optional[IdentityKey] def __init__(self) -> None: super().__init__() self.shared = False self._signing_key = None self._identity_key = None @property def signing_key(self) -> SigningKey: if self._signing_key is None: self._signing_key = SigningKey(self.identity_keys["ed25519"]) return self._signing_key @property def identity_key(self) -> IdentityKey: if self._identity_key is None: self._identity_key = IdentityKey(self.identity_keys["curve25519"]) return self._identity_key @property def fingerprint(self) -> str: """ Fingerprint is the base64-encoded signing key of this account, with spaces every 4 characters. This is what is used for manual device verification. 
""" key = self.signing_key return " ".join([key[i : i + 4] for i in range(0, len(key), 4)]) @classmethod def from_pickle(cls, pickle: bytes, passphrase: str, shared: bool) -> "OlmAccount": account = cast(OlmAccount, super().from_pickle(pickle, passphrase)) account.shared = shared account._signing_key = None account._identity_key = None return account def new_inbound_session(self, sender_key: IdentityKey, ciphertext: str) -> Session: session = olm.InboundSession(self, olm.OlmPreKeyMessage(ciphertext), sender_key) self.remove_one_time_keys(session) return Session.from_pickle( session.pickle("roundtrip"), passphrase="roundtrip", creation_time=datetime.now() ) def new_outbound_session(self, target_key: IdentityKey, one_time_key: IdentityKey) -> Session: session = olm.OutboundSession(self, target_key, one_time_key) return Session.from_pickle( session.pickle("roundtrip"), passphrase="roundtrip", creation_time=datetime.now() ) def get_device_keys(self, user_id: UserID, device_id: DeviceID) -> Dict[str, Any]: device_keys = { "user_id": user_id, "device_id": device_id, "algorithms": [EncryptionAlgorithm.OLM_V1.value, EncryptionAlgorithm.MEGOLM_V1.value], "keys": { f"{algorithm}:{device_id}": key for algorithm, key in self.identity_keys.items() }, } signature = self.sign(base.canonical_json(device_keys)) device_keys["signatures"] = { user_id: {f"{EncryptionKeyAlgorithm.ED25519}:{device_id}": signature} } return device_keys def get_one_time_keys( self, user_id: UserID, device_id: DeviceID, current_otk_count: int ) -> Dict[str, Any]: new_count = self.max_one_time_keys // 2 - current_otk_count if new_count > 0: self.generate_one_time_keys(new_count) keys = {} for key_id, key in self.one_time_keys.get("curve25519", {}).items(): signature = self.sign(base.canonical_json({"key": key})) keys[f"{EncryptionKeyAlgorithm.SIGNED_CURVE25519}:{key_id}"] = { "key": key, "signatures": { user_id: {f"{EncryptionKeyAlgorithm.ED25519}:{device_id}": signature} }, } self.mark_keys_as_published() return keys python-0.20.7/mautrix/crypto/attachments/000077500000000000000000000000001473573527000204625ustar00rootroot00000000000000python-0.20.7/mautrix/crypto/attachments/__init__.py000066400000000000000000000010131473573527000225660ustar00rootroot00000000000000from .async_attachments import ( async_encrypt_attachment, async_generator_from_data, async_inplace_encrypt_attachment, ) from .attachments import ( decrypt_attachment, encrypt_attachment, encrypted_attachment_generator, inplace_encrypt_attachment, ) __all__ = [ "async_encrypt_attachment", "async_generator_from_data", "async_inplace_encrypt_attachment", "decrypt_attachment", "encrypt_attachment", "encrypted_attachment_generator", "inplace_encrypt_attachment", ] python-0.20.7/mautrix/crypto/attachments/async_attachments.py000066400000000000000000000066301473573527000245510ustar00rootroot00000000000000# Copyright © 2018, 2019 Damir Jelić # Copyright © 2019 miruka # # Permission to use, copy, modify, and/or distribute this software for # any purpose with or without fee is hereby granted, provided that the # above copyright notice and this permission notice appear in all copies. # # Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. 
from __future__ import annotations from typing import AsyncGenerator, AsyncIterable, Iterable from functools import partial import asyncio import io from mautrix.types import EncryptedFile from .attachments import _get_decryption_info, _prepare_encryption, inplace_encrypt_attachment async def async_encrypt_attachment( data: bytes | Iterable[bytes] | AsyncIterable[bytes] | io.BufferedIOBase, ) -> AsyncGenerator[bytes | EncryptedFile, None]: """Async generator to encrypt data in order to send it as an encrypted attachment. This function lazily encrypts and yields data, thus it can be used to encrypt large files without fully loading them into memory if an iterable or async iterable of bytes is passed as data. Args: data: The data to encrypt. Passing an async iterable allows the file data to be read in an asynchronous and lazy (without reading the entire file into memory) way. Passing a non-async iterable or standard open binary file object will still allow the data to be read lazily, but not asynchronously. Yields: The encrypted bytes for each chunk of data. The last yielded value will be a dict containing the info needed to decrypt data. The keys are: | key: AES-CTR JWK key object. | iv: Base64 encoded 16 byte AES-CTR IV. | hashes.sha256: Base64 encoded SHA-256 hash of the ciphertext. """ key, iv, cipher, sha256 = _prepare_encryption() loop = asyncio.get_running_loop() async for chunk in async_generator_from_data(data): update_crypt = partial(cipher.encrypt, chunk) crypt_chunk = await loop.run_in_executor(None, update_crypt) update_hash = partial(sha256.update, crypt_chunk) await loop.run_in_executor(None, update_hash) yield crypt_chunk yield _get_decryption_info(key, iv, sha256) async def async_inplace_encrypt_attachment(data: bytearray) -> EncryptedFile: loop = asyncio.get_running_loop() return await loop.run_in_executor(None, partial(inplace_encrypt_attachment, data)) async def async_generator_from_data( data: bytes | Iterable[bytes] | AsyncIterable[bytes] | io.BufferedIOBase, chunk_size: int = 4 * 1024, ) -> AsyncGenerator[bytes, None]: if isinstance(data, bytes): chunks = (data[i : i + chunk_size] for i in range(0, len(data), chunk_size)) for chunk in chunks: yield chunk elif isinstance(data, io.BufferedIOBase): while True: chunk = data.read(chunk_size) if not chunk: return yield chunk elif isinstance(data, Iterable): for chunk in data: yield chunk elif isinstance(data, AsyncIterable): async for chunk in data: yield chunk else: raise TypeError(f"Unknown type for data: {data!r}") python-0.20.7/mautrix/crypto/attachments/async_attachments_test.py000066400000000000000000000026361473573527000256120ustar00rootroot00000000000000# Copyright © 2019 Damir Jelić (under the Apache 2.0 license) # Copyright © 2019 miruka (under the Apache 2.0 license) # Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. 
from __future__ import annotations from mautrix.types import EncryptedFile from .async_attachments import async_encrypt_attachment, async_inplace_encrypt_attachment from .attachments import decrypt_attachment try: from Crypto import Random except ImportError: from Cryptodome import Random async def _get_data_cypher_keys(data: bytes) -> tuple[bytes, EncryptedFile]: *chunks, keys = [i async for i in async_encrypt_attachment(data)] return b"".join(chunks), keys async def test_async_encrypt(): data = b"Test bytes" cyphertext, keys = await _get_data_cypher_keys(data) plaintext = decrypt_attachment(cyphertext, keys.key.key, keys.hashes["sha256"], keys.iv) assert data == plaintext async def test_async_inplace_encrypt(): orig_data = b"Test bytes" data = bytearray(orig_data) keys = await async_inplace_encrypt_attachment(data) assert data != orig_data decrypt_attachment(data, keys.key.key, keys.hashes["sha256"], keys.iv, inplace=True) assert data == orig_data python-0.20.7/mautrix/crypto/attachments/attachments.py000066400000000000000000000123631473573527000233540ustar00rootroot00000000000000# Copyright 2018 Zil0 (under the Apache 2.0 license) # Copyright © 2019 Damir Jelić (under the Apache 2.0 license) # Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations from typing import Generator, Iterable import binascii import struct import unpaddedbase64 from mautrix.errors import DecryptionError from mautrix.types import EncryptedFile, JSONWebKey try: from Crypto import Random from Crypto.Cipher import AES from Crypto.Hash import SHA256 from Crypto.Util import Counter except ImportError: from Cryptodome import Random from Cryptodome.Cipher import AES from Cryptodome.Hash import SHA256 from Cryptodome.Util import Counter def decrypt_attachment( ciphertext: bytes | bytearray | memoryview, key: str, hash: str, iv: str, inplace: bool = False ) -> bytes: """Decrypt an encrypted attachment. Args: ciphertext: The data to decrypt. key: AES_CTR JWK key object. hash: Base64 encoded SHA-256 hash of the ciphertext. iv: Base64 encoded 16 byte AES-CTR IV. inplace: Should the decryption be performed in-place? The input must be a bytearray or writable memoryview to use this. Returns: The plaintext bytes. Raises: EncryptionError: if the integrity check fails. """ expected_hash = unpaddedbase64.decode_base64(hash) h = SHA256.new() h.update(ciphertext) if h.digest() != expected_hash: raise DecryptionError("Mismatched SHA-256 digest") try: byte_key: bytes = unpaddedbase64.decode_base64(key) except (binascii.Error, TypeError): raise DecryptionError("Error decoding key") try: byte_iv: bytes = unpaddedbase64.decode_base64(iv) if len(byte_iv) != 16: raise DecryptionError("Invalid IV length") prefix = byte_iv[:8] # A non-zero IV counter is not spec-compliant, but some clients still do it, # so decode the counter part too. 
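        # The 16-byte IV is treated as an 8-byte counter prefix followed by an
        # 8-byte big-endian initial counter value, which ">Q" unpacks below.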
initial_value = struct.unpack(">Q", byte_iv[8:])[0] except (binascii.Error, TypeError, IndexError, struct.error): raise DecryptionError("Error decoding IV") ctr = Counter.new(64, prefix=prefix, initial_value=initial_value) try: cipher = AES.new(byte_key, AES.MODE_CTR, counter=ctr) except ValueError as e: raise DecryptionError("Failed to create AES cipher") from e if inplace: cipher.decrypt(ciphertext, ciphertext) return ciphertext else: return cipher.decrypt(ciphertext) def encrypt_attachment(plaintext: bytes) -> tuple[bytes, EncryptedFile]: """Encrypt data in order to send it as an encrypted attachment. Args: plaintext: The data to encrypt. Returns: A tuple with the encrypted bytes and a dict containing the info needed to decrypt data. See ``encrypted_attachment_generator()`` for the keys. """ values = list(encrypted_attachment_generator(plaintext)) return b"".join(values[:-1]), values[-1] def _prepare_encryption() -> tuple[bytes, bytes, AES, SHA256.SHA256Hash]: key = Random.new().read(32) # 8 bytes IV iv = Random.new().read(8) # 8 bytes counter, prefixed by the IV ctr = Counter.new(64, prefix=iv, initial_value=0) cipher = AES.new(key, AES.MODE_CTR, counter=ctr) sha256 = SHA256.new() return key, iv, cipher, sha256 def inplace_encrypt_attachment(data: bytearray | memoryview) -> EncryptedFile: key, iv, cipher, sha256 = _prepare_encryption() cipher.encrypt(plaintext=data, output=data) sha256.update(data) return _get_decryption_info(key, iv, sha256) def encrypted_attachment_generator( data: bytes | Iterable[bytes], ) -> Generator[bytes | EncryptedFile, None, None]: """Generator to encrypt data in order to send it as an encrypted attachment. Unlike ``encrypt_attachment()``, this function lazily encrypts and yields data, thus it can be used to encrypt large files without fully loading them into memory if an iterable of bytes is passed as data. Args: data: The data to encrypt. Yields: The encrypted bytes for each chunk of data. The last yielded value will be a dict containing the info needed to decrypt data. """ key, iv, cipher, sha256 = _prepare_encryption() if isinstance(data, bytes): data = [data] for chunk in data: encrypted_chunk = cipher.encrypt(chunk) # in executor sha256.update(encrypted_chunk) # in executor yield encrypted_chunk yield _get_decryption_info(key, iv, sha256) def _get_decryption_info(key: bytes, iv: bytes, sha256: SHA256.SHA256Hash) -> EncryptedFile: return EncryptedFile( version="v2", iv=unpaddedbase64.encode_base64(iv + b"\x00" * 8), hashes={"sha256": unpaddedbase64.encode_base64(sha256.digest())}, key=JSONWebKey( key_type="oct", algorithm="A256CTR", extractable=True, key_ops=["encrypt", "decrypt"], key=unpaddedbase64.encode_base64(key, urlsafe=True), ), ) python-0.20.7/mautrix/crypto/attachments/attachments_test.py000066400000000000000000000053211473573527000244070ustar00rootroot00000000000000# Copyright © 2019 Damir Jelić (under the Apache 2.0 license) # Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. 
import pytest import unpaddedbase64 from mautrix.errors import DecryptionError from .attachments import decrypt_attachment, encrypt_attachment, inplace_encrypt_attachment try: from Crypto import Random except ImportError: from Cryptodome import Random def test_encrypt(): data = b"Test bytes" cyphertext, keys = encrypt_attachment(data) plaintext = decrypt_attachment(cyphertext, keys.key.key, keys.hashes["sha256"], keys.iv) assert data == plaintext def test_inplace_encrypt(): orig_data = b"Test bytes" data = bytearray(orig_data) keys = inplace_encrypt_attachment(data) assert data != orig_data decrypt_attachment(data, keys.key.key, keys.hashes["sha256"], keys.iv, inplace=True) assert data == orig_data def test_hash_verification(): data = b"Test bytes" cyphertext, keys = encrypt_attachment(data) with pytest.raises(DecryptionError): decrypt_attachment(cyphertext, keys.key.key, "Fake hash", keys.iv) def test_invalid_key(): data = b"Test bytes" cyphertext, keys = encrypt_attachment(data) with pytest.raises(DecryptionError): decrypt_attachment(cyphertext, "Fake key", keys.hashes["sha256"], keys.iv) def test_invalid_iv(): data = b"Test bytes" cyphertext, keys = encrypt_attachment(data) with pytest.raises(DecryptionError): decrypt_attachment(cyphertext, keys.key.key, keys.hashes["sha256"], "Fake iv") def test_short_key(): data = b"Test bytes" cyphertext, keys = encrypt_attachment(data) with pytest.raises(DecryptionError): decrypt_attachment( cyphertext, unpaddedbase64.encode_base64(b"Fake key", urlsafe=True), keys["hashes"]["sha256"], keys["iv"], ) def test_short_iv(): data = b"Test bytes" cyphertext, keys = encrypt_attachment(data) with pytest.raises(DecryptionError): decrypt_attachment( cyphertext, keys.key.key, keys.hashes["sha256"], unpaddedbase64.encode_base64(b"F" + b"\x00" * 8), ) def test_fake_key(): data = b"Test bytes" cyphertext, keys = encrypt_attachment(data) fake_key = Random.new().read(32) plaintext = decrypt_attachment( cyphertext, unpaddedbase64.encode_base64(fake_key, urlsafe=True), keys["hashes"]["sha256"], keys["iv"], ) assert plaintext != data python-0.20.7/mautrix/crypto/base.py000066400000000000000000000111641473573527000174360ustar00rootroot00000000000000# Copyright (c) 2023 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations from typing import Any, Awaitable, Callable, TypedDict import asyncio import functools import json import olm from mautrix.errors import MForbidden, MNotFound from mautrix.types import ( DeviceID, EncryptionKeyAlgorithm, EventType, IdentityKey, KeyID, RequestedKeyInfo, RoomEncryptionStateEventContent, RoomID, RoomKeyEventContent, SessionID, SigningKey, TrustState, UserID, ) from mautrix.util.logging import TraceLogger from .. 
import client as cli, crypto class SignedObject(TypedDict): signatures: dict[UserID, dict[str, str]] unsigned: Any class BaseOlmMachine: client: cli.Client log: TraceLogger crypto_store: crypto.CryptoStore state_store: crypto.StateStore account: account.OlmAccount send_keys_min_trust: TrustState share_keys_min_trust: TrustState allow_key_share: Callable[[crypto.DeviceIdentity, RequestedKeyInfo], Awaitable[bool]] delete_outbound_keys_on_ack: bool dont_store_outbound_keys: bool delete_previous_keys_on_receive: bool ratchet_keys_on_decrypt: bool delete_fully_used_keys_on_decrypt: bool delete_keys_on_device_delete: bool disable_device_change_key_rotation: bool # Futures that wait for responses to a key request _key_request_waiters: dict[SessionID, asyncio.Future] # Futures that wait for a session to be received (either normally or through a key request) _inbound_session_waiters: dict[SessionID, asyncio.Future] _prev_unwedge: dict[IdentityKey, float] _fetch_keys_lock: asyncio.Lock _megolm_decrypt_lock: asyncio.Lock _share_keys_lock: asyncio.Lock _last_key_share: float _cs_fetch_attempted: set[UserID] async def wait_for_session( self, room_id: RoomID, session_id: SessionID, timeout: float = 3 ) -> bool: try: fut = self._inbound_session_waiters[session_id] except KeyError: fut = asyncio.get_running_loop().create_future() self._inbound_session_waiters[session_id] = fut try: return await asyncio.wait_for(asyncio.shield(fut), timeout) except asyncio.TimeoutError: return await self.crypto_store.has_group_session(room_id, session_id) def _mark_session_received(self, session_id: SessionID) -> None: try: self._inbound_session_waiters.pop(session_id).set_result(True) except KeyError: return async def _fill_encryption_info(self, evt: RoomKeyEventContent) -> None: encryption_info = await self.state_store.get_encryption_info(evt.room_id) if not encryption_info: self.log.warning( f"Encryption info for {evt.room_id} not found in state store, fetching from server" ) try: encryption_info = await self.client.get_state_event( evt.room_id, EventType.ROOM_ENCRYPTION ) except (MNotFound, MForbidden) as e: self.log.warning( f"Failed to get encryption info for {evt.room_id} from server: {e}," " using defaults" ) encryption_info = RoomEncryptionStateEventContent() if not encryption_info: self.log.warning( f"Didn't find encryption info for {evt.room_id} on server either," " using defaults" ) encryption_info = RoomEncryptionStateEventContent() if not evt.beeper_max_age_ms: evt.beeper_max_age_ms = encryption_info.rotation_period_ms if not evt.beeper_max_messages: evt.beeper_max_messages = encryption_info.rotation_period_msgs canonical_json = functools.partial( json.dumps, ensure_ascii=False, separators=(",", ":"), sort_keys=True ) def verify_signature_json( data: "SignedObject", user_id: UserID, key_name: DeviceID | str, key: SigningKey ) -> bool: data_copy = {**data} data_copy.pop("unsigned", None) signatures = data_copy.pop("signatures") key_id = str(KeyID(EncryptionKeyAlgorithm.ED25519, key_name)) try: signature = signatures[user_id][key_id] except KeyError: return False signed_data = canonical_json(data_copy) try: olm.ed25519_verify(key, signed_data, signature) return True except olm.OlmVerifyError: return False python-0.20.7/mautrix/crypto/decrypt_megolm.py000066400000000000000000000173661473573527000215500ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. 
If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. import json import olm from mautrix.errors import ( DecryptedPayloadError, DecryptionError, DuplicateMessageIndex, MismatchingRoomError, SessionNotFound, VerificationError, ) from mautrix.types import ( EncryptedEvent, EncryptedMegolmEventContent, EncryptionAlgorithm, Event, SessionID, TrustState, ) from .device_lists import DeviceListMachine from .sessions import InboundGroupSession class MegolmDecryptionMachine(DeviceListMachine): async def decrypt_megolm_event(self, evt: EncryptedEvent) -> Event: """ Decrypt an event that was encrypted using Megolm. Args: evt: The whole encrypted event. Returns: The decrypted event, including some unencrypted metadata from the input event. Raises: DecryptionError: If decryption failed. """ if not isinstance(evt.content, EncryptedMegolmEventContent): raise DecryptionError("Unsupported event content class") elif evt.content.algorithm != EncryptionAlgorithm.MEGOLM_V1: raise DecryptionError("Unsupported event encryption algorithm") async with self._megolm_decrypt_lock: session = await self.crypto_store.get_group_session( evt.room_id, evt.content.session_id ) if session is None: # TODO check if olm session is wedged raise SessionNotFound(evt.content.session_id, evt.content.sender_key) try: plaintext, index = session.decrypt(evt.content.ciphertext) except olm.OlmGroupSessionError as e: raise DecryptionError("Failed to decrypt megolm event") from e if not await self.crypto_store.validate_message_index( session.sender_key, SessionID(session.id), evt.event_id, index, evt.timestamp ): raise DuplicateMessageIndex() await self._ratchet_session(session, index) forwarded_keys = False if ( evt.content.device_id == self.client.device_id and session.signing_key == self.account.signing_key and session.sender_key == self.account.identity_key and not session.forwarding_chain ): trust_level = TrustState.VERIFIED else: device = await self.get_or_fetch_device_by_key(evt.sender, session.sender_key) if not session.forwarding_chain or ( len(session.forwarding_chain) == 1 and session.forwarding_chain[0] == session.sender_key ): if not device: self.log.debug( f"Couldn't resolve trust level of session {session.id}: " f"sent by unknown device {evt.sender}/{session.sender_key}" ) trust_level = TrustState.UNKNOWN_DEVICE elif ( device.signing_key != session.signing_key or device.identity_key != session.sender_key ): raise VerificationError() else: trust_level = await self.resolve_trust(device) else: forwarded_keys = True last_chain_item = session.forwarding_chain[-1] received_from = await self.crypto_store.find_device_by_key( evt.sender, last_chain_item ) if received_from: trust_level = await self.resolve_trust(received_from) else: self.log.debug( f"Couldn't resolve trust level of session {session.id}: " f"forwarding chain ends with unknown device {last_chain_item}" ) trust_level = TrustState.FORWARDED try: data = json.loads(plaintext) room_id = data["room_id"] event_type = data["type"] content = data["content"] except json.JSONDecodeError as e: raise DecryptedPayloadError("Failed to parse megolm payload") from e except KeyError as e: raise DecryptedPayloadError("Megolm payload is missing fields") from e if room_id != evt.room_id: raise MismatchingRoomError() if evt.content.relates_to and "m.relates_to" not in content: content["m.relates_to"] = evt.content.relates_to.serialize() result = Event.deserialize( { "room_id": evt.room_id, "event_id": evt.event_id, "sender": 
evt.sender, "origin_server_ts": evt.timestamp, "type": event_type, "content": content, } ) result.unsigned = evt.unsigned result.type = result.type.with_class(evt.type.t_class) result["mautrix"] = { "trust_state": trust_level, "forwarded_keys": forwarded_keys, "was_encrypted": True, } return result async def _ratchet_session(self, sess: InboundGroupSession, index: int) -> None: expected_message_index = sess.ratchet_safety.next_index did_modify = True if index > expected_message_index: sess.ratchet_safety.missed_indices += list(range(expected_message_index, index)) sess.ratchet_safety.next_index = index + 1 elif index == expected_message_index: sess.ratchet_safety.next_index = index + 1 else: try: sess.ratchet_safety.missed_indices.remove(index) except ValueError: did_modify = False # Use presence of received_at as a sign that this is a recent megolm session, # and therefore it's safe to drop missed indices entirely. if ( sess.received_at and sess.ratchet_safety.missed_indices and sess.ratchet_safety.missed_indices[0] < expected_message_index - 10 ): i = 0 for i, lost_index in enumerate(sess.ratchet_safety.missed_indices): if lost_index < expected_message_index - 10: sess.ratchet_safety.lost_indices.append(lost_index) else: break sess.ratchet_safety.missed_indices = sess.ratchet_safety.missed_indices[i + 1 :] ratchet_target_index = sess.ratchet_safety.next_index if len(sess.ratchet_safety.missed_indices) > 0: ratchet_target_index = min(sess.ratchet_safety.missed_indices) self.log.debug( f"Ratchet safety info for {sess.id}: {sess.ratchet_safety}, {ratchet_target_index=}" ) sess_id = SessionID(sess.id) if ( sess.max_messages and ratchet_target_index >= sess.max_messages and not sess.ratchet_safety.missed_indices and self.delete_fully_used_keys_on_decrypt ): self.log.info(f"Deleting fully used session {sess.id}") await self.crypto_store.redact_group_session( sess.room_id, sess_id, reason="maximum messages reached" ) return elif sess.first_known_index < ratchet_target_index and self.ratchet_keys_on_decrypt: self.log.info(f"Ratcheting session {sess.id} to {ratchet_target_index}") sess = sess.ratchet_to(ratchet_target_index) elif not did_modify: return await self.crypto_store.put_group_session(sess.room_id, sess.sender_key, sess_id, sess) python-0.20.7/mautrix/crypto/decrypt_olm.py000066400000000000000000000127251473573527000210510ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. 
from typing import Optional import asyncio import olm from mautrix.errors import DecryptionError, MatchingSessionDecryptionError from mautrix.types import ( DecryptedOlmEvent, EncryptedOlmEventContent, EncryptionAlgorithm, IdentityKey, OlmCiphertext, OlmMsgType, ToDeviceEvent, UserID, ) from mautrix.util import background_task from .base import BaseOlmMachine from .sessions import Session class OlmDecryptionMachine(BaseOlmMachine): async def _decrypt_olm_event(self, evt: ToDeviceEvent) -> DecryptedOlmEvent: if not isinstance(evt.content, EncryptedOlmEventContent): raise DecryptionError("unsupported event content class") elif evt.content.algorithm != EncryptionAlgorithm.OLM_V1: raise DecryptionError("unsupported event encryption algorithm") try: own_content = evt.content.ciphertext[self.account.identity_key] except KeyError: raise DecryptionError("olm event doesn't contain ciphertext for this device") self.log.debug( f"Decrypting to-device olm event from {evt.sender}/{evt.content.sender_key}" ) plaintext = await self._decrypt_olm_ciphertext( evt.sender, evt.content.sender_key, own_content ) try: decrypted_evt: DecryptedOlmEvent = DecryptedOlmEvent.parse_json(plaintext) except Exception: self.log.trace("Failed to parse olm event plaintext: %s", plaintext) raise if decrypted_evt.sender != evt.sender: raise DecryptionError("mismatched sender in olm payload") elif decrypted_evt.recipient != self.client.mxid: raise DecryptionError("mismatched recipient in olm payload") elif decrypted_evt.recipient_keys.ed25519 != self.account.signing_key: raise DecryptionError("mismatched recipient key in olm payload") decrypted_evt.sender_key = evt.content.sender_key decrypted_evt.source = evt self.log.debug( f"Successfully decrypted olm event from {evt.sender}/{decrypted_evt.sender_device} " f"(sender key: {decrypted_evt.sender_key} into a {decrypted_evt.type}" ) return decrypted_evt async def _decrypt_olm_ciphertext( self, sender: UserID, sender_key: IdentityKey, message: OlmCiphertext ) -> str: if message.type not in (OlmMsgType.PREKEY, OlmMsgType.MESSAGE): raise DecryptionError("unsupported olm message type") try: plaintext = await self._try_decrypt_olm_ciphertext(sender_key, message) except MatchingSessionDecryptionError: self.log.warning( f"Found matching session yet decryption failed for sender {sender}" f" with key {sender_key}" ) background_task.create(self._unwedge_session(sender, sender_key)) raise if not plaintext: if message.type != OlmMsgType.PREKEY: background_task.create(self._unwedge_session(sender, sender_key)) raise DecryptionError("Decryption failed for normal message") self.log.trace(f"Trying to create inbound session for {sender}/{sender_key}") try: session = await self._create_inbound_session(sender_key, message.body) except olm.OlmSessionError as e: background_task.create(self._unwedge_session(sender, sender_key)) raise DecryptionError("Failed to create new session from prekey message") from e self.log.debug( f"Created inbound session {session.id} for {sender} (sender key: {sender_key})" ) try: plaintext = session.decrypt(message) except olm.OlmSessionError as e: raise DecryptionError( "Failed to decrypt olm event with session created from prekey message" ) from e await self.crypto_store.update_session(sender_key, session) return plaintext async def _try_decrypt_olm_ciphertext( self, sender_key: IdentityKey, message: OlmCiphertext ) -> Optional[str]: sessions = await self.crypto_store.get_sessions(sender_key) for session in sessions: if message.type == OlmMsgType.PREKEY and not 
session.matches(message.body): continue try: plaintext = session.decrypt(message) except olm.OlmSessionError as e: if message.type == OlmMsgType.PREKEY: raise MatchingSessionDecryptionError( "decryption failed with matching session" ) from e else: await self.crypto_store.update_session(sender_key, session) return plaintext return None async def _create_inbound_session(self, sender_key: IdentityKey, ciphertext: str) -> Session: session = self.account.new_inbound_session(sender_key, ciphertext) await self.crypto_store.put_account(self.account) await self.crypto_store.add_session(sender_key, session) return session async def _unwedge_session(self, sender: UserID, sender_key: IdentityKey) -> None: raise NotImplementedError() python-0.20.7/mautrix/crypto/device_lists.py000066400000000000000000000355361473573527000212120ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations from mautrix.errors import DeviceValidationError from mautrix.types import ( CrossSigner, CrossSigningKeys, CrossSigningUsage, DeviceID, DeviceIdentity, DeviceKeys, EncryptionKeyAlgorithm, IdentityKey, KeyID, QueryKeysResponse, SigningKey, SyncToken, TrustState, UserID, ) from .base import BaseOlmMachine, verify_signature_json class DeviceListMachine(BaseOlmMachine): async def _fetch_keys( self, users: list[UserID], since: SyncToken = "", include_untracked: bool = False ) -> dict[UserID, dict[DeviceID, DeviceIdentity]]: if not include_untracked: users = await self.crypto_store.filter_tracked_users(users) if len(users) == 0: return {} users = set(users) self.log.trace(f"Querying keys for {users}") resp = await self.client.query_keys(users, token=since) missing_users = users.copy() for server, err in resp.failures.items(): self.log.warning(f"Query keys failure for {server}: {err}") data = {} for user_id, devices in resp.device_keys.items(): missing_users.remove(user_id) new_devices = {} existing_devices = (await self.crypto_store.get_devices(user_id)) or {} self.log.trace( f"Updating devices for {user_id}, got {len(devices)}, " f"have {len(existing_devices)} in store" ) changed = False ssks = resp.self_signing_keys.get(user_id) ssk = ssks.first_ed25519_key if ssks else None for device_id, device_keys in devices.items(): try: existing = existing_devices[device_id] except KeyError: existing = None changed = True self.log.trace(f"Validating device {device_keys} of {user_id}") try: new_device = await self._validate_device( user_id, device_id, device_keys, existing ) except DeviceValidationError as e: self.log.warning(f"Failed to validate device {device_id} of {user_id}: {e}") else: if new_device: new_devices[device_id] = new_device await self._store_device_self_signatures(device_keys, ssk) self.log.debug( f"Storing new device list for {user_id} containing {len(new_devices)} devices" ) await self.crypto_store.put_devices(user_id, new_devices) data[user_id] = new_devices if changed or len(new_devices) != len(existing_devices): if self.delete_keys_on_device_delete: for device_id in existing_devices.keys() - new_devices.keys(): device = existing_devices[device_id] removed_ids = await self.crypto_store.redact_group_sessions( room_id=None, sender_key=device.identity_key, reason="device removed" ) self.log.info( "Redacted megolm sessions sent by removed device " 
f"{device.user_id}/{device.device_id}: {removed_ids}" ) await self.on_devices_changed(user_id) for user_id in missing_users: self.log.warning(f"Didn't get any devices for user {user_id}") for user_id in users: await self._store_cross_signing_keys(resp, user_id) return data async def _store_device_self_signatures( self, device_keys: DeviceKeys, self_signing_key: SigningKey | None ) -> None: device_desc = f"Device {device_keys.user_id}/{device_keys.device_id}" try: self_signatures = device_keys.signatures[device_keys.user_id].copy() except KeyError: self.log.warning(f"{device_desc} doesn't have any signatures from the user") return if len(device_keys.signatures) > 1: self.log.debug( f"{device_desc} has signatures from other users (%s)", set(device_keys.signatures.keys()) - {device_keys.user_id}, ) device_self_sig = self_signatures.pop( KeyID(EncryptionKeyAlgorithm.ED25519, device_keys.device_id) ) target = CrossSigner(device_keys.user_id, device_keys.ed25519) # This one is already validated by _validate_device await self.crypto_store.put_signature(target, target, device_self_sig) try: cs_self_sig = self_signatures.pop( KeyID(EncryptionKeyAlgorithm.ED25519, self_signing_key) ) except KeyError: self.log.warning(f"{device_desc} isn't cross-signed") else: is_valid_self_sig = verify_signature_json( device_keys.serialize(), device_keys.user_id, self_signing_key, self_signing_key ) if is_valid_self_sig: signer = CrossSigner(device_keys.user_id, self_signing_key) await self.crypto_store.put_signature(target, signer, cs_self_sig) else: self.log.warning(f"{device_desc} doesn't have a valid cross-signing signature") if len(self_signatures) > 0: self.log.debug( f"{device_desc} has signatures from unexpected keys (%s)", set(self_signatures.keys()), ) async def _store_cross_signing_keys(self, resp: QueryKeysResponse, user_id: UserID) -> None: new_keys: dict[CrossSigningUsage, CrossSigningKeys] = {} try: master = new_keys[CrossSigningUsage.MASTER] = resp.master_keys[user_id] except KeyError: self.log.debug(f"Didn't get a cross-signing master key for {user_id}") return try: new_keys[CrossSigningUsage.SELF] = resp.self_signing_keys[user_id] except KeyError: self.log.debug(f"Didn't get a cross-signing self-signing key for {user_id}") return try: new_keys[CrossSigningUsage.USER] = resp.user_signing_keys[user_id] except KeyError: pass current_keys = await self.crypto_store.get_cross_signing_keys(user_id) for usage, key in current_keys.items(): if usage in new_keys and key.key != new_keys[usage].first_ed25519_key: num = await self.crypto_store.drop_signatures_by_key(CrossSigner(user_id, key.key)) if num >= 0: self.log.debug( f"Dropped {num} signatures made by key {user_id}/{key.key} ({usage})" " as it has been replaced" ) for usage, key in new_keys.items(): actual_key = key.first_ed25519_key self.log.debug(f"Storing cross-signing key for {user_id}: {actual_key} (type {usage})") await self.crypto_store.put_cross_signing_key(user_id, usage, actual_key) if usage != CrossSigningUsage.MASTER and ( KeyID(EncryptionKeyAlgorithm.ED25519, master.first_ed25519_key) not in key.signatures[user_id] ): self.log.warning( f"Cross-signing key {user_id}/{actual_key}/{usage}" " doesn't seem to have a signature from the master key" ) for signer_user_id, signatures in key.signatures.items(): for key_id, signature in signatures.items(): signing_key = SigningKey(key_id.key_id) if signer_user_id == user_id: try: device = resp.device_keys[signer_user_id][DeviceID(key_id.key_id)] signing_key = device.ed25519 except KeyError: pass if 
len(signing_key) != 43: self.log.debug( f"Cross-signing key {user_id}/{actual_key} has a signature from " f"an unknown key {key_id}" ) continue signing_key_log = signing_key if signing_key != key_id.key_id: signing_key_log = f"{signing_key} ({key_id})" self.log.debug( f"Verifying cross-signing key {user_id}/{actual_key} " f"with key {signer_user_id}/{signing_key_log}" ) is_valid_sig = verify_signature_json( key.serialize(), signer_user_id, key_id.key_id, signing_key ) if is_valid_sig: self.log.debug(f"Signature from {signing_key_log} for {key_id} verified") await self.crypto_store.put_signature( target=CrossSigner(user_id, actual_key), signer=CrossSigner(signer_user_id, signing_key), signature=signature, ) else: self.log.warning(f"Invalid signature from {signing_key_log} for {key_id}") async def get_or_fetch_device( self, user_id: UserID, device_id: DeviceID ) -> DeviceIdentity | None: device = await self.crypto_store.get_device(user_id, device_id) if device is not None: return device devices = await self._fetch_keys([user_id], include_untracked=True) try: return devices[user_id][device_id] except KeyError: return None async def get_or_fetch_device_by_key( self, user_id: UserID, identity_key: IdentityKey ) -> DeviceIdentity | None: device = await self.crypto_store.find_device_by_key(user_id, identity_key) if device is not None: return device devices = await self._fetch_keys([user_id], include_untracked=True) for device in devices.get(user_id, {}).values(): if device.identity_key == identity_key: return device return None async def on_devices_changed(self, user_id: UserID) -> None: if self.disable_device_change_key_rotation: return shared_rooms = await self.state_store.find_shared_rooms(user_id) self.log.debug( f"Devices of {user_id} changed, invalidating group session in {shared_rooms}" ) await self.crypto_store.remove_outbound_group_sessions(shared_rooms) @staticmethod async def _validate_device( user_id: UserID, device_id: DeviceID, device_keys: DeviceKeys, existing: DeviceIdentity | None = None, ) -> DeviceIdentity: if user_id != device_keys.user_id: raise DeviceValidationError( f"mismatching user ID (expected {user_id}, got {device_keys.user_id})" ) elif device_id != device_keys.device_id: raise DeviceValidationError( f"mismatching device ID (expected {device_id}, got {device_keys.device_id})" ) signing_key = device_keys.ed25519 if not signing_key: raise DeviceValidationError("didn't find ed25519 signing key") identity_key = device_keys.curve25519 if not identity_key: raise DeviceValidationError("didn't find curve25519 identity key") if existing and existing.signing_key != signing_key: raise DeviceValidationError( f"received update for device with different signing key " f"(expected {existing.signing_key}, got {signing_key})" ) if not verify_signature_json(device_keys.serialize(), user_id, device_id, signing_key): raise DeviceValidationError("invalid signature on device keys") name = device_keys.unsigned.device_display_name or device_id return DeviceIdentity( user_id=user_id, device_id=device_id, identity_key=identity_key, signing_key=signing_key, trust=TrustState.UNVERIFIED, name=name, deleted=False, ) async def resolve_trust(self, device: DeviceIdentity) -> TrustState: try: return await self._try_resolve_trust(device) except Exception: self.log.exception(f"Failed to resolve trust of {device.user_id}/{device.device_id}") return TrustState.UNVERIFIED async def _try_resolve_trust(self, device: DeviceIdentity) -> TrustState: if device.trust in (TrustState.VERIFIED, 
TrustState.BLACKLISTED): return device.trust their_keys = await self.crypto_store.get_cross_signing_keys(device.user_id) if len(their_keys) == 0 and device.user_id not in self._cs_fetch_attempted: self.log.debug(f"Didn't find any cross-signing keys for {device.user_id}, fetching...") async with self._fetch_keys_lock: if device.user_id not in self._cs_fetch_attempted: self._cs_fetch_attempted.add(device.user_id) await self._fetch_keys([device.user_id]) their_keys = await self.crypto_store.get_cross_signing_keys(device.user_id) try: msk = their_keys[CrossSigningUsage.MASTER] ssk = their_keys[CrossSigningUsage.SELF] except KeyError as e: self.log.error(f"Didn't find cross-signing key {e.args[0]} of {device.user_id}") return TrustState.UNVERIFIED ssk_signed = await self.crypto_store.is_key_signed_by( target=CrossSigner(device.user_id, ssk.key), signer=CrossSigner(device.user_id, msk.key), ) if not ssk_signed: self.log.warning( f"Self-signing key of {device.user_id} is not signed by their master key" ) return TrustState.UNVERIFIED device_signed = await self.crypto_store.is_key_signed_by( target=CrossSigner(device.user_id, device.signing_key), signer=CrossSigner(device.user_id, ssk.key), ) if device_signed: if await self.is_user_trusted(device.user_id): return TrustState.CROSS_SIGNED_TRUSTED elif msk.key == msk.first: return TrustState.CROSS_SIGNED_TOFU return TrustState.CROSS_SIGNED_UNTRUSTED return TrustState.UNVERIFIED async def is_user_trusted(self, user_id: UserID) -> bool: # TODO implement once own cross-signing key stuff is ready return False python-0.20.7/mautrix/crypto/encrypt_megolm.py000066400000000000000000000353521473573527000215550ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from typing import Any, Dict, List, Tuple, Union from collections import defaultdict from datetime import datetime, timedelta import asyncio import json import time from mautrix.errors import EncryptionError, SessionShareError from mautrix.types import ( DeviceID, DeviceIdentity, EncryptedMegolmEventContent, EncryptionAlgorithm, EventType, IdentityKey, RelatesTo, RoomID, RoomKeyWithheldCode, RoomKeyWithheldEventContent, Serializable, SessionID, SigningKey, TrustState, UserID, ) from .device_lists import DeviceListMachine from .encrypt_olm import OlmEncryptionMachine from .sessions import InboundGroupSession, OutboundGroupSession, Session class Sentinel: pass already_shared = Sentinel() key_missing = Sentinel() DeviceSessionWrapper = Tuple[Session, DeviceIdentity] DeviceMap = Dict[UserID, Dict[DeviceID, DeviceSessionWrapper]] SessionEncryptResult = Union[ type(already_shared), # already shared DeviceSessionWrapper, # share successful RoomKeyWithheldEventContent, # won't share type(key_missing), # missing device ] class MegolmEncryptionMachine(OlmEncryptionMachine, DeviceListMachine): _megolm_locks: Dict[RoomID, asyncio.Lock] _sharing_group_session: Dict[RoomID, asyncio.Event] def __init__(self) -> None: super().__init__() self._megolm_locks = defaultdict(lambda: asyncio.Lock()) self._sharing_group_session = {} async def encrypt_megolm_event( self, room_id: RoomID, event_type: EventType, content: Any ) -> EncryptedMegolmEventContent: """ Encrypt an event for a specific room using Megolm. Args: room_id: The room to encrypt the message for. event_type: The event type. content: The event content. 
Using the content structs in the mautrix.types module is recommended. Returns: The encrypted event content. Raises: EncryptionError: If a group session has not been shared. Use :meth:`share_group_session` to share a group session if this error is raised. """ # The crypto store is async, so we need to make sure only one thing is writing at a time. async with self._megolm_locks[room_id]: return await self._encrypt_megolm_event(room_id, event_type, content) async def _encrypt_megolm_event( self, room_id: RoomID, event_type: EventType, content: Any ) -> EncryptedMegolmEventContent: self.log.debug(f"Encrypting event of type {event_type} for {room_id}") session = await self.crypto_store.get_outbound_group_session(room_id) if not session: raise EncryptionError("No group session created") ciphertext = session.encrypt( json.dumps( { "room_id": room_id, "type": event_type.serialize(), "content": ( content.serialize() if isinstance(content, Serializable) else content ), } ) ) try: relates_to = content.relates_to except AttributeError: try: relates_to = RelatesTo.deserialize(content["m.relates_to"]) except KeyError: relates_to = None await self.crypto_store.update_outbound_group_session(session) return EncryptedMegolmEventContent( sender_key=self.account.identity_key, device_id=self.client.device_id, ciphertext=ciphertext, session_id=SessionID(session.id), relates_to=relates_to, ) def is_sharing_group_session(self, room_id: RoomID) -> bool: """ Check if there's a group session being shared for a specific room Args: room_id: The room ID to check. Returns: True if a group session share is in progress, False if not """ return room_id in self._sharing_group_session async def wait_group_session_share(self, room_id: RoomID) -> None: """ Wait for a group session to be shared. Args: room_id: The room ID to wait for. """ try: event = self._sharing_group_session[room_id] await event.wait() except KeyError: pass async def share_group_session(self, room_id: RoomID, users: List[UserID]) -> None: """ Create a Megolm session for a specific room and share it with the given list of users. Note that you must not call this method again before the previous share has finished. You should either lock calls yourself, or use :meth:`wait_group_session_share` to use built-in locking capabilities. Args: room_id: The room to create the session for. users: The list of users to share the session with. Raises: SessionShareError: If something went wrong while sharing the session. 
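
        Examples:
            A hedged sketch of sharing on demand; ``machine``, ``room_id``, ``users``
            and ``content`` are assumed to be provided by the caller:

            >>> await machine.wait_group_session_share(room_id)
            >>> try:
            ...     encrypted = await machine.encrypt_megolm_event(
            ...         room_id, EventType.ROOM_MESSAGE, content
            ...     )
            ... except EncryptionError:
            ...     await machine.share_group_session(room_id, users)
            ...     encrypted = await machine.encrypt_megolm_event(
            ...         room_id, EventType.ROOM_MESSAGE, content
            ...     )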
""" if room_id in self._sharing_group_session: raise SessionShareError("Already sharing group session for that room") self._sharing_group_session[room_id] = asyncio.Event() try: await self._share_group_session(room_id, users) finally: self._sharing_group_session.pop(room_id).set() async def _share_group_session(self, room_id: RoomID, users: List[UserID]) -> None: session = await self.crypto_store.get_outbound_group_session(room_id) if session and session.shared and not session.expired: raise SessionShareError("Group session has already been shared") if not session or session.expired: session = await self._new_outbound_group_session(room_id) self.log.debug(f"Sharing group session {session.id} for room {room_id} with {users}") olm_sessions: DeviceMap = defaultdict(lambda: {}) withhold_key_msgs = defaultdict(lambda: {}) missing_sessions: Dict[UserID, Dict[DeviceID, DeviceIdentity]] = defaultdict(lambda: {}) fetch_keys = [] for user_id in users: devices = await self.crypto_store.get_devices(user_id) if devices is None: self.log.debug( f"get_devices returned nil for {user_id}, will fetch keys and retry" ) fetch_keys.append(user_id) elif len(devices) == 0: self.log.debug(f"{user_id} has no devices, skipping") else: self.log.debug(f"Trying to encrypt group session {session.id} for {user_id}") for device_id, device in devices.items(): result = await self._find_olm_sessions(session, user_id, device_id, device) if isinstance(result, RoomKeyWithheldEventContent): withhold_key_msgs[user_id][device_id] = result elif result == key_missing: missing_sessions[user_id][device_id] = device elif isinstance(result, tuple): olm_sessions[user_id][device_id] = result if fetch_keys: self.log.debug(f"Fetching missing keys for {fetch_keys}") fetched_keys = await self._fetch_keys(users, include_untracked=True) for user_id, devices in fetched_keys.items(): missing_sessions[user_id] = devices if missing_sessions: self.log.debug(f"Creating missing outbound sessions {missing_sessions}") try: await self._create_outbound_sessions(missing_sessions) except Exception: self.log.exception("Failed to create missing outbound sessions") for user_id, devices in missing_sessions.items(): for device_id, device in devices.items(): result = await self._find_olm_sessions(session, user_id, device_id, device) if isinstance(result, RoomKeyWithheldEventContent): withhold_key_msgs[user_id][device_id] = result elif isinstance(result, tuple): olm_sessions[user_id][device_id] = result # We don't care about missing keys at this point if len(olm_sessions) > 0: async with self._olm_lock: await self._encrypt_and_share_group_session(session, olm_sessions) if len(withhold_key_msgs) > 0: event_count = sum(len(map) for map in withhold_key_msgs.values()) self.log.debug( f"Sending {event_count} to-device events to report {session.id} is withheld" ) await self.client.send_to_device(EventType.ROOM_KEY_WITHHELD, withhold_key_msgs) await self.client.send_to_device( EventType.ORG_MATRIX_ROOM_KEY_WITHHELD, withhold_key_msgs ) self.log.info(f"Group session {session.id} for {room_id} successfully shared") session.shared = True await self.crypto_store.add_outbound_group_session(session) async def _new_outbound_group_session(self, room_id: RoomID) -> OutboundGroupSession: session = OutboundGroupSession(room_id) encryption_info = await self.state_store.get_encryption_info(room_id) if encryption_info: if encryption_info.algorithm != EncryptionAlgorithm.MEGOLM_V1: raise SessionShareError("Room encryption algorithm is not supported") session.max_messages = 
encryption_info.rotation_period_msgs or session.max_messages session.max_age = ( timedelta(milliseconds=encryption_info.rotation_period_ms) if encryption_info.rotation_period_ms else session.max_age ) self.log.debug( "Got stored encryption state event and configured session to rotate " f"after {session.max_messages} messages or {session.max_age}" ) if not self.dont_store_outbound_keys: await self._create_group_session( self.account.identity_key, self.account.signing_key, room_id, SessionID(session.id), session.session_key, max_messages=session.max_messages, max_age=session.max_age, is_scheduled=False, ) return session async def _encrypt_and_share_group_session( self, session: OutboundGroupSession, olm_sessions: DeviceMap ): msgs = defaultdict(lambda: {}) count = 0 for user_id, devices in olm_sessions.items(): count += len(devices) for device_id, (olm_session, device_identity) in devices.items(): msgs[user_id][device_id] = await self._encrypt_olm_event( olm_session, device_identity, EventType.ROOM_KEY, session.share_content ) self.log.debug( f"Sending to-device events to {count} devices of {len(msgs)} users " f"to share {session.id}" ) await self.client.send_to_device(EventType.TO_DEVICE_ENCRYPTED, msgs) async def _create_group_session( self, sender_key: IdentityKey, signing_key: SigningKey, room_id: RoomID, session_id: SessionID, session_key: str, max_age: Union[timedelta, int], max_messages: int, is_scheduled: bool = False, ) -> None: start = time.monotonic() session = InboundGroupSession( session_key=session_key, signing_key=signing_key, sender_key=sender_key, room_id=room_id, received_at=datetime.utcnow(), max_age=max_age, max_messages=max_messages, is_scheduled=is_scheduled, ) olm_duration = time.monotonic() - start if olm_duration > 5: self.log.warning(f"Creating inbound group session took {olm_duration:.3f} seconds!") if session_id != session.id: self.log.warning(f"Mismatching session IDs: expected {session_id}, got {session.id}") session_id = session.id await self.crypto_store.put_group_session(room_id, sender_key, session_id, session) self._mark_session_received(session_id) self.log.debug( f"Created inbound group session {room_id}/{sender_key}/{session_id} " f"(max {max_age} / {max_messages} messages, {is_scheduled=})" ) async def _find_olm_sessions( self, session: OutboundGroupSession, user_id: UserID, device_id: DeviceID, device: DeviceIdentity, ) -> SessionEncryptResult: key = (user_id, device_id) if key in session.users_ignored or key in session.users_shared_with: return already_shared elif user_id == self.client.mxid and device_id == self.client.device_id: session.users_ignored.add(key) return already_shared trust = await self.resolve_trust(device) if trust == TrustState.BLACKLISTED: self.log.debug( f"Not encrypting group session {session.id} for {device_id} " f"of {user_id}: device is blacklisted" ) session.users_ignored.add(key) return RoomKeyWithheldEventContent( room_id=session.room_id, algorithm=EncryptionAlgorithm.MEGOLM_V1, session_id=SessionID(session.id), sender_key=self.account.identity_key, code=RoomKeyWithheldCode.BLACKLISTED, reason="Device is blacklisted", ) elif self.send_keys_min_trust > trust: self.log.debug( f"Not encrypting group session {session.id} for {device_id} " f"of {user_id}: device is not trusted " f"(min: {self.send_keys_min_trust}, device: {trust})" ) session.users_ignored.add(key) return RoomKeyWithheldEventContent( room_id=session.room_id, algorithm=EncryptionAlgorithm.MEGOLM_V1, session_id=SessionID(session.id), 
sender_key=self.account.identity_key, code=RoomKeyWithheldCode.UNVERIFIED, reason="This device does not encrypt messages for unverified devices", ) device_session = await self.crypto_store.get_latest_session(device.identity_key) if not device_session: return key_missing session.users_shared_with.add(key) return device_session, device python-0.20.7/mautrix/crypto/encrypt_olm.py000066400000000000000000000123241473573527000210560ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from typing import Any, Dict import asyncio from mautrix.types import ( DecryptedOlmEvent, DeviceID, DeviceIdentity, EncryptedOlmEventContent, EncryptionKeyAlgorithm, EventType, OlmEventKeys, ToDeviceEventContent, UserID, ) from .base import BaseOlmMachine, verify_signature_json from .sessions import Session ClaimKeysList = Dict[UserID, Dict[DeviceID, DeviceIdentity]] class OlmEncryptionMachine(BaseOlmMachine): _claim_keys_lock: asyncio.Lock _olm_lock: asyncio.Lock def __init__(self): self._claim_keys_lock = asyncio.Lock() self._olm_lock = asyncio.Lock() async def _encrypt_olm_event( self, session: Session, recipient: DeviceIdentity, event_type: EventType, content: Any ) -> EncryptedOlmEventContent: evt = DecryptedOlmEvent( sender=self.client.mxid, sender_device=self.client.device_id, keys=OlmEventKeys(ed25519=self.account.signing_key), recipient=recipient.user_id, recipient_keys=OlmEventKeys(ed25519=recipient.signing_key), type=event_type, content=content, ) ciphertext = session.encrypt(evt.json()) await self.crypto_store.update_session(recipient.identity_key, session) return EncryptedOlmEventContent( ciphertext={recipient.identity_key: ciphertext}, sender_key=self.account.identity_key ) async def _create_outbound_sessions( self, users: ClaimKeysList, _force_recreate_session: bool = False ) -> None: async with self._claim_keys_lock: return await self._create_outbound_sessions_locked(users, _force_recreate_session) async def _create_outbound_sessions_locked( self, users: ClaimKeysList, _force_recreate_session: bool = False ) -> None: request: Dict[UserID, Dict[DeviceID, EncryptionKeyAlgorithm]] = {} expected_devices = set() for user_id, devices in users.items(): request[user_id] = {} for device_id, identity in devices.items(): if _force_recreate_session or not await self.crypto_store.has_session( identity.identity_key ): request[user_id][device_id] = EncryptionKeyAlgorithm.SIGNED_CURVE25519 expected_devices.add((user_id, device_id)) if not request[user_id]: del request[user_id] if not request: return request_device_count = len(expected_devices) keys = await self.client.claim_keys(request) for server, info in (keys.failures or {}).items(): self.log.warning(f"Key claim failure for {server}: {info}") for user_id, devices in keys.one_time_keys.items(): for device_id, one_time_keys in devices.items(): expected_devices.discard((user_id, device_id)) key_id, one_time_key_data = one_time_keys.popitem() one_time_key = one_time_key_data["key"] identity = users[user_id][device_id] if not verify_signature_json( one_time_key_data, user_id, device_id, identity.signing_key ): self.log.warning(f"Invalid signature for {device_id} of {user_id}") else: session = self.account.new_outbound_session( identity.identity_key, one_time_key ) await self.crypto_store.add_session(identity.identity_key, session) self.log.debug( f"Created 
new Olm session with {user_id}/{device_id} " f"(OTK ID: {key_id})" ) if expected_devices: if request_device_count == 1: raise Exception( "Key claim response didn't contain key " f"for queried device {expected_devices.pop()}" ) else: self.log.warning( "Key claim response didn't contain keys for %d out of %d expected devices: %s", len(expected_devices), request_device_count, expected_devices, ) async def send_encrypted_to_device( self, device: DeviceIdentity, event_type: EventType, content: ToDeviceEventContent, _force_recreate_session: bool = False, ) -> None: await self._create_outbound_sessions( {device.user_id: {device.device_id: device}}, _force_recreate_session=_force_recreate_session, ) session = await self.crypto_store.get_latest_session(device.identity_key) async with self._olm_lock: encrypted_content = await self._encrypt_olm_event(session, device, event_type, content) await self.client.send_to_one_device( EventType.TO_DEVICE_ENCRYPTED, device.user_id, device.device_id, encrypted_content ) python-0.20.7/mautrix/crypto/key_request.py000066400000000000000000000135021473573527000210620ustar00rootroot00000000000000# Copyright (c) 2023 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from typing import Dict, List, Optional, Union import asyncio import uuid from mautrix.types import ( DecryptedOlmEvent, DeviceID, EncryptionAlgorithm, EventType, ForwardedRoomKeyEventContent, IdentityKey, KeyRequestAction, RequestedKeyInfo, RoomID, RoomKeyRequestEventContent, SessionID, UserID, ) from .base import BaseOlmMachine from .sessions import InboundGroupSession class KeyRequestingMachine(BaseOlmMachine): async def request_room_key( self, room_id: RoomID, sender_key: IdentityKey, session_id: SessionID, from_devices: Dict[UserID, List[DeviceID]], timeout: Optional[Union[int, float]] = None, ) -> bool: """ Request keys for a Megolm group session from other devices. Once the keys are received, or if this task is cancelled (via the ``timeout`` parameter), a cancel request event is sent to the remaining devices. If the ``timeout`` is set to zero or less, this will return immediately, and the extra key requests will not be cancelled. Args: room_id: The room where the session is used. sender_key: The key of the user who created the session. session_id: The ID of the session. from_devices: A dict from user ID to list of device IDs whom to ask for the keys. timeout: The maximum number of seconds to wait for the keys. If the timeout is ``None``, the wait time is not limited, but the task can still be cancelled. If it's zero or less, this returns immediately and will never cancel requests. Returns: ``True`` if the keys were received and are now in the crypto store, ``False`` otherwise (including if the method didn't wait at all). 
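
        Example:
            A rough sketch of asking one's own other devices for a missing session.
            ``machine`` is an initialized :class:`OlmMachine`; ``room_id``, ``sender_key``
            and ``session_id`` would normally be taken from the event that failed to
            decrypt, and the device ID below is a placeholder::

                from mautrix.types import DeviceID

                got_keys = await machine.request_room_key(
                    room_id,
                    sender_key,
                    session_id,
                    from_devices={machine.client.mxid: [DeviceID("OTHERDEVICE")]},
                    timeout=30,
                )
                if got_keys:
                    # The session is now in the crypto store, so decryption can be retried.
                    ...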
""" request_id = str(uuid.uuid1()) request = RoomKeyRequestEventContent( action=KeyRequestAction.REQUEST, body=RequestedKeyInfo( algorithm=EncryptionAlgorithm.MEGOLM_V1, room_id=room_id, sender_key=sender_key, session_id=session_id, ), request_id=request_id, requesting_device_id=self.client.device_id, ) wait = timeout is None or timeout > 0 fut: Optional[asyncio.Future] = None if wait: fut = asyncio.get_running_loop().create_future() self._key_request_waiters[session_id] = fut await self.client.send_to_device( EventType.ROOM_KEY_REQUEST, { user_id: {device_id: request for device_id in devices} for user_id, devices in from_devices.items() }, ) if not wait: # Timeout is set and <=0, don't wait for keys return False assert fut is not None got_keys = False try: user_id, device_id = await asyncio.wait_for(fut, timeout=timeout) got_keys = True try: del from_devices[user_id][device_id] if len(from_devices[user_id]) == 0: del from_devices[user_id] except KeyError: pass except (asyncio.CancelledError, asyncio.TimeoutError): pass del self._key_request_waiters[session_id] if len(from_devices) > 0: cancel = RoomKeyRequestEventContent( action=KeyRequestAction.CANCEL, request_id=str(request_id), requesting_device_id=self.client.device_id, ) await self.client.send_to_device( EventType.ROOM_KEY_REQUEST, { user_id: {device_id: cancel for device_id in devices} for user_id, devices in from_devices.items() }, ) return got_keys async def _receive_forwarded_room_key(self, evt: DecryptedOlmEvent) -> None: key: ForwardedRoomKeyEventContent = evt.content if await self.crypto_store.has_group_session(key.room_id, key.session_id): self.log.debug( f"Ignoring received session {key.session_id} from {evt.sender}/" f"{evt.sender_device}, as crypto store says we have it already" ) return if not key.beeper_max_messages or not key.beeper_max_age_ms: await self._fill_encryption_info(key) key.forwarding_key_chain.append(evt.sender_key) sess = InboundGroupSession.import_session( key.session_key, key.signing_key, key.sender_key, key.room_id, key.forwarding_key_chain, max_age=key.beeper_max_age_ms, max_messages=key.beeper_max_messages, is_scheduled=key.beeper_is_scheduled, ) if key.session_id != sess.id: self.log.warning( f"Mismatched session ID while importing forwarded key from " f"{evt.sender}/{evt.sender_device}: '{key.session_id}' != '{sess.id}'" ) return await self.crypto_store.put_group_session( key.room_id, key.sender_key, key.session_id, sess ) self._mark_session_received(key.session_id) self.log.debug( f"Imported {key.session_id} for {key.room_id} " f"from {evt.sender}/{evt.sender_device}" ) try: task = self._key_request_waiters[key.session_id] except KeyError: pass else: task.set_result((evt.sender, evt.sender_device)) python-0.20.7/mautrix/crypto/key_share.py000066400000000000000000000201521473573527000204730ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. 
from typing import Optional from mautrix.errors import MatrixConnectionError, MatrixError, MatrixRequestError from mautrix.types import ( DeviceIdentity, EncryptionAlgorithm, EventType, ForwardedRoomKeyEventContent, KeyRequestAction, RequestedKeyInfo, RoomKeyRequestEventContent, RoomKeyWithheldCode, RoomKeyWithheldEventContent, SessionID, ToDeviceEvent, TrustState, ) from .device_lists import DeviceListMachine from .encrypt_olm import OlmEncryptionMachine class RejectKeyShare(MatrixError): def __init__( self, log_message: str = "", code: Optional[RoomKeyWithheldCode] = None, reason: Optional[str] = None, ) -> None: """ RejectKeyShare is an error used to signal that a key share request should be rejected. Args: log_message: The message to log when rejecting the request. code: The m.room_key.withheld code, or ``None`` to reject silently. reason: The human-readable reason for the rejection. """ super().__init__(log_message) self.code = code self.reason = reason class KeySharingMachine(OlmEncryptionMachine, DeviceListMachine): async def default_allow_key_share( self, device: DeviceIdentity, request: RequestedKeyInfo ) -> bool: """ Check whether or not the given key request should be fulfilled. You can set a custom function in :attr:`allow_key_share` to override this. Args: device: The identity of the device requesting keys. request: The requested key details. Returns: ``True`` if the key share should be accepted, ``False`` if it should be silently ignored. Raises: RejectKeyShare: if the key share should be rejected. """ if device.user_id != self.client.mxid: raise RejectKeyShare( f"Ignoring key request from a different user ({device.user_id})", code=None ) elif device.device_id == self.client.device_id: raise RejectKeyShare("Ignoring key request from ourselves", code=None) elif device.trust == TrustState.BLACKLISTED: raise RejectKeyShare( f"Rejecting key request from blacklisted device {device.device_id}", code=RoomKeyWithheldCode.BLACKLISTED, reason="You have been blacklisted by this device", ) elif await self.resolve_trust(device) >= self.share_keys_min_trust: self.log.debug(f"Accepting key request from trusted device {device.device_id}") return True else: raise RejectKeyShare( f"Rejecting key request from untrusted device {device.device_id}", code=RoomKeyWithheldCode.UNVERIFIED, reason="You have not been verified by this device", ) async def handle_room_key_request( self, evt: ToDeviceEvent, raise_exceptions: bool = False ) -> None: """ Handle a ``m.room_key_request`` where the action is ``request``. This is automatically registered as an event handler and therefore called if the client you passed to the OlmMachine is syncing. You shouldn't need to call this yourself unless you do syncing in some manual way. Args: evt: The to-device event. raise_exceptions: Whether or not errors while handling should be raised. 
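
        Example:
            An illustrative sketch for manual dispatch, only needed when to-device
            events arrive outside the normal sync loop (for example from an appservice
            transaction); ``machine`` and ``evt`` are assumed to exist::

                from mautrix.types import EventType

                if evt.type == EventType.ROOM_KEY_REQUEST:
                    await machine.handle_room_key_request(evt, raise_exceptions=True)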
""" request: RoomKeyRequestEventContent = evt.content if request.action != KeyRequestAction.REQUEST: return elif ( evt.sender == self.client.mxid and request.requesting_device_id == self.client.device_id ): self.log.debug(f"Ignoring key request {request.request_id} from ourselves") return try: device = await self.get_or_fetch_device(evt.sender, request.requesting_device_id) except Exception: self.log.warning( f"Failed to get device {evt.sender}/{request.requesting_device_id} to " f"handle key request {request.request_id}", exc_info=True, ) if raise_exceptions: raise return if not device: self.log.warning( f"Couldn't find device {evt.sender}/{request.requesting_device_id} to " f"handle key request {request.request_id}" ) return self.log.debug( f"Received key request {request.request_id} from {device.user_id}/" f"{device.device_id} for session {request.body.session_id}" ) try: await self._handle_room_key_request(device, request.body) except RejectKeyShare as e: self.log.debug(f"Rejecting key request {request.request_id}: {e}") await self._reject_key_request(e, device, request.body) except (MatrixRequestError, MatrixConnectionError): self.log.exception( f"API error while handling key request {request.request_id} " f"(not sending rejection)" ) if raise_exceptions: raise except Exception: self.log.exception( f"Error while handling key request {request.request_id}, sending rejection..." ) error = RejectKeyShare( code=RoomKeyWithheldCode.UNAVAILABLE, reason="An internal error occurred while trying to share the requested session", ) await self._reject_key_request(error, device, request.body) if raise_exceptions: raise async def _handle_room_key_request( self, device: DeviceIdentity, request: RequestedKeyInfo ) -> None: if not await self.allow_key_share(device, request): return sess = await self.crypto_store.get_group_session(request.room_id, request.session_id) if sess is None: raise RejectKeyShare( f"Didn't find group session {request.session_id} to forward to " f"{device.user_id}/{device.device_id}", code=RoomKeyWithheldCode.UNAVAILABLE, reason="Requested session ID not found on this device", ) exported_key = sess.export_session(sess.first_known_index) forward_content = ForwardedRoomKeyEventContent( algorithm=EncryptionAlgorithm.MEGOLM_V1, room_id=sess.room_id, session_id=SessionID(sess.id), session_key=exported_key, sender_key=sess.sender_key, forwarding_key_chain=sess.forwarding_chain, signing_key=sess.signing_key, ) await self.send_encrypted_to_device(device, EventType.FORWARDED_ROOM_KEY, forward_content) async def _reject_key_request( self, rejection: RejectKeyShare, device: DeviceIdentity, request: RequestedKeyInfo ) -> None: if not rejection.code: # Silent rejection return content = RoomKeyWithheldEventContent( room_id=request.room_id, algorithm=request.algorithm, session_id=request.session_id, sender_key=request.sender_key, code=rejection.code, reason=rejection.reason, ) try: await self.client.send_to_one_device( EventType.ROOM_KEY_WITHHELD, device.user_id, device.device_id, content ) await self.client.send_to_one_device( EventType.ORG_MATRIX_ROOM_KEY_WITHHELD, device.user_id, device.device_id, content ) except MatrixError: self.log.warning( f"Failed to send key share rejection {rejection.code} " f"to {device.user_id}/{device.device_id}", exc_info=True, ) python-0.20.7/mautrix/crypto/machine.py000066400000000000000000000327761473573527000201440ustar00rootroot00000000000000# Copyright (c) 2023 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, 
v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations from typing import Optional import asyncio import logging import time from mautrix import client as cli from mautrix.errors import GroupSessionWithheldError from mautrix.types import ( ASToDeviceEvent, DecryptedOlmEvent, DeviceID, DeviceLists, DeviceOTKCount, EncryptionAlgorithm, EncryptionKeyAlgorithm, EventType, Member, Membership, StateEvent, ToDeviceEvent, TrustState, UserID, ) from mautrix.util import background_task from mautrix.util.logging import TraceLogger from .account import OlmAccount from .decrypt_megolm import MegolmDecryptionMachine from .encrypt_megolm import MegolmEncryptionMachine from .key_request import KeyRequestingMachine from .key_share import KeySharingMachine from .store import CryptoStore, StateStore from .unwedge import OlmUnwedgingMachine class OlmMachine( MegolmEncryptionMachine, MegolmDecryptionMachine, OlmUnwedgingMachine, KeySharingMachine, KeyRequestingMachine, ): """ OlmMachine is the main class for handling things related to Matrix end-to-end encryption with Olm and Megolm. Users primarily need :meth:`encrypt_megolm_event`, :meth:`share_group_session`, and :meth:`decrypt_megolm_event`. Tracking device lists, establishing Olm sessions and handling Megolm group sessions is handled internally. """ client: cli.Client log: TraceLogger crypto_store: CryptoStore state_store: StateStore account: Optional[OlmAccount] def __init__( self, client: cli.Client, crypto_store: CryptoStore, state_store: StateStore, log: Optional[TraceLogger] = None, ) -> None: super().__init__() self.client = client self.log = log or logging.getLogger("mau.crypto") self.crypto_store = crypto_store self.state_store = state_store self.account = None self.send_keys_min_trust = TrustState.UNVERIFIED self.share_keys_min_trust = TrustState.CROSS_SIGNED_TOFU self.allow_key_share = self.default_allow_key_share self.delete_outbound_keys_on_ack = False self.dont_store_outbound_keys = False self.delete_previous_keys_on_receive = False self.ratchet_keys_on_decrypt = False self.delete_fully_used_keys_on_decrypt = False self.delete_keys_on_device_delete = False self.disable_device_change_key_rotation = False self._fetch_keys_lock = asyncio.Lock() self._megolm_decrypt_lock = asyncio.Lock() self._share_keys_lock = asyncio.Lock() self._last_key_share = time.monotonic() - 60 self._key_request_waiters = {} self._inbound_session_waiters = {} self._prev_unwedge = {} self._cs_fetch_attempted = set() self.client.add_event_handler( cli.InternalEventType.DEVICE_OTK_COUNT, self.handle_otk_count, wait_sync=True ) self.client.add_event_handler(cli.InternalEventType.DEVICE_LISTS, self.handle_device_lists) self.client.add_event_handler(EventType.TO_DEVICE_ENCRYPTED, self.handle_to_device_event) self.client.add_event_handler(EventType.ROOM_KEY_REQUEST, self.handle_room_key_request) self.client.add_event_handler(EventType.BEEPER_ROOM_KEY_ACK, self.handle_beep_room_key_ack) # self.client.add_event_handler(EventType.ROOM_KEY_WITHHELD, self.handle_room_key_withheld) # self.client.add_event_handler(EventType.ORG_MATRIX_ROOM_KEY_WITHHELD, # self.handle_room_key_withheld) self.client.add_event_handler(EventType.ROOM_MEMBER, self.handle_member_event) async def load(self) -> None: """Load the Olm account into memory, or create one if the store doesn't have one stored.""" self.account = await self.crypto_store.get_account() if self.account is None: self.account = OlmAccount() 
await self.crypto_store.put_account(self.account) async def handle_as_otk_counts( self, otk_counts: dict[UserID, dict[DeviceID, DeviceOTKCount]] ) -> None: for user_id, devices in otk_counts.items(): for device_id, count in devices.items(): if user_id == self.client.mxid and device_id == self.client.device_id: await self.handle_otk_count(count) else: self.log.warning(f"Got OTK count for unknown device {user_id}/{device_id}") async def handle_as_device_lists(self, device_lists: DeviceLists) -> None: background_task.create(self.handle_device_lists(device_lists)) async def handle_as_to_device_event(self, evt: ASToDeviceEvent) -> None: if evt.to_user_id != self.client.mxid or evt.to_device_id != self.client.device_id: self.log.warning( f"Got to-device event for unknown device {evt.to_user_id}/{evt.to_device_id}" ) return if evt.type == EventType.TO_DEVICE_ENCRYPTED: await self.handle_to_device_event(evt) elif evt.type == EventType.ROOM_KEY_REQUEST: await self.handle_room_key_request(evt) elif evt.type == EventType.BEEPER_ROOM_KEY_ACK: await self.handle_beep_room_key_ack(evt) else: self.log.debug(f"Got unknown to-device event {evt.type} from {evt.sender}") async def handle_otk_count(self, otk_count: DeviceOTKCount) -> None: """ Handle the ``device_one_time_keys_count`` data in a sync response. This is automatically registered as an event handler and therefore called if the client you passed to the OlmMachine is syncing. You shouldn't need to call this yourself unless you do syncing in some manual way. """ if otk_count.signed_curve25519 < self.account.max_one_time_keys // 2: self.log.debug( f"Sync response said we have {otk_count.signed_curve25519} signed" " curve25519 keys left, sharing new ones..." ) await self.share_keys(otk_count.signed_curve25519) async def handle_device_lists(self, device_lists: DeviceLists) -> None: """ Handle the ``device_lists`` data in a sync response. This is automatically registered as an event handler and therefore called if the client you passed to the OlmMachine is syncing. You shouldn't need to call this yourself unless you do syncing in some manual way. """ if len(device_lists.changed) > 0: async with self._fetch_keys_lock: await self._fetch_keys(device_lists.changed, include_untracked=False) async def handle_member_event(self, evt: StateEvent) -> None: """ Handle a new member event. This is automatically registered as an event handler and therefore called if the client you passed to the OlmMachine is syncing. You shouldn't need to call this yourself unless you receive events in some manual way (e.g. 
through appservice transactions) """ if not await self.state_store.is_encrypted(evt.room_id): return prev = evt.prev_content.membership cur = evt.content.membership ignored_changes = { Membership.INVITE: Membership.JOIN, Membership.BAN: Membership.LEAVE, Membership.LEAVE: Membership.BAN, } if prev == cur or ignored_changes.get(prev) == cur: return src = getattr(evt, "source", None) prev_cache = evt.unsigned.get("mautrix_prev_membership") if isinstance(prev_cache, Member) and prev_cache.membership == cur: self.log.debug( f"Got duplicate membership state event in {evt.room_id} changing {evt.state_key} " f"from {prev} to {cur}, cached state was {prev_cache} (event ID: {evt.event_id}, " f"sync source: {src})" ) return self.log.debug( f"Got membership state event in {evt.room_id} changing {evt.state_key} from " f"{prev} to {cur} (event ID: {evt.event_id}, sync source: {src}, " f"cached: {prev_cache.membership if prev_cache else None}), invalidating group session" ) await self.crypto_store.remove_outbound_group_session(evt.room_id) async def handle_to_device_event(self, evt: ToDeviceEvent) -> None: """ Handle an encrypted to-device event. This is automatically registered as an event handler and therefore called if the client you passed to the OlmMachine is syncing. You shouldn't need to call this yourself unless you do syncing in some manual way. """ self.log.trace( f"Handling encrypted to-device event from {evt.content.sender_key} ({evt.sender})" ) decrypted_evt = await self._decrypt_olm_event(evt) if decrypted_evt.type == EventType.ROOM_KEY: await self._receive_room_key(decrypted_evt) elif decrypted_evt.type == EventType.FORWARDED_ROOM_KEY: await self._receive_forwarded_room_key(decrypted_evt) async def _receive_room_key(self, evt: DecryptedOlmEvent) -> None: # TODO nio had a comment saying "handle this better" # for the case where evt.Keys.Ed25519 is none? 
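        # Note: only v1 Megolm room keys that include the sender's Ed25519 key are
        # accepted below; anything else is silently dropped.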
if evt.content.algorithm != EncryptionAlgorithm.MEGOLM_V1 or not evt.keys.ed25519: return if not evt.content.beeper_max_messages or not evt.content.beeper_max_age_ms: await self._fill_encryption_info(evt.content) if self.delete_previous_keys_on_receive and not evt.content.beeper_is_scheduled: removed_ids = await self.crypto_store.redact_group_sessions( evt.content.room_id, evt.sender_key, reason="received new key from device" ) self.log.info(f"Redacted previous megolm sessions: {removed_ids}") await self._create_group_session( evt.sender_key, evt.keys.ed25519, evt.content.room_id, evt.content.session_id, evt.content.session_key, max_age=evt.content.beeper_max_age_ms, max_messages=evt.content.beeper_max_messages, is_scheduled=evt.content.beeper_is_scheduled, ) async def handle_beep_room_key_ack(self, evt: ToDeviceEvent) -> None: try: sess = await self.crypto_store.get_group_session( evt.content.room_id, evt.content.session_id ) except GroupSessionWithheldError: self.log.debug( f"Ignoring room key ack for session {evt.content.session_id}" " that was already redacted" ) return if not sess: self.log.debug(f"Ignoring room key ack for unknown session {evt.content.session_id}") return if ( sess.sender_key == self.account.identity_key and self.delete_outbound_keys_on_ack and evt.content.first_message_index == 0 ): self.log.debug("Redacting inbound copy of outbound group session after ack") await self.crypto_store.redact_group_session( evt.content.room_id, evt.content.session_id, reason="outbound session acked" ) else: self.log.debug(f"Received room key ack for {sess.id}") async def share_keys(self, current_otk_count: int | None = None) -> None: """ Share any keys that need to be shared. This is automatically called from :meth:`handle_otk_count`, so you should not need to call this yourself. Args: current_otk_count: The current number of signed curve25519 keys present on the server. If omitted, the count will be fetched from the server. """ async with self._share_keys_lock: await self._share_keys(current_otk_count) async def _share_keys(self, current_otk_count: int | None) -> None: if current_otk_count is None or ( # If the last key share was recent and the new count is very low, re-check the count # from the server to avoid any race conditions. self._last_key_share + 60 > time.monotonic() and current_otk_count < 10 ): self.log.debug("Checking OTK count on server") current_otk_count = (await self.client.upload_keys()).get( EncryptionKeyAlgorithm.SIGNED_CURVE25519, 0 ) device_keys = ( self.account.get_device_keys(self.client.mxid, self.client.device_id) if not self.account.shared else None ) one_time_keys = self.account.get_one_time_keys( self.client.mxid, self.client.device_id, current_otk_count ) if not device_keys and not one_time_keys: self.log.warning("No one-time keys nor device keys got when trying to share keys") return if device_keys: self.log.debug("Going to upload initial account keys") self.log.debug(f"Uploading {len(one_time_keys)} one-time keys") resp = await self.client.upload_keys(one_time_keys=one_time_keys, device_keys=device_keys) self.account.shared = True self._last_key_share = time.monotonic() await self.crypto_store.put_account(self.account) self.log.debug(f"Shared keys and saved account, new keys: {resp}") python-0.20.7/mautrix/crypto/sessions.py000066400000000000000000000230051473573527000203670ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. 
If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from typing import List, Optional, Set, Tuple, Union, cast from datetime import datetime, timedelta from _libolm import ffi, lib from attr import dataclass import olm from mautrix.errors import EncryptionError from mautrix.types import ( DeviceID, EncryptionAlgorithm, IdentityKey, OlmCiphertext, OlmMsgType, RoomID, RoomKeyEventContent, SerializableAttrs, SigningKey, UserID, field, ) class Session(olm.Session): creation_time: datetime last_encrypted: datetime last_decrypted: datetime def __init__(self): super().__init__() self.creation_time = datetime.now() self.last_encrypted = datetime.now() self.last_decrypted = datetime.now() def __new__(cls, *args, **kwargs): return super().__new__(cls) @property def expired(self): return False @classmethod def from_pickle( cls, pickle: bytes, passphrase: str, creation_time: datetime, last_encrypted: Optional[datetime] = None, last_decrypted: Optional[datetime] = None, ) -> "Session": session = super().from_pickle(pickle, passphrase=passphrase) session.creation_time = creation_time session.last_encrypted = last_encrypted or creation_time session.last_decrypted = last_decrypted or creation_time return session def matches(self, ciphertext: str) -> bool: return super().matches(olm.OlmPreKeyMessage(ciphertext)) def decrypt(self, ciphertext: OlmCiphertext) -> str: plaintext = super().decrypt( olm.OlmPreKeyMessage(ciphertext.body) if ciphertext.type == OlmMsgType.PREKEY else olm.OlmMessage(ciphertext.body) ) self.last_decrypted = datetime.now() return plaintext def encrypt(self, plaintext: str) -> OlmCiphertext: self.last_encrypted = datetime.now() result = super().encrypt(plaintext) return OlmCiphertext( type=( OlmMsgType.PREKEY if isinstance(result, olm.OlmPreKeyMessage) else OlmMsgType.MESSAGE ), body=result.ciphertext, ) def describe(self) -> str: parent = super() if hasattr(parent, "describe"): return parent.describe() elif hasattr(lib, "olm_session_describe"): describe_length = 600 describe_buffer = ffi.new("char[]", describe_length) lib.olm_session_describe(self._session, describe_buffer, describe_length) return ffi.string(describe_buffer).decode("utf-8") else: return "describe not supported" @dataclass class RatchetSafety(SerializableAttrs): next_index: int = 0 missed_indices: List[int] = field(factory=lambda: []) lost_indices: List[int] = field(factory=lambda: []) class InboundGroupSession(olm.InboundGroupSession): room_id: RoomID signing_key: SigningKey sender_key: IdentityKey forwarding_chain: List[IdentityKey] ratchet_safety: RatchetSafety received_at: datetime max_age: timedelta max_messages: int is_scheduled: bool def __init__( self, session_key: str, signing_key: SigningKey, sender_key: IdentityKey, room_id: RoomID, forwarding_chain: Optional[List[IdentityKey]] = None, ratchet_safety: Optional[RatchetSafety] = None, received_at: Optional[datetime] = None, max_age: Union[timedelta, int, None] = None, max_messages: Optional[int] = None, is_scheduled: bool = False, ) -> None: self.signing_key = signing_key self.sender_key = sender_key self.room_id = room_id self.forwarding_chain = forwarding_chain or [] self.ratchet_safety = ratchet_safety or RatchetSafety() self.received_at = received_at or datetime.utcnow() if isinstance(max_age, int): max_age = timedelta(milliseconds=max_age) self.max_age = max_age self.max_messages = max_messages self.is_scheduled = is_scheduled super().__init__(session_key) def __new__(cls, *args, **kwargs): return 
super().__new__(cls) @classmethod def from_pickle( cls, pickle: bytes, passphrase: str, signing_key: SigningKey, sender_key: IdentityKey, room_id: RoomID, forwarding_chain: Optional[List[IdentityKey]] = None, ratchet_safety: Optional[RatchetSafety] = None, received_at: Optional[datetime] = None, max_age: Optional[timedelta] = None, max_messages: Optional[int] = None, is_scheduled: bool = False, ) -> "InboundGroupSession": session = super().from_pickle(pickle, passphrase) session.signing_key = signing_key session.sender_key = sender_key session.room_id = room_id session.forwarding_chain = forwarding_chain or [] session.ratchet_safety = ratchet_safety or RatchetSafety() session.received_at = received_at session.max_age = max_age session.max_messages = max_messages session.is_scheduled = is_scheduled return session @classmethod def import_session( cls, session_key: str, signing_key: SigningKey, sender_key: IdentityKey, room_id: RoomID, forwarding_chain: Optional[List[str]] = None, ratchet_safety: Optional[RatchetSafety] = None, received_at: Optional[datetime] = None, max_age: Union[timedelta, int, None] = None, max_messages: Optional[int] = None, is_scheduled: bool = False, ) -> "InboundGroupSession": session = super().import_session(session_key) session.signing_key = signing_key session.sender_key = sender_key session.room_id = room_id session.forwarding_chain = forwarding_chain or [] session.ratchet_safety = ratchet_safety or RatchetSafety() session.received_at = received_at or datetime.utcnow() if isinstance(max_age, int): max_age = timedelta(milliseconds=max_age) session.max_age = max_age session.max_messages = max_messages session.is_scheduled = is_scheduled return session def ratchet_to(self, index: int) -> "InboundGroupSession": exported = self.export_session(index) return self.import_session( exported, signing_key=self.signing_key, sender_key=self.sender_key, room_id=self.room_id, forwarding_chain=self.forwarding_chain, ratchet_safety=self.ratchet_safety, received_at=self.received_at, max_age=self.max_age, max_messages=self.max_messages, is_scheduled=self.is_scheduled, ) class OutboundGroupSession(olm.OutboundGroupSession): """Outbound group session aware of the users it is shared with. Also remembers the time it was created and the number of messages it has encrypted, in order to know if it needs to be rotated. 
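
    Example:
        A simplified sketch of the rotation check done by the encryption machine
        (``session`` is assumed to be ``None`` or an existing instance loaded from
        the crypto store)::

            if session is None or session.expired:
                session = OutboundGroupSession(room_id)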
""" max_age: timedelta max_messages: int creation_time: datetime use_time: datetime message_count: int room_id: RoomID users_shared_with: Set[Tuple[UserID, DeviceID]] users_ignored: Set[Tuple[UserID, DeviceID]] shared: bool def __init__(self, room_id: RoomID) -> None: self.max_age = timedelta(days=7) self.max_messages = 100 self.creation_time = datetime.now() self.use_time = datetime.now() self.message_count = 0 self.room_id = room_id self.users_shared_with = set() self.users_ignored = set() self.shared = False super().__init__() def __new__(cls, *args, **kwargs): return super().__new__(cls) @property def expired(self): return ( self.message_count >= self.max_messages or datetime.now() - self.creation_time >= self.max_age ) def encrypt(self, plaintext): if not self.shared: raise EncryptionError("Group session has not been shared") if self.expired: raise EncryptionError("Group session has expired") self.message_count += 1 self.use_time = datetime.now() return super().encrypt(plaintext) @classmethod def from_pickle( cls, pickle: bytes, passphrase: str, max_age: timedelta, max_messages: int, creation_time: datetime, use_time: datetime, message_count: int, room_id: RoomID, shared: bool, ) -> "OutboundGroupSession": session = cast(OutboundGroupSession, super().from_pickle(pickle, passphrase)) session.max_age = max_age session.max_messages = max_messages session.creation_time = creation_time session.use_time = use_time session.message_count = message_count session.room_id = room_id session.users_shared_with = set() session.users_ignored = set() session.shared = shared return session @property def share_content(self) -> RoomKeyEventContent: return RoomKeyEventContent( algorithm=EncryptionAlgorithm.MEGOLM_V1, room_id=self.room_id, session_id=self.id, session_key=self.session_key, ) python-0.20.7/mautrix/crypto/store/000077500000000000000000000000001473573527000173035ustar00rootroot00000000000000python-0.20.7/mautrix/crypto/store/__init__.py000066400000000000000000000006211473573527000214130ustar00rootroot00000000000000from mautrix import __optional_imports__ from .abstract import CryptoStore, StateStore from .memory import MemoryCryptoStore try: from .asyncpg import PgCryptoStateStore, PgCryptoStore except ImportError: if __optional_imports__: raise PgCryptoStore = PgCryptoStateStore = None __all__ = ["CryptoStore", "StateStore", "MemoryCryptoStore", "PgCryptoStateStore", "PgCryptoStore"] python-0.20.7/mautrix/crypto/store/abstract.py000066400000000000000000000367231473573527000214730ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. 
from __future__ import annotations from typing import NamedTuple from abc import ABC, abstractmethod from mautrix.types import ( CrossSigner, CrossSigningUsage, DeviceID, DeviceIdentity, EventID, IdentityKey, RoomEncryptionStateEventContent, RoomID, SessionID, SigningKey, TOFUSigningKey, UserID, ) from ..account import OlmAccount from ..sessions import InboundGroupSession, OutboundGroupSession, Session class StateStore(ABC): @abstractmethod async def is_encrypted(self, room_id: RoomID) -> bool: pass @abstractmethod async def get_encryption_info(self, room_id: RoomID) -> RoomEncryptionStateEventContent | None: pass @abstractmethod async def find_shared_rooms(self, user_id: UserID) -> list[RoomID]: pass class CryptoStore(ABC): """ CryptoStore is used by :class:`OlmMachine` to store Olm and Megolm sessions, user device lists and message indices. """ account_id: str """The unique identifier for the account that is stored in this CryptoStore.""" pickle_key: str """The pickle key to use when pickling Olm objects.""" @abstractmethod async def get_device_id(self) -> DeviceID | None: """ Get the device ID corresponding to this account_id Returns: The device ID in the store. """ @abstractmethod async def put_device_id(self, device_id: DeviceID) -> None: """ Store a device ID. Args: device_id: The device ID to store. """ async def open(self) -> None: """ Open the store. If the store doesn't require opening any resources beforehand or only opens when flushing, this can be a no-op """ async def close(self) -> None: """ Close the store when it will no longer be used. The default implementation will simply call .flush(). If the store doesn't keep any persistent resources, the default implementation is sufficient. """ await self.flush() async def flush(self) -> None: """Flush the store. If all the methods persist data immediately, this can be a no-op.""" @abstractmethod async def delete(self) -> None: """Delete the data in the store.""" @abstractmethod async def put_account(self, account: OlmAccount) -> None: """Insert or update the OlmAccount in the store.""" @abstractmethod async def get_account(self) -> OlmAccount | None: """Get the OlmAccount that was previously inserted with :meth:`put_account`. If no account has been inserted, this must return ``None``.""" @abstractmethod async def has_session(self, key: IdentityKey) -> bool: """ Check whether or not the store has a session for a specific device. Args: key: The curve25519 identity key of the device to check. Returns: ``True`` if the session has at least one Olm session for the given identity key, ``False`` otherwise. """ @abstractmethod async def get_sessions(self, key: IdentityKey) -> list[Session]: """ Get all Olm sessions in the store for the specific device. Args: key: The curve25519 identity key of the device whose sessions to get. Returns: A list of Olm sessions for the given identity key. If the store contains no sessions, an empty list. """ @abstractmethod async def get_latest_session(self, key: IdentityKey) -> Session | None: """ Get the Olm session with the highest session ID (lexiographically sorting) for a specific device. It's usually safe to return the most recently added session if sorting by session ID is too difficult. Args: key: The curve25519 identity key of the device whose session to get. Returns: The most recent session for the given device. If the store contains no sessions, ``None``. """ @abstractmethod async def add_session(self, key: IdentityKey, session: Session) -> None: """ Insert an Olm session into the store. 
Args: key: The curve25519 identity key of the device with whom this session was made. session: The session itself. """ @abstractmethod async def update_session(self, key: IdentityKey, session: Session) -> None: """ Update a session in the store. Implementations may assume that the given session was previously either inserted with :meth:`add_session` or fetched with either :meth:`get_sessions` or :meth:`get_latest_session`. Args: key: The curve25519 identity key of the device with whom this session was made. session: The session itself. """ @abstractmethod async def put_group_session( self, room_id: RoomID, sender_key: IdentityKey, session_id: SessionID, session: InboundGroupSession, ) -> None: """ Insert an inbound Megolm session into the store. Args: room_id: The room ID for which this session was made. sender_key: The curve25519 identity key of the user who made this session. session_id: The unique identifier for this session. session: The session itself. """ @abstractmethod async def get_group_session( self, room_id: RoomID, session_id: SessionID ) -> InboundGroupSession | None: """ Get an inbound Megolm group session that was previously inserted with :meth:`put_group_session`. Args: room_id: The room ID for which the session was made. session_id: The unique identifier of the session. Returns: The :class:`InboundGroupSession` object, or ``None`` if not found. """ @abstractmethod async def redact_group_session( self, room_id: RoomID, session_id: SessionID, reason: str ) -> None: """ Remove the keys for a specific Megolm group session. Args: room_id: The room where the session is. session_id: The session ID to remove. reason: The reason the session is being removed. """ @abstractmethod async def redact_group_sessions( self, room_id: RoomID | None, sender_key: IdentityKey | None, reason: str ) -> list[SessionID]: """ Remove the keys for multiple Megolm group sessions, based on the room ID and/or sender device. Args: room_id: The room ID to delete keys from. sender_key: The Olm identity key of the device to delete keys from. reason: The reason why the keys are being deleted. Returns: The list of session IDs that were deleted. """ @abstractmethod async def redact_expired_group_sessions(self) -> list[SessionID]: """ Remove all Megolm group sessions where at least twice the maximum age has passed since receiving the keys. Returns: The list of session IDs that were deleted. """ @abstractmethod async def redact_outdated_group_sessions(self) -> list[SessionID]: """ Remove all Megolm group sessions which lack the metadata to determine when they should expire. Returns: The list of session IDs that were deleted. """ @abstractmethod async def has_group_session(self, room_id: RoomID, session_id: SessionID) -> bool: """ Check whether or not a specific inbound Megolm session is in the store. This is used before importing forwarded keys. Args: room_id: The room ID for which the session was made. session_id: The unique identifier of the session. Returns: ``True`` if the store has a session with the given ID, ``False`` otherwise. """ @abstractmethod async def add_outbound_group_session(self, session: OutboundGroupSession) -> None: """ Insert an outbound Megolm session into the store. The store should index inserted sessions by the room_id field of the session to support getting and removing sessions. There will only be one outbound session per room ID at a time. Args: session: The session itself. 
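
        Example:
            A minimal in-memory sketch of the expected behavior (illustrative only,
            not the actual :class:`MemoryCryptoStore` implementation; the
            ``self._outbound_sessions`` dict keyed by room ID is an assumption)::

                async def add_outbound_group_session(self, session: OutboundGroupSession) -> None:
                    self._outbound_sessions[session.room_id] = session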
""" @abstractmethod async def update_outbound_group_session(self, session: OutboundGroupSession) -> None: """ Update an outbound Megolm session in the store. Implementations may assume that the given session was previously either inserted with :meth:`add_outbound_group_session` or fetched with :meth:`get_outbound_group_session`. Args: session: The session itself. """ @abstractmethod async def get_outbound_group_session(self, room_id: RoomID) -> OutboundGroupSession | None: """ Get the stored outbound Megolm session from the store. Args: room_id: The room whose session to get. Returns: The :class:`OutboundGroupSession` object, or ``None`` if not found. """ @abstractmethod async def remove_outbound_group_session(self, room_id: RoomID) -> None: """ Remove the stored outbound Megolm session for a specific room. This is used when a membership change is received in a specific room. Args: room_id: The room whose session to remove. """ @abstractmethod async def remove_outbound_group_sessions(self, rooms: list[RoomID]) -> None: """ Remove the stored outbound Megolm session for multiple rooms. This is used when the device list of a user changes. Args: rooms: The list of rooms whose sessions to remove. """ @abstractmethod async def validate_message_index( self, sender_key: IdentityKey, session_id: SessionID, event_id: EventID, index: int, timestamp: int, ) -> bool: """ Validate that a specific message isn't a replay attack. Implementations should store a map from ``(sender_key, session_id, index)`` to ``(event_id, timestamp)``, then use that map to check whether or not the message index is valid: * If the map key doesn't exist, the given values should be stored and the message is valid. * If the map key exists and the stored values match the given values, the message is valid. * If the map key exists, but the stored values do not match the given values, the message is not valid. Args: sender_key: The curve25519 identity key of the user who sent the message. session_id: The Megolm session ID for the session with which the message was encrypted. event_id: The event ID of the message. index: The Megolm message index of the message. timestamp: The timestamp of the message. Returns: ``True`` if the message is valid, ``False`` if not. """ @abstractmethod async def get_devices(self, user_id: UserID) -> dict[DeviceID, DeviceIdentity] | None: """ Get all devices for a given user. Args: user_id: The ID of the user whose devices to get. Returns: If there has been a previous call to :meth:`put_devices` with the same user ID (even with an empty dict), a dict from device ID to :class:`DeviceIdentity` object. Otherwise, ``None``. """ @abstractmethod async def get_device(self, user_id: UserID, device_id: DeviceID) -> DeviceIdentity | None: """ Get a specific device identity. Args: user_id: The ID of the user whose device to get. device_id: The ID of the device to get. Returns: The :class:`DeviceIdentity` object, or ``None`` if not found. """ @abstractmethod async def find_device_by_key( self, user_id: UserID, identity_key: IdentityKey ) -> DeviceIdentity | None: """ Find a specific device identity based on the identity key. Args: user_id: The ID of the user whose device to get. identity_key: The identity key of the device to get. Returns: The :class:`DeviceIdentity` object, or ``None`` if not found. """ @abstractmethod async def put_devices(self, user_id: UserID, devices: dict[DeviceID, DeviceIdentity]) -> None: """ Replace the stored device list for a specific user. 
Args: user_id: The ID of the user whose device list to update. devices: A dict from device ID to :class:`DeviceIdentity` object. The dict may be empty. """ @abstractmethod async def filter_tracked_users(self, users: list[UserID]) -> list[UserID]: """ Filter a list of user IDs to only include users whose device lists are being tracked. Args: users: The list of user IDs to filter. Returns: A filtered version of the input list that only includes users who have had a previous call to :meth:`put_devices` (even if the call was with an empty dict). """ @abstractmethod async def put_cross_signing_key( self, user_id: UserID, usage: CrossSigningUsage, key: SigningKey ) -> None: """ Store a single cross-signing key. Args: user_id: The user whose cross-signing key is being stored. usage: The type of key being stored. key: The key itself. """ @abstractmethod async def get_cross_signing_keys( self, user_id: UserID ) -> dict[CrossSigningUsage, TOFUSigningKey]: """ Retrieve stored cross-signing keys for a specific user. Args: user_id: The user whose cross-signing keys to get. Returns: A map from the type of key to a tuple containing the current key and the key that was seen first. If the keys are different, it should be treated as a local TOFU violation. """ @abstractmethod async def put_signature( self, target: CrossSigner, signer: CrossSigner, signature: str ) -> None: """ Store a signature for a given key from a given key. Args: target: The user ID and key being signed. signer: The user ID and key who are doing the signing. signature: The signature. """ @abstractmethod async def is_key_signed_by(self, target: CrossSigner, signer: CrossSigner) -> bool: """ Check if a given key is signed by the given signer. Args: target: The key to check. signer: The signer who is expected to have signed the key. Returns: ``True`` if the database contains a signature for the key, ``False`` otherwise. """ @abstractmethod async def drop_signatures_by_key(self, signer: CrossSigner) -> int: """ Delete signatures made by the given key. Args: signer: The key whose signatures to delete. Returns: The number of signatures deleted. """ python-0.20.7/mautrix/crypto/store/asyncpg/000077500000000000000000000000001473573527000207475ustar00rootroot00000000000000python-0.20.7/mautrix/crypto/store/asyncpg/__init__.py000066400000000000000000000001501473573527000230540ustar00rootroot00000000000000from .store import PgCryptoStateStore, PgCryptoStore __all__ = ["PgCryptoStore", "PgCryptoStateStore"] python-0.20.7/mautrix/crypto/store/asyncpg/store.py000066400000000000000000000666141473573527000224720ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations from collections import defaultdict from datetime import timedelta from asyncpg import UniqueViolationError from mautrix.client.state_store import SyncStore from mautrix.client.state_store.asyncpg import PgStateStore from mautrix.errors import GroupSessionWithheldError from mautrix.types import ( CrossSigner, CrossSigningUsage, DeviceID, DeviceIdentity, EventID, IdentityKey, RoomID, RoomKeyWithheldCode, SessionID, SigningKey, SyncToken, TOFUSigningKey, TrustState, UserID, ) from mautrix.util.async_db import Database, Scheme from mautrix.util.logging import TraceLogger from ... 
import InboundGroupSession, OlmAccount, OutboundGroupSession, RatchetSafety, Session from ..abstract import CryptoStore, StateStore from .upgrade import upgrade_table try: from sqlite3 import IntegrityError, sqlite_version_info as sqlite_version from aiosqlite import Cursor except ImportError: Cursor = None sqlite_version = (0, 0, 0) class IntegrityError(Exception): pass class PgCryptoStateStore(PgStateStore, StateStore): """ This class ensures that the PgStateStore in the client module implements the StateStore methods needed by the crypto module. """ class PgCryptoStore(CryptoStore, SyncStore): upgrade_table = upgrade_table db: Database account_id: str pickle_key: str log: TraceLogger _sync_token: SyncToken | None _device_id: DeviceID | None _account: OlmAccount | None _olm_cache: dict[IdentityKey, dict[SessionID, Session]] def __init__(self, account_id: str, pickle_key: str, db: Database) -> None: self.db = db self.account_id = account_id self.pickle_key = pickle_key self.log = db.log self._sync_token = None self._device_id = DeviceID("") self._account = None self._olm_cache = defaultdict(lambda: {}) async def delete(self) -> None: tables = ("crypto_account", "crypto_olm_session", "crypto_megolm_outbound_session") async with self.db.acquire() as conn, conn.transaction(): for table in tables: await conn.execute(f"DELETE FROM {table} WHERE account_id=$1", self.account_id) async def get_device_id(self) -> DeviceID | None: q = "SELECT device_id FROM crypto_account WHERE account_id=$1" device_id = await self.db.fetchval(q, self.account_id) self._device_id = device_id or self._device_id return self._device_id async def put_device_id(self, device_id: DeviceID) -> None: q = "UPDATE crypto_account SET device_id=$1 WHERE account_id=$2" await self.db.fetchval(q, device_id, self.account_id) self._device_id = device_id async def put_next_batch(self, next_batch: SyncToken) -> None: self._sync_token = next_batch q = "UPDATE crypto_account SET sync_token=$1 WHERE account_id=$2" await self.db.execute(q, self._sync_token, self.account_id) async def get_next_batch(self) -> SyncToken: if self._sync_token is None: q = "SELECT sync_token FROM crypto_account WHERE account_id=$1" self._sync_token = await self.db.fetchval(q, self.account_id) return self._sync_token async def put_account(self, account: OlmAccount) -> None: self._account = account pickle = account.pickle(self.pickle_key) q = """ INSERT INTO crypto_account (account_id, device_id, shared, sync_token, account) VALUES ($1, $2, $3, $4, $5) ON CONFLICT (account_id) DO UPDATE SET shared=excluded.shared, sync_token=excluded.sync_token, account=excluded.account """ await self.db.execute( q, self.account_id, self._device_id or "", account.shared, self._sync_token or "", pickle, ) async def get_account(self) -> OlmAccount: if self._account is None: q = "SELECT shared, account, device_id FROM crypto_account WHERE account_id=$1" row = await self.db.fetchrow(q, self.account_id) if row is not None: self._account = OlmAccount.from_pickle( row["account"], passphrase=self.pickle_key, shared=row["shared"] ) return self._account async def has_session(self, key: IdentityKey) -> bool: if len(self._olm_cache[key]) > 0: return True q = "SELECT session_id FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2" val = await self.db.fetchval(q, key, self.account_id) return val is not None async def get_sessions(self, key: IdentityKey) -> list[Session]: q = """ SELECT session_id, session, created_at, last_encrypted, last_decrypted FROM crypto_olm_session WHERE 
sender_key=$1 AND account_id=$2 ORDER BY last_decrypted DESC """ rows = await self.db.fetch(q, key, self.account_id) sessions = [] for row in rows: try: sess = self._olm_cache[key][row["session_id"]] except KeyError: sess = Session.from_pickle( row["session"], passphrase=self.pickle_key, creation_time=row["created_at"], last_encrypted=row["last_encrypted"], last_decrypted=row["last_decrypted"], ) self._olm_cache[key][SessionID(sess.id)] = sess sessions.append(sess) return sessions async def get_latest_session(self, key: IdentityKey) -> Session | None: q = """ SELECT session_id, session, created_at, last_encrypted, last_decrypted FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 ORDER BY last_decrypted DESC LIMIT 1 """ row = await self.db.fetchrow(q, key, self.account_id) if row is None: return None try: return self._olm_cache[key][row["session_id"]] except KeyError: sess = Session.from_pickle( row["session"], passphrase=self.pickle_key, creation_time=row["created_at"], last_encrypted=row["last_encrypted"], last_decrypted=row["last_decrypted"], ) self._olm_cache[key][SessionID(sess.id)] = sess return sess async def add_session(self, key: IdentityKey, session: Session) -> None: if session.id in self._olm_cache[key]: self.log.warning(f"Cache already contains Olm session with ID {session.id}") self._olm_cache[key][SessionID(session.id)] = session pickle = session.pickle(self.pickle_key) q = """ INSERT INTO crypto_olm_session ( session_id, sender_key, session, created_at, last_encrypted, last_decrypted, account_id ) VALUES ($1, $2, $3, $4, $5, $6, $7) """ await self.db.execute( q, session.id, key, pickle, session.creation_time, session.last_encrypted, session.last_decrypted, self.account_id, ) async def update_session(self, key: IdentityKey, session: Session) -> None: try: assert self._olm_cache[key][SessionID(session.id)] == session except (KeyError, AssertionError) as e: self.log.warning( f"Cached olm session with ID {session.id} " f"isn't equal to the one being saved to the database ({e})" ) pickle = session.pickle(self.pickle_key) q = """ UPDATE crypto_olm_session SET session=$1, last_encrypted=$2, last_decrypted=$3 WHERE session_id=$4 AND account_id=$5 """ await self.db.execute( q, pickle, session.last_encrypted, session.last_decrypted, session.id, self.account_id ) async def put_group_session( self, room_id: RoomID, sender_key: IdentityKey, session_id: SessionID, session: InboundGroupSession, ) -> None: pickle = session.pickle(self.pickle_key) forwarding_chains = ",".join(session.forwarding_chain) q = """ INSERT INTO crypto_megolm_inbound_session ( session_id, sender_key, signing_key, room_id, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled, account_id ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) ON CONFLICT (session_id, account_id) DO UPDATE SET withheld_code=NULL, withheld_reason=NULL, sender_key=excluded.sender_key, signing_key=excluded.signing_key, room_id=excluded.room_id, session=excluded.session, forwarding_chains=excluded.forwarding_chains, ratchet_safety=excluded.ratchet_safety, received_at=excluded.received_at, max_age=excluded.max_age, max_messages=excluded.max_messages, is_scheduled=excluded.is_scheduled """ try: await self.db.execute( q, session_id, sender_key, session.signing_key, room_id, pickle, forwarding_chains, session.ratchet_safety.json(), session.received_at, int(session.max_age.total_seconds() * 1000) if session.max_age else None, session.max_messages, session.is_scheduled, self.account_id, ) 
except (IntegrityError, UniqueViolationError): self.log.exception(f"Failed to insert megolm session {session_id}") async def get_group_session( self, room_id: RoomID, session_id: SessionID ) -> InboundGroupSession | None: q = """ SELECT sender_key, signing_key, session, forwarding_chains, withheld_code, ratchet_safety, received_at, max_age, max_messages, is_scheduled FROM crypto_megolm_inbound_session WHERE room_id=$1 AND session_id=$2 AND account_id=$3 """ row = await self.db.fetchrow(q, room_id, session_id, self.account_id) if row is None: return None if row["withheld_code"] is not None: raise GroupSessionWithheldError(session_id, row["withheld_code"]) forwarding_chain = row["forwarding_chains"].split(",") if row["forwarding_chains"] else [] return InboundGroupSession.from_pickle( row["session"], passphrase=self.pickle_key, signing_key=row["signing_key"], sender_key=row["sender_key"], room_id=room_id, forwarding_chain=forwarding_chain, ratchet_safety=RatchetSafety.parse_json(row["ratchet_safety"] or "{}"), received_at=row["received_at"], max_age=timedelta(milliseconds=row["max_age"]) if row["max_age"] else None, max_messages=row["max_messages"], is_scheduled=row["is_scheduled"], ) async def redact_group_session( self, room_id: RoomID, session_id: SessionID, reason: str ) -> None: q = """ UPDATE crypto_megolm_inbound_session SET withheld_code=$1, withheld_reason=$2, session=NULL, forwarding_chains=NULL WHERE session_id=$3 AND account_id=$4 AND session IS NOT NULL """ await self.db.execute( q, RoomKeyWithheldCode.BEEPER_REDACTED.value, f"Session redacted: {reason}", session_id, self.account_id, ) async def redact_group_sessions( self, room_id: RoomID, sender_key: IdentityKey, reason: str ) -> list[SessionID]: if not room_id and not sender_key: raise ValueError("Either room_id or sender_key must be provided") q = """ UPDATE crypto_megolm_inbound_session SET withheld_code=$1, withheld_reason=$2, session=NULL, forwarding_chains=NULL WHERE (room_id=$3 OR $3='') AND (sender_key=$4 OR $4='') AND account_id=$5 AND session IS NOT NULL AND is_scheduled=false AND received_at IS NOT NULL RETURNING session_id """ rows = await self.db.fetch( q, RoomKeyWithheldCode.BEEPER_REDACTED.value, f"Session redacted: {reason}", room_id, sender_key, self.account_id, ) return [row["session_id"] for row in rows] async def redact_expired_group_sessions(self) -> list[SessionID]: if self.db.scheme == Scheme.SQLITE: q = """ UPDATE crypto_megolm_inbound_session SET withheld_code=$1, withheld_reason=$2, session=NULL, forwarding_chains=NULL WHERE account_id=$3 AND session IS NOT NULL AND is_scheduled=false AND received_at IS NOT NULL and max_age IS NOT NULL AND unixepoch(received_at) + (2 * max_age / 1000) < unixepoch(date('now')) RETURNING session_id """ elif self.db.scheme in (Scheme.POSTGRES, Scheme.COCKROACH): q = """ UPDATE crypto_megolm_inbound_session SET withheld_code=$1, withheld_reason=$2, session=NULL, forwarding_chains=NULL WHERE account_id=$3 AND session IS NOT NULL AND is_scheduled=false AND received_at IS NOT NULL and max_age IS NOT NULL AND received_at + 2 * (max_age * interval '1 millisecond') < now() RETURNING session_id """ else: raise RuntimeError(f"Unsupported dialect {self.db.scheme}") rows = await self.db.fetch( q, RoomKeyWithheldCode.BEEPER_REDACTED.value, f"Session redacted: expired", self.account_id, ) return [row["session_id"] for row in rows] async def redact_outdated_group_sessions(self) -> list[SessionID]: q = """ UPDATE crypto_megolm_inbound_session SET withheld_code=$1, withheld_reason=$2, 
session=NULL, forwarding_chains=NULL WHERE account_id=$3 AND session IS NOT NULL AND received_at IS NULL RETURNING session_id """ rows = await self.db.fetch( q, RoomKeyWithheldCode.BEEPER_REDACTED.value, f"Session redacted: outdated", self.account_id, ) return [row["session_id"] for row in rows] async def has_group_session(self, room_id: RoomID, session_id: SessionID) -> bool: q = """ SELECT COUNT(session) FROM crypto_megolm_inbound_session WHERE room_id=$1 AND session_id=$2 AND account_id=$3 AND session IS NOT NULL """ count = await self.db.fetchval(q, room_id, session_id, self.account_id) return count > 0 async def add_outbound_group_session(self, session: OutboundGroupSession) -> None: pickle = session.pickle(self.pickle_key) max_age = int(session.max_age.total_seconds() * 1000) q = """ INSERT INTO crypto_megolm_outbound_session ( room_id, session_id, session, shared, max_messages, message_count, max_age, created_at, last_used, account_id ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) ON CONFLICT (account_id, room_id) DO UPDATE SET session_id=excluded.session_id, session=excluded.session, shared=excluded.shared, max_messages=excluded.max_messages, message_count=excluded.message_count, max_age=excluded.max_age, created_at=excluded.created_at, last_used=excluded.last_used """ await self.db.execute( q, session.room_id, session.id, pickle, session.shared, session.max_messages, session.message_count, max_age, session.creation_time, session.use_time, self.account_id, ) async def update_outbound_group_session(self, session: OutboundGroupSession) -> None: pickle = session.pickle(self.pickle_key) q = """ UPDATE crypto_megolm_outbound_session SET session=$1, message_count=$2, last_used=$3 WHERE room_id=$4 AND session_id=$5 AND account_id=$6 """ await self.db.execute( q, pickle, session.message_count, session.use_time, session.room_id, session.id, self.account_id, ) async def get_outbound_group_session(self, room_id: RoomID) -> OutboundGroupSession | None: q = """ SELECT room_id, session_id, session, shared, max_messages, message_count, max_age, created_at, last_used FROM crypto_megolm_outbound_session WHERE room_id=$1 AND account_id=$2 """ row = await self.db.fetchrow(q, room_id, self.account_id) if row is None: return None return OutboundGroupSession.from_pickle( row["session"], passphrase=self.pickle_key, room_id=row["room_id"], shared=row["shared"], max_messages=row["max_messages"], message_count=row["message_count"], max_age=timedelta(milliseconds=row["max_age"]), use_time=row["last_used"], creation_time=row["created_at"], ) async def remove_outbound_group_session(self, room_id: RoomID) -> None: q = "DELETE FROM crypto_megolm_outbound_session WHERE room_id=$1 AND account_id=$2" await self.db.execute(q, room_id, self.account_id) async def remove_outbound_group_sessions(self, rooms: list[RoomID]) -> None: if self.db.scheme in (Scheme.POSTGRES, Scheme.COCKROACH): q = """ DELETE FROM crypto_megolm_outbound_session WHERE account_id=$1 AND room_id=ANY($2) """ await self.db.execute(q, self.account_id, rooms) else: params = ",".join(["?"] * len(rooms)) q = f""" DELETE FROM crypto_megolm_outbound_session WHERE account_id=? 
AND room_id IN ({params}) """ await self.db.execute(q, self.account_id, *rooms) _validate_message_index_query = """ INSERT INTO crypto_message_index (sender_key, session_id, "index", event_id, timestamp) VALUES ($1, $2, $3, $4, $5) -- have to update something so that RETURNING * always returns the row ON CONFLICT (sender_key, session_id, "index") DO UPDATE SET sender_key=excluded.sender_key RETURNING * """ async def validate_message_index( self, sender_key: IdentityKey, session_id: SessionID, event_id: EventID, index: int, timestamp: int, ) -> bool: if self.db.scheme in (Scheme.POSTGRES, Scheme.COCKROACH) or ( # RETURNING was added in SQLite 3.35.0 https://www.sqlite.org/lang_returning.html self.db.scheme == Scheme.SQLITE and sqlite_version >= (3, 35) ): row = await self.db.fetchrow( self._validate_message_index_query, sender_key, session_id, index, event_id, timestamp, ) return row["event_id"] == event_id and row["timestamp"] == timestamp else: row = await self.db.fetchrow( "SELECT event_id, timestamp FROM crypto_message_index " 'WHERE sender_key=$1 AND session_id=$2 AND "index"=$3', sender_key, session_id, index, ) if row is not None: return row["event_id"] == event_id and row["timestamp"] == timestamp await self.db.execute( "INSERT INTO crypto_message_index(sender_key, session_id, " ' "index", event_id, timestamp) ' "VALUES ($1, $2, $3, $4, $5)", sender_key, session_id, index, event_id, timestamp, ) return True async def get_devices(self, user_id: UserID) -> dict[DeviceID, DeviceIdentity] | None: q = "SELECT user_id FROM crypto_tracked_user WHERE user_id=$1" tracked_user_id = await self.db.fetchval(q, user_id) if tracked_user_id is None: return None q = """ SELECT device_id, identity_key, signing_key, trust, deleted, name FROM crypto_device WHERE user_id=$1 """ rows = await self.db.fetch(q, user_id) result = {} for row in rows: result[row["device_id"]] = DeviceIdentity( user_id=user_id, device_id=row["device_id"], identity_key=row["identity_key"], signing_key=row["signing_key"], trust=TrustState(row["trust"]), deleted=row["deleted"], name=row["name"], ) return result async def get_device(self, user_id: UserID, device_id: DeviceID) -> DeviceIdentity | None: q = """ SELECT identity_key, signing_key, trust, deleted, name FROM crypto_device WHERE user_id=$1 AND device_id=$2 """ row = await self.db.fetchrow(q, user_id, device_id) if row is None: return None return DeviceIdentity( user_id=user_id, device_id=device_id, name=row["name"], identity_key=row["identity_key"], signing_key=row["signing_key"], trust=TrustState(row["trust"]), deleted=row["deleted"], ) async def find_device_by_key( self, user_id: UserID, identity_key: IdentityKey ) -> DeviceIdentity | None: q = """ SELECT device_id, signing_key, trust, deleted, name FROM crypto_device WHERE user_id=$1 AND identity_key=$2 """ row = await self.db.fetchrow( q, user_id, identity_key, ) if row is None: return None return DeviceIdentity( user_id=user_id, device_id=row["device_id"], name=row["name"], identity_key=identity_key, signing_key=row["signing_key"], trust=TrustState(row["trust"]), deleted=row["deleted"], ) async def put_devices(self, user_id: UserID, devices: dict[DeviceID, DeviceIdentity]) -> None: data = [ ( user_id, device_id, identity.identity_key, identity.signing_key, identity.trust, identity.deleted, identity.name, ) for device_id, identity in devices.items() ] columns = [ "user_id", "device_id", "identity_key", "signing_key", "trust", "deleted", "name", ] async with self.db.acquire() as conn, conn.transaction(): q = """ INSERT 
INTO crypto_tracked_user (user_id) VALUES ($1) ON CONFLICT (user_id) DO NOTHING """ await conn.execute(q, user_id) await conn.execute("DELETE FROM crypto_device WHERE user_id=$1", user_id) if self.db.scheme == Scheme.POSTGRES: await conn.copy_records_to_table("crypto_device", records=data, columns=columns) else: q = """ INSERT INTO crypto_device ( user_id, device_id, identity_key, signing_key, trust, deleted, name ) VALUES ($1, $2, $3, $4, $5, $6, $7) """ await conn.executemany(q, data) async def filter_tracked_users(self, users: list[UserID]) -> list[UserID]: if self.db.scheme in (Scheme.POSTGRES, Scheme.COCKROACH): q = "SELECT user_id FROM crypto_tracked_user WHERE user_id = ANY($1)" rows = await self.db.fetch(q, users) else: params = ",".join(["?"] * len(users)) q = f"SELECT user_id FROM crypto_tracked_user WHERE user_id IN ({params})" rows = await self.db.fetch(q, *users) return [row["user_id"] for row in rows] async def put_cross_signing_key( self, user_id: UserID, usage: CrossSigningUsage, key: SigningKey ) -> None: q = """ INSERT INTO crypto_cross_signing_keys (user_id, usage, key, first_seen_key) VALUES ($1, $2, $3, $4) ON CONFLICT (user_id, usage) DO UPDATE SET key=excluded.key """ try: await self.db.execute(q, user_id, usage.value, key, key) except Exception: self.log.exception(f"Failed to store cross-signing key {user_id}/{key}/{usage}") async def get_cross_signing_keys( self, user_id: UserID ) -> dict[CrossSigningUsage, TOFUSigningKey]: q = "SELECT usage, key, first_seen_key FROM crypto_cross_signing_keys WHERE user_id=$1" return { CrossSigningUsage(row["usage"]): TOFUSigningKey( key=SigningKey(row["key"]), first=SigningKey(row["first_seen_key"]), ) for row in await self.db.fetch(q, user_id) } async def put_signature( self, target: CrossSigner, signer: CrossSigner, signature: str ) -> None: q = """ INSERT INTO crypto_cross_signing_signatures ( signed_user_id, signed_key, signer_user_id, signer_key, signature ) VALUES ($1, $2, $3, $4, $5) ON CONFLICT (signed_user_id, signed_key, signer_user_id, signer_key) DO UPDATE SET signature=excluded.signature """ signed_user_id, signed_key = target signer_user_id, signer_key = signer try: await self.db.execute( q, signed_user_id, signed_key, signer_user_id, signer_key, signature ) except Exception: self.log.exception( f"Failed to store signature from {signer_user_id}/{signer_key} " f"for {signed_user_id}/{signed_key}" ) async def is_key_signed_by(self, target: CrossSigner, signer: CrossSigner) -> bool: q = """ SELECT EXISTS( SELECT 1 FROM crypto_cross_signing_signatures WHERE signed_user_id=$1 AND signed_key=$2 AND signer_user_id=$3 AND signer_key=$4 ) """ signed_user_id, signed_key = target signer_user_id, signer_key = signer return await self.db.fetchval(q, signed_user_id, signed_key, signer_user_id, signer_key) async def drop_signatures_by_key(self, signer: CrossSigner) -> int: signer_user_id, signer_key = signer q = "DELETE FROM crypto_cross_signing_signatures WHERE signer_user_id=$1 AND signer_key=$2" try: res = await self.db.execute(q, signer_user_id, signer_key) except Exception: self.log.exception( f"Failed to drop old signatures made by replaced key {signer_user_id}/{signer_key}" ) return -1 if Cursor is not None and isinstance(res, Cursor): return res.rowcount elif ( isinstance(res, str) and res.startswith("DELETE ") and (intPart := res[len("DELETE ") :]).isdecimal() ): return int(intPart) return -1 
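# --- Usage sketch (added commentary, not part of the original module) --------------------------
# A minimal example of wiring PgCryptoStore to mautrix.util.async_db, mirroring the fixtures in
# store/tests/store_test.py. Database and PgCryptoStore are the names already imported/defined
# above in this module; the account ID, pickle key, function name and database URL below are
# placeholder values, and a real deployment would point the URL at Postgres or an on-disk SQLite
# file instead of the in-memory database used here.
async def _example_open_crypto_store() -> PgCryptoStore:
    db = Database.create(
        "sqlite::memory:",  # placeholder URL, e.g. a postgres:// URL in production
        upgrade_table=PgCryptoStore.upgrade_table,
        db_args={"min_size": 1},
    )
    store = PgCryptoStore(account_id="@example:localhost", pickle_key="hunter2", db=db)
    await db.start()  # applies the migrations defined in upgrade.py before first use
    return store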
python-0.20.7/mautrix/crypto/store/asyncpg/upgrade.py000066400000000000000000000430031473573527000227500ustar00rootroot00000000000000# Copyright (c) 2023 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations import logging from mautrix.util.async_db import Connection, Scheme, UpgradeTable upgrade_table = UpgradeTable( version_table_name="crypto_version", database_name="crypto store", log=logging.getLogger("mau.crypto.db.upgrade"), ) @upgrade_table.register(description="Latest revision", upgrades_to=10) async def upgrade_blank_to_latest(conn: Connection) -> None: await conn.execute( """CREATE TABLE IF NOT EXISTS crypto_account ( account_id TEXT PRIMARY KEY, device_id TEXT NOT NULL, shared BOOLEAN NOT NULL, sync_token TEXT NOT NULL, account bytea NOT NULL )""" ) await conn.execute( """CREATE TABLE IF NOT EXISTS crypto_message_index ( sender_key CHAR(43), session_id CHAR(43), "index" INTEGER, event_id TEXT NOT NULL, timestamp BIGINT NOT NULL, PRIMARY KEY (sender_key, session_id, "index") )""" ) await conn.execute( """CREATE TABLE IF NOT EXISTS crypto_tracked_user ( user_id TEXT PRIMARY KEY )""" ) await conn.execute( """CREATE TABLE IF NOT EXISTS crypto_device ( user_id TEXT, device_id TEXT, identity_key CHAR(43) NOT NULL, signing_key CHAR(43) NOT NULL, trust SMALLINT NOT NULL, deleted BOOLEAN NOT NULL, name TEXT NOT NULL, PRIMARY KEY (user_id, device_id) )""" ) await conn.execute( """CREATE TABLE IF NOT EXISTS crypto_olm_session ( account_id TEXT, session_id CHAR(43), sender_key CHAR(43) NOT NULL, session bytea NOT NULL, created_at timestamp NOT NULL, last_decrypted timestamp NOT NULL, last_encrypted timestamp NOT NULL, PRIMARY KEY (account_id, session_id) )""" ) await conn.execute( """CREATE TABLE IF NOT EXISTS crypto_megolm_inbound_session ( account_id TEXT, session_id CHAR(43), sender_key CHAR(43) NOT NULL, signing_key CHAR(43), room_id TEXT NOT NULL, session bytea, forwarding_chains TEXT, withheld_code TEXT, withheld_reason TEXT, ratchet_safety jsonb, received_at timestamp, max_age BIGINT, max_messages INTEGER, is_scheduled BOOLEAN NOT NULL DEFAULT false, PRIMARY KEY (account_id, session_id) )""" ) await conn.execute( """CREATE TABLE IF NOT EXISTS crypto_megolm_outbound_session ( account_id TEXT, room_id TEXT, session_id CHAR(43) NOT NULL UNIQUE, session bytea NOT NULL, shared BOOLEAN NOT NULL, max_messages INTEGER NOT NULL, message_count INTEGER NOT NULL, max_age BIGINT NOT NULL, created_at timestamp NOT NULL, last_used timestamp NOT NULL, PRIMARY KEY (account_id, room_id) )""" ) await conn.execute( """CREATE TABLE crypto_cross_signing_keys ( user_id TEXT, usage TEXT, key CHAR(43) NOT NULL, first_seen_key CHAR(43) NOT NULL, PRIMARY KEY (user_id, usage) )""" ) await conn.execute( """CREATE TABLE crypto_cross_signing_signatures ( signed_user_id TEXT, signed_key TEXT, signer_user_id TEXT, signer_key TEXT, signature CHAR(88) NOT NULL, PRIMARY KEY (signed_user_id, signed_key, signer_user_id, signer_key) )""" ) @upgrade_table.register(description="Add account_id primary key column") async def upgrade_v2(conn: Connection, scheme: Scheme) -> None: if scheme == Scheme.SQLITE: await conn.execute("DROP TABLE crypto_account") await conn.execute("DROP TABLE crypto_olm_session") await conn.execute("DROP TABLE crypto_megolm_inbound_session") await conn.execute("DROP TABLE crypto_megolm_outbound_session") await 
conn.execute( """CREATE TABLE crypto_account ( account_id VARCHAR(255) PRIMARY KEY, device_id VARCHAR(255) NOT NULL, shared BOOLEAN NOT NULL, sync_token TEXT NOT NULL, account bytea NOT NULL )""" ) await conn.execute( """CREATE TABLE crypto_olm_session ( account_id VARCHAR(255), session_id CHAR(43), sender_key CHAR(43) NOT NULL, session bytea NOT NULL, created_at timestamp NOT NULL, last_used timestamp NOT NULL, PRIMARY KEY (account_id, session_id) )""" ) await conn.execute( """CREATE TABLE crypto_megolm_inbound_session ( account_id VARCHAR(255), session_id CHAR(43), sender_key CHAR(43) NOT NULL, signing_key CHAR(43) NOT NULL, room_id VARCHAR(255) NOT NULL, session bytea NOT NULL, forwarding_chains TEXT NOT NULL, PRIMARY KEY (account_id, session_id) )""" ) await conn.execute( """CREATE TABLE crypto_megolm_outbound_session ( account_id VARCHAR(255), room_id VARCHAR(255), session_id CHAR(43) NOT NULL UNIQUE, session bytea NOT NULL, shared BOOLEAN NOT NULL, max_messages INTEGER NOT NULL, message_count INTEGER NOT NULL, max_age BIGINT NOT NULL, created_at timestamp NOT NULL, last_used timestamp NOT NULL, PRIMARY KEY (account_id, room_id) )""" ) else: async def add_account_id_column(table: str, pkey_columns: list[str]) -> None: await conn.execute(f"ALTER TABLE {table} ADD COLUMN account_id VARCHAR(255)") await conn.execute(f"UPDATE {table} SET account_id=''") await conn.execute(f"ALTER TABLE {table} ALTER COLUMN account_id SET NOT NULL") await conn.execute(f"ALTER TABLE {table} DROP CONSTRAINT {table}_pkey") pkey_columns.append("account_id") pkey_columns_str = ", ".join(f'"{col}"' for col in pkey_columns) await conn.execute( f"ALTER TABLE {table} ADD CONSTRAINT {table}_pkey " f"PRIMARY KEY ({pkey_columns_str})" ) await add_account_id_column("crypto_account", []) await add_account_id_column("crypto_olm_session", ["session_id"]) await add_account_id_column("crypto_megolm_inbound_session", ["session_id"]) await add_account_id_column("crypto_megolm_outbound_session", ["room_id"]) @upgrade_table.register(description="Stop using size-limited string fields") async def upgrade_v3(conn: Connection, scheme: Scheme) -> None: if scheme == Scheme.SQLITE: return await conn.execute("ALTER TABLE crypto_account ALTER COLUMN account_id TYPE TEXT") await conn.execute("ALTER TABLE crypto_account ALTER COLUMN device_id TYPE TEXT") await conn.execute("ALTER TABLE crypto_message_index ALTER COLUMN event_id TYPE TEXT") await conn.execute("ALTER TABLE crypto_tracked_user ALTER COLUMN user_id TYPE TEXT") await conn.execute("ALTER TABLE crypto_device ALTER COLUMN user_id TYPE TEXT") await conn.execute("ALTER TABLE crypto_device ALTER COLUMN device_id TYPE TEXT") await conn.execute("ALTER TABLE crypto_device ALTER COLUMN name TYPE TEXT") await conn.execute("ALTER TABLE crypto_olm_session ALTER COLUMN account_id TYPE TEXT") await conn.execute( "ALTER TABLE crypto_megolm_inbound_session ALTER COLUMN account_id TYPE TEXT" ) await conn.execute("ALTER TABLE crypto_megolm_inbound_session ALTER COLUMN room_id TYPE TEXT") await conn.execute( "ALTER TABLE crypto_megolm_outbound_session ALTER COLUMN account_id TYPE TEXT" ) await conn.execute("ALTER TABLE crypto_megolm_outbound_session ALTER COLUMN room_id TYPE TEXT") @upgrade_table.register(description="Split last_used into last_encrypted and last_decrypted") async def upgrade_v4(conn: Connection, scheme: Scheme) -> None: await conn.execute("ALTER TABLE crypto_olm_session RENAME COLUMN last_used TO last_decrypted") await conn.execute("ALTER TABLE crypto_olm_session ADD COLUMN 
last_encrypted timestamp") await conn.execute("UPDATE crypto_olm_session SET last_encrypted=last_decrypted") if scheme == Scheme.POSTGRES: # This is too hard to do on sqlite, so let's just do it on postgres await conn.execute( "ALTER TABLE crypto_olm_session ALTER COLUMN last_encrypted SET NOT NULL" ) @upgrade_table.register(description="Add cross-signing key and signature caches") async def upgrade_v5(conn: Connection) -> None: await conn.execute( """CREATE TABLE crypto_cross_signing_keys ( user_id TEXT, usage TEXT, key CHAR(43), first_seen_key CHAR(43), PRIMARY KEY (user_id, usage) )""" ) await conn.execute( """CREATE TABLE crypto_cross_signing_signatures ( signed_user_id TEXT, signed_key TEXT, signer_user_id TEXT, signer_key TEXT, signature TEXT, PRIMARY KEY (signed_user_id, signed_key, signer_user_id, signer_key) )""" ) @upgrade_table.register(description="Update trust state values") async def upgrade_v6(conn: Connection) -> None: await conn.execute("UPDATE crypto_device SET trust=300 WHERE trust=1") # verified await conn.execute("UPDATE crypto_device SET trust=-100 WHERE trust=2") # blacklisted await conn.execute("UPDATE crypto_device SET trust=0 WHERE trust=3") # ignored -> unset @upgrade_table.register( description="Synchronize schema with mautrix-go", upgrades_to=9, transaction=False ) async def upgrade_v9(conn: Connection, scheme: Scheme) -> None: if scheme == Scheme.POSTGRES: async with conn.transaction(): await upgrade_v9_postgres(conn) else: await upgrade_v9_sqlite(conn) # These two are never used because the previous one jumps from 6 to 9. @upgrade_table.register async def upgrade_noop_7_to_8(_: Connection) -> None: pass @upgrade_table.register async def upgrade_noop_8_to_9(_: Connection) -> None: pass async def upgrade_v9_postgres(conn: Connection) -> None: await conn.execute("UPDATE crypto_account SET device_id='' WHERE device_id IS NULL") await conn.execute("ALTER TABLE crypto_account ALTER COLUMN device_id SET NOT NULL") await conn.execute( "ALTER TABLE crypto_megolm_inbound_session ALTER COLUMN signing_key DROP NOT NULL" ) await conn.execute( "ALTER TABLE crypto_megolm_inbound_session ALTER COLUMN session DROP NOT NULL" ) await conn.execute( "ALTER TABLE crypto_megolm_inbound_session ALTER COLUMN forwarding_chains DROP NOT NULL" ) await conn.execute("ALTER TABLE crypto_megolm_inbound_session ADD COLUMN withheld_code TEXT") await conn.execute("ALTER TABLE crypto_megolm_inbound_session ADD COLUMN withheld_reason TEXT") await conn.execute("DELETE FROM crypto_cross_signing_keys WHERE key IS NULL") await conn.execute( "UPDATE crypto_cross_signing_keys SET first_seen_key=key WHERE first_seen_key IS NULL" ) await conn.execute("ALTER TABLE crypto_cross_signing_keys ALTER COLUMN key SET NOT NULL") await conn.execute( "ALTER TABLE crypto_cross_signing_keys ALTER COLUMN first_seen_key SET NOT NULL" ) await conn.execute("DELETE FROM crypto_cross_signing_signatures WHERE signature IS NULL") await conn.execute( "ALTER TABLE crypto_cross_signing_signatures ALTER COLUMN signature SET NOT NULL" ) await conn.execute( "ALTER TABLE crypto_megolm_outbound_session ALTER COLUMN max_age TYPE BIGINT " "USING (EXTRACT(EPOCH from max_age)*1000)::bigint" ) async def upgrade_v9_sqlite(conn: Connection) -> None: await conn.execute("PRAGMA foreign_keys = OFF") async with conn.transaction(): await conn.execute( """CREATE TABLE new_crypto_account ( account_id TEXT PRIMARY KEY, device_id TEXT NOT NULL, shared BOOLEAN NOT NULL, sync_token TEXT NOT NULL, account bytea NOT NULL )""" ) await conn.execute( """ 
INSERT INTO new_crypto_account (account_id, device_id, shared, sync_token, account) SELECT account_id, COALESCE(device_id, ''), shared, sync_token, account FROM crypto_account """ ) await conn.execute("DROP TABLE crypto_account") await conn.execute("ALTER TABLE new_crypto_account RENAME TO crypto_account") await conn.execute( """CREATE TABLE new_crypto_megolm_inbound_session ( account_id TEXT, session_id CHAR(43), sender_key CHAR(43) NOT NULL, signing_key CHAR(43), room_id TEXT NOT NULL, session bytea, forwarding_chains TEXT, withheld_code TEXT, withheld_reason TEXT, PRIMARY KEY (account_id, session_id) )""" ) await conn.execute( """ INSERT INTO new_crypto_megolm_inbound_session ( account_id, session_id, sender_key, signing_key, room_id, session, forwarding_chains ) SELECT account_id, session_id, sender_key, signing_key, room_id, session, forwarding_chains FROM crypto_megolm_inbound_session """ ) await conn.execute("DROP TABLE crypto_megolm_inbound_session") await conn.execute( "ALTER TABLE new_crypto_megolm_inbound_session RENAME TO crypto_megolm_inbound_session" ) await conn.execute("UPDATE crypto_megolm_outbound_session SET max_age=max_age*1000") await conn.execute( """CREATE TABLE new_crypto_cross_signing_keys ( user_id TEXT, usage TEXT, key CHAR(43) NOT NULL, first_seen_key CHAR(43) NOT NULL, PRIMARY KEY (user_id, usage) )""" ) await conn.execute( """ INSERT INTO new_crypto_cross_signing_keys (user_id, usage, key, first_seen_key) SELECT user_id, usage, key, COALESCE(first_seen_key, key) FROM crypto_cross_signing_keys WHERE key IS NOT NULL """ ) await conn.execute("DROP TABLE crypto_cross_signing_keys") await conn.execute( "ALTER TABLE new_crypto_cross_signing_keys RENAME TO crypto_cross_signing_keys" ) await conn.execute( """CREATE TABLE new_crypto_cross_signing_signatures ( signed_user_id TEXT, signed_key TEXT, signer_user_id TEXT, signer_key TEXT, signature CHAR(88) NOT NULL, PRIMARY KEY (signed_user_id, signed_key, signer_user_id, signer_key) )""" ) await conn.execute( """ INSERT INTO new_crypto_cross_signing_signatures ( signed_user_id, signed_key, signer_user_id, signer_key, signature ) SELECT signed_user_id, signed_key, signer_user_id, signer_key, signature FROM crypto_cross_signing_signatures WHERE signature IS NOT NULL """ ) await conn.execute("DROP TABLE crypto_cross_signing_signatures") await conn.execute( "ALTER TABLE new_crypto_cross_signing_signatures " "RENAME TO crypto_cross_signing_signatures" ) await conn.execute("PRAGMA foreign_key_check") await conn.execute("PRAGMA foreign_keys = ON") @upgrade_table.register( description="Add metadata for detecting when megolm sessions are safe to delete" ) async def upgrade_v10(conn: Connection) -> None: await conn.execute("ALTER TABLE crypto_megolm_inbound_session ADD COLUMN ratchet_safety jsonb") await conn.execute( "ALTER TABLE crypto_megolm_inbound_session ADD COLUMN received_at timestamp" ) await conn.execute("ALTER TABLE crypto_megolm_inbound_session ADD COLUMN max_age BIGINT") await conn.execute("ALTER TABLE crypto_megolm_inbound_session ADD COLUMN max_messages INTEGER") await conn.execute( "ALTER TABLE crypto_megolm_inbound_session " "ADD COLUMN is_scheduled BOOLEAN NOT NULL DEFAULT false" ) python-0.20.7/mautrix/crypto/store/memory.py000066400000000000000000000175451473573527000212010ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. 
If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations from mautrix.client.state_store import SyncStore from mautrix.types import ( CrossSigner, CrossSigningUsage, DeviceID, DeviceIdentity, EventID, IdentityKey, RoomID, SessionID, SigningKey, SyncToken, TOFUSigningKey, UserID, ) from ..account import OlmAccount from ..sessions import InboundGroupSession, OutboundGroupSession, Session from .abstract import CryptoStore class MemoryCryptoStore(CryptoStore, SyncStore): _device_id: DeviceID | None _sync_token: SyncToken | None _account: OlmAccount | None _message_indices: dict[tuple[IdentityKey, SessionID, int], tuple[EventID, int]] _devices: dict[UserID, dict[DeviceID, DeviceIdentity]] _olm_sessions: dict[IdentityKey, list[Session]] _inbound_sessions: dict[tuple[RoomID, SessionID], InboundGroupSession] _outbound_sessions: dict[RoomID, OutboundGroupSession] _signatures: dict[CrossSigner, dict[CrossSigner, str]] _cross_signing_keys: dict[UserID, dict[CrossSigningUsage, TOFUSigningKey]] def __init__(self, account_id: str, pickle_key: str) -> None: self.account_id = account_id self.pickle_key = pickle_key self._sync_token = None self._device_id = None self._account = None self._message_indices = {} self._devices = {} self._olm_sessions = {} self._inbound_sessions = {} self._outbound_sessions = {} self._signatures = {} self._cross_signing_keys = {} async def get_device_id(self) -> DeviceID | None: return self._device_id async def put_device_id(self, device_id: DeviceID) -> None: self._device_id = device_id async def put_next_batch(self, next_batch: SyncToken) -> None: self._sync_token = next_batch async def get_next_batch(self) -> SyncToken: return self._sync_token async def delete(self) -> None: self._account = None self._device_id = None self._olm_sessions = {} self._outbound_sessions = {} async def put_account(self, account: OlmAccount) -> None: self._account = account async def get_account(self) -> OlmAccount: return self._account async def has_session(self, key: IdentityKey) -> bool: return key in self._olm_sessions async def get_sessions(self, key: IdentityKey) -> list[Session]: return self._olm_sessions.get(key, []) async def get_latest_session(self, key: IdentityKey) -> Session | None: try: return self._olm_sessions[key][-1] except (KeyError, IndexError): return None async def add_session(self, key: IdentityKey, session: Session) -> None: self._olm_sessions.setdefault(key, []).append(session) async def update_session(self, key: IdentityKey, session: Session) -> None: # This is a no-op as the session object is the same one previously added. 
pass async def put_group_session( self, room_id: RoomID, sender_key: IdentityKey, session_id: SessionID, session: InboundGroupSession, ) -> None: self._inbound_sessions[(room_id, session_id)] = session async def get_group_session( self, room_id: RoomID, session_id: SessionID ) -> InboundGroupSession: return self._inbound_sessions.get((room_id, session_id)) async def redact_group_session( self, room_id: RoomID, session_id: SessionID, reason: str ) -> None: self._inbound_sessions.pop((room_id, session_id), None) async def redact_group_sessions( self, room_id: RoomID, sender_key: IdentityKey, reason: str ) -> list[SessionID]: if not room_id and not sender_key: raise ValueError("Either room_id or sender_key must be provided") deleted = [] keys = list(self._inbound_sessions.keys()) for key in keys: item = self._inbound_sessions[key] if (not room_id or item.room_id == room_id) and ( not sender_key or item.sender_key == sender_key ): deleted.append(SessionID(item.id)) del self._inbound_sessions[key] return deleted async def redact_expired_group_sessions(self) -> list[SessionID]: raise NotImplementedError() async def redact_outdated_group_sessions(self) -> list[SessionID]: raise NotImplementedError() async def has_group_session(self, room_id: RoomID, session_id: SessionID) -> bool: return (room_id, session_id) in self._inbound_sessions async def add_outbound_group_session(self, session: OutboundGroupSession) -> None: self._outbound_sessions[session.room_id] = session async def update_outbound_group_session(self, session: OutboundGroupSession) -> None: # This is a no-op as the session object is the same one previously added. pass async def get_outbound_group_session(self, room_id: RoomID) -> OutboundGroupSession | None: return self._outbound_sessions.get(room_id) async def remove_outbound_group_session(self, room_id: RoomID) -> None: self._outbound_sessions.pop(room_id, None) async def remove_outbound_group_sessions(self, rooms: list[RoomID]) -> None: for room_id in rooms: self._outbound_sessions.pop(room_id, None) async def validate_message_index( self, sender_key: IdentityKey, session_id: SessionID, event_id: EventID, index: int, timestamp: int, ) -> bool: try: return self._message_indices[(sender_key, session_id, index)] == (event_id, timestamp) except KeyError: self._message_indices[(sender_key, session_id, index)] = (event_id, timestamp) return True async def get_devices(self, user_id: UserID) -> dict[DeviceID, DeviceIdentity] | None: return self._devices.get(user_id) async def get_device(self, user_id: UserID, device_id: DeviceID) -> DeviceIdentity | None: return self._devices.get(user_id, {}).get(device_id) async def find_device_by_key( self, user_id: UserID, identity_key: IdentityKey ) -> DeviceIdentity | None: for device in self._devices.get(user_id, {}).values(): if device.identity_key == identity_key: return device return None async def put_devices(self, user_id: UserID, devices: dict[DeviceID, DeviceIdentity]) -> None: self._devices[user_id] = devices async def filter_tracked_users(self, users: list[UserID]) -> list[UserID]: return [user_id for user_id in users if user_id in self._devices] async def put_cross_signing_key( self, user_id: UserID, usage: CrossSigningUsage, key: SigningKey ) -> None: try: current = self._cross_signing_keys[user_id][usage] except KeyError: self._cross_signing_keys.setdefault(user_id, {})[usage] = TOFUSigningKey( key=key, first=key ) else: current.key = key async def get_cross_signing_keys( self, user_id: UserID ) -> dict[CrossSigningUsage, TOFUSigningKey]: 
return self._cross_signing_keys.get(user_id, {}) async def put_signature( self, target: CrossSigner, signer: CrossSigner, signature: str ) -> None: self._signatures.setdefault(signer, {})[target] = signature async def is_key_signed_by(self, target: CrossSigner, signer: CrossSigner) -> bool: return target in self._signatures.get(signer, {}) async def drop_signatures_by_key(self, signer: CrossSigner) -> int: deleted = self._signatures.pop(signer, None) return len(deleted) python-0.20.7/mautrix/crypto/store/tests/000077500000000000000000000000001473573527000204455ustar00rootroot00000000000000python-0.20.7/mautrix/crypto/store/tests/__init__.py000066400000000000000000000000001473573527000225440ustar00rootroot00000000000000python-0.20.7/mautrix/crypto/store/tests/store_test.py000066400000000000000000000115131473573527000232130ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations from typing import AsyncContextManager, AsyncIterator, Callable from contextlib import asynccontextmanager import os import random import string import time import asyncpg import pytest from mautrix.client.state_store import SyncStore from mautrix.crypto import InboundGroupSession, OlmAccount, OutboundGroupSession from mautrix.types import DeviceID, EventID, RoomID, SessionID, SyncToken from mautrix.util.async_db import Database from .. import CryptoStore, MemoryCryptoStore, PgCryptoStore @asynccontextmanager async def async_postgres_store() -> AsyncIterator[PgCryptoStore]: try: pg_url = os.environ["MEOW_TEST_PG_URL"] except KeyError: pytest.skip("Skipped Postgres tests (MEOW_TEST_PG_URL not specified)") return conn: asyncpg.Connection = await asyncpg.connect(pg_url) schema_name = "".join(random.choices(string.ascii_lowercase, k=8)) schema_name = f"test_schema_{schema_name}_{int(time.time())}" await conn.execute(f"CREATE SCHEMA {schema_name}") db = Database.create( pg_url, upgrade_table=PgCryptoStore.upgrade_table, db_args={"min_size": 1, "max_size": 3, "server_settings": {"search_path": schema_name}}, ) store = PgCryptoStore("", "test", db) await db.start() yield store await db.stop() await conn.execute(f"DROP SCHEMA {schema_name} CASCADE") await conn.close() @asynccontextmanager async def async_sqlite_store() -> AsyncIterator[PgCryptoStore]: db = Database.create( "sqlite::memory:", upgrade_table=PgCryptoStore.upgrade_table, db_args={"min_size": 1} ) store = PgCryptoStore("", "test", db) await db.start() yield store await db.stop() @asynccontextmanager async def memory_store() -> AsyncIterator[MemoryCryptoStore]: yield MemoryCryptoStore("", "test") @pytest.fixture(params=[async_postgres_store, async_sqlite_store, memory_store]) async def crypto_store(request) -> AsyncIterator[CryptoStore]: param: Callable[[], AsyncContextManager[CryptoStore]] = request.param async with param() as state_store: yield state_store async def test_basic(crypto_store: CryptoStore) -> None: acc = OlmAccount() keys = acc.identity_keys await crypto_store.put_account(acc) await crypto_store.put_device_id(DeviceID("TEST")) if isinstance(crypto_store, SyncStore): await crypto_store.put_next_batch(SyncToken("TEST")) assert await crypto_store.get_device_id() == "TEST" assert (await crypto_store.get_account()).identity_keys == keys if isinstance(crypto_store, SyncStore): assert await 
crypto_store.get_next_batch() == "TEST" def _make_group_sess( acc: OlmAccount, room_id: RoomID ) -> tuple[InboundGroupSession, OutboundGroupSession]: outbound = OutboundGroupSession(room_id) inbound = InboundGroupSession( session_key=outbound.session_key, signing_key=acc.signing_key, sender_key=acc.identity_key, room_id=room_id, ) return inbound, outbound async def test_validate_message_index(crypto_store: CryptoStore) -> None: acc = OlmAccount() inbound, outbound = _make_group_sess(acc, RoomID("!foo:bar.com")) outbound.shared = True orig_plaintext = "hello world" ciphertext = outbound.encrypt(orig_plaintext) ts = int(time.time() * 1000) plaintext, index = inbound.decrypt(ciphertext) assert plaintext == orig_plaintext assert await crypto_store.validate_message_index( acc.identity_key, SessionID(inbound.id), EventID("$foo"), index, ts ), "Initial validation returns True" assert await crypto_store.validate_message_index( acc.identity_key, SessionID(inbound.id), EventID("$foo"), index, ts ), "Validating the same details again returns True" assert not await crypto_store.validate_message_index( acc.identity_key, SessionID(inbound.id), EventID("$bar"), index, ts ), "Different event ID causes validation to fail" assert not await crypto_store.validate_message_index( acc.identity_key, SessionID(inbound.id), EventID("$foo"), index, ts + 1 ), "Different timestamp causes validation to fail" assert not await crypto_store.validate_message_index( acc.identity_key, SessionID(inbound.id), EventID("$foo"), index, ts + 1 ), "Validating incorrect details twice fails" assert await crypto_store.validate_message_index( acc.identity_key, SessionID(inbound.id), EventID("$foo"), index, ts ), "Validating the same details after fails still returns True" # TODO tests for device identity storage, group session storage # and cross-signing key/signature storage python-0.20.7/mautrix/crypto/unwedge.py000066400000000000000000000034671473573527000201710ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. 
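# Module overview (added commentary, not from the original source): OlmUnwedgingMachine recovers
# from broken ("wedged") Olm sessions. When asked to unwedge a session for a given sender and
# curve25519 identity key, it looks up the device identity for that key, creates a brand new Olm
# session with it, and sends an encrypted dummy to-device event through it
# (_force_recreate_session=True). Recreation is throttled per sender key: a monotonic-clock
# timestamp is kept in _prev_unwedge, and a new session is only created if the previous attempt
# was more than MIN_UNWEDGE_INTERVAL (one hour) ago, so a stream of undecryptable messages can't
# trigger a recreation storm.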
import time from mautrix.types import EventType, IdentityKey, Obj, UserID from .decrypt_olm import OlmDecryptionMachine from .device_lists import DeviceListMachine from .encrypt_olm import OlmEncryptionMachine MIN_UNWEDGE_INTERVAL = 1 * 60 * 60 class OlmUnwedgingMachine(OlmDecryptionMachine, OlmEncryptionMachine, DeviceListMachine): async def _unwedge_session(self, sender: UserID, sender_key: IdentityKey) -> None: try: prev_unwedge = self._prev_unwedge[sender_key] except KeyError: pass else: delta = time.monotonic() - prev_unwedge if delta < MIN_UNWEDGE_INTERVAL: self.log.debug( f"Not creating new Olm session with {sender}/{sender_key}, " f"previous recreation was {delta}s ago" ) return self._prev_unwedge[sender_key] = time.monotonic() try: device = await self.get_or_fetch_device_by_key(sender, sender_key) if device is None: self.log.warning( f"Didn't find identity of {sender}/{sender_key}, can't unwedge session" ) return self.log.debug( f"Creating new Olm session with {sender}/{device.user_id} (key: {sender_key})" ) await self.send_encrypted_to_device( device, EventType.TO_DEVICE_DUMMY, Obj(), _force_recreate_session=True ) except Exception: self.log.exception(f"Error unwedging session with {sender}/{sender_key}") python-0.20.7/mautrix/errors/000077500000000000000000000000001473573527000161435ustar00rootroot00000000000000python-0.20.7/mautrix/errors/__init__.py000066400000000000000000000055211473573527000202570ustar00rootroot00000000000000from .base import IntentError, MatrixConnectionError, MatrixError, MatrixResponseError from .crypto import ( CryptoError, DecryptedPayloadError, DecryptionError, DeviceValidationError, DuplicateMessageIndex, EncryptionError, GroupSessionWithheldError, MatchingSessionDecryptionError, MismatchingRoomError, SessionNotFound, SessionShareError, VerificationError, ) from .request import ( MAlreadyJoined, MatrixBadContent, MatrixBadRequest, MatrixInvalidToken, MatrixRequestError, MatrixStandardRequestError, MatrixUnknownRequestError, MBadJSON, MBadState, MCaptchaInvalid, MCaptchaNeeded, MExclusive, MForbidden, MGuestAccessForbidden, MIncompatibleRoomVersion, MInsufficientPower, MInvalidParam, MInvalidRoomState, MInvalidUsername, MLimitExceeded, MMissingParam, MMissingToken, MNotFound, MNotJoined, MNotJSON, MRoomInUse, MTooLarge, MUnauthorized, MUnknown, MUnknownEndpoint, MUnknownToken, MUnrecognized, MUnsupportedRoomVersion, MUserDeactivated, MUserInUse, make_request_error, standard_error, ) from .well_known import ( WellKnownError, WellKnownInvalidVersionsResponse, WellKnownMissingHomeserver, WellKnownNotJSON, WellKnownNotURL, WellKnownUnexpectedStatus, WellKnownUnsupportedScheme, ) __all__ = [ "IntentError", "MatrixConnectionError", "MatrixError", "MatrixResponseError", "CryptoError", "DecryptedPayloadError", "DecryptionError", "DeviceValidationError", "DuplicateMessageIndex", "EncryptionError", "GroupSessionWithheldError", "MatchingSessionDecryptionError", "MismatchingRoomError", "SessionNotFound", "SessionShareError", "VerificationError", "MAlreadyJoined", "MatrixBadContent", "MatrixBadRequest", "MatrixInvalidToken", "MatrixRequestError", "MatrixStandardRequestError", "MatrixUnknownRequestError", "MBadJSON", "MBadState", "MCaptchaInvalid", "MCaptchaNeeded", "MExclusive", "MForbidden", "MGuestAccessForbidden", "MIncompatibleRoomVersion", "MInsufficientPower", "MInvalidParam", "MInvalidRoomState", "MInvalidUsername", "MLimitExceeded", "MMissingParam", "MMissingToken", "MNotFound", "MNotJoined", "MNotJSON", "MRoomInUse", "MTooLarge", "MUnauthorized", 
"MUnknown", "MUnknownEndpoint", "MUnknownToken", "MUnrecognized", "MUnsupportedRoomVersion", "MUserDeactivated", "MUserInUse", "make_request_error", "standard_error", "WellKnownError", "WellKnownInvalidVersionsResponse", "WellKnownMissingHomeserver", "WellKnownNotJSON", "WellKnownNotURL", "WellKnownUnexpectedStatus", "WellKnownUnsupportedScheme", ] python-0.20.7/mautrix/errors/base.py000066400000000000000000000013251473573527000174300ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. class MatrixError(Exception): """A generic Matrix error. Specific errors will subclass this.""" pass class MatrixConnectionError(MatrixError): pass class MatrixResponseError(MatrixError): """The response from the homeserver did not fulfill expectations.""" def __init__(self, message: str) -> None: super().__init__(message) class IntentError(MatrixError): """An intent execution failure, most likely caused by a `MatrixRequestError`.""" pass python-0.20.7/mautrix/errors/crypto.py000066400000000000000000000046631473573527000200460ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations import warnings from mautrix.types import IdentityKey, SessionID from .base import MatrixError class CryptoError(MatrixError): def __init__(self, message: str) -> None: super().__init__(message) self.message = message class EncryptionError(CryptoError): pass class SessionShareError(CryptoError): pass class DecryptionError(CryptoError): @property def human_message(self) -> str: return "the bridge failed to decrypt the message" class MatchingSessionDecryptionError(DecryptionError): pass class GroupSessionWithheldError(DecryptionError): def __init__(self, session_id: SessionID, withheld_code: str) -> None: super().__init__(f"Session ID {session_id} was withheld ({withheld_code})") self.withheld_code = withheld_code class SessionNotFound(DecryptionError): def __init__(self, session_id: SessionID, sender_key: IdentityKey | None = None) -> None: super().__init__( f"Failed to decrypt megolm event: no session with given ID {session_id} found" ) self.session_id = session_id self._sender_key = sender_key @property def human_message(self) -> str: return "the bridge hasn't received the decryption keys" @property def sender_key(self) -> IdentityKey | None: """ .. deprecated:: 0.17.0 Matrix v1.3 deprecated the device_id and sender_key fields in megolm events. 
""" warnings.warn( "The sender_key field in Megolm events was deprecated in Matrix 1.3", DeprecationWarning, ) return self._sender_key class DuplicateMessageIndex(DecryptionError): def __init__(self) -> None: super().__init__("Duplicate message index") class VerificationError(DecryptionError): def __init__(self) -> None: super().__init__("Device keys in session and cached device info do not match") class DecryptedPayloadError(DecryptionError): pass class MismatchingRoomError(DecryptionError): def __init__(self) -> None: super().__init__("Encrypted megolm event is not intended for this room") class DeviceValidationError(EncryptionError): pass python-0.20.7/mautrix/errors/request.py000066400000000000000000000134431473573527000202120ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations from typing import Callable, Type from .base import MatrixError class MatrixRequestError(MatrixError): """An error that was returned by the homeserver.""" http_status: int message: str | None errcode: str class MatrixUnknownRequestError(MatrixRequestError): """An unknown error type returned by the homeserver.""" http_status: int text: str errcode: str | None message: str | None def __init__( self, http_status: int = 0, text: str = "", errcode: str | None = None, message: str | None = None, ) -> None: super().__init__(f"{http_status}: {text}") self.http_status = http_status self.text = text self.errcode = errcode self.message = message class MatrixStandardRequestError(MatrixRequestError): """A standard error type returned by the homeserver.""" errcode: str = None def __init__(self, http_status: int, message: str = "") -> None: super().__init__(message) self.http_status: int = http_status self.message: str = message MxSRE = Type[MatrixStandardRequestError] ec_map: dict[str, MxSRE] = {} uec_map: dict[str, MxSRE] = {} def standard_error(code: str, unstable: str | None = None) -> Callable[[MxSRE], MxSRE]: def decorator(cls: MxSRE) -> MxSRE: cls.errcode = code ec_map[code] = cls if unstable: cls.unstable_errcode = unstable uec_map[unstable] = cls return cls return decorator def make_request_error( http_status: int, text: str, errcode: str | None, message: str | None, unstable_errcode: str | None = None, ) -> MatrixRequestError: """ Determine the correct exception class for the error code and create an instance of that class with the given values. Args: http_status: The HTTP status code. text: The raw response text. errcode: The errcode field in the response JSON. message: The error field in the response JSON. unstable_errcode: The MSC3848 error code field in the response JSON. 
""" if unstable_errcode: try: ec_class = uec_map[unstable_errcode] return ec_class(http_status, message) except KeyError: pass try: ec_class = ec_map[errcode] return ec_class(http_status, message) except KeyError: return MatrixUnknownRequestError(http_status, text, errcode, message) # Standard error codes from https://spec.matrix.org/v1.3/client-server-api/#api-standards # Additionally some combining superclasses for some of the error codes @standard_error("M_FORBIDDEN") class MForbidden(MatrixStandardRequestError): pass @standard_error("M_ALREADY_JOINED", unstable="ORG.MATRIX.MSC3848.ALREADY_JOINED") class MAlreadyJoined(MForbidden): pass @standard_error("M_NOT_JOINED", unstable="ORG.MATRIX.MSC3848.NOT_JOINED") class MNotJoined(MForbidden): pass @standard_error("M_INSUFFICIENT_POWER", unstable="ORG.MATRIX.MSC3848.INSUFFICIENT_POWER") class MInsufficientPower(MForbidden): pass @standard_error("M_UNKNOWN_ENDPOINT") class MUnknownEndpoint(MatrixStandardRequestError): pass @standard_error("M_USER_DEACTIVATED") class MUserDeactivated(MForbidden): pass class MatrixInvalidToken(MatrixStandardRequestError): pass @standard_error("M_UNKNOWN_TOKEN") class MUnknownToken(MatrixInvalidToken): pass @standard_error("M_MISSING_TOKEN") class MMissingToken(MatrixInvalidToken): pass class MatrixBadRequest(MatrixStandardRequestError): pass class MatrixBadContent(MatrixBadRequest): pass @standard_error("M_BAD_JSON") class MBadJSON(MatrixBadContent): pass @standard_error("M_NOT_JSON") class MNotJSON(MatrixBadContent): pass @standard_error("M_NOT_FOUND") class MNotFound(MatrixStandardRequestError): pass @standard_error("M_LIMIT_EXCEEDED") class MLimitExceeded(MatrixStandardRequestError): pass @standard_error("M_UNKNOWN") class MUnknown(MatrixStandardRequestError): pass @standard_error("M_UNRECOGNIZED") class MUnrecognized(MatrixStandardRequestError): pass @standard_error("M_UNAUTHORIZED") class MUnauthorized(MatrixStandardRequestError): pass @standard_error("M_USER_IN_USE") class MUserInUse(MatrixStandardRequestError): pass @standard_error("M_INVALID_USERNAME") class MInvalidUsername(MatrixStandardRequestError): pass @standard_error("M_ROOM_IN_USE") class MRoomInUse(MatrixStandardRequestError): pass @standard_error("M_INVALID_ROOM_STATE") class MInvalidRoomState(MatrixStandardRequestError): pass # TODO THREEPID_ errors @standard_error("M_UNSUPPORTED_ROOM_VERSION") class MUnsupportedRoomVersion(MatrixStandardRequestError): pass @standard_error("M_INCOMPATIBLE_ROOM_VERSION") class MIncompatibleRoomVersion(MatrixStandardRequestError): pass @standard_error("M_BAD_STATE") class MBadState(MatrixStandardRequestError): pass @standard_error("M_GUEST_ACCESS_FORBIDDEN") class MGuestAccessForbidden(MatrixStandardRequestError): pass @standard_error("M_CAPTCHA_NEEDED") class MCaptchaNeeded(MatrixStandardRequestError): pass @standard_error("M_CAPTCHA_INVALID") class MCaptchaInvalid(MatrixStandardRequestError): pass @standard_error("M_MISSING_PARAM") class MMissingParam(MatrixBadRequest): pass @standard_error("M_INVALID_PARAM") class MInvalidParam(MatrixBadRequest): pass @standard_error("M_TOO_LARGE") class MTooLarge(MatrixBadRequest): pass @standard_error("M_EXCLUSIVE") class MExclusive(MatrixStandardRequestError): pass python-0.20.7/mautrix/errors/well_known.py000066400000000000000000000030211473573527000206700ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. 
If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from .base import MatrixResponseError class WellKnownError(MatrixResponseError): """ An error that occurred during server discovery. https://matrix.org/docs/spec/client_server/latest#get-well-known-matrix-client """ pass class WellKnownUnexpectedStatus(WellKnownError): def __init__(self, status: int) -> None: super().__init__(f"Unexpected status code {status} when fetching .well-known file") self.status = status class WellKnownNotJSON(WellKnownError): def __init__(self) -> None: super().__init__(".well-known response was not JSON") class WellKnownMissingHomeserver(WellKnownError): def __init__(self) -> None: super().__init__("No homeserver found in .well-known response") class WellKnownNotURL(WellKnownError): def __init__(self) -> None: super().__init__("Homeserver base URL in .well-known response was not a valid URL") class WellKnownUnsupportedScheme(WellKnownError): def __init__(self, scheme: str) -> None: super().__init__(f"URL in .well-known response has unsupported scheme {scheme}") class WellKnownInvalidVersionsResponse(WellKnownError): def __init__(self) -> None: super().__init__( "URL in .well-known response didn't respond to versions endpoint properly" ) python-0.20.7/mautrix/fixmodule.py000066400000000000000000000030661473573527000172020ustar00rootroot00000000000000# This is a script that fixes the __module__ tags in mautrix-python and some libraries. # It's used to help Sphinx/autodoc figure out where the things are canonically imported from # (by default, it shows the exact module they're defined in rather than the top-level import path). from typing import NewType from types import FunctionType, ModuleType import aiohttp from . import appservice, bridge, client, crypto, errors, types from .crypto import attachments from .util import async_db, config, db, formatter, logging def _fix(obj: ModuleType) -> None: for item_name in getattr(obj, "__all__", None) or dir(obj): item = getattr(obj, item_name) if isinstance(item, (type, FunctionType, NewType)): # Ignore backwards-compatibility imports like the BridgeState import in mautrix.bridge if item.__module__.startswith("mautrix") and not item.__module__.startswith( obj.__name__ ): continue item.__module__ = obj.__name__ if isinstance(item, NewType): # By default autodoc makes a blank "Bases:" text, # so adjust it to show the type as the "base" item.__bases__ = (item.__supertype__,) # elif type(item).__module__ == "typing": # print(obj.__name__, item_name, type(item)) _things_to_fix = [ types, bridge, client, crypto, attachments, errors, appservice, async_db, config, db, formatter, logging, aiohttp, ] for mod in _things_to_fix: _fix(mod) python-0.20.7/mautrix/genall.py000066400000000000000000000027351473573527000164520ustar00rootroot00000000000000# This script generates the __all__ arrays for types/__init__.py and errors/__init__.py # to avoid having to manually add both the import and the __all__ entry. # See https://github.com/mautrix/python/issues/90 for why __all__ is needed at all. 
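# Rough sketch of the core idea (illustrative only, not used by this script):
# collect the names imported in a module's source and turn them into __all__ entries.
def _example_collect_all(source: str) -> list[str]:
    import ast

    names: list[str] = []
    for node in ast.iter_child_nodes(ast.parse(source)):
        # Both plain imports and from-imports contribute their imported names.
        if isinstance(node, (ast.Import, ast.ImportFrom)):
            names += (alias.asname or alias.name for alias in node.names)
    return names


# _example_collect_all("from .auth import LoginType, LoginResponse")
# -> ["LoginType", "LoginResponse"]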
from pathlib import Path import ast import black root_module = Path(__file__).parent black_cfg = black.parse_pyproject_toml(str(root_module.parent / "pyproject.toml")) black_mode = black.Mode( target_versions={black.TargetVersion[ver.upper()] for ver in black_cfg["target_version"]}, line_length=black_cfg["line_length"], ) def add_imports_to_all(dir: str) -> None: init_file = root_module / dir / "__init__.py" with open(init_file) as f: init_ast = ast.parse(f.read(), filename=f"mautrix/{dir}/__init__.py") imports: list[str] = [] all_node: ast.List | None = None for node in ast.iter_child_nodes(init_ast): if isinstance(node, (ast.Import, ast.ImportFrom)): imports += (name.name for name in node.names) elif isinstance(node, ast.Assign) and isinstance(node.value, ast.List): target = node.targets[0] if len(node.targets) == 1 and isinstance(target, ast.Name) and target.id == "__all__": all_node = node.value all_node.elts = [ast.Constant(name) for name in imports] with open(init_file, "w") as f: f.write(black.format_str(ast.unparse(init_ast), mode=black_mode)) add_imports_to_all("types") add_imports_to_all("errors") python-0.20.7/mautrix/py.typed000066400000000000000000000000001473573527000163140ustar00rootroot00000000000000python-0.20.7/mautrix/types/000077500000000000000000000000001473573527000157735ustar00rootroot00000000000000python-0.20.7/mautrix/types/__init__.py000066400000000000000000000217331473573527000201120ustar00rootroot00000000000000from .auth import ( DiscoveryInformation, DiscoveryIntegrations, DiscoveryIntegrationServer, DiscoveryServer, LoginFlow, LoginFlowList, LoginResponse, LoginType, MatrixUserIdentifier, PhoneIdentifier, ThirdPartyIdentifier, UserIdentifier, UserIdentifierType, WhoamiResponse, ) from .crypto import ( ClaimKeysResponse, CrossSigner, CrossSigningKeys, CrossSigningUsage, DecryptedOlmEvent, DeviceIdentity, DeviceKeys, OlmEventKeys, QueryKeysResponse, TOFUSigningKey, TrustState, UnsignedDeviceInfo, ) from .event import ( AccountDataEvent, AccountDataEventContent, ASToDeviceEvent, AudioInfo, BaseEvent, BaseFileInfo, BaseMessageEventContent, BaseMessageEventContentFuncs, BaseRoomEvent, BaseUnsigned, BatchSendEvent, BatchSendStateEvent, BeeperMessageStatusEvent, BeeperMessageStatusEventContent, CallAnswerEventContent, CallCandidate, CallCandidatesEventContent, CallData, CallDataType, CallEvent, CallEventContent, CallHangupEventContent, CallHangupReason, CallInviteEventContent, CallNegotiateEventContent, CallRejectEventContent, CallSelectAnswerEventContent, CanonicalAliasStateEventContent, EncryptedEvent, EncryptedEventContent, EncryptedFile, EncryptedMegolmEventContent, EncryptedOlmEventContent, EncryptionAlgorithm, EncryptionKeyAlgorithm, EphemeralEvent, Event, EventContent, EventType, FileInfo, Format, ForwardedRoomKeyEventContent, GenericEvent, ImageInfo, InReplyTo, JoinRule, JoinRulesStateEventContent, JSONWebKey, KeyID, KeyRequestAction, LocationInfo, LocationMessageEventContent, MediaInfo, MediaMessageEventContent, Membership, MemberStateEventContent, MessageEvent, MessageEventContent, MessageStatus, MessageStatusReason, MessageType, MessageUnsigned, OlmCiphertext, OlmMsgType, PowerLevelStateEventContent, PresenceEvent, PresenceEventContent, PresenceState, ReactionEvent, ReactionEventContent, ReceiptEvent, ReceiptEventContent, ReceiptType, RedactionEvent, RedactionEventContent, RelatesTo, RelationType, RequestedKeyInfo, RoomAvatarStateEventContent, RoomCreateStateEventContent, RoomEncryptionStateEventContent, RoomKeyEventContent, RoomKeyRequestEventContent, 
RoomKeyWithheldCode, RoomKeyWithheldEventContent, RoomNameStateEventContent, RoomPinnedEventsStateEventContent, RoomPredecessor, RoomTagAccountDataEventContent, RoomTagInfo, RoomTombstoneStateEventContent, RoomTopicStateEventContent, RoomType, SingleReceiptEventContent, SpaceChildStateEventContent, SpaceParentStateEventContent, StateEvent, StateEventContent, StateUnsigned, StrippedStateEvent, TextMessageEventContent, ThumbnailInfo, ToDeviceEvent, ToDeviceEventContent, TypingEvent, TypingEventContent, VideoInfo, ) from .filter import EventFilter, Filter, RoomEventFilter, RoomFilter, StateFilter from .matrixuri import IdentifierType, MatrixURI, MatrixURIError, URIAction from .media import ( MediaCreateResponse, MediaRepoConfig, MXOpenGraph, OpenGraphAudio, OpenGraphImage, OpenGraphVideo, ) from .misc import ( BatchSendResponse, BeeperBatchSendResponse, DeviceLists, DeviceOTKCount, DirectoryPaginationToken, EventContext, PaginatedMessages, PaginationDirection, RoomAliasInfo, RoomCreatePreset, RoomDirectoryResponse, RoomDirectoryVisibility, ) from .primitive import ( JSON, BatchID, ContentURI, DeviceID, EventID, FilterID, IdentityKey, RoomAlias, RoomID, SessionID, Signature, SigningKey, SyncToken, UserID, ) from .push_rules import ( PushAction, PushActionDict, PushActionType, PushCondition, PushConditionKind, PushOperator, PushRule, PushRuleID, PushRuleKind, PushRuleScope, ) from .users import Member, User, UserSearchResults from .util import ( ExtensibleEnum, Lst, Obj, Serializable, SerializableAttrs, SerializableEnum, SerializerError, deserializer, field, serializer, ) from .versions import SpecVersions, Version, VersionFormat, VersionsResponse __all__ = [ "DiscoveryInformation", "DiscoveryIntegrations", "DiscoveryIntegrationServer", "DiscoveryServer", "LoginFlow", "LoginFlowList", "LoginResponse", "LoginType", "MatrixUserIdentifier", "PhoneIdentifier", "ThirdPartyIdentifier", "UserIdentifier", "UserIdentifierType", "WhoamiResponse", "ClaimKeysResponse", "CrossSigner", "CrossSigningKeys", "CrossSigningUsage", "DecryptedOlmEvent", "DeviceIdentity", "DeviceKeys", "OlmEventKeys", "QueryKeysResponse", "TOFUSigningKey", "TrustState", "UnsignedDeviceInfo", "AccountDataEvent", "AccountDataEventContent", "ASToDeviceEvent", "AudioInfo", "BaseEvent", "BaseFileInfo", "BaseMessageEventContent", "BaseMessageEventContentFuncs", "BaseRoomEvent", "BaseUnsigned", "BatchSendEvent", "BatchSendStateEvent", "BeeperMessageStatusEvent", "BeeperMessageStatusEventContent", "CallAnswerEventContent", "CallCandidate", "CallCandidatesEventContent", "CallData", "CallDataType", "CallEvent", "CallEventContent", "CallHangupEventContent", "CallHangupReason", "CallInviteEventContent", "CallNegotiateEventContent", "CallRejectEventContent", "CallSelectAnswerEventContent", "CanonicalAliasStateEventContent", "EncryptedEvent", "EncryptedEventContent", "EncryptedFile", "EncryptedMegolmEventContent", "EncryptedOlmEventContent", "EncryptionAlgorithm", "EncryptionKeyAlgorithm", "EphemeralEvent", "Event", "EventContent", "EventType", "FileInfo", "Format", "ForwardedRoomKeyEventContent", "GenericEvent", "ImageInfo", "InReplyTo", "JoinRule", "JoinRulesStateEventContent", "JSONWebKey", "KeyID", "KeyRequestAction", "LocationInfo", "LocationMessageEventContent", "MediaInfo", "MediaMessageEventContent", "Membership", "MemberStateEventContent", "MessageEvent", "MessageEventContent", "MessageStatus", "MessageStatusReason", "MessageType", "MessageUnsigned", "OlmCiphertext", "OlmMsgType", "PowerLevelStateEventContent", "PresenceEvent", 
"PresenceEventContent", "PresenceState", "ReactionEvent", "ReactionEventContent", "ReceiptEvent", "ReceiptEventContent", "ReceiptType", "RedactionEvent", "RedactionEventContent", "RelatesTo", "RelationType", "RequestedKeyInfo", "RoomAvatarStateEventContent", "RoomCreateStateEventContent", "RoomEncryptionStateEventContent", "RoomKeyEventContent", "RoomKeyRequestEventContent", "RoomKeyWithheldCode", "RoomKeyWithheldEventContent", "RoomNameStateEventContent", "RoomPinnedEventsStateEventContent", "RoomPredecessor", "RoomTagAccountDataEventContent", "RoomTagInfo", "RoomTombstoneStateEventContent", "RoomTopicStateEventContent", "RoomType", "SingleReceiptEventContent", "SpaceChildStateEventContent", "SpaceParentStateEventContent", "StateEvent", "StateEventContent", "StateUnsigned", "StrippedStateEvent", "TextMessageEventContent", "ThumbnailInfo", "ToDeviceEvent", "ToDeviceEventContent", "TypingEvent", "TypingEventContent", "VideoInfo", "EventFilter", "Filter", "RoomEventFilter", "RoomFilter", "StateFilter", "IdentifierType", "MatrixURI", "MatrixURIError", "URIAction", "MediaCreateResponse", "MediaRepoConfig", "MXOpenGraph", "OpenGraphAudio", "OpenGraphImage", "OpenGraphVideo", "BatchSendResponse", "DeviceLists", "DeviceOTKCount", "DirectoryPaginationToken", "EventContext", "PaginatedMessages", "PaginationDirection", "RoomAliasInfo", "RoomCreatePreset", "RoomDirectoryResponse", "RoomDirectoryVisibility", "JSON", "BatchID", "ContentURI", "DeviceID", "EventID", "FilterID", "IdentityKey", "RoomAlias", "RoomID", "SessionID", "Signature", "SigningKey", "SyncToken", "UserID", "PushAction", "PushActionDict", "PushActionType", "PushCondition", "PushConditionKind", "PushOperator", "PushRule", "PushRuleID", "PushRuleKind", "PushRuleScope", "Member", "User", "UserSearchResults", "ExtensibleEnum", "Lst", "Obj", "Serializable", "SerializableAttrs", "SerializableEnum", "SerializerError", "deserializer", "field", "serializer", "SpecVersions", "Version", "VersionFormat", "VersionsResponse", ] python-0.20.7/mautrix/types/auth.py000066400000000000000000000142341473573527000173120ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from typing import List, NewType, Optional, Union from attr import dataclass from .primitive import JSON, DeviceID, UserID from .util import ExtensibleEnum, Obj, SerializableAttrs, deserializer, field class LoginType(ExtensibleEnum): """ A login type, as specified in the `POST /login endpoint`_ .. _POST /login endpoint: https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3login """ PASSWORD: "LoginType" = "m.login.password" TOKEN: "LoginType" = "m.login.token" SSO: "LoginType" = "m.login.sso" APPSERVICE: "LoginType" = "m.login.application_service" UNSTABLE_JWT: "LoginType" = "org.matrix.login.jwt" DEVTURE_SHARED_SECRET: "LoginType" = "com.devture.shared_secret_auth" @dataclass class LoginFlow(SerializableAttrs): """ A login flow, as specified in the `GET /login endpoint`_ .. 
_GET /login endpoint: https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3login """ type: LoginType @dataclass class LoginFlowList(SerializableAttrs): flows: List[LoginFlow] def get_first_of_type(self, *types: LoginType) -> Optional[LoginFlow]: for flow in self.flows: if flow.type in types: return flow return None def supports_type(self, *types: LoginType) -> bool: return self.get_first_of_type(*types) is not None class UserIdentifierType(ExtensibleEnum): """ A user identifier type, as specified in the `Identifier types`_ section of the login spec. .. _Identifier types: https://spec.matrix.org/v1.2/client-server-api/#identifier-types """ MATRIX_USER: "UserIdentifierType" = "m.id.user" THIRD_PARTY: "UserIdentifierType" = "m.id.thirdparty" PHONE: "UserIdentifierType" = "m.id.phone" @dataclass class MatrixUserIdentifier(SerializableAttrs): """ A client can identify a user using their Matrix ID. This can either be the fully qualified Matrix user ID, or just the localpart of the user ID. """ user: str """The Matrix user ID or localpart""" type: UserIdentifierType = UserIdentifierType.MATRIX_USER @dataclass class ThirdPartyIdentifier(SerializableAttrs): """ A client can identify a user using a 3PID associated with the user's account on the homeserver, where the 3PID was previously associated using the `/account/3pid`_ API. See the `3PID Types`_ Appendix for a list of Third-party ID media. .. _/account/3pid: https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3account3pid .. _3PID Types: https://spec.matrix.org/v1.2/appendices/#3pid-types """ medium: str address: str type: UserIdentifierType = UserIdentifierType.THIRD_PARTY @dataclass class PhoneIdentifier(SerializableAttrs): """ A client can identify a user using a phone number associated with the user's account, where the phone number was previously associated using the `/account/3pid`_ API. The phone number can be passed in as entered by the user; the homeserver will be responsible for canonicalising it. If the client wishes to canonicalise the phone number, then it can use the ``m.id.thirdparty`` identifier type with a ``medium`` of ``msisdn`` instead. .. _/account/3pid: https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3account3pid """ country: str phone: str type: UserIdentifierType = UserIdentifierType.PHONE UserIdentifier = NewType( "UserIdentifier", Union[MatrixUserIdentifier, ThirdPartyIdentifier, PhoneIdentifier] ) @deserializer(UserIdentifier) def deserialize_user_identifier(data: JSON) -> Union[UserIdentifier, Obj]: try: identifier_type = UserIdentifierType.deserialize(data["type"]) except KeyError: return Obj(**data) if identifier_type == UserIdentifierType.MATRIX_USER: return MatrixUserIdentifier.deserialize(data) elif identifier_type == UserIdentifierType.THIRD_PARTY: return ThirdPartyIdentifier.deserialize(data) elif identifier_type == UserIdentifierType.PHONE: return PhoneIdentifier.deserialize(data) else: return Obj(**data) setattr(UserIdentifier, "deserialize", deserialize_user_identifier) @dataclass class DiscoveryServer(SerializableAttrs): base_url: Optional[str] = None @dataclass class DiscoveryIntegrationServer(SerializableAttrs): ui_url: Optional[str] = None api_url: Optional[str] = None @dataclass class DiscoveryIntegrations(SerializableAttrs): managers: List[DiscoveryIntegrationServer] = field(factory=lambda: []) @dataclass class DiscoveryInformation(SerializableAttrs): """ .well-known discovery information, as specified in the `GET /.well-known/matrix/client endpoint`_ .. 
_GET /.well-known/matrix/client endpoint: https://spec.matrix.org/v1.2/client-server-api/#getwell-knownmatrixclient """ homeserver: Optional[DiscoveryServer] = field(json="m.homeserver", factory=DiscoveryServer) identity_server: Optional[DiscoveryServer] = field( json="m.identity_server", factory=DiscoveryServer ) integrations: Optional[DiscoveryServer] = field( json="m.integrations", factory=DiscoveryIntegrations ) @dataclass class LoginResponse(SerializableAttrs): """ The response for a login request, as specified in the `POST /login endpoint`_ .. _POST /login endpoint: https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3login """ user_id: UserID device_id: DeviceID access_token: str well_known: DiscoveryInformation = field(factory=DiscoveryInformation) @dataclass class WhoamiResponse(SerializableAttrs): """ The response for a whoami request, as specified in the `GET /account/whoami endpoint`_ .. _GET /account/whoami endpoint: https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3accountwhoami """ user_id: UserID device_id: Optional[DeviceID] = None is_guest: bool = False python-0.20.7/mautrix/types/crypto.py000066400000000000000000000115071473573527000176710ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from typing import Any, Dict, List, NamedTuple, Optional from enum import IntEnum from attr import dataclass from .event import EncryptionAlgorithm, EncryptionKeyAlgorithm, KeyID, ToDeviceEvent from .primitive import DeviceID, IdentityKey, Signature, SigningKey, UserID from .util import ExtensibleEnum, SerializableAttrs, field @dataclass class UnsignedDeviceInfo(SerializableAttrs): device_display_name: Optional[str] = None @dataclass class DeviceKeys(SerializableAttrs): user_id: UserID device_id: DeviceID algorithms: List[EncryptionAlgorithm] keys: Dict[KeyID, str] signatures: Dict[UserID, Dict[KeyID, Signature]] unsigned: UnsignedDeviceInfo = None def __attrs_post_init__(self) -> None: if self.unsigned is None: self.unsigned = UnsignedDeviceInfo() @property def ed25519(self) -> Optional[SigningKey]: try: return SigningKey(self.keys[KeyID(EncryptionKeyAlgorithm.ED25519, self.device_id)]) except KeyError: return None @property def curve25519(self) -> Optional[IdentityKey]: try: return IdentityKey(self.keys[KeyID(EncryptionKeyAlgorithm.CURVE25519, self.device_id)]) except KeyError: return None class CrossSigningUsage(ExtensibleEnum): MASTER = "master" SELF = "self_signing" USER = "user_signing" @dataclass class CrossSigningKeys(SerializableAttrs): user_id: UserID usage: List[CrossSigningUsage] keys: Dict[KeyID, SigningKey] signatures: Dict[UserID, Dict[KeyID, Signature]] = field(factory=lambda: {}) @property def first_key(self) -> Optional[SigningKey]: try: return next(iter(self.keys.values())) except StopIteration: return None @property def first_ed25519_key(self) -> Optional[SigningKey]: return self.first_key_with_algorithm(EncryptionKeyAlgorithm.ED25519) def first_key_with_algorithm(self, alg: EncryptionKeyAlgorithm) -> Optional[SigningKey]: if not self.keys: return None try: return next(key for key_id, key in self.keys.items() if key_id.algorithm == alg) except StopIteration: return None @dataclass class QueryKeysResponse(SerializableAttrs): device_keys: Dict[UserID, Dict[DeviceID, DeviceKeys]] = field(factory=lambda: {}) master_keys: 
Dict[UserID, CrossSigningKeys] = field(factory=lambda: {}) self_signing_keys: Dict[UserID, CrossSigningKeys] = field(factory=lambda: {}) user_signing_keys: Dict[UserID, CrossSigningKeys] = field(factory=lambda: {}) failures: Dict[str, Any] = field(factory=lambda: {}) @dataclass class ClaimKeysResponse(SerializableAttrs): one_time_keys: Dict[UserID, Dict[DeviceID, Dict[KeyID, Any]]] failures: Dict[str, Any] = field(factory=lambda: {}) class TrustState(IntEnum): BLACKLISTED = -100 UNVERIFIED = 0 UNKNOWN_DEVICE = 10 FORWARDED = 20 CROSS_SIGNED_UNTRUSTED = 50 CROSS_SIGNED_TOFU = 100 CROSS_SIGNED_TRUSTED = 200 VERIFIED = 300 def __str__(self) -> str: return _trust_state_to_name[self] @classmethod def parse(cls, val: str) -> "TrustState": try: return _name_to_trust_state[val] except KeyError as e: raise ValueError(f"Invalid trust state {val!r}") from e _trust_state_to_name: Dict[TrustState, str] = { val: val.name.lower().replace("_", "-") for val in TrustState } _name_to_trust_state: Dict[str, TrustState] = { value: key for key, value in _trust_state_to_name.items() } @dataclass class DeviceIdentity: user_id: UserID device_id: DeviceID identity_key: IdentityKey signing_key: SigningKey trust: TrustState deleted: bool name: str @dataclass class OlmEventKeys(SerializableAttrs): ed25519: SigningKey @dataclass class DecryptedOlmEvent(ToDeviceEvent, SerializableAttrs): keys: OlmEventKeys recipient: UserID recipient_keys: OlmEventKeys sender_device: Optional[DeviceID] = None sender_key: IdentityKey = field(hidden=True, default=None) class TOFUSigningKey(NamedTuple): """ A tuple representing a single cross-signing key. The first value is the current key, and the second value is the first seen key. If the values don't match, it means the key is not valid for trust-on-first-use. """ key: SigningKey first: SigningKey class CrossSigner(NamedTuple): """ A tuple containing a user ID and a signing key they own. The key can either be a device-owned signing key, or one of the user's cross-signing keys. """ user_id: UserID key: SigningKey python-0.20.7/mautrix/types/event/000077500000000000000000000000001473573527000171145ustar00rootroot00000000000000python-0.20.7/mautrix/types/event/__init__.py000066400000000000000000000057121473573527000212320ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. 
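# Illustrative sketch (assumed usage, not part of the original module): the TrustState
# enum defined in mautrix.types.crypto stringifies its names with hyphens and parses
# the same hyphenated form back.
def _trust_state_example() -> None:
    from mautrix.types.crypto import TrustState

    # Enum names use underscores, but the string form uses hyphens.
    assert str(TrustState.CROSS_SIGNED_TOFU) == "cross-signed-tofu"
    # parse() accepts the stringified form and returns the corresponding member.
    assert TrustState.parse("verified") is TrustState.VERIFIED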
from .account_data import ( AccountDataEvent, AccountDataEventContent, RoomTagAccountDataEventContent, RoomTagInfo, ) from .base import BaseEvent, BaseRoomEvent, BaseUnsigned, GenericEvent from .batch import BatchSendEvent, BatchSendStateEvent from .beeper import ( BeeperMessageStatusEvent, BeeperMessageStatusEventContent, MessageStatus, MessageStatusReason, ) from .encrypted import ( EncryptedEvent, EncryptedEventContent, EncryptedMegolmEventContent, EncryptedOlmEventContent, EncryptionAlgorithm, EncryptionKeyAlgorithm, KeyID, OlmCiphertext, OlmMsgType, ) from .ephemeral import ( EphemeralEvent, PresenceEvent, PresenceEventContent, PresenceState, ReceiptEvent, ReceiptEventContent, ReceiptType, SingleReceiptEventContent, TypingEvent, TypingEventContent, ) from .generic import Event, EventContent from .message import ( AudioInfo, BaseFileInfo, BaseMessageEventContent, BaseMessageEventContentFuncs, EncryptedFile, FileInfo, Format, ImageInfo, InReplyTo, JSONWebKey, LocationInfo, LocationMessageEventContent, MediaInfo, MediaMessageEventContent, MessageEvent, MessageEventContent, MessageType, MessageUnsigned, RelatesTo, RelationType, TextMessageEventContent, ThumbnailInfo, VideoInfo, ) from .reaction import ReactionEvent, ReactionEventContent from .redaction import RedactionEvent, RedactionEventContent from .state import ( CanonicalAliasStateEventContent, JoinRule, JoinRulesStateEventContent, Membership, MemberStateEventContent, PowerLevelStateEventContent, RoomAvatarStateEventContent, RoomCreateStateEventContent, RoomEncryptionStateEventContent, RoomNameStateEventContent, RoomPinnedEventsStateEventContent, RoomPredecessor, RoomTombstoneStateEventContent, RoomTopicStateEventContent, RoomType, SpaceChildStateEventContent, SpaceParentStateEventContent, StateEvent, StateEventContent, StateUnsigned, StrippedStateEvent, ) from .to_device import ( ASToDeviceEvent, ForwardedRoomKeyEventContent, KeyRequestAction, RequestedKeyInfo, RoomKeyEventContent, RoomKeyRequestEventContent, RoomKeyWithheldCode, RoomKeyWithheldEventContent, ToDeviceEvent, ToDeviceEventContent, ) from .type import EventType from .voip import ( CallAnswerEventContent, CallCandidate, CallCandidatesEventContent, CallData, CallDataType, CallEvent, CallEventContent, CallHangupEventContent, CallHangupReason, CallInviteEventContent, CallNegotiateEventContent, CallRejectEventContent, CallSelectAnswerEventContent, ) python-0.20.7/mautrix/types/event/account_data.py000066400000000000000000000036141473573527000221170ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. 
from typing import Dict, List, Union from attr import dataclass import attr from ..primitive import JSON, RoomID, UserID from ..util import Obj, SerializableAttrs, deserializer from .base import BaseEvent, EventType @dataclass class RoomTagInfo(SerializableAttrs): order: Union[int, float, str] = None @dataclass class RoomTagAccountDataEventContent(SerializableAttrs): tags: Dict[str, RoomTagInfo] = attr.ib(default=None, metadata={"json": "tags"}) DirectAccountDataEventContent = Dict[UserID, List[RoomID]] AccountDataEventContent = Union[RoomTagAccountDataEventContent, DirectAccountDataEventContent, Obj] account_data_event_content_map = { EventType.TAG: RoomTagAccountDataEventContent, # m.direct doesn't really need deserializing # EventType.DIRECT: DirectAccountDataEventContent, } # TODO remaining account data event types @dataclass class AccountDataEvent(BaseEvent, SerializableAttrs): content: AccountDataEventContent @classmethod def deserialize(cls, data: JSON) -> "AccountDataEvent": try: evt_type = EventType.find(data.get("type")) data.get("content", {})["__mautrix_event_type"] = evt_type except ValueError: return Obj(**data) evt = super().deserialize(data) evt.type = evt_type return evt @staticmethod @deserializer(AccountDataEventContent) def deserialize_content(data: JSON) -> AccountDataEventContent: evt_type = data.pop("__mautrix_event_type", None) content_type = account_data_event_content_map.get(evt_type, None) if not content_type: return Obj(**data) return content_type.deserialize(data) python-0.20.7/mautrix/types/event/base.py000066400000000000000000000027401473573527000204030ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from typing import Optional from attr import dataclass import attr from ..primitive import EventID, RoomID, UserID from ..util import Obj, SerializableAttrs from .type import EventType @dataclass class BaseUnsigned: """Base unsigned information.""" age: int = None @dataclass class BaseEvent: """Base event class. The only things an event **must** have are content and event type.""" content: Obj type: EventType @dataclass class BaseRoomEvent(BaseEvent): """Base room event class. Room events must have a room ID, event ID, sender and timestamp in addition to the content and type in the base event.""" room_id: RoomID event_id: EventID sender: UserID timestamp: int = attr.ib(metadata={"json": "origin_server_ts"}) @dataclass class GenericEvent(BaseEvent, SerializableAttrs): """ An event class that contains all possible top-level event keys and uses generic Obj's for object keys (content and unsigned) """ content: Obj type: EventType room_id: Optional[RoomID] = None event_id: Optional[EventID] = None sender: Optional[UserID] = None timestamp: Optional[int] = None state_key: Optional[str] = None unsigned: Obj = None redacts: Optional[EventID] = None python-0.20.7/mautrix/types/event/batch.py000066400000000000000000000020111473573527000205410ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan, Sumner Evans # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. 
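# Illustrative sketch (assumed usage, not part of the original module): room tag
# account data deserializes into RoomTagAccountDataEventContent, keyed by tag name.
def _room_tag_example() -> None:
    from mautrix.types import RoomTagAccountDataEventContent

    content = RoomTagAccountDataEventContent.deserialize(
        {"tags": {"m.favourite": {"order": 0.25}}}
    )
    # The order field may be an int, float or string, as allowed by the type annotation.
    assert content.tags["m.favourite"].order == 0.25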
from typing import Any, Optional from attr import dataclass import attr from ..primitive import EventID, UserID from ..util import SerializableAttrs from .base import BaseEvent @dataclass(kw_only=True) class BatchSendEvent(BaseEvent, SerializableAttrs): """Base event class for events sent via a batch send request.""" sender: UserID timestamp: int = attr.ib(metadata={"json": "origin_server_ts"}) content: Any # N.B. Overriding event IDs is not allowed in standard room versions event_id: Optional[EventID] = None @dataclass(kw_only=True) class BatchSendStateEvent(BatchSendEvent, SerializableAttrs): """ State events to be used as initial state events on batch send events. These never need to be deserialized. """ state_key: str python-0.20.7/mautrix/types/event/beeper.py000066400000000000000000000036511473573527000207350ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from typing import Optional from attr import dataclass from ..primitive import EventID, RoomID, SessionID from ..util import SerializableAttrs, SerializableEnum, field from .base import BaseRoomEvent from .message import RelatesTo class MessageStatusReason(SerializableEnum): GENERIC_ERROR = "m.event_not_handled" UNSUPPORTED = "com.beeper.unsupported_event" UNDECRYPTABLE = "com.beeper.undecryptable_event" TOO_OLD = "m.event_too_old" NETWORK_ERROR = "m.foreign_network_error" NO_PERMISSION = "m.no_permission" @property def checkpoint_status(self): from mautrix.util.message_send_checkpoint import MessageSendCheckpointStatus if self == MessageStatusReason.UNSUPPORTED: return MessageSendCheckpointStatus.UNSUPPORTED elif self == MessageStatusReason.TOO_OLD: return MessageSendCheckpointStatus.TIMEOUT return MessageSendCheckpointStatus.PERM_FAILURE class MessageStatus(SerializableEnum): SUCCESS = "SUCCESS" PENDING = "PENDING" RETRIABLE = "FAIL_RETRIABLE" FAIL = "FAIL_PERMANENT" @dataclass(kw_only=True) class BeeperMessageStatusEventContent(SerializableAttrs): relates_to: RelatesTo = field(json="m.relates_to") network: str = "" status: Optional[MessageStatus] = None reason: Optional[MessageStatusReason] = None error: Optional[str] = None message: Optional[str] = None last_retry: Optional[EventID] = None @dataclass class BeeperMessageStatusEvent(BaseRoomEvent, SerializableAttrs): content: BeeperMessageStatusEventContent @dataclass class BeeperRoomKeyAckEventContent(SerializableAttrs): room_id: RoomID session_id: SessionID first_message_index: int python-0.20.7/mautrix/types/event/encrypted.py000066400000000000000000000107441473573527000214710ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. 
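# Illustrative sketch (assumed usage, not part of the original module): the
# MessageStatusReason values defined in beeper.py map onto message send checkpoint
# statuses via the checkpoint_status property.
def _message_status_reason_example() -> None:
    from mautrix.types import MessageStatusReason
    from mautrix.util.message_send_checkpoint import MessageSendCheckpointStatus

    # Unsupported events are reported as UNSUPPORTED checkpoints...
    assert (
        MessageStatusReason.UNSUPPORTED.checkpoint_status
        == MessageSendCheckpointStatus.UNSUPPORTED
    )
    # ...and anything other than UNSUPPORTED/TOO_OLD becomes a permanent failure.
    assert (
        MessageStatusReason.GENERIC_ERROR.checkpoint_status
        == MessageSendCheckpointStatus.PERM_FAILURE
    )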
from typing import Dict, NewType, Optional, Union from enum import IntEnum import warnings from attr import dataclass from ..primitive import JSON, DeviceID, IdentityKey, SessionID from ..util import ExtensibleEnum, Obj, Serializable, SerializableAttrs, deserializer, field from .base import BaseRoomEvent, BaseUnsigned from .message import RelatesTo class EncryptionAlgorithm(ExtensibleEnum): OLM_V1: "EncryptionAlgorithm" = "m.olm.v1.curve25519-aes-sha2" MEGOLM_V1: "EncryptionAlgorithm" = "m.megolm.v1.aes-sha2" class EncryptionKeyAlgorithm(ExtensibleEnum): CURVE25519: "EncryptionKeyAlgorithm" = "curve25519" ED25519: "EncryptionKeyAlgorithm" = "ed25519" SIGNED_CURVE25519: "EncryptionKeyAlgorithm" = "signed_curve25519" @dataclass(frozen=True) class KeyID(Serializable): algorithm: EncryptionKeyAlgorithm key_id: str def serialize(self) -> JSON: return str(self) @classmethod def deserialize(cls, raw: JSON) -> "KeyID": assert isinstance(raw, str), "key IDs must be strings" alg, key_id = raw.split(":", 1) return cls(EncryptionKeyAlgorithm(alg), key_id) def __str__(self) -> str: return f"{self.algorithm.value}:{self.key_id}" class OlmMsgType(Serializable, IntEnum): PREKEY = 0 MESSAGE = 1 def serialize(self) -> JSON: return self.value @classmethod def deserialize(cls, raw: JSON) -> "OlmMsgType": return cls(raw) @dataclass class OlmCiphertext(SerializableAttrs): body: str type: OlmMsgType @dataclass class EncryptedOlmEventContent(SerializableAttrs): ciphertext: Dict[str, OlmCiphertext] sender_key: IdentityKey algorithm: EncryptionAlgorithm = EncryptionAlgorithm.OLM_V1 @dataclass class EncryptedMegolmEventContent(SerializableAttrs): """The content of an m.room.encrypted event""" ciphertext: str session_id: SessionID algorithm: EncryptionAlgorithm = EncryptionAlgorithm.MEGOLM_V1 _sender_key: Optional[IdentityKey] = field(default=None, json="sender_key") _device_id: Optional[DeviceID] = field(default=None, json="device_id") _relates_to: Optional[RelatesTo] = field(default=None, json="m.relates_to") @property def sender_key(self) -> Optional[IdentityKey]: """ .. deprecated:: 0.17.0 Matrix v1.3 deprecated the device_id and sender_key fields in megolm events. """ warnings.warn( "The sender_key field in Megolm events was deprecated in Matrix 1.3", DeprecationWarning, ) return self._sender_key @property def device_id(self) -> Optional[DeviceID]: """ .. deprecated:: 0.17.0 Matrix v1.3 deprecated the device_id and sender_key fields in megolm events. 
""" warnings.warn( "The sender_key field in Megolm events was deprecated in Matrix 1.3", DeprecationWarning, ) return self._device_id @property def relates_to(self) -> RelatesTo: if self._relates_to is None: self._relates_to = RelatesTo() return self._relates_to @relates_to.setter def relates_to(self, relates_to: RelatesTo) -> None: self._relates_to = relates_to EncryptedEventContent = NewType( "EncryptedEventContent", Union[EncryptedOlmEventContent, EncryptedMegolmEventContent] ) @deserializer(EncryptedEventContent) def deserialize_encrypted(data: JSON) -> Union[EncryptedEventContent, Obj]: alg = data.get("algorithm", None) if alg == EncryptionAlgorithm.MEGOLM_V1.value: return EncryptedMegolmEventContent.deserialize(data) elif alg == EncryptionAlgorithm.OLM_V1.value: return EncryptedOlmEventContent.deserialize(data) return Obj(**data) setattr(EncryptedEventContent, "deserialize", deserialize_encrypted) @dataclass class EncryptedEvent(BaseRoomEvent, SerializableAttrs): """A m.room.encrypted event""" content: EncryptedEventContent _unsigned: Optional[BaseUnsigned] = field(default=None, json="unsigned") @property def unsigned(self) -> BaseUnsigned: if not self._unsigned: self._unsigned = BaseUnsigned() return self._unsigned @unsigned.setter def unsigned(self, value: BaseUnsigned) -> None: self._unsigned = value python-0.20.7/mautrix/types/event/ephemeral.py000066400000000000000000000042151473573527000214320ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from typing import Dict, List, NewType, Union from attr import dataclass from ..primitive import JSON, EventID, RoomID, UserID from ..util import ExtensibleEnum, SerializableAttrs, SerializableEnum, deserializer from .base import BaseEvent, GenericEvent from .type import EventType @dataclass class TypingEventContent(SerializableAttrs): user_ids: List[UserID] @dataclass class TypingEvent(BaseEvent, SerializableAttrs): room_id: RoomID content: TypingEventContent class PresenceState(SerializableEnum): ONLINE = "online" OFFLINE = "offline" UNAVAILABLE = "unavailable" @dataclass class PresenceEventContent(SerializableAttrs): presence: PresenceState last_active_ago: int = None status_msg: str = None currently_active: bool = None @dataclass class PresenceEvent(BaseEvent, SerializableAttrs): sender: UserID content: PresenceEventContent @dataclass class SingleReceiptEventContent(SerializableAttrs): ts: int class ReceiptType(ExtensibleEnum): READ = "m.read" READ_PRIVATE = "m.read.private" ReceiptEventContent = Dict[EventID, Dict[ReceiptType, Dict[UserID, SingleReceiptEventContent]]] @dataclass class ReceiptEvent(BaseEvent, SerializableAttrs): room_id: RoomID content: ReceiptEventContent EphemeralEvent = NewType("EphemeralEvent", Union[PresenceEvent, TypingEvent, ReceiptEvent]) @deserializer(EphemeralEvent) def deserialize_ephemeral_event(data: JSON) -> EphemeralEvent: event_type = EventType.find(data.get("type", None)) if event_type == EventType.RECEIPT: evt = ReceiptEvent.deserialize(data) elif event_type == EventType.TYPING: evt = TypingEvent.deserialize(data) elif event_type == EventType.PRESENCE: evt = PresenceEvent.deserialize(data) else: evt = GenericEvent.deserialize(data) evt.type = event_type return evt setattr(EphemeralEvent, "deserialize", deserialize_ephemeral_event) 
python-0.20.7/mautrix/types/event/generic.py000066400000000000000000000060131473573527000211020ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from typing import NewType, Union from ..primitive import JSON from ..util import Obj, deserializer from .account_data import AccountDataEvent, AccountDataEventContent from .base import EventType, GenericEvent from .beeper import BeeperMessageStatusEvent, BeeperMessageStatusEventContent from .encrypted import EncryptedEvent, EncryptedEventContent from .ephemeral import ( EphemeralEvent, PresenceEvent, ReceiptEvent, ReceiptEventContent, TypingEvent, TypingEventContent, ) from .message import MessageEvent, MessageEventContent from .reaction import ReactionEvent, ReactionEventContent from .redaction import RedactionEvent, RedactionEventContent from .state import StateEvent, StateEventContent from .to_device import ASToDeviceEvent, ToDeviceEvent, ToDeviceEventContent from .voip import CallEvent, CallEventContent, type_to_class as voip_types Event = NewType( "Event", Union[ MessageEvent, ReactionEvent, RedactionEvent, StateEvent, TypingEvent, ReceiptEvent, PresenceEvent, EncryptedEvent, ToDeviceEvent, ASToDeviceEvent, CallEvent, BeeperMessageStatusEvent, GenericEvent, ], ) EventContent = Union[ MessageEventContent, RedactionEventContent, ReactionEventContent, StateEventContent, AccountDataEventContent, ReceiptEventContent, TypingEventContent, EncryptedEventContent, ToDeviceEventContent, CallEventContent, BeeperMessageStatusEventContent, Obj, ] @deserializer(Event) def deserialize_event(data: JSON) -> Event: event_type = EventType.find(data.get("type", None)) if event_type == EventType.ROOM_MESSAGE: return MessageEvent.deserialize(data) elif event_type == EventType.STICKER: data.get("content", {})["msgtype"] = "m.sticker" return MessageEvent.deserialize(data) elif event_type == EventType.REACTION: return ReactionEvent.deserialize(data) elif event_type == EventType.ROOM_REDACTION: return RedactionEvent.deserialize(data) elif event_type == EventType.ROOM_ENCRYPTED: return EncryptedEvent.deserialize(data) elif event_type in voip_types.keys(): return CallEvent.deserialize(data, event_type=event_type) elif event_type.is_to_device: return ToDeviceEvent.deserialize(data) elif event_type.is_state: return StateEvent.deserialize(data) elif event_type.is_account_data: return AccountDataEvent.deserialize(data) elif event_type.is_ephemeral: return EphemeralEvent.deserialize(data) elif event_type == EventType.BEEPER_MESSAGE_STATUS: return BeeperMessageStatusEvent.deserialize(data) else: return GenericEvent.deserialize(data) setattr(Event, "deserialize", deserialize_event) python-0.20.7/mautrix/types/event/message.py000066400000000000000000000310601473573527000211120ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. 
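# Illustrative sketch (assumed usage, not part of the original module): the top-level
# Event deserializer in generic.py dispatches on the event type, so a plain
# m.room.message payload is parsed into a MessageEvent.
def _event_dispatch_example() -> None:
    from mautrix.types import Event, MessageEvent

    raw = {
        "type": "m.room.message",
        "room_id": "!room:example.com",
        "event_id": "$event1",
        "sender": "@alice:example.com",
        "origin_server_ts": 1_600_000_000_000,
        "content": {"msgtype": "m.text", "body": "hello"},
    }
    evt = Event.deserialize(raw)
    assert isinstance(evt, MessageEvent)
    assert evt.content.body == "hello"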
from typing import Dict, List, Optional, Pattern, Union from html import escape import re from attr import dataclass import attr from ..primitive import JSON, ContentURI, EventID from ..util import ExtensibleEnum, Obj, SerializableAttrs, deserializer, field from .base import BaseRoomEvent, BaseUnsigned # region Message types class Format(ExtensibleEnum): """A message format. Currently only ``org.matrix.custom.html`` is available. This will probably be deprecated when extensible events are implemented.""" HTML: "Format" = "org.matrix.custom.html" TEXT_MESSAGE_TYPES = ("m.text", "m.emote", "m.notice") MEDIA_MESSAGE_TYPES = ("m.image", "m.sticker", "m.video", "m.audio", "m.file") class MessageType(ExtensibleEnum): """A message type.""" TEXT: "MessageType" = "m.text" EMOTE: "MessageType" = "m.emote" NOTICE: "MessageType" = "m.notice" IMAGE: "MessageType" = "m.image" STICKER: "MessageType" = "m.sticker" VIDEO: "MessageType" = "m.video" AUDIO: "MessageType" = "m.audio" FILE: "MessageType" = "m.file" LOCATION: "MessageType" = "m.location" @property def is_text(self) -> bool: return self.value in TEXT_MESSAGE_TYPES @property def is_media(self) -> bool: return self.value in MEDIA_MESSAGE_TYPES # endregion # region Relations @dataclass class InReplyTo(SerializableAttrs): event_id: EventID class RelationType(ExtensibleEnum): ANNOTATION: "RelationType" = "m.annotation" REFERENCE: "RelationType" = "m.reference" REPLACE: "RelationType" = "m.replace" THREAD: "RelationType" = "m.thread" @dataclass class RelatesTo(SerializableAttrs): """Message relations. Used for reactions, edits and replies.""" rel_type: RelationType = None event_id: Optional[EventID] = None key: Optional[str] = None is_falling_back: Optional[bool] = None in_reply_to: Optional[InReplyTo] = field(default=None, json="m.in_reply_to") def __bool__(self) -> bool: return (bool(self.rel_type) and bool(self.event_id)) or bool(self.in_reply_to) def serialize(self) -> JSON: if not self: return attr.NOTHING return super().serialize() # endregion # region Base event content class BaseMessageEventContentFuncs: """Base class for the contents of all message-type events (currently m.room.message and m.sticker). Contains relation helpers.""" body: str _relates_to: Optional[RelatesTo] def set_reply(self, reply_to: Union[EventID, "MessageEvent"], **kwargs) -> None: self.relates_to.in_reply_to = InReplyTo( event_id=reply_to if isinstance(reply_to, str) else reply_to.event_id ) def set_thread_parent( self, thread_parent: Union[EventID, "MessageEvent"], last_event_in_thread: Union[EventID, "MessageEvent", None] = None, disable_reply_fallback: bool = False, **kwargs, ) -> None: self.relates_to.rel_type = RelationType.THREAD self.relates_to.event_id = ( thread_parent if isinstance(thread_parent, str) else thread_parent.event_id ) if isinstance(thread_parent, MessageEvent) and isinstance( thread_parent.content, BaseMessageEventContentFuncs ): self.relates_to.event_id = ( thread_parent.content.get_thread_parent() or self.relates_to.event_id ) if not disable_reply_fallback: self.set_reply(last_event_in_thread or thread_parent, **kwargs) self.relates_to.is_falling_back = True def set_edit(self, edits: Union[EventID, "MessageEvent"]) -> None: self.relates_to.rel_type = RelationType.REPLACE self.relates_to.event_id = edits if isinstance(edits, str) else edits.event_id # Library consumers may create message content by setting a reply first, # then later marking it as an edit. As edits can't change the reply, just remove # the reply metadata when marking as a reply. 
if self.relates_to.in_reply_to: self.relates_to.in_reply_to = None self.relates_to.is_falling_back = None def serialize(self) -> JSON: data = SerializableAttrs.serialize(self) evt = self.get_edit() if evt: new_content = {**data} del new_content["m.relates_to"] data["m.new_content"] = new_content if "body" in data: data["body"] = f"* {data['body']}" if "formatted_body" in data: data["formatted_body"] = f"* {data['formatted_body']}" return data @property def relates_to(self) -> RelatesTo: if self._relates_to is None: self._relates_to = RelatesTo() return self._relates_to @relates_to.setter def relates_to(self, relates_to: RelatesTo) -> None: self._relates_to = relates_to def get_reply_to(self) -> Optional[EventID]: if self._relates_to and self._relates_to.in_reply_to: return self._relates_to.in_reply_to.event_id return None def get_edit(self) -> Optional[EventID]: if self._relates_to and self._relates_to.rel_type == RelationType.REPLACE: return self._relates_to.event_id return None def get_thread_parent(self) -> Optional[EventID]: if self._relates_to and self._relates_to.rel_type == RelationType.THREAD: return self._relates_to.event_id return None def trim_reply_fallback(self) -> None: pass @dataclass class BaseMessageEventContent(BaseMessageEventContentFuncs): """Base event content for all m.room.message-type events.""" msgtype: MessageType = None body: str = "" external_url: str = None _relates_to: Optional[RelatesTo] = attr.ib(default=None, metadata={"json": "m.relates_to"}) # endregion # region Media info @dataclass class JSONWebKey(SerializableAttrs): key: str = attr.ib(metadata={"json": "k"}) algorithm: str = attr.ib(default="A256CTR", metadata={"json": "alg"}) extractable: bool = attr.ib(default=True, metadata={"json": "ext"}) key_type: str = attr.ib(default="oct", metadata={"json": "kty"}) key_ops: List[str] = attr.ib(factory=lambda: ["encrypt", "decrypt"]) @dataclass class EncryptedFile(SerializableAttrs): key: JSONWebKey iv: str hashes: Dict[str, str] url: Optional[ContentURI] = None version: str = attr.ib(default="v2", metadata={"json": "v"}) @dataclass class BaseFileInfo(SerializableAttrs): mimetype: str = None size: int = None @dataclass class ThumbnailInfo(BaseFileInfo, SerializableAttrs): """Information about the thumbnail for a document, video, image or location.""" height: int = attr.ib(default=None, metadata={"json": "h"}) width: int = attr.ib(default=None, metadata={"json": "w"}) orientation: int = None @dataclass class FileInfo(BaseFileInfo, SerializableAttrs): """Information about a document message.""" thumbnail_info: Optional[ThumbnailInfo] = None thumbnail_file: Optional[EncryptedFile] = None thumbnail_url: Optional[ContentURI] = None @dataclass class ImageInfo(FileInfo, SerializableAttrs): """Information about an image message.""" height: int = attr.ib(default=None, metadata={"json": "h"}) width: int = attr.ib(default=None, metadata={"json": "w"}) orientation: int = None @dataclass class VideoInfo(ImageInfo, SerializableAttrs): """Information about a video message.""" duration: int = None orientation: int = None @dataclass class AudioInfo(BaseFileInfo, SerializableAttrs): """Information about an audio message.""" duration: int = None MediaInfo = Union[ImageInfo, VideoInfo, AudioInfo, FileInfo, Obj] @dataclass class LocationInfo(SerializableAttrs): """Information about a location message.""" thumbnail_url: Optional[ContentURI] = None thumbnail_info: Optional[ThumbnailInfo] = None thumbnail_file: Optional[EncryptedFile] = None # endregion # region Event content 
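# Illustrative sketch (assumed usage, not part of the original module): the media info
# dataclasses above serialize their height/width into the spec's short "h"/"w" keys.
def _image_info_example() -> None:
    info = ImageInfo(mimetype="image/png", size=1024, height=64, width=64)
    data = info.serialize()
    # The JSON field names follow the Matrix spec, not the Python attribute names.
    assert data["h"] == 64 and data["w"] == 64
    assert data["mimetype"] == "image/png"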
@dataclass class LocationMessageEventContent(BaseMessageEventContent, SerializableAttrs): geo_uri: str = None info: LocationInfo = None html_reply_fallback_regex: Pattern = re.compile(r"^[\s\S]+") @dataclass class TextMessageEventContent(BaseMessageEventContent, SerializableAttrs): """The content of a text message event (m.text, m.notice, m.emote)""" format: Format = None formatted_body: str = None def ensure_has_html(self) -> None: if not self.formatted_body or self.format != Format.HTML: self.format = Format.HTML self.formatted_body = escape(self.body).replace("\n", "
") def formatted(self, format: Format) -> Optional[str]: if self.format == format: return self.formatted_body return None def trim_reply_fallback(self) -> None: if self.get_reply_to() and not getattr(self, "__reply_fallback_trimmed", False): self._trim_reply_fallback_text() self._trim_reply_fallback_html() setattr(self, "__reply_fallback_trimmed", True) def _trim_reply_fallback_text(self) -> None: if ( not self.body.startswith("> <") and not self.body.startswith("> * <") ) or "\n" not in self.body: return lines = self.body.split("\n") while len(lines) > 0 and lines[0].startswith("> "): lines.pop(0) self.body = "\n".join(lines).strip() def _trim_reply_fallback_html(self) -> None: if self.formatted_body and self.format == Format.HTML: self.formatted_body = html_reply_fallback_regex.sub("", self.formatted_body) @dataclass class MediaMessageEventContent(TextMessageEventContent, SerializableAttrs): """The content of a media message event (m.image, m.audio, m.video, m.file)""" url: Optional[ContentURI] = None info: Optional[MediaInfo] = None file: Optional[EncryptedFile] = None filename: Optional[str] = None @staticmethod @deserializer(MediaInfo) @deserializer(Optional[MediaInfo]) def deserialize_info(data: JSON) -> MediaInfo: if not isinstance(data, dict): return Obj() msgtype = data.pop("__mautrix_msgtype", None) if msgtype == "m.image" or msgtype == "m.sticker": return ImageInfo.deserialize(data) elif msgtype == "m.video": return VideoInfo.deserialize(data) elif msgtype == "m.audio": return AudioInfo.deserialize(data) elif msgtype == "m.file": return FileInfo.deserialize(data) else: return Obj(**data) MessageEventContent = Union[ TextMessageEventContent, MediaMessageEventContent, LocationMessageEventContent, Obj ] # endregion @dataclass class MessageUnsigned(BaseUnsigned, SerializableAttrs): """Unsigned information sent with message events.""" transaction_id: str = None html_reply_fallback_format = ( "
" "In reply to " "{displayname}
" "{content}" "
" ) media_reply_fallback_body_map = { MessageType.IMAGE: "an image", MessageType.STICKER: "a sticker", MessageType.AUDIO: "audio", MessageType.VIDEO: "a video", MessageType.FILE: "a file", MessageType.LOCATION: "a location", } @dataclass class MessageEvent(BaseRoomEvent, SerializableAttrs): """An m.room.message event""" content: MessageEventContent unsigned: Optional[MessageUnsigned] = field(factory=lambda: MessageUnsigned()) @staticmethod @deserializer(MessageEventContent) def deserialize_content(data: JSON) -> MessageEventContent: if not isinstance(data, dict): return Obj() rel = data.get("m.relates_to", None) or {} if rel.get("rel_type", None) == RelationType.REPLACE.value: data = data.get("m.new_content", data) data["m.relates_to"] = rel msgtype = data.get("msgtype", None) if msgtype in TEXT_MESSAGE_TYPES: return TextMessageEventContent.deserialize(data) elif msgtype in MEDIA_MESSAGE_TYPES: data.get("info", {})["__mautrix_msgtype"] = msgtype return MediaMessageEventContent.deserialize(data) elif msgtype == "m.location": return LocationMessageEventContent.deserialize(data) else: return Obj(**data) python-0.20.7/mautrix/types/event/reaction.py000066400000000000000000000025641473573527000213010ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from typing import Optional from attr import dataclass import attr from ..util import SerializableAttrs from .base import BaseRoomEvent, BaseUnsigned from .message import RelatesTo @dataclass class ReactionEventContent(SerializableAttrs): """The content of an m.reaction event""" _relates_to: Optional[RelatesTo] = attr.ib(default=None, metadata={"json": "m.relates_to"}) @property def relates_to(self) -> RelatesTo: if self._relates_to is None: self._relates_to = RelatesTo() return self._relates_to @relates_to.setter def relates_to(self, relates_to: RelatesTo) -> None: self._relates_to = relates_to @dataclass class ReactionEvent(BaseRoomEvent, SerializableAttrs): """A m.reaction event""" content: ReactionEventContent _unsigned: Optional[BaseUnsigned] = attr.ib(default=None, metadata={"json": "unsigned"}) @property def unsigned(self) -> BaseUnsigned: if not self._unsigned: self._unsigned = BaseUnsigned() return self._unsigned @unsigned.setter def unsigned(self, value: BaseUnsigned) -> None: self._unsigned = value python-0.20.7/mautrix/types/event/redaction.py000066400000000000000000000020631473573527000214370ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. 
from typing import Optional from attr import dataclass import attr from ..primitive import EventID from ..util import SerializableAttrs from .base import BaseRoomEvent, BaseUnsigned @dataclass class RedactionEventContent(SerializableAttrs): """The content of an m.room.redaction event""" reason: str = None @dataclass class RedactionEvent(BaseRoomEvent, SerializableAttrs): """A m.room.redaction event""" content: RedactionEventContent redacts: EventID _unsigned: Optional[BaseUnsigned] = attr.ib(default=None, metadata={"json": "unsigned"}) @property def unsigned(self) -> BaseUnsigned: if not self._unsigned: self._unsigned = BaseUnsigned() return self._unsigned @unsigned.setter def unsigned(self, value: BaseUnsigned) -> None: self._unsigned = value python-0.20.7/mautrix/types/event/state.py000066400000000000000000000237701473573527000206170ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from typing import Dict, List, Optional, Union from attr import dataclass import attr from ..primitive import JSON, ContentURI, EventID, RoomAlias, RoomID, UserID from ..util import Obj, SerializableAttrs, SerializableEnum, deserializer, field from .base import BaseRoomEvent, BaseUnsigned from .encrypted import EncryptionAlgorithm from .type import EventType, RoomType @dataclass class NotificationPowerLevels(SerializableAttrs): room: int = 50 @dataclass class PowerLevelStateEventContent(SerializableAttrs): """The content of a power level event.""" users: Dict[UserID, int] = attr.ib(default=attr.Factory(dict), metadata={"omitempty": False}) users_default: int = 0 events: Dict[EventType, int] = attr.ib( default=attr.Factory(dict), metadata={"omitempty": False} ) events_default: int = 0 notifications: NotificationPowerLevels = attr.ib(factory=lambda: NotificationPowerLevels()) state_default: int = 50 invite: int = 50 kick: int = 50 ban: int = 50 redact: int = 50 def get_user_level(self, user_id: UserID) -> int: return int(self.users.get(user_id, self.users_default)) def set_user_level(self, user_id: UserID, level: int) -> None: if level == self.users_default: del self.users[user_id] else: self.users[user_id] = level def ensure_user_level(self, user_id: UserID, level: int) -> bool: if self.get_user_level(user_id) != level: self.set_user_level(user_id, level) return True return False def get_event_level(self, event_type: EventType) -> int: return int( self.events.get( event_type, (self.state_default if event_type.is_state else self.events_default) ) ) def set_event_level(self, event_type: EventType, level: int) -> None: if level == self.state_default if event_type.is_state else self.events_default: del self.events[event_type] else: self.events[event_type] = level def ensure_event_level(self, event_type: EventType, level: int) -> bool: if self.get_event_level(event_type) != level: self.set_event_level(event_type, level) return True return False class Membership(SerializableEnum): """ The membership state of a user in a room as specified in section `8.4 Room membership`_ of the spec. .. _8.4 Room membership: https://spec.matrix.org/v1.2/client-server-api/#room-membership """ JOIN = "join" LEAVE = "leave" INVITE = "invite" BAN = "ban" KNOCK = "knock" @dataclass class MemberStateEventContent(SerializableAttrs): """The content of a membership event. `Spec link`_ .. 
_Spec link: https://spec.matrix.org/v1.2/client-server-api/#mroommember""" membership: Membership = Membership.LEAVE avatar_url: ContentURI = None displayname: str = None is_direct: bool = False reason: str = None third_party_invite: JSON = None @dataclass class CanonicalAliasStateEventContent(SerializableAttrs): """ The content of a ``m.room.canonical_alias`` event (:attr:`EventType.ROOM_CANONICAL_ALIAS`). This event is used to inform the room about which alias should be considered the canonical one, and which other aliases point to the room. This could be for display purposes or as suggestion to users which alias to use to advertise and access the room. See also: `m.room.canonical_alias in the spec`_ .. _m.room.canonical_alias in the spec: https://spec.matrix.org/v1.2/client-server-api/#mroomcanonical_alias """ canonical_alias: RoomAlias = attr.ib(default=None, metadata={"json": "alias"}) alt_aliases: List[RoomAlias] = attr.ib(factory=lambda: []) @dataclass class RoomNameStateEventContent(SerializableAttrs): name: str = None @dataclass class RoomTopicStateEventContent(SerializableAttrs): topic: str = None @dataclass class RoomAvatarStateEventContent(SerializableAttrs): url: Optional[ContentURI] = None class JoinRule(SerializableEnum): PUBLIC = "public" KNOCK = "knock" RESTRICTED = "restricted" INVITE = "invite" PRIVATE = "private" KNOCK_RESTRICTED = "knock_restricted" class JoinRestrictionType(SerializableEnum): ROOM_MEMBERSHIP = "m.room_membership" @dataclass class JoinRestriction(SerializableAttrs): type: JoinRestrictionType room_id: Optional[RoomID] = None @dataclass class JoinRulesStateEventContent(SerializableAttrs): join_rule: JoinRule allow: Optional[List[JoinRestriction]] = None @dataclass class RoomPinnedEventsStateEventContent(SerializableAttrs): pinned: List[EventID] = None @dataclass class RoomTombstoneStateEventContent(SerializableAttrs): body: str = None replacement_room: RoomID = None @dataclass class RoomEncryptionStateEventContent(SerializableAttrs): algorithm: EncryptionAlgorithm = None rotation_period_ms: int = 604800000 rotation_period_msgs: int = 100 @dataclass class RoomPredecessor(SerializableAttrs): room_id: RoomID = None event_id: EventID = None @dataclass class RoomCreateStateEventContent(SerializableAttrs): room_version: str = "1" federate: bool = field(json="m.federate", omit_default=True, default=True) predecessor: Optional[RoomPredecessor] = None type: Optional[RoomType] = None @dataclass class SpaceChildStateEventContent(SerializableAttrs): via: List[str] = None order: str = "" suggested: bool = False @dataclass class SpaceParentStateEventContent(SerializableAttrs): via: List[str] = None canonical: bool = False StateEventContent = Union[ PowerLevelStateEventContent, MemberStateEventContent, CanonicalAliasStateEventContent, RoomNameStateEventContent, RoomAvatarStateEventContent, RoomTopicStateEventContent, RoomPinnedEventsStateEventContent, RoomTombstoneStateEventContent, RoomEncryptionStateEventContent, RoomCreateStateEventContent, SpaceChildStateEventContent, SpaceParentStateEventContent, JoinRulesStateEventContent, Obj, ] @dataclass class StrippedStateUnsigned(BaseUnsigned, SerializableAttrs): """Unsigned information sent with state events.""" prev_content: StateEventContent = None prev_sender: UserID = None replaces_state: EventID = None @dataclass class StrippedStateEvent(SerializableAttrs): """Stripped state events included with some invite events.""" content: StateEventContent = None room_id: RoomID = None sender: UserID = None type: EventType = None 
state_key: str = None unsigned: Optional[StrippedStateUnsigned] = None @property def prev_content(self) -> StateEventContent: if self.unsigned and self.unsigned.prev_content: return self.unsigned.prev_content return state_event_content_map.get(self.type, Obj)() @classmethod def deserialize(cls, data: JSON) -> "StrippedStateEvent": try: event_type = EventType.find(data.get("type", None)) data.get("content", {})["__mautrix_event_type"] = event_type (data.get("unsigned") or {}).get("prev_content", {})[ "__mautrix_event_type" ] = event_type except ValueError: pass return super().deserialize(data) @dataclass class StateUnsigned(StrippedStateUnsigned, SerializableAttrs): invite_room_state: Optional[List[StrippedStateEvent]] = None state_event_content_map = { EventType.ROOM_CREATE: RoomCreateStateEventContent, EventType.ROOM_POWER_LEVELS: PowerLevelStateEventContent, EventType.ROOM_MEMBER: MemberStateEventContent, EventType.ROOM_PINNED_EVENTS: RoomPinnedEventsStateEventContent, EventType.ROOM_CANONICAL_ALIAS: CanonicalAliasStateEventContent, EventType.ROOM_NAME: RoomNameStateEventContent, EventType.ROOM_AVATAR: RoomAvatarStateEventContent, EventType.ROOM_TOPIC: RoomTopicStateEventContent, EventType.ROOM_JOIN_RULES: JoinRulesStateEventContent, EventType.ROOM_TOMBSTONE: RoomTombstoneStateEventContent, EventType.ROOM_ENCRYPTION: RoomEncryptionStateEventContent, EventType.SPACE_CHILD: SpaceChildStateEventContent, EventType.SPACE_PARENT: SpaceParentStateEventContent, } @dataclass class StateEvent(BaseRoomEvent, SerializableAttrs): """A room state event.""" state_key: str content: StateEventContent unsigned: Optional[StateUnsigned] = field(factory=lambda: StateUnsigned()) @property def prev_content(self) -> StateEventContent: if self.unsigned and self.unsigned.prev_content: return self.unsigned.prev_content return state_event_content_map.get(self.type, Obj)() @classmethod def deserialize(cls, data: JSON) -> "StateEvent": try: event_type = EventType.find(data.get("type"), t_class=EventType.Class.STATE) data.get("content", {})["__mautrix_event_type"] = event_type if "prev_content" in data and "prev_content" not in (data.get("unsigned") or {}): # This if is a workaround for Conduit being extremely dumb if data.get("unsigned", {}) is None: data["unsigned"] = {} data.setdefault("unsigned", {})["prev_content"] = data["prev_content"] data.get("unsigned", {}).get("prev_content", {})["__mautrix_event_type"] = event_type except ValueError: return Obj(**data) evt = super().deserialize(data) evt.type = event_type return evt @staticmethod @deserializer(StateEventContent) def deserialize_content(data: JSON) -> StateEventContent: evt_type = data.pop("__mautrix_event_type", None) content_type = state_event_content_map.get(evt_type, None) if not content_type: return Obj(**data) return content_type.deserialize(data) python-0.20.7/mautrix/types/event/to_device.py000066400000000000000000000101371473573527000214310ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. 
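# A minimal sketch of how the to-device event wrapper defined below dispatches content
# parsing via to_device_event_content_map; the sender, device ID and request ID here are
# placeholder values:
#
#   >>> from mautrix.types.event.to_device import ToDeviceEvent, RoomKeyRequestEventContent
#   >>> evt = ToDeviceEvent.deserialize({
#   ...     "type": "m.room_key_request",
#   ...     "sender": "@user:example.com",
#   ...     "content": {"action": "request", "requesting_device_id": "ABCDEFG", "request_id": "1"},
#   ... })
#   >>> isinstance(evt.content, RoomKeyRequestEventContent)
#   True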
from typing import List, Optional, Union from attr import dataclass import attr from ..primitive import JSON, DeviceID, IdentityKey, RoomID, SessionID, SigningKey, UserID from ..util import ExtensibleEnum, Obj, SerializableAttrs, deserializer, field from .base import BaseEvent, EventType from .beeper import BeeperRoomKeyAckEventContent from .encrypted import EncryptedOlmEventContent, EncryptionAlgorithm class RoomKeyWithheldCode(ExtensibleEnum): BLACKLISTED: "RoomKeyWithheldCode" = "m.blacklisted" UNVERIFIED: "RoomKeyWithheldCode" = "m.unverified" UNAUTHORIZED: "RoomKeyWithheldCode" = "m.unauthorised" UNAVAILABLE: "RoomKeyWithheldCode" = "m.unavailable" NO_OLM_SESSION: "RoomKeyWithheldCode" = "m.no_olm" BEEPER_REDACTED: "RoomKeyWithheldCode" = "com.beeper.redacted" @dataclass class RoomKeyWithheldEventContent(SerializableAttrs): algorithm: EncryptionAlgorithm sender_key: IdentityKey code: RoomKeyWithheldCode reason: Optional[str] = None room_id: Optional[RoomID] = None session_id: Optional[SessionID] = None @dataclass class RoomKeyEventContent(SerializableAttrs): algorithm: EncryptionAlgorithm room_id: RoomID session_id: SessionID session_key: str beeper_max_age_ms: Optional[int] = field(json="com.beeper.max_age_ms", default=None) beeper_max_messages: Optional[int] = field(json="com.beeper.max_messages", default=None) beeper_is_scheduled: Optional[bool] = field(json="com.beeper.is_scheduled", default=False) class KeyRequestAction(ExtensibleEnum): REQUEST: "KeyRequestAction" = "request" CANCEL: "KeyRequestAction" = "request_cancellation" @dataclass class RequestedKeyInfo(SerializableAttrs): algorithm: EncryptionAlgorithm room_id: RoomID sender_key: IdentityKey session_id: SessionID @dataclass class RoomKeyRequestEventContent(SerializableAttrs): action: KeyRequestAction requesting_device_id: DeviceID request_id: str body: Optional[RequestedKeyInfo] = None @dataclass(kw_only=True) class ForwardedRoomKeyEventContent(RoomKeyEventContent, SerializableAttrs): sender_key: IdentityKey signing_key: SigningKey = attr.ib(metadata={"json": "sender_claimed_ed25519_key"}) forwarding_key_chain: List[str] = attr.ib(metadata={"json": "forwarding_curve25519_key_chain"}) ToDeviceEventContent = Union[ Obj, EncryptedOlmEventContent, RoomKeyWithheldEventContent, RoomKeyEventContent, RoomKeyRequestEventContent, ForwardedRoomKeyEventContent, BeeperRoomKeyAckEventContent, ] to_device_event_content_map = { EventType.TO_DEVICE_ENCRYPTED: EncryptedOlmEventContent, EventType.ROOM_KEY_WITHHELD: RoomKeyWithheldEventContent, EventType.ROOM_KEY_REQUEST: RoomKeyRequestEventContent, EventType.ROOM_KEY: RoomKeyEventContent, EventType.FORWARDED_ROOM_KEY: ForwardedRoomKeyEventContent, EventType.BEEPER_ROOM_KEY_ACK: BeeperRoomKeyAckEventContent, } @dataclass class ToDeviceEvent(BaseEvent, SerializableAttrs): sender: UserID content: ToDeviceEventContent @classmethod def deserialize(cls, data: JSON) -> "ToDeviceEvent": try: evt_type = EventType.find(data.get("type"), t_class=EventType.Class.TO_DEVICE) data.setdefault("content", {})["__mautrix_event_type"] = evt_type except ValueError: return Obj(**data) evt = super().deserialize(data) evt.type = evt_type return evt @staticmethod @deserializer(ToDeviceEventContent) def deserialize_content(data: JSON) -> ToDeviceEventContent: evt_type = data.pop("__mautrix_event_type", None) content_type = to_device_event_content_map.get(evt_type, None) if not content_type: return Obj(**data) return content_type.deserialize(data) @dataclass class ASToDeviceEvent(ToDeviceEvent, SerializableAttrs): 
to_user_id: UserID to_device_id: DeviceID python-0.20.7/mautrix/types/event/type.py000066400000000000000000000203561473573527000204550ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from typing import Any, Optional import json from ..primitive import JSON from ..util import ExtensibleEnum, Serializable, SerializableEnum class RoomType(ExtensibleEnum): SPACE = "m.space" class EventType(Serializable): """ An immutable enum-like class that represents a specific Matrix event type. In addition to the plain event type string, this also includes the context that the event is used in (see: :class:`Class`). Comparing ``EventType`` instances for equality will check both the type string and the class. The idea behind the wrapper is that incoming event parsers will always create an ``EventType`` instance with the correct class, regardless of what the usual context for the event is. Then when the event is being handled, the type will not be equal to ``EventType`` instances with a different class. For example, if someone sends a non-state ``m.room.name`` event, checking ``if event.type == EventType.ROOM_NAME`` would return ``False``, because the class would be different. Bugs caused by not checking the context of an event (especially state event vs message event) were very common in the past, and using a wrapper like this helps prevent them. """ _by_event_type = {} class Class(SerializableEnum): """The context that an event type is used in.""" UNKNOWN = "unknown" STATE = "state" """Room state events""" MESSAGE = "message" """Room message events, i.e. room events that are not state events""" ACCOUNT_DATA = "account_data" """Account data events, user-specific storage used for synchronizing info between clients. Can be global or room-specific.""" EPHEMERAL = "ephemeral" """Ephemeral events. Currently only typing notifications, read receipts and presence are in this class, as custom ephemeral events are not yet possible.""" TO_DEVICE = "to_device" """Device-to-device events, primarily used for exchanging encryption keys""" __slots__ = ("t", "t_class") t: str """The type string of the event.""" t_class: Class """The context where the event appeared.""" def __init__(self, t: str, t_class: Class) -> None: object.__setattr__(self, "t", t) object.__setattr__(self, "t_class", t_class) if t not in self._by_event_type: self._by_event_type[t] = self def serialize(self) -> JSON: return self.t @classmethod def deserialize(cls, raw: JSON) -> Any: return cls.find(raw) @classmethod def find(cls, t: str, t_class: Optional[Class] = None) -> "EventType": """ Create a new ``EventType`` instance with the given type and class. If an ``EventType`` instance with the same type string and class has been created before, or if no class is specified here, this will return the same instance instead of making a new one. Examples: >>> from mautrix.client import Client >>> from mautrix.types import EventType >>> MY_CUSTOM_TYPE = EventType.find("com.example.custom_event", EventType.Class.STATE) >>> client = Client(...) >>> @client.on(MY_CUSTOM_TYPE) ... async def handle_event(evt): ... Args: t: The type string. t_class: The class of the event type. Returns: An ``EventType`` instance with the given parameters. 
""" try: return cls._by_event_type[t].with_class(t_class) except KeyError: return EventType(t, t_class=t_class or cls.Class.UNKNOWN) def json(self) -> str: return json.dumps(self.serialize()) @classmethod def parse_json(cls, data: str) -> "EventType": return cls.deserialize(json.loads(data)) def __setattr__(self, *args, **kwargs) -> None: raise TypeError("EventTypes are frozen") def __delattr__(self, *args, **kwargs) -> None: raise TypeError("EventTypes are frozen") def __str__(self): return self.t def __repr__(self): return f'EventType("{self.t}", EventType.Class.{self.t_class.name})' def __hash__(self): return hash(self.t) ^ hash(self.t_class) def __eq__(self, other: Any) -> bool: if not isinstance(other, EventType): return False return self.t == other.t and self.t_class == other.t_class def with_class(self, t_class: Optional[Class]) -> "EventType": """Return a copy of this ``EventType`` with the given class. If the given class is the same as what this instance has, or if the given class is ``None``, this returns ``self`` instead of making a copy.""" if t_class is None or self.t_class == t_class: return self return EventType(t=self.t, t_class=t_class) @property def is_message(self) -> bool: """A shortcut for ``type.t_class == EventType.Class.MESSAGE``""" return self.t_class == EventType.Class.MESSAGE @property def is_state(self) -> bool: """A shortcut for ``type.t_class == EventType.Class.STATE``""" return self.t_class == EventType.Class.STATE @property def is_ephemeral(self) -> bool: """A shortcut for ``type.t_class == EventType.Class.EPHEMERAL``""" return self.t_class == EventType.Class.EPHEMERAL @property def is_account_data(self) -> bool: """A shortcut for ``type.t_class == EventType.Class.ACCOUNT_DATA``""" return self.t_class == EventType.Class.ACCOUNT_DATA @property def is_to_device(self) -> bool: """A shortcut for ``type.t_class == EventType.Class.TO_DEVICE``""" return self.t_class == EventType.Class.TO_DEVICE _standard_types = { EventType.Class.STATE: { "m.room.canonical_alias": "ROOM_CANONICAL_ALIAS", "m.room.create": "ROOM_CREATE", "m.room.join_rules": "ROOM_JOIN_RULES", "m.room.member": "ROOM_MEMBER", "m.room.power_levels": "ROOM_POWER_LEVELS", "m.room.history_visibility": "ROOM_HISTORY_VISIBILITY", "m.room.name": "ROOM_NAME", "m.room.topic": "ROOM_TOPIC", "m.room.avatar": "ROOM_AVATAR", "m.room.pinned_events": "ROOM_PINNED_EVENTS", "m.room.tombstone": "ROOM_TOMBSTONE", "m.room.encryption": "ROOM_ENCRYPTION", "m.space.child": "SPACE_CHILD", "m.space.parent": "SPACE_PARENT", }, EventType.Class.MESSAGE: { "m.room.redaction": "ROOM_REDACTION", "m.room.message": "ROOM_MESSAGE", "m.room.encrypted": "ROOM_ENCRYPTED", "m.sticker": "STICKER", "m.reaction": "REACTION", "m.call.invite": "CALL_INVITE", "m.call.candidates": "CALL_CANDIDATES", "m.call.select_answer": "CALL_SELECT_ANSWER", "m.call.answer": "CALL_ANSWER", "m.call.hangup": "CALL_HANGUP", "m.call.reject": "CALL_REJECT", "m.call.negotiate": "CALL_NEGOTIATE", "com.beeper.message_send_status": "BEEPER_MESSAGE_STATUS", }, EventType.Class.EPHEMERAL: { "m.receipt": "RECEIPT", "m.typing": "TYPING", "m.presence": "PRESENCE", }, EventType.Class.ACCOUNT_DATA: { "m.direct": "DIRECT", "m.push_rules": "PUSH_RULES", "m.tag": "TAG", "m.ignored_user_list": "IGNORED_USER_LIST", }, EventType.Class.TO_DEVICE: { "m.room.encrypted": "TO_DEVICE_ENCRYPTED", "m.room_key": "ROOM_KEY", "m.room_key.withheld": "ROOM_KEY_WITHHELD", "org.matrix.room_key.withheld": "ORG_MATRIX_ROOM_KEY_WITHHELD", "m.room_key_request": "ROOM_KEY_REQUEST", 
"m.forwarded_room_key": "FORWARDED_ROOM_KEY", "m.dummy": "TO_DEVICE_DUMMY", "com.beeper.room_key.ack": "BEEPER_ROOM_KEY_ACK", }, EventType.Class.UNKNOWN: { "__ALL__": "ALL", # This is not a real event type }, } for _t_class, _types in _standard_types.items(): for _t, _name in _types.items(): _event_type = EventType(t=_t, t_class=_t_class) setattr(EventType, _name, _event_type) python-0.20.7/mautrix/types/event/type.pyi000066400000000000000000000047341473573527000206300ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from typing import Any, ClassVar, Optional from mautrix.types import JSON, ExtensibleEnum, Serializable, SerializableEnum class RoomType(ExtensibleEnum): SPACE: "RoomType" class EventType(Serializable): class Class(SerializableEnum): UNKNOWN = "unknown" STATE = "state" MESSAGE = "message" ACCOUNT_DATA = "account_data" EPHEMERAL = "ephemeral" TO_DEVICE = "to_device" _by_event_type: ClassVar[dict[str, EventType]] ROOM_CANONICAL_ALIAS: "EventType" ROOM_CREATE: "EventType" ROOM_JOIN_RULES: "EventType" ROOM_MEMBER: "EventType" ROOM_POWER_LEVELS: "EventType" ROOM_HISTORY_VISIBILITY: "EventType" ROOM_NAME: "EventType" ROOM_TOPIC: "EventType" ROOM_AVATAR: "EventType" ROOM_PINNED_EVENTS: "EventType" ROOM_TOMBSTONE: "EventType" ROOM_ENCRYPTION: "EventType" SPACE_CHILD: "EventType" SPACE_PARENT: "EventType" ROOM_REDACTION: "EventType" ROOM_MESSAGE: "EventType" ROOM_ENCRYPTED: "EventType" STICKER: "EventType" REACTION: "EventType" CALL_INVITE: "EventType" CALL_CANDIDATES: "EventType" CALL_SELECT_ANSWER: "EventType" CALL_ANSWER: "EventType" CALL_HANGUP: "EventType" CALL_REJECT: "EventType" CALL_NEGOTIATE: "EventType" BEEPER_MESSAGE_STATUS: "EventType" RECEIPT: "EventType" TYPING: "EventType" PRESENCE: "EventType" DIRECT: "EventType" PUSH_RULES: "EventType" TAG: "EventType" IGNORED_USER_LIST: "EventType" TO_DEVICE_ENCRYPTED: "EventType" TO_DEVICE_DUMMY: "EventType" ROOM_KEY: "EventType" ROOM_KEY_WITHHELD: "EventType" ORG_MATRIX_ROOM_KEY_WITHHELD: "EventType" ROOM_KEY_REQUEST: "EventType" FORWARDED_ROOM_KEY: "EventType" BEEPER_ROOM_KEY_ACK: "EventType" ALL: "EventType" is_message: bool is_state: bool is_ephemeral: bool is_account_data: bool is_to_device: bool t: str t_class: Class def __init__(self, t: str, t_class: Class) -> None: ... @classmethod def find(cls, t: str, t_class: Optional[Class] = None) -> "EventType": ... def serialize(self) -> JSON: ... @classmethod def deserialize(cls, raw: JSON) -> Any: ... def with_class(self, t_class: Class) -> "EventType": ... python-0.20.7/mautrix/types/event/voip.py000066400000000000000000000063671473573527000204570ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. 
from typing import Generic, List, Optional, TypeVar, Union from attr import dataclass import attr from ..primitive import JSON, UserID from ..util import ExtensibleEnum, SerializableAttrs from .base import BaseRoomEvent from .type import EventType class CallDataType(ExtensibleEnum): OFFER = "offer" ANSWER = "answer" class CallHangupReason(ExtensibleEnum): ICE_FAILED = "ice_failed" INVITE_TIMEOUT = "invite_timeout" USER_HANGUP = "user_hangup" USER_MEDIA_FAILED = "user_media_failed" UNKNOWN_ERROR = "unknown_error" @dataclass class CallData(SerializableAttrs): sdp: str type: CallDataType @dataclass class CallCandidate(SerializableAttrs): candidate: str sdp_m_line_index: int = attr.ib(metadata={"json": "sdpMLineIndex"}, default=None) sdp_mid: str = attr.ib(metadata={"json": "sdpMid"}, default=None) @dataclass class CallInviteEventContent(SerializableAttrs): call_id: str lifetime: int version: int offer: CallData party_id: Optional[str] = None invitee: Optional[UserID] = None @dataclass class CallCandidatesEventContent(SerializableAttrs): call_id: str version: int candidates: List[CallCandidate] party_id: Optional[str] = None @dataclass class CallSelectAnswerEventContent(SerializableAttrs): call_id: str version: int party_id: str selected_party_id: str @dataclass class CallAnswerEventContent(SerializableAttrs): call_id: str version: int answer: CallData party_id: Optional[str] = None @dataclass class CallHangupEventContent(SerializableAttrs): call_id: str version: int reason: CallHangupReason = CallHangupReason.USER_HANGUP party_id: Optional[str] = None @dataclass class CallRejectEventContent(SerializableAttrs): call_id: str version: int party_id: str @dataclass class CallNegotiateEventContent(SerializableAttrs): call_id: str version: int lifetime: int party_id: str description: CallData type_to_class = { EventType.CALL_INVITE: CallInviteEventContent, EventType.CALL_CANDIDATES: CallCandidatesEventContent, EventType.CALL_SELECT_ANSWER: CallSelectAnswerEventContent, EventType.CALL_ANSWER: CallAnswerEventContent, EventType.CALL_HANGUP: CallHangupEventContent, EventType.CALL_NEGOTIATE: CallNegotiateEventContent, EventType.CALL_REJECT: CallRejectEventContent, } CallEventContent = Union[ CallInviteEventContent, CallCandidatesEventContent, CallAnswerEventContent, CallSelectAnswerEventContent, CallHangupEventContent, CallNegotiateEventContent, CallRejectEventContent, ] T = TypeVar("T", bound=CallEventContent) @dataclass class CallEvent(BaseRoomEvent, SerializableAttrs, Generic[T]): content: T @classmethod def deserialize(cls, data: JSON, event_type: Optional[EventType] = None) -> "CallEvent": event_type = event_type or EventType.find(data.get("type")) data["content"] = type_to_class[event_type].deserialize(data["content"]) return super().deserialize(data) python-0.20.7/mautrix/types/filter.py000066400000000000000000000142421473573527000176350ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from typing import List from attr import dataclass from .event import EventType from .primitive import RoomID, UserID from .util import SerializableAttrs, SerializableEnum class EventFormat(SerializableEnum): """ Federation event format enum, as specified in the `create filter endpoint`_. .. 
_create filter endpoint: https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3useruseridfilter """ CLIENT = "client" FEDERATION = "federation" @dataclass class EventFilter(SerializableAttrs): """ Event filter object, as specified in the `create filter endpoint`_. .. _create filter endpoint: https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3useruseridfilter """ limit: int = None """The maximum number of events to return.""" not_senders: List[UserID] = None """A list of sender IDs to exclude. If this list is absent then no senders are excluded A matching sender will be excluded even if it is listed in the :attr:`senders` filter.""" not_types: List[EventType] = None """A list of event types to exclude. If this list is absent then no event types are excluded. A matching type will be excluded even if it is listed in the :attr:`types` filter. A ``'*'`` can be used as a wildcard to match any sequence of characters.""" senders: List[UserID] = None """A list of senders IDs to include. If this list is absent then all senders are included.""" types: List[EventType] = None """A list of event types to include. If this list is absent then all event types are included. A ``'*'`` can be used as a wildcard to match any sequence of characters.""" @dataclass class RoomEventFilter(EventFilter, SerializableAttrs): """ Room event filter object, as specified in the `create filter endpoint`_. .. _create filter endpoint: https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3useruseridfilter """ lazy_load_members: bool = False """ If ``True``, enables lazy-loading of membership events. See `Lazy-loading room members`_ for more information. .. _Lazy-loading room members: https://matrix.org/docs/spec/client_server/r0.5.0#lazy-loading-room-members """ include_redundant_members: bool = False """ If ``True``, sends all membership events for all events, even if they have already been sent to the client. Does not apply unless :attr:`lazy_load_members` is true. See `Lazy-loading room members`_ for more information.""" not_rooms: List[RoomID] = None """A list of room IDs to exclude. If this list is absent then no rooms are excluded. A matching room will be excluded even if it is listed in the :attr:`rooms` filter.""" rooms: List[RoomID] = None """A list of room IDs to include. If this list is absent then all rooms are included.""" contains_url: bool = None """If ``True``, includes only events with a url key in their content. If ``False``, excludes those events. If omitted, ``url`` key is not considered for filtering.""" @dataclass class StateFilter(RoomEventFilter, SerializableAttrs): """ State event filter object, as specified in the `create filter endpoint`_. Currently this is the same as :class:`RoomEventFilter`. .. _create filter endpoint: https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3useruseridfilter """ pass @dataclass class RoomFilter(SerializableAttrs): """ Room filter object, as specified in the `create filter endpoint`_. .. _create filter endpoint: https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3useruseridfilter """ not_rooms: List[RoomID] = None """A list of room IDs to exclude. If this list is absent then no rooms are excluded. A matching room will be excluded even if it is listed in the ``'rooms'`` filter. This filter is applied before the filters in :attr:`ephemeral`, :attr:`state`, :attr:`timeline` or :attr:`account_data`.""" rooms: List[RoomID] = None """A list of room IDs to include. 
If this list is absent then all rooms are included. This filter is applied before the filters in :attr:`ephemeral`, :attr:`state`, :attr:`timeline` or :attr:`account_data`.""" ephemeral: RoomEventFilter = None """The events that aren't recorded in the room history, e.g. typing and receipts, to include for rooms.""" include_leave: bool = False """Include rooms that the user has left in the sync.""" state: StateFilter = None """The state events to include for rooms.""" timeline: RoomEventFilter = None """The message and state update events to include for rooms.""" account_data: RoomEventFilter = None """The per user account data to include for rooms.""" @dataclass class Filter(SerializableAttrs): """ Base filter object, as specified in the `create filter endpoint`_. .. _create filter endpoint: https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3useruseridfilter """ event_fields: List[str] = None """List of event fields to include. If this list is absent then all fields are included. The entries may include ``.`` charaters to indicate sub-fields. So ``['content.body']`` will include the ``body`` field of the ``content`` object. A literal ``.`` character in a field name may be escaped using a ``\\``. A server may include more fields than were requested.""" event_format: EventFormat = None """The format to use for events. ``'client'`` will return the events in a format suitable for clients. ``'federation'`` will return the raw event as receieved over federation. The default is :attr:`~EventFormat.CLIENT`.""" presence: EventFilter = None """The presence updates to include.""" account_data: EventFilter = None """The user account data that isn't associated with rooms to include.""" room: RoomFilter = None """Filters to be applied to room data.""" python-0.20.7/mautrix/types/matrixuri.py000066400000000000000000000343751473573527000204050ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations from typing import ClassVar, NamedTuple from enum import Enum import urllib.parse from yarl import URL from .primitive import EventID, RoomAlias, RoomID, UserID from .util import ExtensibleEnum class IdentifierType(Enum): """The type qualifier for entities in a Matrix URI.""" EVENT = "$" USER = "@" ROOM_ALIAS = "#" ROOM_ID = "!" 
@property def sigil(self) -> str: """The sigil of the identifier, used in Matrix events, matrix.to URLs and other places""" return _type_to_sigil[self] @property def uri_type_qualifier(self) -> str: """The type qualifier of the identifier, only used in ``matrix:`` URIs.""" return _type_to_path[self] @classmethod def from_sigil(cls, sigil: str) -> IdentifierType: """Get the IdentifierType corresponding to the given sigil.""" return _sigil_to_type[sigil] @classmethod def from_uri_type_qualifier(cls, uri_type_qualifier: str) -> IdentifierType: """Get the IdentifierType corresponding to the given ``matrix:`` URI type qualifier.""" return _path_to_type[uri_type_qualifier] def __repr__(self) -> str: return f"{self.__class__.__name__}.{self.name}" class URIAction(ExtensibleEnum): """Represents an intent for what the client should do with a Matrix URI.""" JOIN = "join" CHAT = "chat" _type_to_path: dict[IdentifierType, str] = { IdentifierType.EVENT: "e", IdentifierType.USER: "u", IdentifierType.ROOM_ALIAS: "r", IdentifierType.ROOM_ID: "roomid", } _path_to_type: dict[str, IdentifierType] = {v: k for k, v in _type_to_path.items()} _type_to_sigil: dict[IdentifierType, str] = {it: it.value for it in IdentifierType} _sigil_to_type: dict[str, IdentifierType] = {v: k for k, v in _type_to_sigil.items()} class _PathPart(NamedTuple): type: IdentifierType identifier: str @classmethod def from_mxid(cls, mxid: UserID | RoomID | EventID | RoomAlias | str) -> _PathPart: return _PathPart(type=IdentifierType.from_sigil(mxid[0]), identifier=mxid[1:]) @property def mxid(self) -> str: return f"{self.type.sigil}{self.identifier}" def __str__(self) -> str: return self.mxid def __repr__(self) -> str: return f"_PathPart({self.type!r}, {self.identifier!r})" def __eq__(self, other: _PathPart) -> bool: if not isinstance(other, _PathPart): return False return other.type == self.type and other.identifier == self.identifier _uri_base = URL.build(scheme="matrix") class MatrixURIError(ValueError): """Raised by :meth:`MatrixURI.parse` when parsing a URI fails.""" class MatrixURI: """ A container for Matrix URI data. Supports parsing and generating both ``matrix:`` URIs and ``https://matrix.to`` URLs with the same interface. """ URI_BY_DEFAULT: ClassVar[bool] = False """Whether :meth:`__str__` should return the matrix: URI instead of matrix.to URL.""" _part1: _PathPart _part2: _PathPart | None via: list[str] | None """Servers that know about the resource. Important for room ID links.""" action: URIAction | None """The intent for what clients should do with the URI.""" def __init__(self) -> None: """Internal initializer for MatrixURI, external users should use either :meth:`build` or :meth:`parse`.""" self._part2 = None self.via = None self.action = None @classmethod def build( cls, part1: RoomID | UserID | RoomAlias, part2: EventID | None = None, via: list[str] | None = None, action: URIAction | None = None, ) -> MatrixURI: """ Construct a MatrixURI instance using an identifier. Args: part1: The first part of the URI, a user ID, room ID, or room alias. part2: The second part of the URI. Only event IDs are allowed, and only allowed when the first part is a room ID or alias. via: Servers that know about the resource. Important for room ID links. action: The intent for what clients should do with the URI. Returns: The constructed MatrixURI. Raises: ValueError: if one of the identifiers doesn't have a valid sigil. 
Examples: >>> from mautrix.types import MatrixURI, UserID, RoomAlias, EventID >>> MatrixURI.build(UserID("@user:example.com")).matrix_to_url 'https://matrix.to/#/%40user%3Aexample.com' >>> MatrixURI.build(UserID("@user:example.com")).matrix_uri 'matrix:u/user:example.com' >>> # Picks the format based on the URI_BY_DEFAULT field. >>> # The default value will be changed to True in a later release. >>> str(MatrixURI.build(UserID("@user:example.com"))) 'https://matrix.to/#/%40user%3Aexample.com' >>> MatrixURI.build(RoomAlias("#room:example.com"), EventID("$abc123")).matrix_uri 'matrix:r/room:example.com/e/abc123' """ uri = cls() try: uri._part1 = _PathPart.from_mxid(part1) except KeyError as e: raise ValueError(f"Invalid sigil in part 1 '{part1[0]}'") from e if uri._part1.type == IdentifierType.EVENT: raise ValueError(f"Event ID links must have a room ID or alias too") if part2: try: uri._part2 = _PathPart.from_mxid(part2) except KeyError as e: raise ValueError(f"Invalid sigil in part 2 '{part2[0]}'") from e if uri._part2.type != IdentifierType.EVENT: raise ValueError("The second part of Matrix URIs can only be an event ID") if uri._part1.type not in (IdentifierType.ROOM_ID, IdentifierType.ROOM_ALIAS): raise ValueError("Can't create an event ID link without a room link as the base") uri.via = via uri.action = action return uri @classmethod def try_parse(cls, url: str | URL) -> MatrixURI | None: """ Try to parse a ``matrix:`` URI or ``https://matrix.to`` URL into parts. If parsing fails, return ``None`` instead of throwing an error. Args: url: The URI to parse, either as a string or a :class:`yarl.URL` instance. Returns: The parsed data, or ``None`` if parsing failed. """ try: return cls.parse(url) except ValueError: return None @classmethod def parse(cls, url: str | URL) -> MatrixURI: """ Parse a ``matrix:`` URI or ``https://matrix.to`` URL into parts. Args: url: The URI to parse, either as a string or a :class:`yarl.URL` instance. Returns: The parsed data. Raises: ValueError: if yarl fails to parse the given URL string. MatrixURIError: if the URL isn't valid in the Matrix spec. 
Examples: >>> from mautrix.types import MatrixURI >>> MatrixURI.parse("https://matrix.to/#/@user:example.com").user_id '@user:example.com' >>> MatrixURI.parse("https://matrix.to/#/#room:example.com/$abc123").event_id '$abc123' >>> MatrixURI.parse("matrix:r/room:example.com/e/abc123").event_id '$abc123' """ url = URL(url) if url.scheme == "matrix": return cls._parse_matrix_uri(url) elif url.scheme == "https" and url.host == "matrix.to": return cls._parse_matrix_to_url(url) else: raise MatrixURIError("Invalid URI (not matrix: nor https://matrix.to)") @classmethod def _parse_matrix_to_url(cls, url: URL) -> MatrixURI: path, *rest = url.raw_fragment.split("?", 1) path_parts = path.split("/") if len(path_parts) < 2: raise MatrixURIError("matrix.to URL doesn't have enough parts") # The first component is the blank part between the # and / if path_parts[0] != "": raise MatrixURIError("first component of matrix.to URL is not empty as expected") query = urllib.parse.parse_qs(rest[0] if len(rest) > 0 else "") uri = cls() part1 = urllib.parse.unquote(path_parts[1]) if len(part1) < 2: raise MatrixURIError(f"Invalid first entity '{part1}' in matrix.to URL") try: uri._part1 = _PathPart.from_mxid(part1) except KeyError as e: raise MatrixURIError( f"Invalid sigil '{part1[0]}' in first entity of matrix.to URL" ) from e if len(path_parts) > 2 and len(path_parts[2]) > 0: part2 = urllib.parse.unquote(path_parts[2]) if len(part2) < 2: raise MatrixURIError(f"Invalid second entity '{part2}' in matrix.to URL") try: uri._part2 = _PathPart.from_mxid(part2) except KeyError as e: raise MatrixURIError( f"Invalid sigil '{part2[0]}' in second entity of matrix.to URL" ) from e uri.via = query.get("via", None) try: uri.action = URIAction(query["action"]) except KeyError: pass return uri @classmethod def _parse_matrix_uri(cls, url: URL) -> MatrixURI: components = url.raw_path.split("/") if len(components) < 2: raise MatrixURIError("URI doesn't contain enough parts") try: type1 = IdentifierType.from_uri_type_qualifier(components[0]) except KeyError as e: raise MatrixURIError( f"Invalid type qualifier '{components[0]}' in first entity of matrix: URI" ) from e if not components[1]: raise MatrixURIError("Identifier in first entity of matrix: URI is empty") uri = cls() uri._part1 = _PathPart(type1, components[1]) if len(components) >= 3 and components[2]: try: type2 = IdentifierType.from_uri_type_qualifier(components[2]) except KeyError as e: raise MatrixURIError( f"Invalid type qualifier '{components[2]}' in second entity of matrix: URI" ) from e if len(components) < 4 or not components[3]: raise MatrixURIError("Identifier in second entity of matrix: URI is empty") uri._part2 = _PathPart(type2, components[3]) uri.via = url.query.getall("via", None) try: uri.action = URIAction(url.query["action"]) except KeyError: pass return uri @property def user_id(self) -> UserID | None: """ Get the user ID from this parsed URI. Returns: The user ID in this URI, or ``None`` if this is not a link to a user. """ if self._part1.type == IdentifierType.USER: return UserID(self._part1.mxid) return None @property def room_id(self) -> RoomID | None: """ Get the room ID from this parsed URI. Returns: The room ID in this URI, or ``None`` if this is not a link to a room (or event). """ if self._part1.type == IdentifierType.ROOM_ID: return RoomID(self._part1.mxid) return None @property def room_alias(self) -> RoomAlias | None: """ Get the room alias from this parsed URI. 
Returns: The room alias in this URI, or ``None`` if this is not a link to a room (or event). """ if self._part1.type == IdentifierType.ROOM_ALIAS: return RoomAlias(self._part1.mxid) return None @property def event_id(self) -> EventID | None: """ Get the event ID from this parsed URI. Returns: The event ID in this URI, or ``None`` if this is not a link to an event in a room. """ if ( self._part2 and (self.room_id or self.room_alias) and self._part2.type == IdentifierType.EVENT ): return EventID(self._part2.mxid) return None @property def matrix_to_url(self) -> str: """ Convert this parsed URI into a ``https://matrix.to`` URL. Returns: The link as a matrix.to URL. """ url = f"https://matrix.to/#/{urllib.parse.quote(self._part1.mxid)}" if self._part2: url += f"/{urllib.parse.quote(self._part2.mxid)}" qp = [] if self.via: qp += (("via", server) for server in self.via) if self.action: qp.append(("action", self.action)) if qp: url += f"?{urllib.parse.urlencode(qp)}" return url @property def matrix_uri(self) -> str: """ Convert this parsed URI into a ``matrix:`` URI. Returns: The link as a ``matrix:`` URI. """ u = _uri_base / self._part1.type.uri_type_qualifier / self._part1.identifier if self._part2: u = u / self._part2.type.uri_type_qualifier / self._part2.identifier if self.via: u = u.update_query({"via": self.via}) if self.action: u = u.update_query({"action": self.action.value}) return str(u) def __str__(self) -> str: if self.URI_BY_DEFAULT: return self.matrix_uri else: return self.matrix_to_url def __repr__(self) -> str: parts = ", ".join(f"{part!r}" for part in (self._part1, self._part2) if part) return f"MatrixURI({parts}, via={self.via!r}, action={self.action!r})" def __eq__(self, other: MatrixURI) -> bool: """ Checks equality between two parsed Matrix URIs. The order of the via parameters is ignored, but otherwise everything has to match exactly. """ if not isinstance(other, MatrixURI): return False return ( self._part1 == other._part1 and self._part2 == other._part2 and set(self.via or []) == set(other.via or []) and self.action == other.action ) python-0.20.7/mautrix/types/matrixuri_test.py000066400000000000000000000130671473573527000214370ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. 
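# These parser/builder round-trip tests run under pytest; a typical invocation from the
# repository root (the exact path depends on the checkout layout) is:
#
#   pytest mautrix/types/matrixuri_test.py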
from __future__ import annotations from typing import NamedTuple import pytest from .matrixuri import IdentifierType, MatrixURI, MatrixURIError, URIAction, _PathPart from .primitive import EventID, RoomAlias, RoomID, UserID def test_basic_parse_uri() -> None: for test in basic_tests: assert MatrixURI.parse(test.uri) == test.parsed def test_basic_stringify_uri() -> None: for test in basic_tests: assert test.uri == test.parsed.matrix_uri def test_basic_parse_url() -> None: for test in basic_tests: assert MatrixURI.parse(test.url) == test.parsed def test_basic_stringify_url() -> None: for test in basic_tests: assert test.url == test.parsed.matrix_to_url def test_basic_build() -> None: for test in basic_tests: assert MatrixURI.build(*test.params) == test.parsed def test_parse_unescaped() -> None: assert MatrixURI.parse("https://matrix.to/#/#hello:world").room_alias == "#hello:world" def test_parse_trailing_slash() -> None: assert MatrixURI.parse("https://matrix.to/#/#hello:world/").room_alias == "#hello:world" assert MatrixURI.parse("matrix:r/hello:world/").room_alias == "#hello:world" def test_parse_errors() -> None: tests = [ "https://example.com", "matrix:invalid/foo", "matrix:hello world", "matrix:/roomid", "matrix:roomid/", "matrix:roomid/foo/e/", "matrix:roomid/foo/e", "https://matrix.to", "https://matrix.to/#/", "https://matrix.to/#foo/#hello:world", "https://matrix.to/#/#hello:world/hmm", ] for test in tests: with pytest.raises(MatrixURIError): print(MatrixURI.parse(test)) def test_build_errors() -> None: with pytest.raises(ValueError): MatrixURI.build("hello world") with pytest.raises(ValueError): MatrixURI.build(EventID("$uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s")) with pytest.raises(ValueError): MatrixURI.build( UserID("@user:example.org"), EventID("$uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s"), ) with pytest.raises(ValueError): MatrixURI.build( RoomID("!room:example.org"), RoomID("!anotherroom:example.com"), ) with pytest.raises(ValueError): MatrixURI.build( RoomID("!room:example.org"), "hmm", ) def _make_parsed( part1: _PathPart, part2: _PathPart | None = None, via: list[str] | None = None, action: URIAction | None = None, ) -> MatrixURI: uri = MatrixURI() uri._part1 = part1 uri._part2 = part2 uri.via = via uri.action = action return uri class BasicTestItems(NamedTuple): url: str uri: str parsed: MatrixURI params: tuple[RoomID | UserID | RoomAlias, EventID | None, list[str] | None, URIAction | None] basic_tests = [ BasicTestItems( "https://matrix.to/#/%217NdBVvkd4aLSbgKt9RXl%3Aexample.org", "matrix:roomid/7NdBVvkd4aLSbgKt9RXl:example.org", _make_parsed(_PathPart(IdentifierType.ROOM_ID, "7NdBVvkd4aLSbgKt9RXl:example.org")), (RoomID("!7NdBVvkd4aLSbgKt9RXl:example.org"), None, None, None), ), BasicTestItems( "https://matrix.to/#/%217NdBVvkd4aLSbgKt9RXl%3Aexample.org?via=maunium.net&via=matrix.org", "matrix:roomid/7NdBVvkd4aLSbgKt9RXl:example.org?via=maunium.net&via=matrix.org", _make_parsed( _PathPart(IdentifierType.ROOM_ID, "7NdBVvkd4aLSbgKt9RXl:example.org"), via=["maunium.net", "matrix.org"], ), (RoomID("!7NdBVvkd4aLSbgKt9RXl:example.org"), None, ["maunium.net", "matrix.org"], None), ), BasicTestItems( "https://matrix.to/#/%23someroom%3Aexample.org", "matrix:r/someroom:example.org", _make_parsed(_PathPart(IdentifierType.ROOM_ALIAS, "someroom:example.org")), (RoomAlias("#someroom:example.org"), None, None, None), ), BasicTestItems( "https://matrix.to/#/%217NdBVvkd4aLSbgKt9RXl%3Aexample.org/%24uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s", 
"matrix:roomid/7NdBVvkd4aLSbgKt9RXl:example.org/e/uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s", _make_parsed( _PathPart(IdentifierType.ROOM_ID, "7NdBVvkd4aLSbgKt9RXl:example.org"), _PathPart(IdentifierType.EVENT, "uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s"), ), ( RoomID("!7NdBVvkd4aLSbgKt9RXl:example.org"), EventID("$uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s"), None, None, ), ), BasicTestItems( "https://matrix.to/#/%23someroom%3Aexample.org/%24uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s", "matrix:r/someroom:example.org/e/uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s", _make_parsed( _PathPart(IdentifierType.ROOM_ALIAS, "someroom:example.org"), _PathPart(IdentifierType.EVENT, "uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s"), ), ( RoomAlias("#someroom:example.org"), EventID("$uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s"), None, None, ), ), BasicTestItems( "https://matrix.to/#/%40user%3Aexample.org", "matrix:u/user:example.org", _make_parsed(_PathPart(IdentifierType.USER, "user:example.org")), (UserID("@user:example.org"), None, None, None), ), ] python-0.20.7/mautrix/types/media.py000066400000000000000000000046511473573527000174320ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from typing import Optional from attr import dataclass from .primitive import ContentURI from .util import SerializableAttrs, field @dataclass class MediaRepoConfig(SerializableAttrs): """ Matrix media repo config. See `GET /_matrix/media/v3/config`_. .. _GET /_matrix/media/v3/config: https://spec.matrix.org/v1.2/client-server-api/#get_matrixmediav3config """ upload_size: int = field(default=50 * 1024 * 1024, json="m.upload.size") @dataclass class OpenGraphImage(SerializableAttrs): url: ContentURI = field(default=None, json="og:image") mimetype: str = field(default=None, json="og:image:type") height: int = field(default=None, json="og:image:width") width: int = field(default=None, json="og:image:height") size: int = field(default=None, json="matrix:image:size") @dataclass class OpenGraphVideo(SerializableAttrs): url: ContentURI = field(default=None, json="og:video") mimetype: str = field(default=None, json="og:video:type") height: int = field(default=None, json="og:video:width") width: int = field(default=None, json="og:video:height") size: int = field(default=None, json="matrix:video:size") @dataclass class OpenGraphAudio(SerializableAttrs): url: ContentURI = field(default=None, json="og:audio") mimetype: str = field(default=None, json="og:audio:type") @dataclass class MXOpenGraph(SerializableAttrs): """ Matrix URL preview response. See `GET /_matrix/media/v3/preview_url`_. .. 
_GET /_matrix/media/v3/preview_url: https://spec.matrix.org/v1.2/client-server-api/#get_matrixmediav3preview_url """ title: str = field(default=None, json="og:title") description: str = field(default=None, json="og:description") image: OpenGraphImage = field(default=None, flatten=True) video: OpenGraphVideo = field(default=None, flatten=True) audio: OpenGraphAudio = field(default=None, flatten=True) @dataclass class MediaCreateResponse(SerializableAttrs): """ Matrix media create response including MSC3870 """ content_uri: ContentURI unused_expired_at: Optional[int] = None unstable_upload_url: Optional[str] = field(default=None, json="com.beeper.msc3870.upload_url") python-0.20.7/mautrix/types/misc.py000066400000000000000000000064371473573527000173120ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from typing import List, NamedTuple, NewType, Optional from enum import Enum from attr import dataclass import attr from .event import Event, StateEvent from .primitive import BatchID, ContentURI, EventID, RoomAlias, RoomID, SyncToken, UserID from .util import SerializableAttrs @dataclass class DeviceLists(SerializableAttrs): changed: List[UserID] = attr.ib(factory=lambda: []) left: List[UserID] = attr.ib(factory=lambda: []) def __bool__(self) -> bool: return bool(self.changed or self.left) @dataclass class DeviceOTKCount(SerializableAttrs): signed_curve25519: int = 0 curve25519: int = 0 class RoomCreatePreset(Enum): """ Room creation preset, as specified in the `createRoom endpoint`_ .. _createRoom endpoint: https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3createroom """ PRIVATE = "private_chat" TRUSTED_PRIVATE = "trusted_private_chat" PUBLIC = "public_chat" class RoomDirectoryVisibility(Enum): """ Room directory visibility, as specified in the `createRoom endpoint`_ .. _createRoom endpoint: https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3createroom """ PRIVATE = "private" PUBLIC = "public" class PaginationDirection(Enum): """Pagination direction used in various endpoints that support pagination.""" FORWARD = "f" BACKWARD = "b" @dataclass class RoomAliasInfo(SerializableAttrs): """ Room alias query result, as specified in the `alias resolve endpoint`_ .. 
_alias resolve endpoint: https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3directoryroomroomalias """ room_id: RoomID = None """The room ID for this room alias.""" servers: List[str] = None """A list of servers that are aware of this room alias.""" DirectoryPaginationToken = NewType("DirectoryPaginationToken", str) @dataclass class PublicRoomInfo(SerializableAttrs): room_id: RoomID num_joined_members: int world_readable: bool guest_can_join: bool name: str = None topic: str = None avatar_url: ContentURI = None aliases: List[RoomAlias] = None canonical_alias: RoomAlias = None @dataclass class RoomDirectoryResponse(SerializableAttrs): chunk: List[PublicRoomInfo] next_batch: DirectoryPaginationToken = None prev_batch: DirectoryPaginationToken = None total_room_count_estimate: int = None PaginatedMessages = NamedTuple( "PaginatedMessages", start=SyncToken, end=SyncToken, events=List[Event] ) @dataclass class EventContext(SerializableAttrs): end: SyncToken start: SyncToken event: Event events_after: List[Event] events_before: List[Event] state: List[StateEvent] @dataclass class BatchSendResponse(SerializableAttrs): state_event_ids: List[EventID] event_ids: List[EventID] insertion_event_id: EventID batch_event_id: EventID next_batch_id: BatchID base_insertion_event_id: Optional[EventID] = None @dataclass class BeeperBatchSendResponse(SerializableAttrs): event_ids: List[EventID] python-0.20.7/mautrix/types/primitive.py000066400000000000000000000042021473573527000203530ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from typing import Dict, List, NewType, Union JSON = NewType("JSON", Union[str, int, float, bool, None, Dict[str, "JSON"], List["JSON"]]) JSON.__doc__ = "A union type that covers all JSON-serializable data." UserID = NewType("UserID", str) UserID.__doc__ = "A Matrix user ID (``@user:example.com``)" EventID = NewType("EventID", str) EventID.__doc__ = "A Matrix event ID (``$base64`` or ``$legacyid:example.com``)" RoomID = NewType("RoomID", str) RoomID.__doc__ = "An internal Matrix room ID (``!randomstring:example.com``)" RoomAlias = NewType("RoomAlias", str) RoomAlias.__doc__ = "A Matrix room address (``#alias:example.com``)" FilterID = NewType("FilterID", str) FilterID.__doc__ = """ A filter ID returned by ``POST /filter`` (:meth:`mautrix.client.ClientAPI.create_filter`) """ BatchID = NewType("BatchID", str) BatchID.__doc__ = """ A message batch ID returned by ``POST /batch_send`` (:meth:`mautrix.appservice.IntentAPI.batch_send`) """ ContentURI = NewType("ContentURI", str) ContentURI.__doc__ = """ A Matrix `content URI`_, used by the content repository. .. _content URI: https://spec.matrix.org/v1.2/client-server-api/#matrix-content-mxc-uris """ SyncToken = NewType("SyncToken", str) SyncToken.__doc__ = """ A ``next_batch`` token from a ``/sync`` response (:meth:`mautrix.client.ClientAPI.sync`) """ DeviceID = NewType("DeviceID", str) DeviceID.__doc__ = "A Matrix device ID. Arbitrary, potentially client-specified string." SessionID = NewType("SessionID", str) SessionID.__doc__ = """ A `Megolm`_ session ID. .. 
_Megolm: https://gitlab.matrix.org/matrix-org/olm/-/blob/master/docs/megolm.md """ SigningKey = NewType("SigningKey", str) SigningKey.__doc__ = "A ed25519 public key as unpadded base64" IdentityKey = NewType("IdentityKey", str) IdentityKey.__doc__ = "A curve25519 public key as unpadded base64" Signature = NewType("Signature", str) Signature.__doc__ = "An ed25519 signature as unpadded base64" python-0.20.7/mautrix/types/push_rules.py000066400000000000000000000040631473573527000205410ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from typing import List, Optional, Union from attr import dataclass import attr from .primitive import JSON, RoomID, UserID from .util import ExtensibleEnum, SerializableAttrs, deserializer PushRuleID = Union[RoomID, UserID, str] class PushActionType(ExtensibleEnum): NOTIFY = "notify" DONT_NOTIFY = "dont_notify" COALESCE = "coalesce" @dataclass class PushActionDict(SerializableAttrs): set_tweak: Optional[str] = None value: Optional[str] = None PushAction = Union[PushActionDict, PushActionType] @deserializer(PushAction) def deserialize_push_action(data: JSON) -> PushAction: if isinstance(data, str): return PushActionType(data) else: return PushActionDict.deserialize(data) class PushOperator(ExtensibleEnum): EQUALS = "==" LESS_THAN = "<" GREATER_THAN = ">" LESS_THAN_OR_EQUAL = "<=" GREATER_THAN_OR_EQUAL = ">=" class PushRuleScope(ExtensibleEnum): GLOBAL = "global" class PushConditionKind(ExtensibleEnum): EVENT_MATCH = "event_match" CONTAINS_DISPLAY_NAME = "contains_display_name" ROOM_MEMBER_COUNT = "room_member_count" SENDER_NOTIFICATION_PERMISSION = "sender_notification_permission" class PushRuleKind(ExtensibleEnum): OVERRIDE = "override" SENDER = "sender" ROOM = "room" CONTENT = "content" UNDERRIDE = "underride" @dataclass class PushCondition(SerializableAttrs): kind: PushConditionKind key: Optional[str] = None pattern: Optional[str] = None operator: PushOperator = attr.ib( default=PushOperator.EQUALS, metadata={"json": "is", "omitdefault": True} ) @dataclass class PushRule(SerializableAttrs): rule_id: PushRuleID default: bool enabled: bool actions: List[PushAction] conditions: List[PushCondition] = attr.ib(factory=lambda: []) pattern: Optional[str] = None python-0.20.7/mautrix/types/users.py000066400000000000000000000013621473573527000175100ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. 
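# A minimal sketch of deserializing a profile with the dataclasses defined below; the
# user ID and display name are placeholder values:
#
#   >>> from mautrix.types.users import User
#   >>> user = User.deserialize({"user_id": "@alice:example.com", "displayname": "Alice"})
#   >>> user.displayname
#   'Alice'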
from typing import List, NamedTuple from attr import dataclass from .event import Membership from .primitive import ContentURI, UserID from .util import SerializableAttrs @dataclass class Member(SerializableAttrs): membership: Membership = None avatar_url: ContentURI = None displayname: str = None @dataclass class User(SerializableAttrs): user_id: UserID avatar_url: ContentURI = None displayname: str = None class UserSearchResults(NamedTuple): results: List[User] limit: int python-0.20.7/mautrix/types/util/000077500000000000000000000000001473573527000167505ustar00rootroot00000000000000python-0.20.7/mautrix/types/util/__init__.py000066400000000000000000000003301473573527000210550ustar00rootroot00000000000000from .enum import ExtensibleEnum from .obj import Lst, Obj from .serializable import Serializable, SerializableEnum, SerializerError from .serializable_attrs import SerializableAttrs, deserializer, field, serializer python-0.20.7/mautrix/types/util/enum.py000066400000000000000000000100361473573527000202660ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations from typing import Any, Iterable, Type, cast from ..primitive import JSON from .serializable import Serializable def _is_descriptor(obj): return hasattr(obj, "__get__") or hasattr(obj, "__set__") or hasattr(obj, "__delete__") class ExtensibleEnumMeta(type): _by_value: dict[Any, ExtensibleEnum] _by_key: dict[str, ExtensibleEnum] def __new__( mcs: Type[ExtensibleEnumMeta], name: str, bases: tuple[Type, ...], classdict: dict[str, Any], ) -> Type[ExtensibleEnum]: create = [ (key, val) for key, val in classdict.items() if not key.startswith("_") and not _is_descriptor(val) ] classdict = { key: val for key, val in classdict.items() if key.startswith("_") or _is_descriptor(val) } classdict["_by_value"] = {} classdict["_by_key"] = {} enum_class = cast(Type["ExtensibleEnum"], super().__new__(mcs, name, bases, classdict)) for key, val in create: ExtensibleEnum.__new__(enum_class, val).key = key return enum_class def __bool__(cls: Type["ExtensibleEnum"]) -> bool: return True def __contains__(cls: Type["ExtensibleEnum"], value: Any) -> bool: if isinstance(value, cls): return value in cls._by_value.values() else: return value in cls._by_value.keys() def __getattr__(cls: Type["ExtensibleEnum"], name: Any) -> "ExtensibleEnum": try: return cls._by_key[name] except KeyError: raise AttributeError(name) from None def __setattr__(cls: Type["ExtensibleEnum"], key: str, value: Any) -> None: if key.startswith("_"): return super().__setattr__(key, value) if not isinstance(value, cls): value = cls(value) value.key = key def __getitem__(cls: Type["ExtensibleEnum"], name: str) -> "ExtensibleEnum": try: return cls._by_key[name] except KeyError: raise KeyError(name) from None def __setitem__(cls: Type["ExtensibleEnum"], key: str, value: Any) -> None: return cls.__setattr__(cls, key, value) def __iter__(cls: Type["ExtensibleEnum"]) -> Iterable["ExtensibleEnum"]: return cls._by_key.values().__iter__() def __len__(cls: Type["ExtensibleEnum"]) -> int: return len(cls._by_key) def __repr__(cls: Type["ExtensibleEnum"]) -> str: return f"" class ExtensibleEnum(Serializable, metaclass=ExtensibleEnumMeta): _by_value: dict[Any, ExtensibleEnum] _by_key: dict[str, ExtensibleEnum] _inited: bool _key: str | None value: Any def 
__init__(self, value: Any) -> None: if getattr(self, "_inited", False): return self.value = value self._key = None self._inited = True def __new__(cls: Type[ExtensibleEnum], value: Any) -> ExtensibleEnum: try: return cls._by_value[value] except KeyError as e: self = super().__new__(cls) self.__objclass__ = cls self.__init__(value) cls._by_value[value] = self return self def __str__(self) -> str: return str(self.value) def __repr__(self) -> str: if self._key: return f"<{self.__class__.__name__}.{self._key}: {self.value!r}>" else: return f"{self.__class__.__name__}({self.value!r})" @property def key(self) -> str: return self._key @key.setter def key(self, key: str) -> None: self._key = key self._by_key[key] = self def serialize(self) -> JSON: return self.value @classmethod def deserialize(cls, raw: JSON) -> Any: return cls(raw) python-0.20.7/mautrix/types/util/enum_test.py000066400000000000000000000036101473573527000213250ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from attr import dataclass from .enum import ExtensibleEnum from .serializable_attrs import SerializableAttrs def test_extensible_enum_int(): class Hello(ExtensibleEnum): HI = 1 HMM = 2 assert Hello.HI.value == 1 assert Hello.HI.key == "HI" assert 1 in Hello assert Hello(1) == Hello.HI assert Hello["HMM"] == Hello.HMM assert len(Hello) == 2 hello3 = Hello(3) assert hello3.value == 3 assert not hello3.key Hello.YAY = hello3 assert len(Hello) == 3 assert hello3.key == "YAY" @dataclass class Wrapper(SerializableAttrs): hello: Hello assert Wrapper.deserialize({"hello": 1}).hello == Hello.HI assert Wrapper.deserialize({"hello": 2}).hello == Hello.HMM assert Wrapper.deserialize({"hello": 3}).hello == hello3 assert Wrapper.deserialize({"hello": 4}).hello.value == 4 def test_extensible_enum_str(): class Hello(ExtensibleEnum): HI = "hi" HMM = "🤔" assert Hello.HI.value == "hi" assert Hello.HI.key == "HI" assert "🤔" in Hello assert Hello("🤔") == Hello.HMM assert Hello["HI"] == Hello.HI assert len(Hello) == 2 hello3 = Hello("yay") assert hello3.value == "yay" assert not hello3.key Hello.YAY = hello3 assert len(Hello) == 3 assert hello3.key == "YAY" @dataclass class Wrapper(SerializableAttrs): hello: Hello assert Wrapper.deserialize({"hello": "hi"}).hello == Hello.HI assert Wrapper.deserialize({"hello": "🤔"}).hello == Hello.HMM assert Wrapper.deserialize({"hello": "yay"}).hello == hello3 assert Wrapper.deserialize({"hello": "thonk"}).hello.value == "thonk" python-0.20.7/mautrix/types/util/obj.py000066400000000000000000000043571473573527000201050ustar00rootroot00000000000000# From https://github.com/Lonami/dumbot/blob/master/dumbot.py # Modified to add Serializable base from __future__ import annotations from ..primitive import JSON from .serializable import AbstractSerializable, Serializable class Obj(AbstractSerializable): """""" def __init__(self, **kwargs): self.__dict__ = { k: Obj(**v) if isinstance(v, dict) else (Lst(v) if isinstance(v, list) else v) for k, v in kwargs.items() } def __getattr__(self, name): name = name.rstrip("_") obj = self.__dict__.get(name) if obj is None: obj = Obj() self.__dict__[name] = obj return obj def __getitem__(self, name): return self.__dict__.get(name) def __setitem__(self, key, value): self.__dict__[key] = value def __str__(self): return str(self.serialize()) def __repr__(self): return 
repr(self.serialize()) def __getstate__(self): return self.__dict__ def __setstate__(self, state): self.__dict__.update(state) def __bool__(self): return bool(self.__dict__) def __contains__(self, item): return item in self.__dict__ def popitem(self): return self.__dict__.popitem() def get(self, key, default=None): obj = self.__dict__.get(key) if obj is None: return default else: return obj def serialize(self) -> dict[str, JSON]: return { k: v.serialize() if isinstance(v, Serializable) else v for k, v in self.__dict__.items() } @classmethod def deserialize(cls, data: dict[str, JSON]) -> Obj: return cls(**data) class Lst(list, AbstractSerializable): def __init__(self, iterable=()): list.__init__( self, ( Obj(**x) if isinstance(x, dict) else (Lst(x) if isinstance(x, list) else x) for x in iterable ), ) def __repr__(self): return super().__repr__() def serialize(self) -> list[JSON]: return [v.serialize() if isinstance(v, Serializable) else v for v in self] @classmethod def deserialize(cls, data: list[JSON]) -> Lst: return cls(data) python-0.20.7/mautrix/types/util/serializable.py000066400000000000000000000062601473573527000217740ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from typing import Type, TypeVar, Union from abc import ABC, abstractmethod from enum import Enum import json from ..primitive import JSON SerializableSubtype = TypeVar("SerializableSubtype", bound="SerializableAttrs") class Serializable: """Serializable is the base class for types with custom JSON serializers.""" def serialize(self) -> JSON: """Convert this object into objects directly serializable with `json`.""" raise NotImplementedError() @classmethod def deserialize(cls: Type[SerializableSubtype], raw: JSON) -> SerializableSubtype: """Convert the given data parsed from JSON into an object of this type.""" raise NotImplementedError() def json(self) -> str: """Serialize this object and dump the output as JSON.""" return json.dumps(self.serialize()) @classmethod def parse_json(cls: Type[SerializableSubtype], data: Union[str, bytes]) -> SerializableSubtype: """Parse the given string as JSON and deserialize the result into this type.""" return cls.deserialize(json.loads(data)) class SerializerError(Exception): """ SerializerErrors are raised if something goes wrong during serialization or deserialization. """ pass class UnknownSerializationError(SerializerError): def __init__(self) -> None: super().__init__("Unknown serialization error") class AbstractSerializable(ABC, Serializable): """ An abstract Serializable that adds ``@abstractmethod`` decorators. """ @abstractmethod def serialize(self) -> JSON: pass @classmethod @abstractmethod def deserialize(cls: Type[SerializableSubtype], raw: JSON) -> SerializableSubtype: pass class SerializableEnum(Serializable, Enum): """ A simple Serializable implementation for Enums. Examples: >>> class MyEnum(SerializableEnum): ... FOO = "foo value" ... BAR = "hmm" >>> MyEnum.FOO.serialize() "foo value" >>> MyEnum.BAR.json() '"hmm"' """ def __init__(self, _) -> None: """ A fake ``__init__`` to stop the type checker from complaining. Enum's ``__new__`` overrides this. """ super().__init__() def serialize(self) -> str: """ Convert this object into objects directly serializable with `json`, i.e. return the value set to this enum value. 
""" return self.value @classmethod def deserialize(cls: Type[SerializableSubtype], raw: str) -> SerializableSubtype: """ Convert the given data parsed from JSON into an object of this type, i.e. find the enum value for the given string using ``cls(raw)``. """ try: return cls(raw) except ValueError as e: raise SerializerError() from e def __str__(self): return str(self.value) def __repr__(self): return f"{self.__class__.__name__}.{self.name}" python-0.20.7/mautrix/types/util/serializable_attrs.py000066400000000000000000000320371473573527000232120ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from typing import Any, Callable, Dict, Iterator, NewType, Optional, Tuple, Type, TypeVar, Union from uuid import UUID import copy import logging import attr from ..primitive import JSON from .obj import Lst, Obj from .serializable import ( AbstractSerializable, Serializable, SerializableSubtype, SerializerError, UnknownSerializationError, ) T = TypeVar("T") T2 = TypeVar("T2") Serializer = NewType("Serializer", Callable[[T], JSON]) Deserializer = NewType("Deserializer", Callable[[JSON], T]) serializer_map: Dict[Type[T], Serializer] = { UUID: str, } deserializer_map: Dict[Type[T], Deserializer] = { UUID: UUID, } META_JSON = "json" META_FLATTEN = "flatten" META_HIDDEN = "hidden" META_IGNORE_ERRORS = "ignore_errors" META_OMIT_EMPTY = "omitempty" META_OMIT_DEFAULT = "omitdefault" log = logging.getLogger("mau.attrs") def field( default: Any = attr.NOTHING, factory: Optional[Callable[[], Any]] = None, json: Optional[str] = None, flatten: bool = False, hidden: bool = False, ignore_errors: bool = False, omit_empty: bool = True, omit_default: bool = False, metadata: Optional[Dict[str, Any]] = None, **kwargs, ): """ A wrapper around :meth:`attr.ib` to conveniently add SerializableAttrs metadata fields. Args: default: Same as attr.ib, the default value for the field. factory: Same as attr.ib, a factory function that creates the default value. json: The JSON key used for de/serializing the object. flatten: Set to flatten subfields inside this field to be a part of the parent object in serialized objects. When deserializing, the input data will be deserialized into both the parent and child fields, so the classes should ignore unknown keys. hidden: Set to always omit the key from serialized objects. ignore_errors: Set to ignore type errors while deserializing. omit_empty: Set to omit the key from serialized objects if the value is ``None``. omit_default: Set to omit the key from serialized objects if the value is equal to the default. metadata: Additional metadata for attr.ib. **kwargs: Additional keyword arguments for attr.ib. Returns: The decorator function returned by attr.ib. Examples: >>> from attr import dataclass >>> from mautrix.types import SerializableAttrs, field >>> @dataclass ... class SomeData(SerializableAttrs): ... my_field: str = field(json="com.example.namespaced_field", default="hi") ... 
>>> SomeData().serialize() {'com.example.namespaced_field': 'hi'} >>> SomeData.deserialize({"com.example.namespaced_field": "hmm"}) SomeData(my_field='hmm') """ custom_meta = { META_JSON: json, META_FLATTEN: flatten, META_HIDDEN: hidden, META_IGNORE_ERRORS: ignore_errors, META_OMIT_EMPTY: omit_empty, META_OMIT_DEFAULT: omit_default, } metadata = metadata or {} metadata.update({k: v for k, v in custom_meta.items() if v is not None}) return attr.ib(default=default, factory=factory, metadata=metadata, **kwargs) def serializer(elem_type: Type[T]) -> Callable[[Serializer], Serializer]: """ Define a custom serialization function for the given type. Args: elem_type: The type to define the serializer for. Returns: Decorator for the function. The decorator will simply add the function to a map of deserializers and return the function. Examples: >>> from datetime import datetime >>> from mautrix.types import serializer, JSON >>> @serializer(datetime) ... def serialize_datetime(dt: datetime) -> JSON: ... return dt.timestamp() """ def decorator(func: Serializer) -> Serializer: serializer_map[elem_type] = func return func return decorator def deserializer(elem_type: Type[T]) -> Callable[[Deserializer], Deserializer]: """ Define a custom deserialization function for a given type hint. Args: elem_type: The type hint to define the deserializer for. Returns: Decorator for the function. The decorator will simply add the function to a map of deserializers and return the function. Examples: >>> from datetime import datetime >>> from mautrix.types import deserializer, JSON >>> @deserializer(datetime) ... def deserialize_datetime(data: JSON) -> datetime: ... return datetime.fromtimestamp(data) """ def decorator(func: Deserializer) -> Deserializer: deserializer_map[elem_type] = func return func return decorator def _fields(attrs_type: Type[T], only_if_flatten: bool = None) -> Iterator[Tuple[str, Type[T2]]]: for field in attr.fields(attrs_type): if field.metadata.get(META_HIDDEN, False): continue if only_if_flatten is None or field.metadata.get(META_FLATTEN, False) == only_if_flatten: yield field.metadata.get(META_JSON, field.name), field immutable = int, str, float, bool, type(None) def _safe_default(val: T) -> T: if isinstance(val, immutable): return val elif val is attr.NOTHING: return None elif isinstance(val, attr.Factory): if val.takes_self: # TODO implement? 
return None else: return val.factory() return copy.copy(val) def _dict_to_attrs( attrs_type: Type[T], data: JSON, default: Optional[T] = None, default_if_empty: bool = False ) -> T: data = data or {} unrecognized = {} new_items = { field_meta.name.lstrip("_"): _try_deserialize(field_meta, data) for _, field_meta in _fields(attrs_type, only_if_flatten=True) } fields = dict(_fields(attrs_type, only_if_flatten=False)) for key, value in data.items(): try: field_meta = fields[key] except KeyError: unrecognized[key] = value continue name = field_meta.name.lstrip("_") try: new_items[name] = _try_deserialize(field_meta, value) except UnknownSerializationError as e: raise SerializerError( f"Failed to deserialize {value} into key {name} of {attrs_type.__name__}" ) from e except SerializerError: raise except Exception as e: raise SerializerError( f"Failed to deserialize {value} into key {name} of {attrs_type.__name__}" ) from e if len(new_items) == 0 and default_if_empty and default is not attr.NOTHING: return _safe_default(default) try: obj = attrs_type(**new_items) except TypeError as e: for key, field_meta in _fields(attrs_type): if field_meta.default is attr.NOTHING and key not in new_items: log.debug("Failed to deserialize %s into %s", data, attrs_type.__name__) json_key = field_meta.metadata.get(META_JSON, key) raise SerializerError( f"Missing value for required key {json_key} in {attrs_type.__name__}" ) from e raise UnknownSerializationError() from e if len(unrecognized) > 0: obj.unrecognized_ = unrecognized return obj def _try_deserialize(field, value: JSON) -> T: try: return _deserialize(field.type, value, field.default) except SerializerError: if not field.metadata.get(META_IGNORE_ERRORS, False): raise except (TypeError, ValueError, KeyError) as e: if not field.metadata.get(META_IGNORE_ERRORS, False): raise UnknownSerializationError() from e def _has_custom_deserializer(cls) -> bool: return issubclass(cls, Serializable) and getattr(cls.deserialize, "__func__") != getattr( SerializableAttrs.deserialize, "__func__" ) def _deserialize(cls: Type[T], value: JSON, default: Optional[T] = None) -> T: if value is None: return _safe_default(default) try: deser = deserializer_map[cls] except KeyError: pass else: return deser(value) supertype = getattr(cls, "__supertype__", None) if supertype: cls = supertype try: deser = deserializer_map[supertype] except KeyError: pass else: return deser(value) if attr.has(cls): if _has_custom_deserializer(cls): return cls.deserialize(value) return _dict_to_attrs(cls, value, default, default_if_empty=True) elif cls == Any or cls == JSON: return value elif isinstance(cls, type) and issubclass(cls, Serializable): return cls.deserialize(value) type_class = getattr(cls, "__origin__", None) args = getattr(cls, "__args__", None) if type_class is Union: if len(args) == 2 and isinstance(None, args[1]): return _deserialize(args[0], value, default) elif type_class == list: (item_cls,) = args return [_deserialize(item_cls, item) for item in value] elif type_class == set: (item_cls,) = args return {_deserialize(item_cls, item) for item in value} elif type_class == dict: key_cls, val_cls = args return { _deserialize(key_cls, key): _deserialize(val_cls, item) for key, item in value.items() } if isinstance(value, list): return Lst(value) elif isinstance(value, dict): return Obj(**value) return value def _actual_type(cls: Type[T]) -> Type[T]: if cls is None: return cls if getattr(cls, "__origin__", None) is Union: if len(cls.__args__) == 2 and isinstance(None, cls.__args__[1]): 
return cls.__args__[0] return cls def _get_serializer(cls: Type[T]) -> Serializer: return serializer_map.get(_actual_type(cls), _serialize) def _serialize_attrs_field(data: T, field: T2) -> JSON: field_val = getattr(data, field.name) if field_val is None: if not field.metadata.get(META_OMIT_EMPTY, True): if field.default is not attr.NOTHING: field_val = _safe_default(field.default) else: return attr.NOTHING if field.metadata.get(META_OMIT_DEFAULT, False) and field_val == field.default: return attr.NOTHING return _get_serializer(field.type)(field_val) def _attrs_to_dict(data: T) -> JSON: new_dict = {} for json_name, field in _fields(data.__class__): if not json_name: continue serialized = _serialize_attrs_field(data, field) if serialized is not attr.NOTHING: if field.metadata.get(META_FLATTEN, False) and isinstance(serialized, dict): new_dict.update(serialized) else: new_dict[json_name] = serialized try: new_dict.update(data.unrecognized_) except (AttributeError, TypeError): pass return new_dict def _serialize(val: Any) -> JSON: if isinstance(val, Serializable): return val.serialize() elif isinstance(val, (tuple, list, set)): return [_serialize(subval) for subval in val] elif isinstance(val, dict): return {_serialize(subkey): _serialize(subval) for subkey, subval in val.items()} elif attr.has(val.__class__): return _attrs_to_dict(val) return val class SerializableAttrs(AbstractSerializable): """ An abstract :class:`Serializable` that assumes the subclass is an attrs dataclass. Examples: >>> from attr import dataclass >>> from mautrix.types import SerializableAttrs >>> @dataclass ... class Foo(SerializableAttrs): ... index: int ... field: Optional[str] = None """ unrecognized_: Dict[str, JSON] def __init__(self): self.unrecognized_ = {} @classmethod def deserialize(cls: Type[SerializableSubtype], data: JSON) -> SerializableSubtype: return _dict_to_attrs(cls, data) def serialize(self) -> JSON: return _attrs_to_dict(self) def get(self, item: str, default: Any = None) -> Any: try: return self[item] except KeyError: return default def __contains__(self, item: str) -> bool: return hasattr(self, item) or item in getattr(self, "unrecognized_", {}) def __getitem__(self, item: str) -> Any: try: return getattr(self, item) except AttributeError: try: return self.unrecognized_[item] except AttributeError: self.unrecognized_ = {} raise KeyError(item) def __setitem__(self, item: str, value: Any) -> None: if hasattr(self, item): setattr(self, item, value) else: try: self.unrecognized_[item] = value except AttributeError: self.unrecognized_ = {item: value} python-0.20.7/mautrix/types/util/serializable_attrs_test.py000066400000000000000000000216701473573527000242520ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. 
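# --- Editor's illustrative sketch (not part of the upstream source) ---
# How unknown JSON keys are preserved by SerializableAttrs: keys that don't map
# to any attrs field end up in `unrecognized_`, are reachable through item
# access, and are included again on serialization. The class below is a
# hypothetical example.
from attr import dataclass

from mautrix.types import SerializableAttrs


@dataclass
class Example(SerializableAttrs):
    known: int


parsed = Example.deserialize({"known": 1, "com.example.extra": "kept"})
assert parsed.known == 1
assert parsed["com.example.extra"] == "kept"  # falls through to unrecognized_
assert "com.example.extra" in parsed
assert parsed.serialize() == {"known": 1, "com.example.extra": "kept"}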
from typing import List, Optional from attr import dataclass import pytest from ..primitive import JSON from .serializable_attrs import Serializable, SerializableAttrs, SerializerError, field def test_simple_class(): @dataclass class Foo(SerializableAttrs): hello: int world: str serialized = {"hello": 5, "world": "hi"} deserialized = Foo.deserialize(serialized) assert deserialized == Foo(5, "hi") assert deserialized.serialize() == serialized with pytest.raises(SerializerError): Foo.deserialize({"world": "hi"}) def test_default(): @dataclass class Default(SerializableAttrs): no_default: int defaultful_value: int = 5 d1 = Default.deserialize({"no_default": 3}) assert d1.no_default == 3 assert d1.defaultful_value == 5 d2 = Default.deserialize({"no_default": 4, "defaultful_value": 6}) assert d2.no_default == 4 assert d2.defaultful_value == 6 def test_factory(): @dataclass class Factory(SerializableAttrs): manufactured_value: List[str] = field(factory=lambda: ["hi"]) assert Factory.deserialize({}).manufactured_value == ["hi"] assert Factory.deserialize({"manufactured_value": None}).manufactured_value == ["hi"] factory1 = Factory.deserialize({}) factory2 = Factory.deserialize({}) assert factory1.manufactured_value is not factory2.manufactured_value assert Factory.deserialize({"manufactured_value": ["bye"]}).manufactured_value == ["bye"] def test_hidden(): @dataclass class HiddenField(SerializableAttrs): visible: str hidden: int = field(hidden=True, default=5) deserialized_hidden = HiddenField.deserialize({"visible": "yay", "hidden": 4}) assert deserialized_hidden.hidden == 5 assert deserialized_hidden.unrecognized_["hidden"] == 4 assert HiddenField("hmm", 5).serialize() == {"visible": "hmm"} def test_ignore_errors(): @dataclass class Something(SerializableAttrs): required: bool @dataclass class Wrapper(SerializableAttrs): something: Optional[Something] = field(ignore_errors=True) @dataclass class ErroringWrapper(SerializableAttrs): something: Optional[Something] = field(ignore_errors=False) assert Wrapper.deserialize({"something": {"required": True}}) == Wrapper(Something(True)) assert Wrapper.deserialize({"something": {}}) == Wrapper(None) with pytest.raises(SerializerError): ErroringWrapper.deserialize({"something": 5}) with pytest.raises(SerializerError): ErroringWrapper.deserialize({"something": {}}) def test_json_key_override(): @dataclass class Meow(SerializableAttrs): meow: int = field(json="fi.mau.namespaced_meow") serialized = {"fi.mau.namespaced_meow": 123} deserialized = Meow.deserialize(serialized) assert deserialized == Meow(123) assert deserialized.serialize() == serialized def test_omit_empty(): @dataclass class OmitEmpty(SerializableAttrs): omitted: Optional[int] = field(omit_empty=True) @dataclass class DontOmitEmpty(SerializableAttrs): not_omitted: Optional[int] = field(omit_empty=False) assert OmitEmpty(None).serialize() == {} assert OmitEmpty(0).serialize() == {"omitted": 0} assert DontOmitEmpty(None).serialize() == {"not_omitted": None} assert DontOmitEmpty(0).serialize() == {"not_omitted": 0} def test_omit_default(): @dataclass class OmitDefault(SerializableAttrs): omitted: int = field(default=5, omit_default=True) @dataclass class DontOmitDefault(SerializableAttrs): not_omitted: int = 5 assert OmitDefault().serialize() == {} assert OmitDefault(5).serialize() == {} assert OmitDefault(6).serialize() == {"omitted": 6} assert DontOmitDefault().serialize() == {"not_omitted": 5} assert DontOmitDefault(5).serialize() == {"not_omitted": 5} assert DontOmitDefault(6).serialize() == 
{"not_omitted": 6} def test_flatten(): from mautrix.types import ContentURI @dataclass class OpenGraphImage(SerializableAttrs): url: ContentURI = field(default=None, json="og:image") mimetype: str = field(default=None, json="og:image:type") height: int = field(default=None, json="og:image:width") width: int = field(default=None, json="og:image:height") size: int = field(default=None, json="matrix:image:size") @dataclass class OpenGraphVideo(SerializableAttrs): url: ContentURI = field(default=None, json="og:video") mimetype: str = field(default=None, json="og:video:type") height: int = field(default=None, json="og:video:width") width: int = field(default=None, json="og:video:height") size: int = field(default=None, json="matrix:video:size") @dataclass class OpenGraphAudio(SerializableAttrs): url: ContentURI = field(default=None, json="og:audio") mimetype: str = field(default=None, json="og:audio:type") @dataclass class MXOpenGraph(SerializableAttrs): title: str = field(default=None, json="og:title") description: str = field(default=None, json="og:description") image: OpenGraphImage = field(default=None, flatten=True) video: OpenGraphVideo = field(default=None, flatten=True) audio: OpenGraphAudio = field(default=None, flatten=True) example_com_preview = { "og:title": "Example Domain", "og:description": "Example Domain\n\nThis domain is for use in illustrative examples in " "documents. You may use this domain in literature without prior " "coordination or asking for permission.\n\nMore information...", } google_com_preview = { "og:title": "Google", "og:image": "mxc://maunium.net/2021-06-20_jkscuJXkHjvzNaUJ", "og:description": "Search\n\nImages\n\nMaps\n\nPlay\n\nYouTube\n\nNews\n\nGmail\n\n" "Drive\n\nMore\n\n\u00bb\n\nWeb History\n\n|\n\nSettings\n\n|\n\n" "Sign in\n\nAdvanced search\n\nGoogle offered in:\n\nDeutsch\n\n" "AdvertisingPrograms\n\nBusiness Solutions", "og:image:width": 128, "og:image:height": 128, "og:image:type": "image/png", "matrix:image:size": 3428, } example_com_deserialized = MXOpenGraph.deserialize(example_com_preview) assert example_com_deserialized.title == "Example Domain" assert example_com_deserialized.image is None assert example_com_deserialized.video is None assert example_com_deserialized.audio is None google_com_deserialized = MXOpenGraph.deserialize(google_com_preview) assert google_com_deserialized.title == "Google" assert google_com_deserialized.image is not None assert google_com_deserialized.image.url == "mxc://maunium.net/2021-06-20_jkscuJXkHjvzNaUJ" assert google_com_deserialized.image.width == 128 assert google_com_deserialized.image.height == 128 assert google_com_deserialized.image.mimetype == "image/png" assert google_com_deserialized.image.size == 3428 assert google_com_deserialized.video is None assert google_com_deserialized.audio is None def test_flatten_arbitrary_serializable(): @dataclass class CustomSerializable(Serializable): is_hello: bool = True def serialize(self) -> JSON: if self.is_hello: return {"hello": "world"} return {} @classmethod def deserialize(cls, raw: JSON) -> "CustomSerializable": return CustomSerializable(is_hello=raw.get("hello") == "world") @dataclass class Thing(SerializableAttrs): custom: CustomSerializable = field(flatten=True) another: int = field(default=5) thing_1 = { "hello": "world", "another": 6, } thing_1_deserialized = Thing.deserialize(thing_1) assert thing_1_deserialized.custom.is_hello is True assert thing_1_deserialized.another == 6 thing_2 = { "another": 4, } thing_2_deserialized = 
Thing.deserialize(thing_2) assert thing_2_deserialized.custom.is_hello is False assert thing_2_deserialized.another == 4 assert Thing(custom=CustomSerializable(is_hello=True)).serialize() == { "hello": "world", "another": 5, } def test_flatten_optional(): @dataclass class OptionalThing(SerializableAttrs): key: str @classmethod def deserialize(cls, data: JSON) -> Optional["OptionalThing"]: if "key" in data: return super().deserialize(data) return None @dataclass class ThingWithOptional(SerializableAttrs): optional: OptionalThing = field(flatten=True) another_field: int = 2 assert ThingWithOptional.deserialize({}).optional is None assert ThingWithOptional.deserialize({"key": "hi"}).optional.key == "hi" python-0.20.7/mautrix/types/versions.py000066400000000000000000000113631473573527000202210ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from typing import Dict, List, NamedTuple, Optional, Union from enum import IntEnum import re from attr import dataclass import attr from . import JSON from .util import Serializable, SerializableAttrs class VersionFormat(IntEnum): UNKNOWN = -1 LEGACY = 0 MODERN = 1 def __repr__(self) -> str: return f"VersionFormat.{self.name}" legacy_version_regex = re.compile(r"^r(\d+)\.(\d+)\.(\d+)$") modern_version_regex = re.compile(r"^v(\d+)\.(\d+)$") @attr.dataclass(frozen=True) class Version(Serializable): format: VersionFormat major: int minor: int patch: int raw: str def __str__(self) -> str: if self.format == VersionFormat.MODERN: return f"v{self.major}.{self.minor}" elif self.format == VersionFormat.LEGACY: return f"r{self.major}.{self.minor}.{self.patch}" else: return self.raw def serialize(self) -> JSON: return str(self) @classmethod def deserialize(cls, raw: JSON) -> "Version": assert isinstance(raw, str), "versions must be strings" if modern := modern_version_regex.fullmatch(raw): major, minor = modern.groups() return Version(VersionFormat.MODERN, int(major), int(minor), 0, raw) elif legacy := legacy_version_regex.fullmatch(raw): major, minor, patch = legacy.groups() return Version(VersionFormat.LEGACY, int(major), int(minor), int(patch), raw) else: return Version(VersionFormat.UNKNOWN, 0, 0, 0, raw) class SpecVersions: R010 = Version.deserialize("r0.1.0") R020 = Version.deserialize("r0.2.0") R030 = Version.deserialize("r0.3.0") R040 = Version.deserialize("r0.4.0") R050 = Version.deserialize("r0.5.0") R060 = Version.deserialize("r0.6.0") R061 = Version.deserialize("r0.6.1") V11 = Version.deserialize("v1.1") V12 = Version.deserialize("v1.2") V13 = Version.deserialize("v1.3") V14 = Version.deserialize("v1.4") V15 = Version.deserialize("v1.5") V16 = Version.deserialize("v1.6") V17 = Version.deserialize("v1.7") V18 = Version.deserialize("v1.8") V19 = Version.deserialize("v1.9") V110 = Version.deserialize("v1.10") V111 = Version.deserialize("v1.11") @dataclass class VersionsResponse(SerializableAttrs): versions: List[Version] unstable_features: Dict[str, bool] = attr.ib(factory=lambda: {}) def supports(self, thing: Union[Version, str]) -> Optional[bool]: """ Check if the versions response contains the given spec version or unstable feature. Args: thing: The spec version (as a :class:`Version` or string) or unstable feature name (as a string) to check. 
Returns: ``True`` if the exact version or unstable feature is supported, ``False`` if it's not supported, ``None`` for unstable features which are not included in the response at all. """ if isinstance(thing, Version): return thing in self.versions elif (parsed_version := Version.deserialize(thing)).format != VersionFormat.UNKNOWN: return parsed_version in self.versions return self.unstable_features.get(thing) def supports_at_least(self, version: Union[Version, str]) -> bool: """ Check if the versions response contains the given spec version or any higher version. Args: version: The spec version as a :class:`Version` or a string. Returns: ``True`` if a version equal to or higher than the given version is found, ``False`` otherwise. """ if isinstance(version, str): version = Version.deserialize(version) return any(v for v in self.versions if v > version) @property def latest_version(self) -> Version: return max(self.versions) @property def has_legacy_versions(self) -> bool: """ Check if the response contains any legacy (r0.x.y) versions. .. deprecated:: 0.16.10 :meth:`supports_at_least` and :meth:`supports` methods are now preferred. """ return any(v for v in self.versions if v.format == VersionFormat.LEGACY) @property def has_modern_versions(self) -> bool: """ Check if the response contains any modern (v1.1 or higher) versions. .. deprecated:: 0.16.10 :meth:`supports_at_least` and :meth:`supports` methods are now preferred. """ return self.supports_at_least(SpecVersions.V11) python-0.20.7/mautrix/util/000077500000000000000000000000001473573527000156045ustar00rootroot00000000000000python-0.20.7/mautrix/util/__init__.py000066400000000000000000000007761473573527000177270ustar00rootroot00000000000000__all__ = [ # Directory modules "async_db", "config", "db", "formatter", "logging", # File modules "async_body", "async_getter_lock", "background_task", "bridge_state", "color_log", "ffmpeg", "file_store", "format_duration", "magic", "manhole", "markdown", "message_send_checkpoint", "opt_prometheus", "program", "signed_token", "simple_lock", "simple_template", "utf16_surrogate", "variation_selector", ] python-0.20.7/mautrix/util/async_body.py000066400000000000000000000063611473573527000203160ustar00rootroot00000000000000# Copyright (c) 2023 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations from typing import AsyncGenerator, Union import logging import aiohttp AsyncBody = AsyncGenerator[Union[bytes, bytearray, memoryview], None] async def async_iter_bytes(data: bytearray | bytes, chunk_size: int = 1024**2) -> AsyncBody: """ Return memory views into a byte array in chunks. This is used to prevent aiohttp from copying the entire request body. Args: data: The underlying data to iterate through. chunk_size: How big each returned chunk should be. Returns: An async generator that yields the given data in chunks. 
""" with memoryview(data) as mv: for i in range(0, len(data), chunk_size): yield mv[i : i + chunk_size] class FileTooLargeError(Exception): def __init__(self, max_size: int) -> None: super().__init__(f"File size larger than maximum ({max_size / 1024 / 1024} MiB)") _default_dl_log = logging.getLogger("mau.util.download") async def read_response_chunks( resp: aiohttp.ClientResponse, max_size: int, log: logging.Logger = _default_dl_log ) -> bytearray: """ Read the body from an aiohttp response in chunks into a mutable bytearray. Args: resp: The aiohttp response object to read the body from. max_size: The maximum size to read. FileTooLargeError will be raised if the Content-Length is higher than this, or if the body exceeds this size during reading. log: A logger for logging download status. Returns: The body data as a byte array. Raises: FileTooLargeError: if the body is larger than the provided max_size. """ content_length = int(resp.headers.get("Content-Length", "0")) if 0 < max_size < content_length: raise FileTooLargeError(max_size) size_str = "unknown length" if content_length == 0 else f"{content_length} bytes" log.info(f"Reading file download response with {size_str} (max: {max_size})") data = bytearray(content_length) mv = memoryview(data) if content_length > 0 else None read_size = 0 max_size += 1 while True: block = await resp.content.readany() if not block: break max_size -= len(block) if max_size <= 0: raise FileTooLargeError(max_size) if len(data) >= read_size + len(block): mv[read_size : read_size + len(block)] = block elif len(data) > read_size: log.warning("File being downloaded is bigger than expected") mv[read_size:] = block[: len(data) - read_size] mv.release() mv = None data.extend(block[len(data) - read_size :]) else: if mv is not None: mv.release() mv = None data.extend(block) read_size += len(block) if mv is not None: mv.release() log.info(f"Successfully read {read_size} bytes of file download response") return data __all__ = ["AsyncBody", "FileTooLargeError", "async_iter_bytes", "async_read_bytes"] python-0.20.7/mautrix/util/async_db/000077500000000000000000000000001473573527000173665ustar00rootroot00000000000000python-0.20.7/mautrix/util/async_db/__init__.py000066400000000000000000000017051473573527000215020ustar00rootroot00000000000000from mautrix import __optional_imports__ from .connection import LoggingConnection as Connection from .database import Database from .errors import ( DatabaseException, DatabaseNotOwned, ForeignTablesFound, UnsupportedDatabaseVersion, ) from .scheme import Scheme from .upgrade import UpgradeTable, register_upgrade try: from .asyncpg import PostgresDatabase except ImportError: if __optional_imports__: raise PostgresDatabase = None try: from aiosqlite import Cursor as SQLiteCursor from .aiosqlite import SQLiteDatabase except ImportError: if __optional_imports__: raise SQLiteDatabase = None SQLiteCursor = None __all__ = [ "Database", "UpgradeTable", "register_upgrade", "PostgresDatabase", "SQLiteDatabase", "SQLiteCursor", "Connection", "Scheme", "DatabaseException", "DatabaseNotOwned", "UnsupportedDatabaseVersion", "ForeignTablesFound", ] python-0.20.7/mautrix/util/async_db/aiosqlite.py000066400000000000000000000154121473573527000217350ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. 
from __future__ import annotations from typing import Any, AsyncContextManager from contextlib import asynccontextmanager import asyncio import logging import os import re import sqlite3 from yarl import URL import aiosqlite from .connection import LoggingConnection from .database import Database from .scheme import Scheme from .upgrade import UpgradeTable POSITIONAL_PARAM_PATTERN = re.compile(r"\$(\d+)") class TxnConnection(aiosqlite.Connection): def __init__(self, path: str, **kwargs) -> None: def connector() -> sqlite3.Connection: return sqlite3.connect( path, detect_types=sqlite3.PARSE_DECLTYPES, isolation_level=None, **kwargs ) super().__init__(connector, iter_chunk_size=64) @asynccontextmanager async def transaction(self) -> None: await self.execute("BEGIN TRANSACTION") try: yield except Exception: await self.rollback() raise else: await self.commit() def __execute(self, query: str, *args: Any): query = POSITIONAL_PARAM_PATTERN.sub(r"?\1", query) return super().execute(query, args) async def execute( self, query: str, *args: Any, timeout: float | None = None ) -> aiosqlite.Cursor: return await self.__execute(query, *args) async def executemany( self, query: str, *args: Any, timeout: float | None = None ) -> aiosqlite.Cursor: query = POSITIONAL_PARAM_PATTERN.sub(r"?\1", query) return await super().executemany(query, *args) async def fetch( self, query: str, *args: Any, timeout: float | None = None ) -> list[sqlite3.Row]: async with self.__execute(query, *args) as cursor: return list(await cursor.fetchall()) async def fetchrow( self, query: str, *args: Any, timeout: float | None = None ) -> sqlite3.Row | None: async with self.__execute(query, *args) as cursor: return await cursor.fetchone() async def fetchval( self, query: str, *args: Any, column: int = 0, timeout: float | None = None ) -> Any: row = await self.fetchrow(query, *args) if row is None: return None return row[column] class SQLiteDatabase(Database): scheme = Scheme.SQLITE _parent: SQLiteDatabase | None _pool: asyncio.Queue[TxnConnection] _stopped: bool _conns: int _init_commands: list[str] def __init__( self, url: URL, upgrade_table: UpgradeTable, db_args: dict[str, Any] | None = None, log: logging.Logger | None = None, owner_name: str | None = None, ignore_foreign_tables: bool = True, ) -> None: super().__init__( url, db_args=db_args, upgrade_table=upgrade_table, log=log, owner_name=owner_name, ignore_foreign_tables=ignore_foreign_tables, ) self._parent = None self._path = url.path self._pool = asyncio.Queue(self._db_args.pop("min_size", 1)) self._db_args.pop("max_size", None) self._stopped = False self._conns = 0 self._init_commands = self._add_missing_pragmas(self._db_args.pop("init_commands", [])) @staticmethod def _add_missing_pragmas(init_commands: list[str]) -> list[str]: has_foreign_keys = False has_journal_mode = False has_synchronous = False has_busy_timeout = False for cmd in init_commands: if "PRAGMA" not in cmd: continue if "foreign_keys" in cmd: has_foreign_keys = True elif "journal_mode" in cmd: has_journal_mode = True elif "synchronous" in cmd: has_synchronous = True elif "busy_timeout" in cmd: has_busy_timeout = True if not has_foreign_keys: init_commands.append("PRAGMA foreign_keys = ON") if not has_journal_mode: init_commands.append("PRAGMA journal_mode = WAL") if not has_synchronous and "PRAGMA journal_mode = WAL" in init_commands: init_commands.append("PRAGMA synchronous = NORMAL") if not has_busy_timeout: init_commands.append("PRAGMA busy_timeout = 5000") return init_commands def override_pool(self, 
db: Database) -> None: assert isinstance(db, SQLiteDatabase) self._parent = db async def start(self) -> None: if self._parent: await super().start() return if self._conns: raise RuntimeError("database pool has already been started") elif self._stopped: raise RuntimeError("database pool can't be restarted") self.log.debug(f"Connecting to {self.url}") self.log.debug(f"Database connection init commands: {self._init_commands}") if os.path.exists(self._path): if not os.access(self._path, os.W_OK): self.log.warning("Database file doesn't seem writable") elif not os.access(os.path.dirname(os.path.abspath(self._path)), os.W_OK): self.log.warning("Database file doesn't exist and directory doesn't seem writable") for _ in range(self._pool.maxsize): conn = await TxnConnection(self._path, **self._db_args) if self._init_commands: cur = await conn.cursor() for command in self._init_commands: self.log.trace("Executing init command: %s", command) await cur.execute(command) await conn.commit() conn.row_factory = sqlite3.Row self._pool.put_nowait(conn) self._conns += 1 await super().start() async def stop(self) -> None: if self._parent: return self._stopped = True while self._conns > 0: conn = await self._pool.get() self._conns -= 1 await conn.close() def acquire(self) -> AsyncContextManager[LoggingConnection]: if self._parent: return self._parent.acquire() return self._acquire() @asynccontextmanager async def _acquire(self) -> LoggingConnection: if self._stopped: raise RuntimeError("database pool has been stopped") conn = await self._pool.get() try: yield LoggingConnection(self.scheme, conn, self.log) finally: self._pool.put_nowait(conn) Database.schemes["sqlite"] = SQLiteDatabase Database.schemes["sqlite3"] = SQLiteDatabase python-0.20.7/mautrix/util/async_db/asyncpg.py000066400000000000000000000070701473573527000214100ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. 
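# --- Editor's illustrative sketch (not part of the upstream source) ---
# Opening an SQLite-backed Database through the scheme registry populated at
# the end of aiosqlite.py above. The wrapper rewrites asyncpg-style $1
# placeholders into ?1 for SQLite, so the same query syntax works on both
# backends. The file name and the extra init command are placeholder values.
import asyncio

from mautrix.util.async_db import Database


async def example() -> None:
    db = Database.create(
        "sqlite:example.db",
        # Missing pragmas (foreign_keys, journal_mode, synchronous, busy_timeout)
        # are appended automatically by _add_missing_pragmas.
        db_args={"init_commands": ["PRAGMA foreign_keys = ON"]},
    )
    await db.start()
    try:
        assert await db.fetchval("SELECT $1", 1) == 1
    finally:
        await db.stop()


if __name__ == "__main__":
    asyncio.run(example())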
from __future__ import annotations from typing import Any from contextlib import asynccontextmanager import asyncio import logging import sys import traceback from yarl import URL import asyncpg from .connection import LoggingConnection from .database import Database from .scheme import Scheme from .upgrade import UpgradeTable class PostgresDatabase(Database): scheme = Scheme.POSTGRES _pool: asyncpg.pool.Pool | None _pool_override: bool _exit_on_ice: bool def __init__( self, url: URL, upgrade_table: UpgradeTable, db_args: dict[str, Any] = None, log: logging.Logger | None = None, owner_name: str | None = None, ignore_foreign_tables: bool = True, ) -> None: if url.scheme in ("cockroach", "cockroachdb"): self.scheme = Scheme.COCKROACH # Send postgres scheme to asyncpg url = url.with_scheme("postgres") self._exit_on_ice = True if db_args: self._exit_on_ice = db_args.pop("meow_exit_on_ice", True) db_args.pop("init_commands", None) super().__init__( url, db_args=db_args, upgrade_table=upgrade_table, log=log, owner_name=owner_name, ignore_foreign_tables=ignore_foreign_tables, ) self._pool = None self._pool_override = False def override_pool(self, db: PostgresDatabase) -> None: self._pool = db._pool self._pool_override = True async def start(self) -> None: if not self._pool_override: if self._pool: raise RuntimeError("Database has already been started") self._db_args["loop"] = asyncio.get_running_loop() log_url = self.url if log_url.password: log_url = log_url.with_password("password-redacted") self.log.debug(f"Connecting to {log_url}") self._pool = await asyncpg.create_pool(str(self.url), **self._db_args) await super().start() @property def pool(self) -> asyncpg.pool.Pool: if not self._pool: raise RuntimeError("Database has not been started") return self._pool async def stop(self) -> None: if not self._pool_override and self._pool is not None: await self._pool.close() async def _handle_exception(self, err: Exception) -> None: if self._exit_on_ice and isinstance(err, asyncpg.InternalClientError): pre_stack = traceback.format_stack()[:-2] post_stack = traceback.format_exception(err) header = post_stack[0] post_stack = post_stack[1:] self.log.critical( "Got asyncpg internal client error, exiting...\n%s%s%s", header, "".join(pre_stack), "".join(post_stack), ) sys.exit(26) @asynccontextmanager async def acquire(self) -> LoggingConnection: async with self.pool.acquire() as conn: yield LoggingConnection( self.scheme, conn, self.log, handle_exception=self._handle_exception ) Database.schemes["postgres"] = PostgresDatabase Database.schemes["postgresql"] = PostgresDatabase Database.schemes["cockroach"] = PostgresDatabase Database.schemes["cockroachdb"] = PostgresDatabase python-0.20.7/mautrix/util/async_db/connection.py000066400000000000000000000120401473573527000220740ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations from typing import Any, Awaitable, Callable, TypeVar from contextlib import asynccontextmanager from logging import WARNING import functools import time from mautrix import __optional_imports__ from mautrix.util.logging import SILLY, TraceLogger from .scheme import Scheme if __optional_imports__: from sqlite3 import Row from aiosqlite import Cursor from asyncpg import Record import asyncpg from . 
import aiosqlite Decorated = TypeVar("Decorated", bound=Callable[..., Any]) LOG_MESSAGE = "%s(%r) took %.3f seconds" def log_duration(func: Decorated) -> Decorated: func_name = func.__name__ @functools.wraps(func) async def wrapper(self: LoggingConnection, arg: str, *args: Any, **kwargs: str) -> Any: start = time.monotonic() ret = await func(self, arg, *args, **kwargs) duration = time.monotonic() - start self.log.log(WARNING if duration > 1 else SILLY, LOG_MESSAGE, func_name, arg, duration) return ret return wrapper async def handle_exception_noop(_: Exception) -> None: pass class LoggingConnection: def __init__( self, scheme: Scheme, wrapped: aiosqlite.TxnConnection | asyncpg.Connection, log: TraceLogger, handle_exception: Callable[[Exception], Awaitable[None]] = handle_exception_noop, ) -> None: self.scheme = scheme self.wrapped = wrapped self.log = log self._handle_exception = handle_exception self._inited = True def __setattr__(self, key: str, value: Any) -> None: if getattr(self, "_inited", False): raise RuntimeError("LoggingConnection fields are frozen") super().__setattr__(key, value) @asynccontextmanager async def transaction(self) -> None: try: async with self.wrapped.transaction(): yield except Exception as e: await self._handle_exception(e) raise @log_duration async def execute(self, query: str, *args: Any, timeout: float | None = None) -> str | Cursor: try: return await self.wrapped.execute(query, *args, timeout=timeout) except Exception as e: await self._handle_exception(e) raise @log_duration async def executemany( self, query: str, *args: Any, timeout: float | None = None ) -> str | Cursor: try: return await self.wrapped.executemany(query, *args, timeout=timeout) except Exception as e: await self._handle_exception(e) raise @log_duration async def fetch( self, query: str, *args: Any, timeout: float | None = None ) -> list[Row | Record]: try: return await self.wrapped.fetch(query, *args, timeout=timeout) except Exception as e: await self._handle_exception(e) raise @log_duration async def fetchval( self, query: str, *args: Any, column: int = 0, timeout: float | None = None ) -> Any: try: return await self.wrapped.fetchval(query, *args, column=column, timeout=timeout) except Exception as e: await self._handle_exception(e) raise @log_duration async def fetchrow( self, query: str, *args: Any, timeout: float | None = None ) -> Row | Record | None: try: return await self.wrapped.fetchrow(query, *args, timeout=timeout) except Exception as e: await self._handle_exception(e) raise async def table_exists(self, name: str) -> bool: if self.scheme == Scheme.SQLITE: return await self.fetchval( "SELECT EXISTS(SELECT 1 FROM sqlite_master WHERE type='table' AND name=?1)", name ) elif self.scheme in (Scheme.POSTGRES, Scheme.COCKROACH): return await self.fetchval( "SELECT EXISTS(SELECT 1 FROM information_schema.tables WHERE table_name=$1)", name ) else: raise RuntimeError(f"Unknown scheme {self.scheme}") @log_duration async def copy_records_to_table( self, table_name: str, *, records: list[tuple[Any, ...]], columns: tuple[str, ...] 
| list[str], schema_name: str | None = None, timeout: float | None = None, ) -> None: if self.scheme != Scheme.POSTGRES: raise RuntimeError("copy_records_to_table is only supported on Postgres") try: return await self.wrapped.copy_records_to_table( table_name, records=records, columns=columns, schema_name=schema_name, timeout=timeout, ) except Exception as e: await self._handle_exception(e) raise python-0.20.7/mautrix/util/async_db/connection.pyi000066400000000000000000000034651473573527000222600ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from typing import Any, AsyncContextManager, Awaitable, Callable from sqlite3 import Row from asyncpg import Record import asyncpg from mautrix.util.logging import TraceLogger from . import aiosqlite from .scheme import Scheme class LoggingConnection: scheme: Scheme wrapped: aiosqlite.TxnConnection | asyncpg.Connection _handle_exception: Callable[[Exception], Awaitable[None]] log: TraceLogger def __init__( self, scheme: Scheme, wrapped: aiosqlite.TxnConnection | asyncpg.Connection, log: TraceLogger, handle_exception: Callable[[Exception], Awaitable[None]] = None, ) -> None: ... async def transaction(self) -> AsyncContextManager[None]: ... async def execute(self, query: str, *args: Any, timeout: float | None = None) -> str: ... async def executemany(self, query: str, *args: Any, timeout: float | None = None) -> str: ... async def fetch( self, query: str, *args: Any, timeout: float | None = None ) -> list[Row | Record]: ... async def fetchval( self, query: str, *args: Any, column: int = 0, timeout: float | None = None ) -> Any: ... async def fetchrow( self, query: str, *args: Any, timeout: float | None = None ) -> Row | Record | None: ... async def table_exists(self, name: str) -> bool: ... async def copy_records_to_table( self, table_name: str, *, records: list[tuple[Any, ...]], columns: tuple[str, ...] | list[str], schema_name: str | None = None, timeout: float | None = None, ) -> None: ... python-0.20.7/mautrix/util/async_db/database.py000066400000000000000000000133641473573527000215130ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. 
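# --- Editor's illustrative sketch (not part of the upstream source) ---
# How the LoggingConnection wrapper defined above is normally reached: by
# acquiring a connection from a started Database and running queries inside a
# transaction. The log_duration decorator logs queries slower than one second
# at WARNING and the rest at the SILLY trace level. The table and column names
# here are hypothetical.
from mautrix.util.async_db import Connection, Database


async def rename_user(db: Database, old_name: str, new_name: str) -> None:
    async with db.acquire() as conn:
        conn: Connection
        async with conn.transaction():
            await conn.execute("UPDATE example_user SET name=$1 WHERE name=$2", new_name, old_name)
            await conn.execute("UPDATE example_alias SET owner=$1 WHERE owner=$2", new_name, old_name)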
from __future__ import annotations from typing import Any, AsyncContextManager, Type from abc import ABC, abstractmethod import logging from yarl import URL from mautrix import __optional_imports__ from mautrix.util.logging import TraceLogger from .connection import LoggingConnection from .errors import DatabaseNotOwned, ForeignTablesFound from .scheme import Scheme from .upgrade import UpgradeTable, upgrade_tables if __optional_imports__: from aiosqlite import Cursor from asyncpg import Record class Database(ABC): schemes: dict[str, Type[Database]] = {} log: TraceLogger scheme: Scheme url: URL _db_args: dict[str, Any] upgrade_table: UpgradeTable | None owner_name: str | None ignore_foreign_tables: bool def __init__( self, url: URL, upgrade_table: UpgradeTable | None, db_args: dict[str, Any] | None = None, log: TraceLogger | None = None, owner_name: str | None = None, ignore_foreign_tables: bool = True, ) -> None: self.url = url self._db_args = {**db_args} if db_args else {} self.upgrade_table = upgrade_table self.owner_name = owner_name self.ignore_foreign_tables = ignore_foreign_tables self.log = log or logging.getLogger("mau.db") assert isinstance(self.log, TraceLogger) @classmethod def create( cls, url: str | URL, *, db_args: dict[str, Any] | None = None, upgrade_table: UpgradeTable | str | None = None, log: logging.Logger | TraceLogger | None = None, owner_name: str | None = None, ignore_foreign_tables: bool = True, ) -> Database: url = URL(url) try: impl = cls.schemes[url.scheme] except KeyError as e: if url.scheme in ("postgres", "postgresql"): raise RuntimeError( f"Unknown database scheme {url.scheme}." " Perhaps you forgot to install asyncpg?" ) from e elif url.scheme in ("sqlite", "sqlite3"): raise RuntimeError( f"Unknown database scheme {url.scheme}." " Perhaps you forgot to install aiosqlite?" 
) from e raise RuntimeError(f"Unknown database scheme {url.scheme}") from e if isinstance(upgrade_table, str): upgrade_table = upgrade_tables[upgrade_table] elif upgrade_table is None: upgrade_table = UpgradeTable() elif not isinstance(upgrade_table, UpgradeTable): raise ValueError(f"Can't use {type(upgrade_table)} as the upgrade table") return impl( url, db_args=db_args, upgrade_table=upgrade_table, log=log, owner_name=owner_name, ignore_foreign_tables=ignore_foreign_tables, ) def override_pool(self, db: Database) -> None: pass async def start(self) -> None: if not self.ignore_foreign_tables: await self._check_foreign_tables() if self.owner_name: await self._check_owner() if self.upgrade_table and len(self.upgrade_table.upgrades) > 0: await self.upgrade_table.upgrade(self) async def _check_foreign_tables(self) -> None: if await self.table_exists("state_groups_state"): raise ForeignTablesFound("found state_groups_state likely belonging to Synapse") elif await self.table_exists("roomserver_rooms"): raise ForeignTablesFound("found roomserver_rooms likely belonging to Dendrite") async def _check_owner(self) -> None: await self.execute( """CREATE TABLE IF NOT EXISTS database_owner ( key INTEGER PRIMARY KEY DEFAULT 0, owner TEXT NOT NULL )""" ) owner = await self.fetchval("SELECT owner FROM database_owner WHERE key=0") if not owner: await self.execute("INSERT INTO database_owner (owner) VALUES ($1)", self.owner_name) elif owner != self.owner_name: raise DatabaseNotOwned(owner) @abstractmethod async def stop(self) -> None: pass @abstractmethod def acquire(self) -> AsyncContextManager[LoggingConnection]: pass async def execute(self, query: str, *args: Any, timeout: float | None = None) -> str | Cursor: async with self.acquire() as conn: return await conn.execute(query, *args, timeout=timeout) async def executemany( self, query: str, *args: Any, timeout: float | None = None ) -> str | Cursor: async with self.acquire() as conn: return await conn.executemany(query, *args, timeout=timeout) async def fetch(self, query: str, *args: Any, timeout: float | None = None) -> list[Record]: async with self.acquire() as conn: return await conn.fetch(query, *args, timeout=timeout) async def fetchval( self, query: str, *args: Any, column: int = 0, timeout: float | None = None ) -> Any: async with self.acquire() as conn: return await conn.fetchval(query, *args, column=column, timeout=timeout) async def fetchrow( self, query: str, *args: Any, timeout: float | None = None ) -> Record | None: async with self.acquire() as conn: return await conn.fetchrow(query, *args, timeout=timeout) async def table_exists(self, name: str) -> bool: async with self.acquire() as conn: return await conn.table_exists(name) python-0.20.7/mautrix/util/async_db/errors.py000066400000000000000000000024561473573527000212630ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. 
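# --- Editor's illustrative sketch (not part of the upstream source) ---
# A minimal end-to-end use of Database.create together with an UpgradeTable and
# the ownership check implemented in database.py above. The database URL, owner
# name and schema are placeholder values.
import asyncio

from mautrix.util.async_db import Connection, Database, DatabaseException, UpgradeTable

upgrade_table = UpgradeTable(database_name="example")


@upgrade_table.register(description="Initial schema")
async def upgrade_v1(conn: Connection) -> None:
    await conn.execute("CREATE TABLE example_thing (id INTEGER PRIMARY KEY, name TEXT)")


async def main() -> None:
    db = Database.create(
        "postgres://user:password@localhost/example",
        upgrade_table=upgrade_table,
        owner_name="example-program",  # written to the database_owner table on first start
    )
    try:
        # start() checks ownership and runs any pending upgrades in order.
        await db.start()
    except DatabaseException as e:
        print(f"Failed to start database: {e}")
        return
    try:
        print(await db.fetchval("SELECT COUNT(*) FROM example_thing"))
    finally:
        await db.stop()


if __name__ == "__main__":
    asyncio.run(main())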
from __future__ import annotations class DatabaseException(RuntimeError): pass @property def explanation(self) -> str | None: return None class UnsupportedDatabaseVersion(DatabaseException): def __init__(self, name: str, version: int, latest: int) -> None: super().__init__( f"Unsupported {name} schema version v{version} (latest known is v{latest})" ) @property def explanation(self) -> str: return "Downgrading is not supported" class ForeignTablesFound(DatabaseException): def __init__(self, explanation: str) -> None: super().__init__(f"The database contains foreign tables ({explanation})") @property def explanation(self) -> str: return "You can use --ignore-foreign-tables to ignore this error" class DatabaseNotOwned(DatabaseException): def __init__(self, owner: str) -> None: super().__init__(f"The database is owned by {owner}") @property def explanation(self) -> str: return "Sharing the same database with different programs is not supported" python-0.20.7/mautrix/util/async_db/scheme.py000066400000000000000000000010741473573527000212060ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations from enum import Enum class Scheme(Enum): POSTGRES = "postgres" COCKROACH = "cockroach" SQLITE = "sqlite" def __eq__(self, other: Scheme | str) -> bool: if isinstance(other, str): return self.value == other else: return super().__eq__(other) python-0.20.7/mautrix/util/async_db/upgrade.py000066400000000000000000000152711473573527000213750ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations from typing import Awaitable, Callable, Optional, cast import functools import inspect import logging from mautrix.util.logging import TraceLogger from .. 
import async_db from .connection import LoggingConnection from .errors import UnsupportedDatabaseVersion from .scheme import Scheme Upgrade = Callable[[LoggingConnection, Scheme], Awaitable[Optional[int]]] UpgradeWithoutScheme = Callable[[LoggingConnection], Awaitable[Optional[int]]] async def noop_upgrade(_: LoggingConnection, _2: Scheme) -> None: pass def _wrap_upgrade(fn: UpgradeWithoutScheme | Upgrade) -> Upgrade: params = inspect.signature(fn).parameters if len(params) == 1: _wrapped: UpgradeWithoutScheme = cast(UpgradeWithoutScheme, fn) @functools.wraps(_wrapped) async def _wrapper(conn: LoggingConnection, _: Scheme) -> Optional[int]: return await _wrapped(conn) return _wrapper else: return fn class UpgradeTable: upgrades: list[Upgrade] allow_unsupported: bool database_name: str version_table_name: str log: TraceLogger def __init__( self, allow_unsupported: bool = False, version_table_name: str = "version", database_name: str = "database", log: logging.Logger | TraceLogger | None = None, ) -> None: self.upgrades = [] self.allow_unsupported = allow_unsupported self.version_table_name = version_table_name self.database_name = database_name self.log = log or logging.getLogger("mau.db.upgrade") def register( self, _outer_fn: Upgrade | UpgradeWithoutScheme | None = None, *, index: int = -1, description: str = "", transaction: bool = True, upgrades_to: int | Upgrade | None = None, ) -> Upgrade | Callable[[Upgrade | UpgradeWithoutScheme], Upgrade]: if isinstance(index, str): description = index index = -1 def actually_register(fn: Upgrade | UpgradeWithoutScheme) -> Upgrade: fn = _wrap_upgrade(fn) fn.__mau_db_upgrade_description__ = description fn.__mau_db_upgrade_transaction__ = transaction fn.__mau_db_upgrade_destination__ = ( upgrades_to if not upgrades_to or isinstance(upgrades_to, int) else _wrap_upgrade(upgrades_to) ) if index == -1 or index == len(self.upgrades): self.upgrades.append(fn) else: if len(self.upgrades) <= index: self.upgrades += [noop_upgrade] * (index - len(self.upgrades) + 1) self.upgrades[index] = fn return fn return actually_register(_outer_fn) if _outer_fn else actually_register async def _save_version(self, conn: LoggingConnection, version: int) -> None: self.log.trace(f"Saving current version (v{version}) to database") await conn.execute(f"DELETE FROM {self.version_table_name}") await conn.execute(f"INSERT INTO {self.version_table_name} (version) VALUES ($1)", version) async def upgrade(self, db: async_db.Database) -> None: await db.execute( f"""CREATE TABLE IF NOT EXISTS {self.version_table_name} ( version INTEGER PRIMARY KEY )""" ) row = await db.fetchrow(f"SELECT version FROM {self.version_table_name} LIMIT 1") version = row["version"] if row else 0 if len(self.upgrades) < version: unsupported_version_error = UnsupportedDatabaseVersion( self.database_name, version, len(self.upgrades) ) if not self.allow_unsupported: raise unsupported_version_error else: self.log.warning(str(unsupported_version_error)) return elif len(self.upgrades) == version: self.log.debug(f"Database at v{version}, not upgrading") return async with db.acquire() as conn: while version < len(self.upgrades): old_version = version upgrade = self.upgrades[version] new_version = ( getattr(upgrade, "__mau_db_upgrade_destination__", None) or version + 1 ) if callable(new_version): new_version = await new_version(conn, db.scheme) desc = getattr(upgrade, "__mau_db_upgrade_description__", None) suffix = f": {desc}" if desc else "" self.log.debug( f"Upgrading {self.database_name} from v{old_version} to 
v{new_version}{suffix}" ) if getattr(upgrade, "__mau_db_upgrade_transaction__", True): async with conn.transaction(): version = await upgrade(conn, db.scheme) or new_version await self._save_version(conn, version) else: version = await upgrade(conn, db.scheme) or new_version await self._save_version(conn, version) if version != new_version: self.log.warning( f"Upgrading {self.database_name} actually went from v{old_version} " f"to v{version}" ) upgrade_tables: dict[str, UpgradeTable] = {} def register_upgrade_table_parent_module(name: str) -> None: upgrade_tables[name] = UpgradeTable() def _find_upgrade_table(fn: Upgrade) -> UpgradeTable: try: module = fn.__module__ except AttributeError as e: raise ValueError( "Registering upgrades without an UpgradeTable requires the function " "to have the __module__ attribute." ) from e parts = module.split(".") used_parts = [] last_error = None for part in parts: used_parts.append(part) try: return upgrade_tables[".".join(used_parts)] except KeyError as e: last_error = e raise KeyError( "Registering upgrades without an UpgradeTable requires you to register a parent " "module with register_upgrade_table_parent_module first." ) from last_error def register_upgrade(index: int = -1, description: str = "") -> Callable[[Upgrade], Upgrade]: def actually_register(fn: Upgrade) -> Upgrade: return _find_upgrade_table(fn).register(fn, index=index, description=description) return actually_register python-0.20.7/mautrix/util/async_getter_lock.py000066400000000000000000000036551473573527000216660ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations from typing import Any import functools from mautrix import __optional_imports__ if __optional_imports__: from typing import Awaitable, Callable, ParamSpec Param = ParamSpec("Param") Func = Callable[Param, Awaitable[Any]] def async_getter_lock(fn: Func) -> Func: """ A utility decorator for locking async getters that have caches (preventing race conditions between cache check and e.g. async database actions). The class must have an ```_async_get_locks`` defaultdict that contains :class:`asyncio.Lock`s (see example for exact definition). Non-cache-affecting arguments should be only passed as keyword args. Args: fn: The function to decorate. Returns: The decorated function. Examples: >>> import asyncio >>> from collections import defaultdict >>> class User: ... _async_get_locks: dict[Any, asyncio.Lock] = defaultdict(lambda: asyncio.Lock()) ... db: Any ... cache: dict[str, User] ... @classmethod ... @async_getter_lock ... async def get(cls, id: str, *, create: bool = False) -> User | None: ... try: ... return cls.cache[id] ... except KeyError: ... pass ... user = await cls.db.fetch_user(id) ... if user: ... return user ... elif create: ... return await cls.db.create_user(id) ... return None """ @functools.wraps(fn) async def wrapper(cls, *args, **kwargs) -> Any: async with cls._async_get_locks[args]: return await fn(cls, *args, **kwargs) return wrapper python-0.20.7/mautrix/util/background_task.py000066400000000000000000000034341473573527000213230ustar00rootroot00000000000000# Copyright (c) 2023 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. 
If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations from typing import Coroutine import asyncio import logging _tasks = set() log = logging.getLogger("mau.background_task") async def catch(coro: Coroutine, caller: str) -> None: try: await coro except Exception: log.exception(f"Uncaught error in background task (created in {caller})") # Logger.findCaller finds the 3rd stack frame, so add an intermediate function # to get the caller of create(). def _find_caller() -> tuple[str, int, str, None]: return log.findCaller() def create(coro: Coroutine, *, name: str | None = None, catch_errors: bool = True) -> asyncio.Task: """ Create a background asyncio task safely, ensuring a reference is kept until the task completes. It also catches and logs uncaught errors (unless disabled via the parameter). Args: coro: The coroutine to wrap in a task and execute. name: An optional name for the created task. catch_errors: Should the task be wrapped in a try-except block to log any uncaught errors? Returns: An asyncio Task object wrapping the given coroutine. """ if catch_errors: try: file_name, line_number, function_name, _ = _find_caller() caller = f"{function_name} at {file_name}:{line_number}" except ValueError: caller = "unknown function" task = asyncio.create_task(catch(coro, caller), name=name) else: task = asyncio.create_task(coro, name=name) _tasks.add(task) task.add_done_callback(_tasks.discard) return task python-0.20.7/mautrix/util/bridge_state.py000066400000000000000000000120011473573527000206040ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. 
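# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the library): firing off a coroutine with
# background_task.create() from above, which keeps a reference to the task
# and logs uncaught errors instead of silently dropping them. The
# sync_puppets coroutine and task name are invented for the example.
import asyncio

from mautrix.util import background_task


async def _example_background_task() -> None:
    async def sync_puppets() -> None:
        await asyncio.sleep(1)
        raise RuntimeError("oops")  # would be caught and logged, not lost

    task = background_task.create(sync_puppets(), name="puppet-sync")
    # ... continue doing other work; the task keeps itself referenced until done.
    await asyncio.sleep(2)
    assert task.done()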
from typing import Any, ClassVar, Dict, Optional import logging import time from attr import dataclass import aiohttp from mautrix.api import HTTPAPI from mautrix.types import SerializableAttrs, SerializableEnum, UserID, field class BridgeStateEvent(SerializableEnum): ##################################### # Global state events, no remote ID # ##################################### # Bridge process is starting up STARTING = "STARTING" # Bridge has started but has no valid credentials UNCONFIGURED = "UNCONFIGURED" # Bridge is running RUNNING = "RUNNING" # The server was unable to reach the bridge BRIDGE_UNREACHABLE = "BRIDGE_UNREACHABLE" ################################################ # Remote state events, should have a remote ID # ################################################ # Bridge has credentials and has started connecting to a remote network CONNECTING = "CONNECTING" # Bridge has begun backfilling BACKFILLING = "BACKFILLING" # Bridge has happily connected and is bridging messages CONNECTED = "CONNECTED" # Bridge has temporarily disconnected, expected to reconnect automatically TRANSIENT_DISCONNECT = "TRANSIENT_DISCONNECT" # Bridge has disconnected, will require user to log in again BAD_CREDENTIALS = "BAD_CREDENTIALS" # Bridge has disconnected for an unknown/unexpected reason - we should investigate UNKNOWN_ERROR = "UNKNOWN_ERROR" # User has logged out - stop tracking this remote LOGGED_OUT = "LOGGED_OUT" ok_ish_states = ( BridgeStateEvent.STARTING, BridgeStateEvent.UNCONFIGURED, BridgeStateEvent.RUNNING, BridgeStateEvent.CONNECTING, BridgeStateEvent.CONNECTED, BridgeStateEvent.BACKFILLING, ) @dataclass(kw_only=True) class BridgeState(SerializableAttrs): human_readable_errors: ClassVar[Dict[Optional[str], str]] = {} default_source: ClassVar[str] = "bridge" default_error_ttl: ClassVar[int] = 3600 default_ok_ttl: ClassVar[int] = 21600 state_event: BridgeStateEvent user_id: Optional[UserID] = None remote_id: Optional[str] = None remote_name: Optional[str] = None timestamp: Optional[int] = None ttl: int = 0 source: Optional[str] = None error: Optional[str] = None message: Optional[str] = None info: Optional[Dict[str, Any]] = None reason: Optional[str] = None send_attempts_: int = field(default=0, hidden=True) def fill(self) -> "BridgeState": self.timestamp = self.timestamp or int(time.time()) self.source = self.source or self.default_source if not self.ttl: self.ttl = ( self.default_ok_ttl if self.state_event in ok_ish_states else self.default_error_ttl ) if self.error: try: msg = self.human_readable_errors[self.error] except KeyError: pass else: self.message = msg.format(message=self.message) if self.message else msg return self def should_deduplicate(self, prev_state: Optional["BridgeState"]) -> bool: if ( not prev_state or prev_state.state_event != self.state_event or prev_state.error != self.error or prev_state.info != self.info ): # If there's no previous state or the state was different, send this one. 
return False # If the previous state is recent, drop this one return prev_state.timestamp + prev_state.ttl > self.timestamp async def send(self, url: str, token: str, log: logging.Logger, log_sent: bool = True) -> bool: if not url: return True self.send_attempts_ += 1 headers = {"Authorization": f"Bearer {token}", "User-Agent": HTTPAPI.default_ua} try: async with ( aiohttp.ClientSession() as sess, sess.post(url, json=self.serialize(), headers=headers) as resp, ): if not 200 <= resp.status < 300: text = await resp.text() text = text.replace("\n", "\\n") log.warning( f"Unexpected status code {resp.status} " f"sending bridge state update: {text}" ) return False elif log_sent: log.debug(f"Sent new bridge state {self}") except Exception as e: log.warning(f"Failed to send updated bridge state: {e}") return False return True @dataclass(kw_only=True) class GlobalBridgeState(SerializableAttrs): remote_states: Optional[Dict[str, BridgeState]] = field(json="remoteState", default=None) bridge_state: BridgeState = field(json="bridgeState") python-0.20.7/mautrix/util/color_log.py000066400000000000000000000001361473573527000201350ustar00rootroot00000000000000# This only exists for compatibility with old log configs from .logging import ColorFormatter python-0.20.7/mautrix/util/config/000077500000000000000000000000001473573527000170515ustar00rootroot00000000000000python-0.20.7/mautrix/util/config/__init__.py000066400000000000000000000011261473573527000211620ustar00rootroot00000000000000from .base import BaseConfig, BaseMissingError, ConfigUpdateHelper from .file import BaseFileConfig, yaml from .proxy import BaseProxyConfig from .recursive_dict import RecursiveDict from .string import BaseStringConfig from .validation import BaseValidatableConfig, ConfigValueError, ForbiddenDefault, ForbiddenKey __all__ = [ "BaseConfig", "BaseMissingError", "ConfigUpdateHelper", "BaseFileConfig", "yaml", "BaseProxyConfig", "RecursiveDict", "BaseStringConfig", "BaseValidatableConfig", "ConfigValueError", "ForbiddenDefault", "ForbiddenKey", ] python-0.20.7/mautrix/util/config/base.py000066400000000000000000000045611473573527000203430ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations from abc import ABC, abstractmethod from ruamel.yaml.comments import Comment, CommentedBase, CommentedMap from .recursive_dict import RecursiveDict class BaseMissingError(ValueError): pass class ConfigUpdateHelper: base: RecursiveDict[CommentedMap] def __init__(self, base: RecursiveDict, config: RecursiveDict) -> None: self.base = base self.source = config def copy(self, from_path: str, to_path: str | None = None) -> None: if from_path in self.source: val = self.source[from_path] # Small hack to make sure comments from the user config don't # partially leak into the updated version. 
if isinstance(val, CommentedBase): setattr(val, Comment.attrib, Comment()) self.base[to_path or from_path] = val def copy_dict( self, from_path: str, to_path: str | None = None, override_existing_map: bool = True, ) -> None: if from_path in self.source: to_path = to_path or from_path if override_existing_map or to_path not in self.base: self.base[to_path] = CommentedMap() for key, value in self.source[from_path].items(): self.base[to_path][key] = value def __iter__(self): yield self.copy yield self.copy_dict yield self.base class BaseConfig(ABC, RecursiveDict[CommentedMap]): @abstractmethod def load(self) -> None: pass @abstractmethod def load_base(self) -> RecursiveDict[CommentedMap] | None: pass def load_and_update(self) -> None: self.load() self.update() @abstractmethod def save(self) -> None: pass def update(self, save: bool = True) -> None: base = self.load_base() if not base: raise BaseMissingError("Can't update() without base config") self.do_update(ConfigUpdateHelper(base, self)) self._data = base._data if save: self.save() @abstractmethod def do_update(self, helper: ConfigUpdateHelper) -> None: pass python-0.20.7/mautrix/util/config/file.py000066400000000000000000000043641473573527000203510ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations from abc import ABC import logging import os import pkgutil import tempfile from ruamel.yaml import YAML from ruamel.yaml.comments import CommentedMap from yarl import URL from .base import BaseConfig from .recursive_dict import RecursiveDict yaml = YAML() yaml.indent(4) yaml.width = 200 log: logging.Logger = logging.getLogger("mau.util.config") class BaseFileConfig(BaseConfig, ABC): def __init__(self, path: str, base_path: str) -> None: super().__init__() self._data = CommentedMap() self.path: str = path self.base_path: str = base_path def load(self) -> None: with open(self.path, "r") as stream: self._data = yaml.load(stream) def load_base(self) -> RecursiveDict[CommentedMap] | None: if self.base_path.startswith("pkg://"): url = URL(self.base_path) return RecursiveDict(yaml.load(pkgutil.get_data(url.host, url.path)), CommentedMap) try: with open(self.base_path, "r") as stream: return RecursiveDict(yaml.load(stream), CommentedMap) except OSError: pass return None def save(self) -> None: try: tf = tempfile.NamedTemporaryFile( mode="w", delete=False, suffix=".yaml", dir=os.path.dirname(self.path) ) except OSError as e: log.warning(f"Failed to create tempfile to write updated config to disk: {e}") return try: yaml.dump(self._data, tf) except OSError as e: log.warning(f"Failed to write updated config to tempfile: {e}") tf.file.close() os.remove(tf.name) return tf.file.close() try: os.rename(tf.name, self.path) except OSError as e: log.warning(f"Failed to rename tempfile with updated config to {self.path}: {e}") try: os.remove(tf.name) except FileNotFoundError: pass python-0.20.7/mautrix/util/config/proxy.py000066400000000000000000000021441473573527000206050ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. 
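# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the library): a minimal subclass of
# BaseFileConfig from above. do_update() copies values from the user's config
# into the freshly loaded base config via the ConfigUpdateHelper; the config
# keys and file names are invented for the example.
from mautrix.util.config import BaseFileConfig, ConfigUpdateHelper


class ExampleConfig(BaseFileConfig):
    def do_update(self, helper: ConfigUpdateHelper) -> None:
        copy, copy_dict, base = helper  # __iter__ yields the helper's parts
        copy("homeserver.address")
        copy("homeserver.domain")
        copy_dict("bridge.permissions")


# Usage sketch: load config.yaml, merge it over the bundled example config,
# and write the updated file back to disk.
# config = ExampleConfig("config.yaml", "example-config.yaml")
# config.load_and_update()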
from __future__ import annotations from typing import Callable from abc import ABC from ruamel.yaml.comments import CommentedMap from .base import BaseConfig from .recursive_dict import RecursiveDict class BaseProxyConfig(BaseConfig, ABC): def __init__( self, load: Callable[[], CommentedMap], load_base: Callable[[], RecursiveDict[CommentedMap] | None], save: Callable[[RecursiveDict[CommentedMap]], None], ) -> None: super().__init__() self._data = CommentedMap() self._load_proxy = load self._load_base_proxy = load_base self._save_proxy = save def load(self) -> None: self._data = self._load_proxy() or CommentedMap() def load_base(self) -> RecursiveDict[CommentedMap] | None: return self._load_base_proxy() def save(self) -> None: self._save_proxy(self._data) python-0.20.7/mautrix/util/config/recursive_dict.py000066400000000000000000000066731473573527000224510ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations from typing import Any, Generic, Type, TypeVar import copy from ruamel.yaml.comments import CommentedMap T = TypeVar("T") class RecursiveDict(Generic[T]): def __init__(self, data: T | None = None, dict_factory: Type[T] | None = None) -> None: self._dict_factory = dict_factory or dict self._data: CommentedMap = data or self._dict_factory() def clone(self) -> RecursiveDict: return RecursiveDict(data=copy.deepcopy(self._data), dict_factory=self._dict_factory) @staticmethod def parse_key(key: str) -> tuple[str, str | None]: if "." not in key: return key, None key, next_key = key.split(".", 1) if len(key) > 0 and key[0] == "[": end_index = next_key.index("]") key = key[1:] + "." + next_key[:end_index] next_key = next_key[end_index + 2 :] if len(next_key) > end_index + 1 else None return key, next_key def _recursive_get(self, data: T, key: str, default_value: Any) -> Any: key, next_key = self.parse_key(key) if next_key is not None: next_data = data.get(key, self._dict_factory()) return self._recursive_get(next_data, next_key, default_value) try: return data[key] except (AttributeError, KeyError): return default_value def get(self, key: str, default_value: Any, allow_recursion: bool = True) -> Any: if allow_recursion and "." in key: return self._recursive_get(self._data, key, default_value) return self._data.get(key, default_value) def __getitem__(self, key: str) -> Any: return self.get(key, None) def __contains__(self, key: str) -> bool: return self.get(key, None) is not None def _recursive_set(self, data: T, key: str, value: Any) -> None: key, next_key = self.parse_key(key) if next_key is not None: if key not in data: data[key] = self._dict_factory() next_data = data.get(key, self._dict_factory()) return self._recursive_set(next_data, next_key, value) data[key] = value def set(self, key: str, value: Any, allow_recursion: bool = True) -> None: if allow_recursion and "." 
in key: self._recursive_set(self._data, key, value) return self._data[key] = value def __setitem__(self, key: str, value: Any) -> None: self.set(key, value) def _recursive_del(self, data: T, key: str) -> None: key, next_key = self.parse_key(key) if next_key is not None: if key not in data: return next_data = data[key] return self._recursive_del(next_data, next_key) try: del data[key] del data.ca.items[key] except KeyError: pass def delete(self, key: str, allow_recursion: bool = True) -> None: if allow_recursion and "." in key: self._recursive_del(self._data, key) return try: del self._data[key] del self._data.ca.items[key] except KeyError: pass def __delitem__(self, key: str) -> None: self.delete(key) python-0.20.7/mautrix/util/config/string.py000066400000000000000000000017321473573527000207340ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations from abc import ABC import io from ruamel.yaml import YAML from ruamel.yaml.comments import CommentedMap from .base import BaseConfig from .recursive_dict import RecursiveDict yaml = YAML() yaml.indent(4) yaml.width = 200 class BaseStringConfig(BaseConfig, ABC): def __init__(self, data: str, base_data: str) -> None: super().__init__() self._data = yaml.load(data) self._base = RecursiveDict(yaml.load(base_data), CommentedMap) def load(self) -> None: pass def load_base(self) -> RecursiveDict[CommentedMap] | None: return self._base def save(self) -> str: buf = io.StringIO() yaml.dump(self._data, buf) return buf.getvalue() python-0.20.7/mautrix/util/config/validation.py000066400000000000000000000030201473573527000215500ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations from typing import Any from abc import ABC, abstractmethod from attr import dataclass import attr from .base import BaseConfig class ConfigValueError(ValueError): def __init__(self, key: str, message: str) -> None: super().__init__( f"{key} not configured. 
{message}" if message else f"{key} not configured" ) class ForbiddenKey(str): pass @dataclass class ForbiddenDefault: key: str value: Any error: str | None = None condition: str | None = attr.ib(default=None, kw_only=True) def check(self, config: BaseConfig) -> bool: if self.condition and not config[self.condition]: return False elif isinstance(self.value, ForbiddenKey): return str(self.value) in config[self.key] else: return config[self.key] == self.value @property def exception(self) -> ConfigValueError: return ConfigValueError(self.key, self.error) class BaseValidatableConfig(BaseConfig, ABC): @property @abstractmethod def forbidden_defaults(self) -> list[ForbiddenDefault]: pass def check_default_values(self) -> None: for default in self.forbidden_defaults: if default.check(self): raise default.exception python-0.20.7/mautrix/util/db/000077500000000000000000000000001473573527000161715ustar00rootroot00000000000000python-0.20.7/mautrix/util/db/__init__.py000066400000000000000000000001031473573527000202740ustar00rootroot00000000000000from .base import Base, BaseClass __all__ = ["Base", "BaseClass"] python-0.20.7/mautrix/util/db/base.py000066400000000000000000000201301473573527000174510ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Type, TypeVar, cast from contextlib import contextmanager from sqlalchemy import Constraint, Table from sqlalchemy.dialects.postgresql import insert as pg_insert from sqlalchemy.engine.base import Connection, Engine from sqlalchemy.ext.declarative import as_declarative, declarative_base from sqlalchemy.sql.base import ImmutableColumnCollection from sqlalchemy.sql.expression import ClauseElement, Select, and_ if TYPE_CHECKING: from sqlalchemy.engine.result import ResultProxy, RowProxy T = TypeVar("T", bound="BaseClass") class BaseClass: """ Base class for SQLAlchemy models. Provides SQLAlchemy declarative base features and some additional utilities. .. deprecated:: 0.15.0 The :mod:`mautrix.util.async_db` utility is now recommended over SQLAlchemy. """ __tablename__: str db: Engine t: Table __table__: Table c: ImmutableColumnCollection column_names: List[str] @classmethod def bind(cls, db_engine: Engine) -> None: cls.db = db_engine cls.t = cls.__table__ cls.c = cls.t.columns cls.column_names = cls.c.keys() @classmethod def copy( cls, bind: Optional[Engine] = None, rebase: Optional[declarative_base] = None ) -> Type[T]: copy = cast(Type[T], type(cls.__name__, (cls, rebase) if rebase else (cls,), {})) if bind is not None: copy.bind(db_engine=bind) return copy @classmethod def _one_or_none(cls: Type[T], rows: "ResultProxy") -> Optional[T]: """ Try scanning one row from a ResultProxy and return ``None`` if it fails. Args: rows: The SQLAlchemy result to scan. Returns: The scanned object, or ``None`` if there were no rows. """ try: return cls.scan(next(rows)) except StopIteration: return None @classmethod def _all(cls: Type[T], rows: "ResultProxy") -> Iterator[T]: """ Scan all rows from a ResultProxy. Args: rows: The SQLAlchemy result to scan. Yields: Each row scanned with :meth:`scan` """ for row in rows: yield cls.scan(row) @classmethod def scan(cls: Type[T], row: "RowProxy") -> T: """ Read the data from a row into an object. Args: row: The RowProxy object. 
Returns: An object containing the information in the row. """ return cls(**dict(zip(cls.column_names, row))) @classmethod def _make_simple_select(cls: Type[T], *args: ClauseElement) -> Select: """ Create a simple ``SELECT * FROM table WHERE `` statement. Args: *args: The WHERE clauses. If there are many elements, they're joined with AND. Returns: The SQLAlchemy SELECT statement object. """ if len(args) > 1: return cls.t.select().where(and_(*args)) elif len(args) == 1: return cls.t.select().where(args[0]) else: return cls.t.select() @classmethod def _select_all(cls: Type[T], *args: ClauseElement) -> Iterator[T]: """ Select all rows with given conditions. This is intended to be used by table-specific select methods. Args: *args: The WHERE clauses. If there are many elements, they're joined with AND. Yields: The objects representing the rows read with :meth:`scan` """ yield from cls._all(cls.db.execute(cls._make_simple_select(*args))) @classmethod def _select_one_or_none(cls: Type[T], *args: ClauseElement) -> T: """ Select one row with given conditions. If no row is found, return ``None``. This is intended to be used by table-specific select methods. Args: *args: The WHERE clauses. If there are many elements, they're joined with AND. Returns: The object representing the matched row read with :meth:`scan`, or ``None`` if no rows matched. """ return cls._one_or_none(cls.db.execute(cls._make_simple_select(*args))) def _constraint_to_clause(self, constraint: Constraint) -> ClauseElement: return and_( *[column == self.__dict__[name] for name, column in constraint.columns.items()] ) @property def _edit_identity(self: T) -> ClauseElement: """The SQLAlchemy WHERE clause used for editing and deleting individual rows. Usually AND of primary keys.""" return self._constraint_to_clause(self.t.primary_key) def edit(self: T, *, _update_values: bool = True, **values) -> None: """ Edit this row. Args: _update_values: Whether or not the values in memory should be updated as well as the values in the database. **values: The values to change. """ with self.db.begin() as conn: conn.execute(self.t.update().where(self._edit_identity).values(**values)) if _update_values: for key, value in values.items(): setattr(self, key, value) @contextmanager def edit_mode(self: T) -> None: """ Edit this row in a fancy context manager way. This stores the current edit identity, then yields to the context manager and finally puts the new values into the row using the old edit identity in the WHERE clause. >>> class TableClass(Base): ... ... >>> db_instance = TableClass(id="something") >>> with db_instance.edit_mode(): ... db_instance.id = "new_id" """ old_identity = self._edit_identity yield old_identity with self.db.begin() as conn: conn.execute(self.t.update().where(old_identity).values(**self._insert_values)) def delete(self: T) -> None: """Delete this row.""" with self.db.begin() as conn: conn.execute(self.t.delete().where(self._edit_identity)) @property def _insert_values(self: T) -> Dict[str, Any]: """Values for inserts. 
Generally you want all the values in the table.""" return { column_name: self.__dict__[column_name] for column_name in self.column_names if column_name in self.__dict__ } def insert(self) -> None: with self.db.begin() as conn: conn.execute(self.t.insert().values(**self._insert_values)) @property def _upsert_values(self: T) -> Dict[str, Any]: """The values to set when an upsert-insert conflicts and moves to the update part.""" return self._insert_values def _upsert_postgres(self: T, conn: Connection) -> None: conn.execute( pg_insert(self.t) .values(**self._insert_values) .on_conflict_do_update(constraint=self.t.primary_key, set_=self._upsert_values) ) def _upsert_sqlite(self: T, conn: Connection) -> None: conn.execute(self.t.insert().values(**self._insert_values).prefix_with("OR REPLACE")) def _upsert_generic(self: T, conn: Connection): conn.execute(self.t.delete().where(self._edit_identity)) conn.execute(self.t.insert().values(**self._insert_values)) def upsert(self: T) -> None: with self.db.begin() as conn: if self.db.dialect.name == "postgresql": self._upsert_postgres(conn) elif self.db.dialect.name == "sqlite": self._upsert_sqlite(conn) else: self._upsert_generic(conn) def __iter__(self): for key in self.column_names: yield self.__dict__[key] @as_declarative() class Base(BaseClass): """ .. deprecated:: 0.15.0 The :mod:`mautrix.util.async_db` utility is now recommended over SQLAlchemy. """ pass python-0.20.7/mautrix/util/ffmpeg.py000066400000000000000000000173121473573527000174260ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations from typing import Any, Iterable from pathlib import Path import asyncio import json import logging import mimetypes import os import shutil import tempfile try: from . import magic except ImportError: magic = None def _abswhich(program: str) -> str | None: path = shutil.which(program) return os.path.abspath(path) if path else None class ConverterError(ChildProcessError): pass class NotInstalledError(ConverterError): def __init__(self) -> None: super().__init__("failed to transcode media: ffmpeg is not installed") ffmpeg_path = _abswhich("ffmpeg") ffmpeg_default_params = ("-hide_banner", "-loglevel", "warning", "-y") ffprobe_path = _abswhich("ffprobe") ffprobe_default_params = ( "-loglevel", "quiet", "-print_format", "json", "-show_optional_fields", "1", "-show_format", "-show_streams", ) async def probe_path( input_file: os.PathLike[str] | str, logger: logging.Logger | None = None, ) -> Any: """ Probes a media file on the disk using ffprobe. Args: input_file: The full path to the file. Returns: A Python object containing the parsed JSON response from ffprobe Raises: ConverterError: if ffprobe returns a non-zero exit code. 
""" if ffprobe_path is None: raise NotInstalledError() input_file = Path(input_file) proc = await asyncio.create_subprocess_exec( ffprobe_path, *ffprobe_default_params, str(input_file), stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, stdin=asyncio.subprocess.PIPE, ) stdout, stderr = await proc.communicate() if proc.returncode != 0: err_text = stderr.decode("utf-8") if stderr else f"unknown ({proc.returncode})" raise ConverterError(f"ffprobe error: {err_text}") elif stderr and logger: logger.warning(f"ffprobe warning: {stderr.decode('utf-8')}") return json.loads(stdout) async def probe_bytes( data: bytes, input_mime: str | None = None, logger: logging.Logger | None = None, ) -> Any: """ Probe media file data using ffprobe. Args: data: The bytes of the file to probe. input_mime: The mime type of the input data. If not specified, will be guessed using magic. Returns: A Python object containing the parsed JSON response from ffprobe Raises: ConverterError: if ffprobe returns a non-zero exit code. """ if ffprobe_path is None: raise NotInstalledError() if input_mime is None: if magic is None: raise ValueError("input_mime was not specified and magic is not installed") input_mime = magic.mimetype(data) input_extension = mimetypes.guess_extension(input_mime) with tempfile.TemporaryDirectory(prefix="mautrix_ffmpeg_") as tmpdir: input_file = Path(tmpdir) / f"data{input_extension}" with open(input_file, "wb") as file: file.write(data) return await probe_path(input_file=input_file, logger=logger) async def convert_path( input_file: os.PathLike[str] | str, output_extension: str | None, input_args: Iterable[str] | None = None, output_args: Iterable[str] | None = None, remove_input: bool = False, output_path_override: os.PathLike[str] | str | None = None, logger: logging.Logger | None = None, ) -> Path | bytes: """ Convert a media file on the disk using ffmpeg. Args: input_file: The full path to the file. output_extension: The extension that the output file should be. input_args: Arguments to tell ffmpeg how to parse the input file. output_args: Arguments to tell ffmpeg how to convert the file to reach the wanted output. remove_input: Whether the input file should be removed after converting. Not compatible with ``output_path_override``. output_path_override: A custom output path to use (instead of using the input path with a different extension). Returns: The path to the converted file, or the stdout if ``output_path_override`` was set to ``-``. Raises: ConverterError: if ffmpeg returns a non-zero exit code. 
""" if ffmpeg_path is None: raise NotInstalledError() if output_path_override: output_file = output_path_override if remove_input: raise ValueError("remove_input can't be specified with output_path_override") elif not output_extension: raise ValueError("output_extension or output_path_override is required") else: input_file = Path(input_file) output_file = input_file.parent / f"{input_file.stem}{output_extension}" if input_file == output_file: output_file = Path(output_file) output_file = output_file.parent / f"{output_file.stem}-new{output_extension}" proc = await asyncio.create_subprocess_exec( ffmpeg_path, *ffmpeg_default_params, *(input_args or ()), "-i", str(input_file), *(output_args or ()), str(output_file), stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, stdin=asyncio.subprocess.PIPE, ) stdout, stderr = await proc.communicate() if proc.returncode != 0: err_text = stderr.decode("utf-8") if stderr else f"unknown ({proc.returncode})" raise ConverterError(f"ffmpeg error: {err_text}") elif stderr and logger: logger.warning(f"ffmpeg warning: {stderr.decode('utf-8')}") if remove_input and isinstance(input_file, Path): input_file.unlink(missing_ok=True) return stdout if output_file == "-" else output_file async def convert_bytes( data: bytes, output_extension: str, input_args: Iterable[str] | None = None, output_args: Iterable[str] | None = None, input_mime: str | None = None, logger: logging.Logger | None = None, ) -> bytes: """ Convert media file data using ffmpeg. Args: data: The bytes of the file to convert. output_extension: The extension that the output file should be. input_args: Arguments to tell ffmpeg how to parse the input file. output_args: Arguments to tell ffmpeg how to convert the file to reach the wanted output. input_mime: The mime type of the input data. If not specified, will be guessed using magic. Returns: The converted file as bytes. Raises: ConverterError: if ffmpeg returns a non-zero exit code. """ if ffmpeg_path is None: raise NotInstalledError() if input_mime is None: if magic is None: raise ValueError("input_mime was not specified and magic is not installed") input_mime = magic.mimetype(data) input_extension = mimetypes.guess_extension(input_mime) with tempfile.TemporaryDirectory(prefix="mautrix_ffmpeg_") as tmpdir: input_file = Path(tmpdir) / f"data{input_extension}" with open(input_file, "wb") as file: file.write(data) output_file = await convert_path( input_file=input_file, output_extension=output_extension, input_args=input_args, output_args=output_args, logger=logger, ) with open(output_file, "rb") as file: return file.read() __all__ = [ "ffmpeg_path", "ffmpeg_default_params", "ConverterError", "NotInstalledError", "convert_bytes", "convert_path", "probe_bytes", "probe_path", ] python-0.20.7/mautrix/util/file_store.py000066400000000000000000000043741473573527000203210ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. 
from __future__ import annotations from typing import IO, Any, Protocol from abc import ABC, abstractmethod from pathlib import Path import json import pickle import time class Filer(Protocol): def dump(self, obj: Any, file: IO) -> None: pass def load(self, file: IO) -> Any: pass class FileStore(ABC): path: str | Path | IO filer: Filer binary: bool save_interval: float _last_save: float def __init__( self, path: str | Path | IO, filer: Filer | None = None, binary: bool = True, save_interval: float = 60.0, ) -> None: self.path = path self.filer = filer or (pickle if binary else json) self.binary = binary self.save_interval = save_interval self._last_save = time.monotonic() @abstractmethod def serialize(self) -> Any: pass @abstractmethod def deserialize(self, data: Any) -> None: pass def _save(self) -> None: if isinstance(self.path, IO): file = self.path close = False else: file = open(self.path, "wb" if self.binary else "w") close = True try: self.filer.dump(self.serialize(), file) finally: if close: file.close() def _load(self) -> None: if isinstance(self.path, IO): file = self.path close = False else: try: file = open(self.path, "rb" if self.binary else "r") except FileNotFoundError: return close = True try: self.deserialize(self.filer.load(file)) finally: if close: file.close() async def flush(self) -> None: self._save() async def open(self) -> None: self._load() def _time_limited_flush(self) -> None: if self._last_save + self.save_interval < time.monotonic(): self._save() self._last_save = time.monotonic() python-0.20.7/mautrix/util/format_duration.py000066400000000000000000000033311473573527000213530ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. def _pluralize(count: int, singular: str) -> str: return singular if count == 1 else f"{singular}s" def _include_if_positive(count: int, word: str) -> str: return f"{count} {_pluralize(count, word)}" if count > 0 else "" def format_duration(seconds: int) -> str: """ Format seconds as a simple duration in weeks/days/hours/minutes/seconds. Args: seconds: The number of seconds as an integer. Must be positive. Returns: The formatted duration. Examples: >>> from mautrix.util.format_duration import format_duration >>> format_duration(1234) '20 minutes and 34 seconds' >>> format_duration(987654) '1 week, 4 days, 10 hours, 20 minutes and 54 seconds' >>> format_duration(60) '1 minute' Raises: ValueError: if the duration is not positive. """ if seconds <= 0: raise ValueError("format_duration only accepts positive values") minutes, seconds = divmod(seconds, 60) hours, minutes = divmod(minutes, 60) days, hours = divmod(hours, 24) weeks, days = divmod(days, 7) parts = [ _include_if_positive(weeks, "week"), _include_if_positive(days, "day"), _include_if_positive(hours, "hour"), _include_if_positive(minutes, "minute"), _include_if_positive(seconds, "second"), ] parts = [part for part in parts if part] if len(parts) > 2: parts = [", ".join(parts[:-1]), parts[-1]] return " and ".join(parts) python-0.20.7/mautrix/util/format_duration_test.py000066400000000000000000000015401473573527000224120ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. 
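# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the library): a minimal FileStore subclass
# (from file_store.py above) that persists a single dict as JSON. The field
# name and file name are invented for the example.
from typing import Any

from mautrix.util.file_store import FileStore


class ExampleStore(FileStore):
    def __init__(self, path: str) -> None:
        super().__init__(path, binary=False)  # binary=False selects the JSON filer
        self.data: dict[str, Any] = {}

    def serialize(self) -> Any:
        return self.data

    def deserialize(self, data: Any) -> None:
        self.data = data or {}


# Usage sketch:
# store = ExampleStore("state.json")
# await store.open()        # loads existing data, if any
# store.data["foo"] = "bar"
# await store.flush()       # writes the data back to disk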
import pytest from .format_duration import format_duration tests = { 1234: "20 minutes and 34 seconds", 987654: "1 week, 4 days, 10 hours, 20 minutes and 54 seconds", 694861: "1 week, 1 day, 1 hour, 1 minute and 1 second", 1: "1 second", 59: "59 seconds", 60: "1 minute", 120: "2 minutes", } def test_format_duration() -> None: for seconds, formatted in tests.items(): assert format_duration(seconds) == formatted def test_non_positive_error() -> None: with pytest.raises(ValueError): format_duration(0) with pytest.raises(ValueError): format_duration(-123) python-0.20.7/mautrix/util/formatter/000077500000000000000000000000001473573527000176075ustar00rootroot00000000000000python-0.20.7/mautrix/util/formatter/__init__.py000066400000000000000000000015721473573527000217250ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from .entity_string import AbstractEntity, EntityString, SemiAbstractEntity, SimpleEntity from .formatted_string import EntityType, FormattedString from .html_reader import HTMLNode, read_html from .markdown_string import MarkdownString from .parser import MatrixParser, RecursionContext async def parse_html(input_html: str) -> str: return (await MatrixParser().parse(input_html)).text __all__ = [ "AbstractEntity", "EntityString", "SemiAbstractEntity", "SimpleEntity", "EntityType", "FormattedString", "HTMLNode", "read_html", "MarkdownString", "MatrixParser", "RecursionContext", "parse_html", ] python-0.20.7/mautrix/util/formatter/entity_string.py000066400000000000000000000121501473573527000230620ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. 
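# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the library): turning Matrix-flavoured HTML
# into the markdown-ish plaintext produced by the default MatrixParser, using
# the parse_html() helper exported above. The sample HTML is invented.
from mautrix.util.formatter import parse_html


async def _example_parse_html() -> None:
    html = "Hello <b>world</b>, see <a href='https://example.com'>this link</a>"
    text = await parse_html(html)
    print(text)  # e.g. "Hello **world**, see [this link](https://example.com)"


# asyncio.run(_example_parse_html())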
from __future__ import annotations from typing import Any, Generic, Iterable, Sequence, Type, TypeVar from abc import ABC, abstractmethod from itertools import chain from attr import dataclass import attr from .formatted_string import EntityType, FormattedString class AbstractEntity(ABC): def __init__( self, type: EntityType, offset: int, length: int, extra_info: dict[str, Any] ) -> None: pass @abstractmethod def copy(self) -> AbstractEntity: pass @abstractmethod def adjust_offset(self, offset: int, max_length: int = -1) -> AbstractEntity | None: pass class SemiAbstractEntity(AbstractEntity, ABC): offset: int length: int def adjust_offset(self, offset: int, max_length: int = -1) -> SemiAbstractEntity | None: entity = self.copy() entity.offset += offset if entity.offset < 0: entity.length += entity.offset if entity.length < 0: return None entity.offset = 0 elif entity.offset > max_length > -1: return None elif entity.offset + entity.length > max_length > -1: entity.length = max_length - entity.offset return entity @dataclass class SimpleEntity(SemiAbstractEntity): type: EntityType offset: int length: int extra_info: dict[str, Any] = attr.ib(factory=dict) def copy(self) -> SimpleEntity: return attr.evolve(self) TEntity = TypeVar("TEntity", bound=AbstractEntity) TEntityType = TypeVar("TEntityType") class EntityString(Generic[TEntity, TEntityType], FormattedString): text: str _entities: list[TEntity] entity_class: Type[AbstractEntity] = SimpleEntity def __init__(self, text: str = "", entities: list[TEntity] = None) -> None: self.text = text self._entities = entities or [] def __repr__(self) -> str: return f"{self.__class__.__name__}(text='{self.text}', entities={self.entities})" def __str__(self) -> str: return self.text @property def entities(self) -> list[TEntity]: return self._entities @entities.setter def entities(self, val: Iterable[TEntity]) -> None: self._entities = [entity for entity in val if entity is not None] def _offset_entities(self, offset: int) -> EntityString: self.entities = (entity.adjust_offset(offset, len(self.text)) for entity in self.entities) return self def append(self, *args: str | FormattedString) -> EntityString: for msg in args: if isinstance(msg, EntityString): self.entities += (entity.adjust_offset(len(self.text)) for entity in msg.entities) self.text += msg.text else: self.text += str(msg) return self def prepend(self, *args: str | FormattedString) -> EntityString: for msg in args: if isinstance(msg, EntityString): self.text = msg.text + self.text self.entities = chain( msg.entities, (entity.adjust_offset(len(msg.text)) for entity in self.entities) ) else: text = str(msg) self.text = text + self.text self.entities = (entity.adjust_offset(len(text)) for entity in self.entities) return self def format( self, entity_type: TEntityType, offset: int = None, length: int = None, **kwargs ) -> EntityString: self.entities.append( self.entity_class( type=entity_type, offset=offset or 0, length=length or len(self.text), extra_info=kwargs, ) ) return self def trim(self) -> EntityString: orig_len = len(self.text) self.text = self.text.lstrip() diff = orig_len - len(self.text) self.text = self.text.rstrip() self._offset_entities(-diff) return self def split(self, separator, max_items: int = -1) -> list[EntityString]: text_parts = self.text.split(separator, max_items - 1) output: list[EntityString] = [] offset = 0 for part in text_parts: msg = type(self)(part) msg.entities = (entity.adjust_offset(-offset, len(part)) for entity in self.entities) output.append(msg) offset += 
len(part) offset += len(separator) return output @classmethod def join(cls, items: Sequence[str | EntityString], separator: str = " ") -> EntityString: main = cls() for msg in items: if not isinstance(msg, EntityString): msg = cls(text=str(msg)) main.entities += [entity.adjust_offset(len(main.text)) for entity in msg.entities] main.text += msg.text + separator if len(separator) > 0: main.text = main.text[: -len(separator)] return main python-0.20.7/mautrix/util/formatter/formatted_string.py000066400000000000000000000105511473573527000235360ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations from typing import Sequence from abc import ABC, abstractmethod from enum import Enum, auto class EntityType(Enum): """EntityType is a Matrix formatting entity type.""" BOLD = auto() ITALIC = auto() STRIKETHROUGH = auto() UNDERLINE = auto() URL = auto() EMAIL = auto() USER_MENTION = auto() ROOM_MENTION = auto() PREFORMATTED = auto() INLINE_CODE = auto() BLOCKQUOTE = auto() HEADER = auto() COLOR = auto() SPOILER = auto() class FormattedString(ABC): """FormattedString is an abstract HTML parsing target.""" @abstractmethod def append(self, *args: str | FormattedString) -> FormattedString: """ Append strings to this FormattedString. This method may mutate the source object, but it is not required to do so. Make sure to always use the return value when mutating and to duplicate strings if you don't want the original to change. Args: *args: The strings to append. Returns: A FormattedString that is a concatenation of this string and the given strings. """ pass @abstractmethod def prepend(self, *args: str | FormattedString) -> FormattedString: """ Prepend strings to this FormattedString. This method may mutate the source object, but it is not required to do so. Make sure to always use the return value when mutating and to duplicate strings if you don't want the original to change. Args: *args: The strings to prepend. Returns: A FormattedString that is a concatenation of the given strings and this string. """ pass @abstractmethod def format(self, entity_type: EntityType, **kwargs) -> FormattedString: """ Apply formatting to this FormattedString. This method may mutate the source object, but it is not required to do so. Make sure to always use the return value when mutating and to duplicate strings if you don't want the original to change. Args: entity_type: The type of formatting to apply to this string. **kwargs: Additional metadata required by the formatting type. Returns: A FormattedString with the given formatting applied. """ pass @abstractmethod def trim(self) -> FormattedString: """ Trim surrounding whitespace from this FormattedString. This method may mutate the source object, but it is not required to do so. Make sure to always use the return value when mutating and to duplicate strings if you don't want the original to change. Returns: A FormattedString without surrounding whitespace. """ pass @abstractmethod def split(self, separator, max_items: int = -1) -> list[FormattedString]: """ Split this FormattedString by the given separator. Args: separator: The separator to split by. max_items: The maximum number of items to return. If the limit is reached, the remaining string will be returned as one even if it contains the separator. Returns: The split strings. 
""" pass @classmethod def concat(cls, *args: str | FormattedString) -> FormattedString: """ Concatenate many FormattedStrings. Args: *args: The strings to concatenate. Returns: A FormattedString that is a concatenation of the given strings. """ return cls.join(items=args, separator="") @classmethod @abstractmethod def join(cls, items: Sequence[str | FormattedString], separator: str = " ") -> FormattedString: """ Join a list of FormattedStrings with the given separator. Args: items: The strings to join. separator: The separator to join them with. Returns: A FormattedString that is a combination of the given strings with the given separator between each one. """ pass python-0.20.7/mautrix/util/formatter/html_reader.py000066400000000000000000000041071473573527000224510ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations from html.parser import HTMLParser class HTMLNode(list): tag: str text: str tail: str attrib: dict[str, str] def __repr__(self) -> str: return ( f"HTMLNode(tag='{self.tag}', attrs={self.attrib}, text='{self.text}', " f"tail='{self.tail}', children={list(self)})" ) def __init__(self, tag: str, attrs: list[tuple[str, str]]) -> None: super().__init__() self.tag = tag self.text = "" self.tail = "" self.attrib = dict(attrs) class NodeifyingParser(HTMLParser): # From https://www.w3.org/TR/html5/syntax.html#writing-html-documents-elements void_tags = ( "area", "base", "br", "col", "command", "embed", "hr", "img", "input", "link", "meta", "param", "source", "track", "wbr", ) stack: list[HTMLNode] def __init__(self) -> None: super().__init__() self.stack = [HTMLNode("html", [])] def handle_starttag(self, tag: str, attrs: list[tuple[str, str]]) -> None: node = HTMLNode(tag, attrs) self.stack[-1].append(node) if tag not in self.void_tags: self.stack.append(node) def handle_startendtag(self, tag, attrs): self.stack[-1].append(HTMLNode(tag, attrs)) def handle_endtag(self, tag: str) -> None: if tag == self.stack[-1].tag: self.stack.pop() def handle_data(self, data: str) -> None: if len(self.stack[-1]) > 0: self.stack[-1][-1].tail += data else: self.stack[-1].text += data def error(self, message: str) -> None: pass def read_html(data: str) -> HTMLNode: parser = NodeifyingParser() parser.feed(data) return parser.stack[0] python-0.20.7/mautrix/util/formatter/html_reader.pyi000066400000000000000000000007061473573527000226230ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. class HTMLNode(list[HTMLNode]): tag: str text: str tail: str attrib: dict[str, str] def __init__(self, tag: str, attrs: list[tuple[str, str]]) -> None: ... def read_html(data: str) -> HTMLNode: ... python-0.20.7/mautrix/util/formatter/markdown_string.py000066400000000000000000000051201473573527000233670ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. 
from __future__ import annotations from typing import List, Sequence, Union from .formatted_string import EntityType, FormattedString class MarkdownString(FormattedString): text: str def __init__(self, text: str = "") -> None: self.text = text def __str__(self) -> str: return self.text def append(self, *args: Union[str, FormattedString]) -> MarkdownString: self.text += "".join(str(arg) for arg in args) return self def prepend(self, *args: Union[str, FormattedString]) -> MarkdownString: self.text = "".join(str(arg) for arg in args + (self.text,)) return self def format(self, entity_type: EntityType, **kwargs) -> MarkdownString: if entity_type == EntityType.BOLD: self.text = f"**{self.text}**" elif entity_type == EntityType.ITALIC: self.text = f"_{self.text}_" elif entity_type == EntityType.STRIKETHROUGH: self.text = f"~~{self.text}~~" elif entity_type == EntityType.SPOILER: reason = kwargs.get("reason", "") if reason: self.text = f"{reason}|{self.text}" self.text = f"||{self.text}||" elif entity_type == EntityType.URL: if kwargs["url"] != self.text: self.text = f"[{self.text}]({kwargs['url']})" elif entity_type == EntityType.PREFORMATTED: self.text = f"```{kwargs['language']}\n{self.text}\n```" elif entity_type == EntityType.INLINE_CODE: self.text = f"`{self.text}`" elif entity_type == EntityType.BLOCKQUOTE: children = self.trim().split("\n") children = [child.prepend("> ") for child in children] self.text = self.join(children, "\n").text elif entity_type == EntityType.HEADER: prefix = "#" * kwargs["size"] self.text = f"{prefix} {self.text}" return self def trim(self) -> MarkdownString: self.text = self.text.strip() return self def split(self, separator, max_items: int = -1) -> List[MarkdownString]: return [MarkdownString(text) for text in self.text.split(separator, max_items)] @classmethod def join( cls, items: Sequence[Union[str, FormattedString]], separator: str = " " ) -> MarkdownString: return cls(separator.join(str(item) for item in items)) python-0.20.7/mautrix/util/formatter/parser.py000066400000000000000000000266371473573527000214730ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations from typing import Any, Callable, Generic, Type, TypeVar import re from ...types import EventID, MatrixURI, RoomAlias, RoomID, UserID from .formatted_string import EntityType, FormattedString from .html_reader import HTMLNode, read_html from .markdown_string import MarkdownString class RecursionContext: preserve_whitespace: bool ul_depth: int _inited: bool def __init__(self, preserve_whitespace: bool = False, ul_depth: int = 0) -> None: self.preserve_whitespace = preserve_whitespace self.ul_depth = ul_depth self._inited = True def __setattr__(self, key: str, value: Any) -> None: if getattr(self, "_inited", False) is True: raise TypeError("'RecursionContext' object is immutable") super(RecursionContext, self).__setattr__(key, value) def enter_list(self) -> RecursionContext: return RecursionContext( preserve_whitespace=self.preserve_whitespace, ul_depth=self.ul_depth + 1 ) def enter_code_block(self) -> RecursionContext: return RecursionContext(preserve_whitespace=True, ul_depth=self.ul_depth) T = TypeVar("T", bound=FormattedString) spaces = re.compile(r"\s+") space = " " class MatrixParser(Generic[T]): block_tags: tuple[str, ...] 
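# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the library): applying formatting entities
# with MarkdownString from above. The strings and URL are invented for the
# example; the expected outputs follow the format() rules defined in the class.
from mautrix.util.formatter import EntityType, MarkdownString


def _example_markdown_string() -> None:
    print(MarkdownString("important").format(EntityType.BOLD).text)  # **important**
    print(MarkdownString("docs").format(EntityType.URL, url="https://docs.mau.fi").text)
    # -> [docs](https://docs.mau.fi)
    print(MarkdownString("Heading").format(EntityType.HEADER, size=2).text)  # ## Heading
    combined = MarkdownString.join(["one", MarkdownString("two")], separator=", ")
    print(combined.text)  # one, two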
= ( "p", "pre", "blockquote", "ol", "ul", "li", "h1", "h2", "h3", "h4", "h5", "h6", "div", "hr", "table", ) list_bullets: tuple[str, ...] = ("●", "○", "■", "‣") e: Type[EntityType] = EntityType fs: Type[T] = MarkdownString read_html: Callable[[str], HTMLNode] = staticmethod(read_html) ignore_less_relevant_links: bool = True exclude_plaintext_attrib: str = "data-mautrix-exclude-plaintext" def list_bullet(self, depth: int) -> str: return self.list_bullets[(depth - 1) % len(self.list_bullets)] + " " async def list_to_fstring(self, node: HTMLNode, ctx: RecursionContext) -> T: ordered: bool = node.tag == "ol" tagged_children: list[tuple[T, str]] = await self.node_to_tagged_fstrings(node, ctx) counter: int = 1 indent_length: int = 0 if ordered: try: counter = int(node.attrib.get("start", "1")) except ValueError: counter = 1 longest_index = counter - 1 + len(tagged_children) indent_length = len(str(longest_index)) indent: str = (indent_length + 2) * " " children: list[T] = [] for child, tag in tagged_children: if tag != "li": continue if ordered: prefix = f"{counter}. " counter += 1 else: prefix = self.list_bullet(ctx.ul_depth) child = child.prepend(prefix) parts = child.split("\n") parts = parts[:1] + [part.prepend(indent) for part in parts[1:]] child = self.fs.join(parts, "\n") children.append(child) return self.fs.join(children, "\n") async def blockquote_to_fstring(self, node: HTMLNode, ctx: RecursionContext) -> T: msg = await self.tag_aware_parse_node(node, ctx) return msg.format(self.e.BLOCKQUOTE) async def hr_to_fstring(self, node: HTMLNode, ctx: RecursionContext) -> T: return self.fs("---") async def header_to_fstring(self, node: HTMLNode, ctx: RecursionContext) -> T: children = await self.node_to_fstrings(node, ctx) length = int(node.tag[1]) return self.fs.join(children, "").format(self.e.HEADER, size=length) async def basic_format_to_fstring(self, node: HTMLNode, ctx: RecursionContext) -> T: msg = await self.tag_aware_parse_node(node, ctx) if self.exclude_plaintext_attrib in node.attrib: return msg if node.tag in ("b", "strong"): msg = msg.format(self.e.BOLD) elif node.tag in ("i", "em"): msg = msg.format(self.e.ITALIC) elif node.tag in ("s", "strike", "del"): msg = msg.format(self.e.STRIKETHROUGH) elif node.tag in ("u", "ins"): msg = msg.format(self.e.UNDERLINE) return msg async def link_to_fstring(self, node: HTMLNode, ctx: RecursionContext) -> T: msg = await self.tag_aware_parse_node(node, ctx) href = node.attrib.get("href", "") if not href: return msg if href.startswith("mailto:"): return self.fs(href[len("mailto:") :]).format(self.e.EMAIL) matrix_uri = MatrixURI.try_parse(href) if matrix_uri: if matrix_uri.user_id: new_msg = await self.user_pill_to_fstring(msg, matrix_uri.user_id) elif matrix_uri.event_id: new_msg = await self.event_link_to_fstring( msg, matrix_uri.room_id or matrix_uri.room_alias, matrix_uri.event_id ) elif matrix_uri.room_alias: new_msg = await self.room_pill_to_fstring(msg, matrix_uri.room_alias) elif matrix_uri.room_id: new_msg = await self.room_id_link_to_fstring(msg, matrix_uri.room_id) else: new_msg = None if new_msg: return new_msg # Custom attribute to tell the parser that the link isn't relevant and # shouldn't be included in plaintext representation. 
if self.ignore_less_relevant_links and self.exclude_plaintext_attrib in node.attrib: return msg return await self.url_to_fstring(msg, href) async def url_to_fstring(self, msg: T, url: str) -> T | None: return msg.format(self.e.URL, url=url) async def user_pill_to_fstring(self, msg: T, user_id: UserID) -> T | None: return msg.format(self.e.USER_MENTION, user_id=user_id) async def room_pill_to_fstring(self, msg: T, room_alias: RoomAlias) -> T | None: return None async def room_id_link_to_fstring(self, msg: T, room_id: RoomID) -> T | None: return None async def event_link_to_fstring( self, msg: T, room: RoomID | RoomAlias, event_id: EventID ) -> T | None: return None async def img_to_fstring(self, node: HTMLNode, ctx: RecursionContext) -> T: return self.fs(node.attrib.get("alt") or node.attrib.get("title") or "") async def custom_node_to_fstring(self, node: HTMLNode, ctx: RecursionContext) -> T | None: return None async def color_to_fstring(self, msg: T, color: str) -> T: return msg.format(self.e.COLOR, color=color) async def spoiler_to_fstring(self, msg: T, reason: str) -> T: return msg.format(self.e.SPOILER, reason=reason) async def node_to_fstring(self, node: HTMLNode, ctx: RecursionContext) -> T: custom = await self.custom_node_to_fstring(node, ctx) if custom: return custom elif node.tag == "mx-reply": return self.fs("") elif node.tag == "blockquote": return await self.blockquote_to_fstring(node, ctx) elif node.tag == "hr": return await self.hr_to_fstring(node, ctx) elif node.tag == "ol": return await self.list_to_fstring(node, ctx) elif node.tag == "ul": return await self.list_to_fstring(node, ctx.enter_list()) elif node.tag in ("h1", "h2", "h3", "h4", "h5", "h6"): return await self.header_to_fstring(node, ctx) elif node.tag == "br": return self.fs("\n") elif node.tag in ("b", "strong", "i", "em", "s", "del", "u", "ins"): return await self.basic_format_to_fstring(node, ctx) elif node.tag == "a": return await self.link_to_fstring(node, ctx) elif node.tag == "img": return await self.img_to_fstring(node, ctx) elif node.tag == "p": return (await self.tag_aware_parse_node(node, ctx)).append("\n") elif node.tag in ("font", "span"): msg = await self.tag_aware_parse_node(node, ctx) try: spoiler = node.attrib["data-mx-spoiler"] except KeyError: pass else: msg = await self.spoiler_to_fstring(msg, spoiler) try: color = node.attrib["color"] except KeyError: try: color = node.attrib["data-mx-color"] except KeyError: color = None if color: msg = await self.color_to_fstring(msg, color) return msg elif node.tag == "pre": lang = "" try: if node[0].tag == "code": node = node[0] lang = node.attrib["class"][len("language-") :] except (IndexError, KeyError): pass return (await self.parse_node(node, ctx.enter_code_block())).format( self.e.PREFORMATTED, language=lang ) elif node.tag == "code": return (await self.parse_node(node, ctx.enter_code_block())).format(self.e.INLINE_CODE) return await self.tag_aware_parse_node(node, ctx) async def text_to_fstring( self, text: str, ctx: RecursionContext, strip_leading_whitespace: bool = False ) -> T: if not ctx.preserve_whitespace: text = spaces.sub(space, text.lstrip() if strip_leading_whitespace else text) return self.fs(text) async def node_to_tagged_fstrings( self, node: HTMLNode, ctx: RecursionContext ) -> list[tuple[T, str]]: output = [] if node.text: output.append((await self.text_to_fstring(node.text, ctx), "text")) for child in node: output.append((await self.node_to_fstring(child, ctx), child.tag)) if child.tail: # For text following a block tag, the leading 
whitespace is meaningless (there'll # be a newline added later), but for other tags it can be interpreted as a space. text = await self.text_to_fstring( child.tail, ctx, strip_leading_whitespace=child.tag in self.block_tags ) output.append((text, "text")) return output async def node_to_fstrings(self, node: HTMLNode, ctx: RecursionContext) -> list[T]: return [msg for (msg, tag) in await self.node_to_tagged_fstrings(node, ctx)] async def tag_aware_parse_node(self, node: HTMLNode, ctx: RecursionContext) -> T: msgs = await self.node_to_tagged_fstrings(node, ctx) output = self.fs() prev_was_block = False for msg, tag in msgs: if tag in self.block_tags: msg = msg.append("\n") if not prev_was_block: msg = msg.prepend("\n") prev_was_block = True output = output.append(msg) return output.trim() async def parse_node(self, node: HTMLNode, ctx: RecursionContext) -> T: return self.fs.join(await self.node_to_fstrings(node, ctx)) async def parse(self, data: str) -> T: msg = await self.node_to_fstring( self.read_html(f"<body>{data}</body>"), RecursionContext() ) return msg python-0.20.7/mautrix/util/formatter/parser_test.py000066400000000000000000000040111473573527000225100ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. import pytest from . import parse_html async def test_basic_markdown() -> None: tests = { "<strong>test</strong>": "**test**", "<strong>t<em>e<del>s</del>t</em>!</strong>": "**t_e~~s~~t_!**", '<a href="https://example.com">example</a>': "[example](https://example.com)",
        '<pre><code class="language-css">div {\n    display: none;\n}</code></pre>': "```css\ndiv {\n    display: none;\n}\n```",
        "<code>hello</code>": "`hello`",
        "<blockquote>Testing<br>123</blockquote>": "> Testing\n> 123",
        "<ul><li>test</li>\n<li>foo</li>\n<li>bar</li>\n</ul>": "● test\n● foo\n● bar",
        '<ol start="123"><li>test</li>\n<li>foo</li>\n<li>bar</li>\n</ol>': "123. test\n124. foo\n125. bar",
        "<h4>header</h4>": "#### header",
        "<span data-mx-spoiler>spoiler?</span>": "||spoiler?||",
        '<span data-mx-spoiler="SPOILER!">not really</span>': "||SPOILER!|not really||",
    } for html, markdown_ish in tests.items(): assert await parse_html(html) == markdown_ish async def test_nested_markdown() -> None: input_html = """
<h1>Hello, World!</h1>
<blockquote>
    <ol>
        <li><a href="https://example.com">example</a></li>
        <li>
            <ul>
                <li>item 1</li>
                <li>item 2</li>
            </ul>
        </li>
        <li><pre><code class="language-python">def random() -> int:
    if 4 is 1:
        return 5
    return 4</code></pre></li>
        <li><strong>Just some text</strong></li>
    </ol>
</blockquote>
""".strip() expected_output = """ # Hello, World! > 1. [example](https://example.com) > 2. ● item 1 > ● item 2 > 3. ```python > def random() -> int: > if 4 is 1: > return 5 > return 4 > ``` > 4. **Just some text** """.strip() assert await parse_html(input_html) == expected_output python-0.20.7/mautrix/util/logging/000077500000000000000000000000001473573527000172325ustar00rootroot00000000000000python-0.20.7/mautrix/util/logging/__init__.py000066400000000000000000000002161473573527000213420ustar00rootroot00000000000000from .color import ColorFormatter from .trace import SILLY, TRACE, TraceLogger __all__ = ["ColorFormatter", "TraceLogger", "SILLY", "TRACE"] python-0.20.7/mautrix/util/logging/color.py000066400000000000000000000034761473573527000207340ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from copy import copy from logging import Formatter, LogRecord PREFIX = "\033[" RESET = PREFIX + "0m" MAU_COLOR = PREFIX + "32m" # green AIOHTTP_COLOR = PREFIX + "36m" # cyan MXID_COLOR = PREFIX + "33m" # yellow LEVEL_COLORS = { "DEBUG": "37m", # white "INFO": "36m", # cyan "WARNING": "33;1m", # yellow "ERROR": "31;1m", # red "CRITICAL": f"37;1m{PREFIX}41m", # white on red bg } LEVELNAME_OVERRIDE = { name: f"{PREFIX}{color}{name}{RESET}" for name, color in LEVEL_COLORS.items() } class ColorFormatter(Formatter): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) def _color_name(self, module: str) -> str: as_api = "mau.as.api" if module.startswith(as_api): return f"{MAU_COLOR}{as_api}{RESET}.{MXID_COLOR}{module[len(as_api) + 1:]}{RESET}" elif module.startswith("mau."): try: next_dot = module.index(".", len("mau.")) return ( f"{MAU_COLOR}{module[:next_dot]}{RESET}" f".{MXID_COLOR}{module[next_dot+1:]}{RESET}" ) except ValueError: return MAU_COLOR + module + RESET elif module.startswith("aiohttp"): return AIOHTTP_COLOR + module + RESET return module def format(self, record: LogRecord): colored_record: LogRecord = copy(record) colored_record.name = self._color_name(record.name) colored_record.levelname = LEVELNAME_OVERRIDE.get(record.levelname, record.levelname) return super().format(colored_record) python-0.20.7/mautrix/util/logging/trace.py000066400000000000000000000016151473573527000207050ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. 
from __future__ import annotations from typing import Type, cast import logging TRACE = logging.TRACE = 5 logging.addLevelName(TRACE, "TRACE") SILLY = logging.SILLY = 1 logging.addLevelName(SILLY, "SILLY") OldLogger: Type[logging.Logger] = cast(Type[logging.Logger], logging.getLoggerClass()) class TraceLogger(OldLogger): def trace(self, msg, *args, **kwargs) -> None: self.log(TRACE, msg, *args, **kwargs) def silly(self, msg, *args, **kwargs) -> None: self.log(SILLY, msg, *args, **kwargs) def getChild(self, suffix: str) -> TraceLogger: return cast(TraceLogger, super().getChild(suffix)) logging.setLoggerClass(TraceLogger) python-0.20.7/mautrix/util/magic.py000066400000000000000000000026611473573527000172430ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations import functools import magic try: _from_buffer = functools.partial(magic.from_buffer, mime=True) _from_filename = functools.partial(magic.from_file, mime=True) except AttributeError: _from_buffer = lambda data: magic.detect_from_content(data).mime_type _from_filename = lambda file: magic.detect_from_filename(file).mime_type def mimetype(data: bytes | bytearray | str) -> str: """ Uses magic to determine the mimetype of a file on disk or in memory. Supports both libmagic's Python bindings and the python-magic package. Args: data: The file data, either as in-memory bytes or a path to the file as a string. Returns: The mime type as a string. """ if isinstance(data, str): return _from_filename(data) elif isinstance(data, bytes): return _from_buffer(data) elif isinstance(data, bytearray): # Magic doesn't like bytearrays directly, so just copy the first 1024 bytes for it. return _from_buffer(bytes(data[:1024])) else: raise TypeError( f"mimetype() argument must be a string or bytes, not {type(data).__name__!r}" ) __all__ = ["mimetype"] python-0.20.7/mautrix/util/manhole.py000066400000000000000000000247651473573527000176170ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. 
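# Minimal usage sketch of the mimetype() helper from magic.py above (the file path and its
# contents are illustrative assumptions):
#
#     from mautrix.util import magic
#
#     print(magic.mimetype("/tmp/avatar.png"))      # detect from a path on disk, e.g. "image/png"
#     with open("/tmp/avatar.png", "rb") as file:
#         print(magic.mimetype(file.read()))        # detect from in-memory bytes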
# # Based on https://github.com/nhoad/aiomanhole Copyright (c) 2014, Nathan Hoad from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union from abc import ABC, abstractmethod from io import BytesIO, StringIO from socket import SOL_SOCKET from types import CodeType import ast import asyncio import codeop import contextlib import functools import inspect import logging import os import pwd import struct import sys import traceback try: from socket import SO_PEERCRED except ImportError: SO_PEERCRED = None log = logging.getLogger("mau.manhole") TOP_LEVEL_AWAIT = sys.version_info >= (3, 8) ASYNC_EVAL_WRAPPER: str = """ async def __eval_async_expr(): try: pass finally: globals().update(locals()) """ def compile_async(tree: ast.AST) -> CodeType: flags = 0 if TOP_LEVEL_AWAIT: flags += ast.PyCF_ALLOW_TOP_LEVEL_AWAIT node_to_compile = tree else: insert_returns(tree.body) wrapper_node: ast.AST = ast.parse(ASYNC_EVAL_WRAPPER, "", "single") method_stmt = wrapper_node.body[0] try_stmt = method_stmt.body[0] try_stmt.body = tree.body node_to_compile = wrapper_node return compile(node_to_compile, "", "single", flags=flags) # From https://gist.github.com/nitros12/2c3c265813121492655bc95aa54da6b9 def insert_returns(body: List[ast.AST]) -> None: if isinstance(body[-1], ast.Expr): body[-1] = ast.Return(body[-1].value) ast.fix_missing_locations(body[-1]) elif isinstance(body[-1], ast.If): insert_returns(body[-1].body) insert_returns(body[-1].orelse) elif isinstance(body[-1], (ast.With, ast.AsyncWith)): insert_returns(body[-1].body) class StatefulCommandCompiler(codeop.CommandCompiler): """A command compiler that buffers input until a full command is available.""" buf: BytesIO def __init__(self) -> None: super().__init__() self.compiler = functools.partial( compile, optimize=1, flags=( ast.PyCF_ONLY_AST | codeop.PyCF_DONT_IMPLY_DEDENT | codeop.PyCF_ALLOW_INCOMPLETE_INPUT ), ) self.buf = BytesIO() def is_partial_command(self) -> bool: return bool(self.buf.getvalue()) def __call__(self, source: bytes, **kwargs: Any) -> Optional[CodeType]: buf = self.buf if self.is_partial_command(): buf.write(b"\n") buf.write(source) code = self.buf.getvalue().decode("utf-8") codeobj = super().__call__(code, **kwargs) if codeobj: self.reset() return compile_async(codeobj) return None def reset(self) -> None: self.buf.seek(0) self.buf.truncate(0) class Interpreter(ABC): @abstractmethod def __init__(self, namespace: Dict[str, Any], banner: Union[bytes, str]) -> None: pass @abstractmethod def close(self) -> None: pass @abstractmethod async def __call__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None: pass class AsyncInterpreter(Interpreter): """An interactive asynchronous interpreter.""" reader: asyncio.StreamReader writer: asyncio.StreamWriter namespace: Dict[str, Any] banner: bytes compiler: StatefulCommandCompiler running: bool def __init__(self, namespace: Dict[str, Any], banner: Union[bytes, str]) -> None: super().__init__(namespace, banner) self.namespace = namespace self.banner = banner if isinstance(banner, bytes) else str(banner).encode("utf-8") self.compiler = StatefulCommandCompiler() async def send_exception(self) -> None: """When an exception has occurred, write the traceback to the user.""" self.compiler.reset() exc = traceback.format_exc() self.writer.write(exc.encode("utf-8")) await self.writer.drain() async def execute(self, codeobj: CodeType) -> Tuple[Any, str]: with contextlib.redirect_stdout(StringIO()) as buf: if TOP_LEVEL_AWAIT: value = eval(codeobj, 
self.namespace) if codeobj.co_flags & inspect.CO_COROUTINE: value = await value else: exec(codeobj, self.namespace) value = await eval("__eval_async_expr()", self.namespace) return value, buf.getvalue() async def handle_one_command(self) -> None: """Process a single command. May have many lines.""" while True: await self.write_prompt() codeobj = await self.read_command() if codeobj is not None: await self.run_command(codeobj) return async def run_command(self, codeobj: CodeType) -> None: """Execute a compiled code object, and write the output back to the client.""" try: value, stdout = await self.execute(codeobj) except Exception: await self.send_exception() return else: await self.send_output(value, stdout) async def write_prompt(self) -> None: writer = self.writer if self.compiler.is_partial_command(): writer.write(b"... ") else: writer.write(b">>> ") await writer.drain() async def read_command(self) -> Optional[CodeType]: """Read a command from the user line by line. Returns a code object suitable for execution. """ reader = self.reader line = await reader.readline() if line == b"": raise ConnectionResetError() try: # skip the newline to make CommandCompiler work as advertised codeobj = self.compiler(line.rstrip(b"\n")) except SyntaxError: await self.send_exception() return None return codeobj async def send_output(self, value: str, stdout: str) -> None: """Write the output or value of the expression back to user. >>> 5 5 >>> print('cash rules everything around me') cash rules everything around me """ writer = self.writer if value is not None: writer.write(f"{value!r}\n".encode("utf-8")) if stdout: writer.write(stdout.encode("utf-8")) await writer.drain() def close(self) -> None: if self.running: self.writer.close() self.running = False async def __call__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None: """Main entry point for an interpreter session with a single client.""" self.reader = reader self.writer = writer self.running = True if self.banner: writer.write(self.banner) await writer.drain() while self.running: try: await self.handle_one_command() except ConnectionResetError: writer.close() self.running = False break except Exception: log.exception("Exception in manhole REPL") self.writer.write(traceback.format_exc()) await self.writer.drain() class InterpreterFactory: namespace: Dict[str, Any] banner: bytes interpreter_class: Type[Interpreter] clients: List[Interpreter] whitelist: Set[int] _conn_id: int def __init__( self, namespace: Dict[str, Any], banner: Union[bytes, str], interpreter_class: Type[Interpreter], whitelist: Set[int], ) -> None: self.namespace = namespace or {} self.banner = banner self.interpreter_class = interpreter_class self.clients = [] self.whitelist = whitelist self._conn_id = 0 @property def conn_id(self) -> int: self._conn_id += 1 return self._conn_id async def __call__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None: sock = writer.transport.get_extra_info("socket") # TODO support non-linux OSes # I think FreeBSD uses SCM_CREDS creds = sock.getsockopt(SOL_SOCKET, SO_PEERCRED, struct.calcsize("3i")) pid, uid, gid = struct.unpack("3i", creds) user_info = pwd.getpwuid(uid) username = f"{user_info.pw_name} ({uid})" if user_info and user_info.pw_name else uid if len(self.whitelist) > 0 and uid not in self.whitelist: writer.write(b"You are not whitelisted to use the manhole.") log.warning(f"Non-whitelisted user {username} tried to connect from PID {pid}") await writer.drain() writer.close() return namespace = 
{**self.namespace} interpreter = self.interpreter_class(namespace=namespace, banner=self.banner) namespace["exit"] = interpreter.close self.clients.append(interpreter) conn_id = self.conn_id log.info(f"Manhole connection OPENED: {conn_id} from PID {pid} by {username}") await asyncio.create_task(interpreter(reader, writer)) log.info(f"Manhole connection CLOSED: {conn_id} from PID {pid} by {username}") self.clients.remove(interpreter) async def start_manhole( path: str, banner: str = "", namespace: Optional[Dict[str, Any]] = None, loop: asyncio.AbstractEventLoop = None, whitelist: Set[int] = None, ) -> Tuple[asyncio.AbstractServer, Callable[[], None]]: """ Starts a manhole server on a given UNIX address. Args: path: The path to create the UNIX socket at. banner: The banner to show when clients connect. namespace: The globals to provide to connected clients. loop: The asyncio event loop to use. whitelist: List of user IDs to allow connecting. """ if not SO_PEERCRED: raise ValueError("SO_PEERCRED is not supported on this platform") factory = InterpreterFactory( namespace=namespace, banner=banner, interpreter_class=AsyncInterpreter, whitelist=whitelist, ) server = await asyncio.start_unix_server(factory, path=path) os.chmod(path, 0o666) def stop(): for client in factory.clients: client.close() server.close() return server, stop python-0.20.7/mautrix/util/markdown.py000066400000000000000000000021051473573527000177760ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. import commonmark class HtmlEscapingRenderer(commonmark.HtmlRenderer): def __init__(self, allow_html: bool = False): super().__init__() self.allow_html = allow_html def lit(self, s): if self.allow_html: return super().lit(s) return super().lit(s.replace("<", "&lt;").replace(">", "&gt;")) def image(self, node, entering): prev = self.allow_html self.allow_html = True super().image(node, entering) self.allow_html = prev md_parser = commonmark.Parser() yes_html_renderer = commonmark.HtmlRenderer() no_html_renderer = HtmlEscapingRenderer() def render(message: str, allow_html: bool = False) -> str: parsed = md_parser.parse(message) if allow_html: return yes_html_renderer.render(parsed) else: return no_html_renderer.render(parsed) python-0.20.7/mautrix/util/message_send_checkpoint.py000066400000000000000000000063151473573527000230270ustar00rootroot00000000000000# Copyright (c) 2022 Sumner Evans # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/.
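# Minimal usage sketch of start_manhole() from manhole.py above, assuming it is awaited
# from an already running asyncio program; the socket path, banner and namespace values
# are illustrative assumptions:
#
#     server, stop = await start_manhole(
#         path="/var/run/mybridge/manhole.sock",
#         banner="example manhole\n",
#         namespace={"answer": 42},
#         whitelist={0},  # UIDs allowed to connect; an empty set allows everyone
#     )
#     # ...later, on shutdown:
#     stop()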
from typing import Optional import logging from aiohttp.client import ClientTimeout from attr import dataclass import aiohttp from mautrix.api import HTTPAPI from mautrix.types import EventType, MessageType, SerializableAttrs, SerializableEnum class MessageSendCheckpointStep(SerializableEnum): CLIENT = "CLIENT" HOMESERVER = "HOMESERVER" BRIDGE = "BRIDGE" DECRYPTED = "DECRYPTED" REMOTE = "REMOTE" COMMAND = "COMMAND" class MessageSendCheckpointStatus(SerializableEnum): SUCCESS = "SUCCESS" WILL_RETRY = "WILL_RETRY" PERM_FAILURE = "PERM_FAILURE" UNSUPPORTED = "UNSUPPORTED" TIMEOUT = "TIMEOUT" DELIVERY_FAILED = "DELIVERY_FAILED" class MessageSendCheckpointReportedBy(SerializableEnum): ASMUX = "ASMUX" BRIDGE = "BRIDGE" @dataclass class MessageSendCheckpoint(SerializableAttrs): event_id: str room_id: str step: MessageSendCheckpointStep timestamp: int status: MessageSendCheckpointStatus event_type: EventType reported_by: MessageSendCheckpointReportedBy retry_num: int = 0 message_type: Optional[MessageType] = None info: Optional[str] = None client_type: Optional[str] = None client_version: Optional[str] = None async def send(self, endpoint: str, as_token: str, log: logging.Logger) -> None: if not endpoint: return try: headers = {"Authorization": f"Bearer {as_token}", "User-Agent": HTTPAPI.default_ua} async with ( aiohttp.ClientSession() as sess, sess.post( endpoint, json={"checkpoints": [self.serialize()]}, headers=headers, timeout=ClientTimeout(30), ) as resp, ): if not 200 <= resp.status < 300: text = await resp.text() text = text.replace("\n", "\\n") log.warning( f"Unexpected status code {resp.status} sending checkpoint " f"for {self.event_id} ({self.step}/{self.status}): {text}" ) else: log.info( f"Successfully sent checkpoint for {self.event_id} " f"({self.step}/{self.status})" ) except Exception as e: log.warning( f"Failed to send checkpoint for {self.event_id} ({self.step}/{self.status}): " f"{type(e).__name__}: {e}" ) CHECKPOINT_TYPES = { EventType.ROOM_REDACTION, EventType.ROOM_MESSAGE, EventType.ROOM_ENCRYPTED, EventType.ROOM_MEMBER, EventType.ROOM_NAME, EventType.ROOM_AVATAR, EventType.ROOM_TOPIC, EventType.STICKER, EventType.REACTION, EventType.CALL_INVITE, EventType.CALL_CANDIDATES, EventType.CALL_SELECT_ANSWER, EventType.CALL_ANSWER, EventType.CALL_HANGUP, EventType.CALL_REJECT, EventType.CALL_NEGOTIATE, } python-0.20.7/mautrix/util/opt_prometheus.py000066400000000000000000000036221473573527000212360ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. 
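# Minimal sketch of reporting a checkpoint with the MessageSendCheckpoint class above,
# from an async context. The IDs, endpoint and token are illustrative assumptions, and the
# timestamp is assumed to be unix milliseconds:
#
#     import logging
#     import time
#
#     checkpoint = MessageSendCheckpoint(
#         event_id="$eventid:example.com",
#         room_id="!roomid:example.com",
#         step=MessageSendCheckpointStep.REMOTE,
#         timestamp=int(time.time() * 1000),
#         status=MessageSendCheckpointStatus.SUCCESS,
#         event_type=EventType.ROOM_MESSAGE,
#         reported_by=MessageSendCheckpointReportedBy.BRIDGE,
#     )
#     await checkpoint.send(
#         "https://example.com/bridgebox/checkpoints",
#         as_token="hunter2",
#         log=logging.getLogger("mau.checkpoint"),
#     )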
from __future__ import annotations from typing import Any, cast class _NoopPrometheusEntity: """NoopPrometheusEntity is a class that can be used as a no-op placeholder for prometheus metrics objects when prometheus_client isn't installed.""" def __init__(self, *args, **kwargs): pass def __call__(self, *args, **kwargs): if not kwargs and len(args) == 1 and callable(args[0]): return args[0] return self def __enter__(self): pass def __exit__(self, exc_type, exc_val, exc_tb): pass def __getattr__(self, item): return self try: from prometheus_client import Counter, Enum, Gauge, Histogram, Info, Summary is_installed = True except ImportError: Counter = Gauge = Summary = Histogram = Info = Enum = cast(Any, _NoopPrometheusEntity) is_installed = False def async_time(metric: Gauge | Summary | Histogram): """ Measure the time that each execution of the decorated async function takes. This is equivalent to the ``time`` method-decorator in the metrics, but supports async functions. Args: metric: The metric instance to store the measures in. """ if not hasattr(metric, "time") or not callable(metric.time): raise ValueError("async_time only supports metrics that support timing") def decorator(fn): async def wrapper(*args, **kwargs): with metric.time(): return await fn(*args, **kwargs) return wrapper if is_installed else fn return decorator __all__ = [ "Counter", "Gauge", "Summary", "Histogram", "Info", "Enum", "async_time", "is_installed", ] python-0.20.7/mautrix/util/opt_prometheus.pyi000066400000000000000000000060021473573527000214020ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from typing import Any, Callable, Generic, Iterable, TypeVar T = TypeVar("T") Number = int | float class Metric: name: str documentation: str unit: str typ: str samples: list[Any] def add_sample( self, name: str, labels: Iterable[str], value: Any, timestamp: Any = None, exemplar: Any = None, ) -> None: ... class MetricWrapperBase(Generic[T]): def __init__( self, name: str, documentation: str, labelnames: Iterable[str] = (), namespace: str = "", subsystem: str = "", unit: str = "", registry: Any = None, labelvalues: Any = None, ) -> None: ... def describe(self) -> list[Metric]: ... def collect(self) -> list[Metric]: ... def labels(self, *labelvalues, **labelkwargs) -> T: ... def remove(self, *labelvalues) -> None: ... class ContextManager: def __enter__(self) -> None: ... def __exit__(self, exc_type, exc_val, exc_tb) -> None: ... def __call__(self, f) -> None: ... class Counter(MetricWrapperBase[Counter]): def inc(self, amount: Number = 1) -> None: ... def count_exceptions(self, exception: Exception = Exception) -> ContextManager: ... class Gauge(MetricWrapperBase[Gauge]): def inc(self, amount: Number = 1) -> None: ... def dec(self, amount: Number = 1) -> None: ... def set(self, value: Number = 1) -> None: ... def set_to_current_time(self) -> None: ... def track_inprogress(self) -> ContextManager: ... def time(self) -> ContextManager: ... def set_function(self, f: Callable[[], Number]) -> None: ... class Summary(MetricWrapperBase[Summary]): def observe(self, amount: Number) -> None: ... def time(self) -> ContextManager: ... 
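# Minimal sketch of the async_time() decorator from opt_prometheus.py above; the metric
# name and the decorated coroutine are illustrative assumptions:
#
#     from mautrix.util.opt_prometheus import Histogram, async_time
#
#     BACKFILL_TIME = Histogram("bridge_backfill_seconds", "Time spent backfilling a room")
#
#     @async_time(BACKFILL_TIME)
#     async def backfill_room() -> None:
#         ...
#
# If prometheus_client is not installed, Histogram resolves to the no-op placeholder and
# the decorator simply returns the coroutine function unchanged.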
class Histogram(MetricWrapperBase[Histogram]): def __init__( self, name: str, documentation: str, labelnames: Iterable[str] = (), namespace: str = "", subsystem: str = "", unit: str = "", registry: Any = None, labelvalues: Any = None, buckets: Iterable[Number] = (), ) -> None: ... def observe(self, amount: Number = 1) -> None: ... def time(self) -> ContextManager: ... class Info(MetricWrapperBase[Info]): def info(self, val: dict[str, str]) -> None: ... class Enum(MetricWrapperBase[Enum]): def __init__( self, name: str, documentation: str, labelnames: Iterable[str] = (), namespace: str = "", subsystem: str = "", unit: str = "", registry: Any = None, labelvalues: Any = None, states: Iterable[str] = None, ) -> None: ... def state(self, state: str) -> None: ... def async_time(metric: Gauge | Summary | Histogram) -> Callable[[Callable], Callable]: ... python-0.20.7/mautrix/util/program.py000066400000000000000000000240351473573527000176310ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations from typing import Any, AsyncIterable, Awaitable, Iterable, Union, cast from itertools import chain from time import time import argparse import asyncio import copy import inspect import logging import logging.config import signal import sys from .config import BaseFileConfig, BaseMissingError, BaseValidatableConfig, ConfigValueError from .logging import TraceLogger try: import uvloop except ImportError: uvloop = None try: import prometheus_client as prometheus except ImportError: prometheus = None NewTask = Union[Awaitable[Any], Iterable[Awaitable[Any]], AsyncIterable[Awaitable[Any]]] TaskList = Iterable[Awaitable[Any]] class Program: """ A generic main class for programs that handles argument parsing, config loading, logger setup and general startup/shutdown lifecycle. """ loop: asyncio.AbstractEventLoop log: TraceLogger parser: argparse.ArgumentParser args: argparse.Namespace config_class: type[BaseFileConfig] config: BaseFileConfig startup_actions: TaskList shutdown_actions: TaskList module: str name: str version: str command: str description: str def __init__( self, module: str | None = None, name: str | None = None, description: str | None = None, command: str | None = None, version: str | None = None, config_class: type[BaseFileConfig] | None = None, ) -> None: if module: self.module = module if name: self.name = name if description: self.description = description if command: self.command = command if version: self.version = version if config_class: self.config_class = config_class self.startup_actions = [] self.shutdown_actions = [] self._automatic_prometheus = True def run(self) -> None: """ Prepare and run the program. This is the main entrypoint and the only function that should be called manually. """ self._prepare() self._run() def _prepare(self) -> None: start_ts = time() self.preinit() self.log.info(f"Initializing {self.name} {self.version}") try: self.prepare() except Exception: self.log.critical("Unexpected error in initialization", exc_info=True) sys.exit(1) end_ts = time() self.log.info(f"Initialization complete in {round(end_ts - start_ts, 2)} seconds") def preinit(self) -> None: """ First part of startup: parse command-line arguments, load and update config, prepare logger. Exceptions thrown here will crash the program immediately. 
Asyncio must not be used at this stage, as the loop is only initialized later. """ self.prepare_arg_parser() self.args = self.parser.parse_args() self.prepare_config() self.prepare_log() self.check_config() @property def base_config_path(self) -> str: return f"pkg://{self.module}/example-config.yaml" def prepare_arg_parser(self) -> None: """Pre-init lifecycle method. Extend this if you want custom command-line arguments.""" self.parser = argparse.ArgumentParser(description=self.description, prog=self.command) self.parser.add_argument( "-c", "--config", type=str, default="config.yaml", metavar="", help="the path to your config file", ) self.parser.add_argument( "-n", "--no-update", action="store_true", help="Don't save updated config to disk" ) def prepare_config(self) -> None: """Pre-init lifecycle method. Extend this if you want to customize config loading.""" self.config = self.config_class(self.args.config, self.base_config_path) self.load_and_update_config() def load_and_update_config(self) -> None: self.config.load() try: self.config.update(save=not self.args.no_update) except BaseMissingError: print( "Failed to read base config from the default path " f"({self.base_config_path}). Maybe your installation is corrupted?" ) sys.exit(12) def check_config(self) -> None: """Pre-init lifecycle method. Extend this if you want to customize config validation.""" if not isinstance(self.config, BaseValidatableConfig): return try: self.config.check_default_values() except ConfigValueError as e: self.log.fatal(f"Configuration error: {e}") sys.exit(11) def prepare_log(self) -> None: """Pre-init lifecycle method. Extend this if you want to customize logging setup.""" logging.config.dictConfig(copy.deepcopy(self.config["logging"])) self.log = cast(TraceLogger, logging.getLogger("mau.init")) def prepare(self) -> None: """ Lifecycle method where the primary program initialization happens. Use this to fill startup_actions with async startup tasks. 
""" self.prepare_loop() def prepare_loop(self) -> None: """Init lifecycle method where the asyncio event loop is created.""" if uvloop is not None: uvloop.install() self.log.debug("Using uvloop for asyncio") self.loop = asyncio.new_event_loop() asyncio.set_event_loop(self.loop) def start_prometheus(self) -> None: try: enabled = self.config["metrics.enabled"] listen_port = self.config["metrics.listen_port"] except KeyError: return if not enabled: return elif not prometheus: self.log.warning( "Metrics are enabled in config, but prometheus_client is not installed" ) return prometheus.start_http_server(listen_port) def _run(self) -> None: signal.signal(signal.SIGINT, signal.default_int_handler) signal.signal(signal.SIGTERM, signal.default_int_handler) self._stop_task = self.loop.create_future() exit_code = 0 try: self.log.debug("Running startup actions...") start_ts = time() self.loop.run_until_complete(self.start()) end_ts = time() self.log.info( f"Startup actions complete in {round(end_ts - start_ts, 2)} seconds, " "now running forever" ) exit_code = self.loop.run_until_complete(self._stop_task) self.log.debug("manual_stop() called, stopping...") except KeyboardInterrupt: self.log.debug("Interrupt received, stopping...") except Exception: self.log.critical("Unexpected error in main event loop", exc_info=True) self.loop.run_until_complete(self.system_exit()) sys.exit(2) except SystemExit: self.loop.run_until_complete(self.system_exit()) raise self.prepare_stop() self.loop.run_until_complete(self.stop()) self.prepare_shutdown() self.loop.close() asyncio.set_event_loop(None) self.log.info("Everything stopped, shutting down") sys.exit(exit_code) async def system_exit(self) -> None: """Lifecycle method that is called if the main event loop exits using ``sys.exit()``.""" async def start(self) -> None: """ First lifecycle method called inside the asyncio event loop. Extend this if you want more control over startup than just filling startup_actions in the prepare step. """ if self._automatic_prometheus: self.start_prometheus() await asyncio.gather(*(self.startup_actions or [])) def prepare_stop(self) -> None: """ Lifecycle method that is called before awaiting :meth:`stop`. Useful for filling shutdown_actions. """ async def stop(self) -> None: """ Lifecycle method used to stop things that need awaiting to stop. Extend this if you want more control over shutdown than just filling shutdown_actions in the prepare_stop method. 
""" await asyncio.gather(*(self.shutdown_actions or [])) def prepare_shutdown(self) -> None: """Lifecycle method that is called right before ``sys.exit(0)``.""" def manual_stop(self, exit_code: int = 0) -> None: """Tell the event loop to cleanly stop and run the stop lifecycle steps.""" self._stop_task.set_result(exit_code) def add_startup_actions(self, *actions: NewTask) -> None: self.startup_actions = self._add_actions(self.startup_actions, actions) def add_shutdown_actions(self, *actions: NewTask) -> None: self.shutdown_actions = self._add_actions(self.shutdown_actions, actions) @staticmethod async def _unpack_async_iterator(iterable: AsyncIterable[Awaitable[Any]]) -> None: tasks = [] async for task in iterable: if inspect.isawaitable(task): tasks.append(asyncio.create_task(task)) await asyncio.gather(*tasks) def _add_actions(self, to: TaskList, add: tuple[NewTask, ...]) -> TaskList: for item in add: if inspect.isasyncgen(item): to.append(self._unpack_async_iterator(item)) elif inspect.isawaitable(item): if isinstance(to, list): to.append(item) else: to = chain(to, [item]) elif isinstance(item, list): if isinstance(to, list): to += item else: to = chain(to, item) else: to = chain(to, item) return to python-0.20.7/mautrix/util/proxy.py000066400000000000000000000075061473573527000173470ustar00rootroot00000000000000from __future__ import annotations from typing import Awaitable, Callable, TypeVar import asyncio import json import logging import time import urllib.request from aiohttp import ClientConnectionError from yarl import URL from mautrix.util.logging import TraceLogger try: from aiohttp_socks import ProxyConnectionError, ProxyError, ProxyTimeoutError except ImportError: class ProxyError(Exception): pass ProxyConnectionError = ProxyTimeoutError = ProxyError RETRYABLE_PROXY_EXCEPTIONS = ( ProxyError, ProxyTimeoutError, ProxyConnectionError, ClientConnectionError, ConnectionError, asyncio.TimeoutError, ) class ProxyHandler: current_proxy_url: str | None = None log = logging.getLogger("mau.proxy") def __init__(self, api_url: str | None) -> None: self.api_url = api_url def get_proxy_url_from_api(self, reason: str | None = None) -> str | None: assert self.api_url is not None api_url = str(URL(self.api_url).update_query({"reason": reason} if reason else {})) # NOTE: using urllib.request to intentionally block the whole bridge until the proxy change applied request = urllib.request.Request(api_url, method="GET") self.log.debug("Requesting proxy from: %s", api_url) try: with urllib.request.urlopen(request) as f: response = json.loads(f.read().decode()) except Exception: self.log.exception("Failed to retrieve proxy from API") return self.current_proxy_url else: return response["proxy_url"] def update_proxy_url(self, reason: str | None = None) -> bool: old_proxy = self.current_proxy_url new_proxy = None if self.api_url is not None: new_proxy = self.get_proxy_url_from_api(reason) else: new_proxy = urllib.request.getproxies().get("http") if old_proxy != new_proxy: self.log.debug("Set new proxy URL: %s", new_proxy) self.current_proxy_url = new_proxy return True self.log.debug("Got same proxy URL: %s", new_proxy) return False def get_proxy_url(self) -> str | None: if not self.current_proxy_url: self.update_proxy_url() return self.current_proxy_url T = TypeVar("T") async def proxy_with_retry( name: str, func: Callable[[], Awaitable[T]], logger: TraceLogger, proxy_handler: ProxyHandler, on_proxy_change: Callable[[], Awaitable[None]], max_retries: int = 10, min_wait_seconds: int = 0, 
max_wait_seconds: int = 60, multiply_wait_seconds: int = 10, retryable_exceptions: tuple[Exception] = RETRYABLE_PROXY_EXCEPTIONS, reset_after_seconds: int | None = None, ) -> T: errors = 0 last_error = 0 while True: try: return await func() except retryable_exceptions as e: errors += 1 if errors > max_retries: raise wait = errors * multiply_wait_seconds wait = max(wait, min_wait_seconds) wait = min(wait, max_wait_seconds) logger.warning( "%s while trying to %s, retrying in %d seconds", e.__class__.__name__, name, wait, ) if errors > 1 and proxy_handler.update_proxy_url( f"{e.__class__.__name__} while trying to {name}" ): await on_proxy_change() # If sufficient time has passed since the previous error, reset the # error count. Useful for long running tasks with rare failures. if reset_after_seconds is not None: now = time.time() if last_error and now - last_error > reset_after_seconds: errors = 0 last_error = now python-0.20.7/mautrix/util/signed_token.py000066400000000000000000000024661473573527000206370ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations from hashlib import sha256 import base64 import hmac import json def _get_checksum(key: str, payload: bytes) -> str: hasher = hmac.new(key.encode("utf-8"), msg=payload, digestmod=sha256) checksum = base64.urlsafe_b64encode(hasher.digest()) return checksum.decode("utf-8").rstrip("=") def sign_token(key: str, payload: dict) -> str: payload_b64 = base64.urlsafe_b64encode(json.dumps(payload).encode("utf-8")) checksum = _get_checksum(key, payload_b64) payload_str = payload_b64.decode("utf-8").rstrip("=") return f"{checksum}:{payload_str}" def verify_token(key: str, data: str) -> dict | None: if not data: return None try: checksum, payload = data.split(":", 1) except ValueError: return None payload += (3 - (len(payload) + 3) % 4) * "=" if checksum != _get_checksum(key, payload.encode("utf-8")): return None payload = base64.urlsafe_b64decode(payload).decode("utf-8") try: return json.loads(payload) except json.JSONDecodeError: return None python-0.20.7/mautrix/util/simple_lock.py000066400000000000000000000027151473573527000204640ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. 
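# Minimal round-trip sketch of the sign_token()/verify_token() helpers from signed_token.py
# above (the key and payload are illustrative assumptions):
#
#     from mautrix.util.signed_token import sign_token, verify_token
#
#     token = sign_token("hmac key", {"mxid": "@user:example.com", "expires": 1700000000})
#     assert verify_token("hmac key", token) == {"mxid": "@user:example.com", "expires": 1700000000}
#     assert verify_token("wrong key", token) is None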
from __future__ import annotations import asyncio import logging class SimpleLock: _event: asyncio.Event log: logging.Logger | None message: str | None noop_mode: bool def __init__( self, message: str | None = None, log: logging.Logger | None = None, noop_mode: bool = False, ) -> None: self.noop_mode = noop_mode if not noop_mode: self._event = asyncio.Event() self._event.set() self.log = log self.message = message def __enter__(self) -> None: if not self.noop_mode: self._event.clear() async def __aenter__(self) -> None: self.__enter__() def __exit__(self, exc_type, exc_val, exc_tb) -> None: if not self.noop_mode: self._event.set() def __aexit__(self, exc_type, exc_val, exc_tb) -> None: self.__exit__(exc_type, exc_val, exc_tb) @property def locked(self) -> bool: return not self.noop_mode and not self._event.is_set() async def wait(self, task: str | None = None) -> None: if self.locked: if self.log and self.message: self.log.debug(self.message, task) await self._event.wait() python-0.20.7/mautrix/util/simple_template.py000066400000000000000000000030601473573527000213410ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations from typing import Generic, TypeVar T = TypeVar("T") class SimpleTemplate(Generic[T]): _template: str _keyword: str _prefix: str _suffix: str _type: type[T] def __init__( self, template: str, keyword: str, prefix: str = "", suffix: str = "", type: type[T] = str ) -> None: self._template = template self._keyword = keyword index = self._template.find("{%s}" % keyword) length = len(keyword) + 2 self._prefix = prefix + self._template[:index] self._suffix = self._template[index + length :] + suffix self._type = type def format(self, arg: T) -> str: return self._template.format(**{self._keyword: arg}) def format_full(self, arg: T) -> str: return f"{self._prefix}{arg}{self._suffix}" def parse(self, val: str) -> T | None: prefix_ok = val[: len(self._prefix)] == self._prefix has_suffix = len(self._suffix) > 0 suffix_ok = not has_suffix or val[-len(self._suffix) :] == self._suffix if prefix_ok and suffix_ok: start = len(self._prefix) end = -len(self._suffix) if has_suffix else len(val) try: return self._type(val[start:end]) except ValueError: pass return None python-0.20.7/mautrix/util/utf16_surrogate.py000066400000000000000000000026531473573527000212240ustar00rootroot00000000000000# From https://github.com/LonamiWebs/Telethon/blob/v1.24.0/telethon/helpers.py#L38-L62 # Copyright (c) LonamiWebs, MIT license import struct def add(text: str) -> str: """ Add surrogate pairs to characters in the text. This makes the indices match how most platforms calculate string length when formatting texts using offset-based entities. Args: text: The text to add surrogate pairs to. Returns: The text with surrogate pairs. """ return "".join( ( "".join(chr(y) for y in struct.unpack("<HH", x.encode("utf-16le"))) if (0x10000 <= ord(x) <= 0x10FFFF) else x ) for x in text ) def remove(text: str) -> str: """ Remove surrogate pairs from text. This does the opposite of :func:`add`. Args: text: The text with surrogate pairs. Returns: The text without surrogate pairs. """ return text.encode("utf-16", "surrogatepass").decode("utf-16") def is_within(text: str, index: int, *, length: int = None) -> bool: """ Returns: `True` if ``index`` is within a surrogate (before and after it, not at!).
""" if length is None: length = len(text) return ( 1 < index < len(text) and "\ud800" <= text[index - 1] <= "\udfff" # in bounds and "\ud800" <= text[index] <= "\udfff" # previous is # current is ) __all__ = ["add", "remove"] python-0.20.7/mautrix/util/variation_selector.json000066400000000000000000000156511473573527000224030ustar00rootroot00000000000000{ "0023": "#", "002A": "*", "0030": "0", "0031": "1", "0032": "2", "0033": "3", "0034": "4", "0035": "5", "0036": "6", "0037": "7", "0038": "8", "0039": "9", "00A9": "©", "00AE": "®", "203C": "‼", "2049": "⁉", "2122": "™", "2139": "ℹ", "2194": "↔", "2195": "↕", "2196": "↖", "2197": "↗", "2198": "↘", "2199": "↙", "21A9": "↩", "21AA": "↪", "231A": "⌚", "231B": "⌛", "2328": "⌨", "23CF": "⏏", "23E9": "⏩", "23EA": "⏪", "23ED": "⏭", "23EE": "⏮", "23EF": "⏯", "23F1": "⏱", "23F2": "⏲", "23F3": "⏳", "23F8": "⏸", "23F9": "⏹", "23FA": "⏺", "24C2": "Ⓜ", "25AA": "▪", "25AB": "▫", "25B6": "▶", "25C0": "◀", "25FB": "◻", "25FC": "◼", "25FD": "◽", "25FE": "◾", "2600": "☀", "2601": "☁", "2602": "☂", "2603": "☃", "2604": "☄", "260E": "☎", "2611": "☑", "2614": "☔", "2615": "☕", "2618": "☘", "261D": "☝", "2620": "☠", "2622": "☢", "2623": "☣", "2626": "☦", "262A": "☪", "262E": "☮", "262F": "☯", "2638": "☸", "2639": "☹", "263A": "☺", "2640": "♀", "2642": "♂", "2648": "♈", "2649": "♉", "264A": "♊", "264B": "♋", "264C": "♌", "264D": "♍", "264E": "♎", "264F": "♏", "2650": "♐", "2651": "♑", "2652": "♒", "2653": "♓", "265F": "♟", "2660": "♠", "2663": "♣", "2665": "♥", "2666": "♦", "2668": "♨", "267B": "♻", "267E": "♾", "267F": "♿", "2692": "⚒", "2693": "⚓", "2694": "⚔", "2695": "⚕", "2696": "⚖", "2697": "⚗", "2699": "⚙", "269B": "⚛", "269C": "⚜", "26A0": "⚠", "26A1": "⚡", "26A7": "⚧", "26AA": "⚪", "26AB": "⚫", "26B0": "⚰", "26B1": "⚱", "26BD": "⚽", "26BE": "⚾", "26C4": "⛄", "26C5": "⛅", "26C8": "⛈", "26CF": "⛏", "26D1": "⛑", "26D3": "⛓", "26D4": "⛔", "26E9": "⛩", "26EA": "⛪", "26F0": "⛰", "26F1": "⛱", "26F2": "⛲", "26F3": "⛳", "26F4": "⛴", "26F5": "⛵", "26F7": "⛷", "26F8": "⛸", "26F9": "⛹", "26FA": "⛺", "26FD": "⛽", "2702": "✂", "2708": "✈", "2709": "✉", "270C": "✌", "270D": "✍", "270F": "✏", "2712": "✒", "2714": "✔", "2716": "✖", "271D": "✝", "2721": "✡", "2733": "✳", "2734": "✴", "2744": "❄", "2747": "❇", "2753": "❓", "2757": "❗", "2763": "❣", "2764": "❤", "27A1": "➡", "2934": "⤴", "2935": "⤵", "2B05": "⬅", "2B06": "⬆", "2B07": "⬇", "2B1B": "⬛", "2B1C": "⬜", "2B50": "⭐", "2B55": "⭕", "3030": "〰", "303D": "〽", "3297": "㊗", "3299": "㊙", "1F004": "🀄", "1F170": "🅰", "1F171": "🅱", "1F17E": "🅾", "1F17F": "🅿", "1F202": "🈂", "1F21A": "🈚", "1F22F": "🈯", "1F237": "🈷", "1F30D": "🌍", "1F30E": "🌎", "1F30F": "🌏", "1F315": "🌕", "1F31C": "🌜", "1F321": "🌡", "1F324": "🌤", "1F325": "🌥", "1F326": "🌦", "1F327": "🌧", "1F328": "🌨", "1F329": "🌩", "1F32A": "🌪", "1F32B": "🌫", "1F32C": "🌬", "1F336": "🌶", "1F378": "🍸", "1F37D": "🍽", "1F393": "🎓", "1F396": "🎖", "1F397": "🎗", "1F399": "🎙", "1F39A": "🎚", "1F39B": "🎛", "1F39E": "🎞", "1F39F": "🎟", "1F3A7": "🎧", "1F3AC": "🎬", "1F3AD": "🎭", "1F3AE": "🎮", "1F3C2": "🏂", "1F3C4": "🏄", "1F3C6": "🏆", "1F3CA": "🏊", "1F3CB": "🏋", "1F3CC": "🏌", "1F3CD": "🏍", "1F3CE": "🏎", "1F3D4": "🏔", "1F3D5": "🏕", "1F3D6": "🏖", "1F3D7": "🏗", "1F3D8": "🏘", "1F3D9": "🏙", "1F3DA": "🏚", "1F3DB": "🏛", "1F3DC": "🏜", "1F3DD": "🏝", "1F3DE": "🏞", "1F3DF": "🏟", "1F3E0": "🏠", "1F3ED": "🏭", "1F3F3": "🏳", "1F3F5": "🏵", "1F3F7": "🏷", "1F408": "🐈", "1F415": "🐕", "1F41F": "🐟", "1F426": "🐦", "1F43F": "🐿", "1F441": "👁", "1F442": "👂", "1F446": "👆", "1F447": "👇", "1F448": "👈", "1F449": "👉", 
"1F44D": "👍", "1F44E": "👎", "1F453": "👓", "1F46A": "👪", "1F47D": "👽", "1F4A3": "💣", "1F4B0": "💰", "1F4B3": "💳", "1F4BB": "💻", "1F4BF": "💿", "1F4CB": "📋", "1F4DA": "📚", "1F4DF": "📟", "1F4E4": "📤", "1F4E5": "📥", "1F4E6": "📦", "1F4EA": "📪", "1F4EB": "📫", "1F4EC": "📬", "1F4ED": "📭", "1F4F7": "📷", "1F4F9": "📹", "1F4FA": "📺", "1F4FB": "📻", "1F4FD": "📽", "1F508": "🔈", "1F50D": "🔍", "1F512": "🔒", "1F513": "🔓", "1F549": "🕉", "1F54A": "🕊", "1F550": "🕐", "1F551": "🕑", "1F552": "🕒", "1F553": "🕓", "1F554": "🕔", "1F555": "🕕", "1F556": "🕖", "1F557": "🕗", "1F558": "🕘", "1F559": "🕙", "1F55A": "🕚", "1F55B": "🕛", "1F55C": "🕜", "1F55D": "🕝", "1F55E": "🕞", "1F55F": "🕟", "1F560": "🕠", "1F561": "🕡", "1F562": "🕢", "1F563": "🕣", "1F564": "🕤", "1F565": "🕥", "1F566": "🕦", "1F567": "🕧", "1F56F": "🕯", "1F570": "🕰", "1F573": "🕳", "1F574": "🕴", "1F575": "🕵", "1F576": "🕶", "1F577": "🕷", "1F578": "🕸", "1F579": "🕹", "1F587": "🖇", "1F58A": "🖊", "1F58B": "🖋", "1F58C": "🖌", "1F58D": "🖍", "1F590": "🖐", "1F5A5": "🖥", "1F5A8": "🖨", "1F5B1": "🖱", "1F5B2": "🖲", "1F5BC": "🖼", "1F5C2": "🗂", "1F5C3": "🗃", "1F5C4": "🗄", "1F5D1": "🗑", "1F5D2": "🗒", "1F5D3": "🗓", "1F5DC": "🗜", "1F5DD": "🗝", "1F5DE": "🗞", "1F5E1": "🗡", "1F5E3": "🗣", "1F5E8": "🗨", "1F5EF": "🗯", "1F5F3": "🗳", "1F5FA": "🗺", "1F610": "😐", "1F687": "🚇", "1F68D": "🚍", "1F691": "🚑", "1F694": "🚔", "1F698": "🚘", "1F6AD": "🚭", "1F6B2": "🚲", "1F6B9": "🚹", "1F6BA": "🚺", "1F6BC": "🚼", "1F6CB": "🛋", "1F6CD": "🛍", "1F6CE": "🛎", "1F6CF": "🛏", "1F6E0": "🛠", "1F6E1": "🛡", "1F6E2": "🛢", "1F6E3": "🛣", "1F6E4": "🛤", "1F6E5": "🛥", "1F6E9": "🛩", "1F6F0": "🛰", "1F6F3": "🛳" } python-0.20.7/mautrix/util/variation_selector.py000066400000000000000000000071741473573527000220630ustar00rootroot00000000000000# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations import json import pkgutil import aiohttp EMOJI_VAR_URL = "https://www.unicode.org/Public/14.0.0/ucd/emoji/emoji-variation-sequences.txt" def read_data() -> dict[str, str]: """ Get the list of emoji that need a variation selector. This loads the local data file that was previously generated from the Unicode spec data files. Returns: A dict from hex to the emoji string (you have to bring the variation selectors yourself). """ return json.loads(pkgutil.get_data("mautrix.util", "variation_selector.json")) async def fetch_data() -> dict[str, str]: """ Generate the list of emoji that need a variation selector from the Unicode spec data files. Returns: A dict from hex to the emoji string (you have to bring the variation selectors yourself). 
""" async with aiohttp.ClientSession() as sess, sess.get(EMOJI_VAR_URL) as resp: data = await resp.text() emojis = {} for line in data.split("\n"): if "emoji style" in line: emoji_hex = line.split(" ", 1)[0] emojis[emoji_hex] = rf"\U{emoji_hex:>08}".encode("ascii").decode("unicode-escape") return emojis if __name__ == "__main__": import asyncio import sys import pkg_resources path = pkg_resources.resource_filename("mautrix.util", "variation_selector.json") emojis = asyncio.run(fetch_data()) with open(path, "w") as file: json.dump(emojis, file, indent=" ", ensure_ascii=False) file.write("\n") print(f"Wrote {len(emojis)} emojis to {path}") sys.exit(0) VARIATION_SELECTOR_16 = "\ufe0f" ADD_VARIATION_TRANSLATION = str.maketrans( {ord(emoji): f"{emoji}{VARIATION_SELECTOR_16}" for emoji in read_data().values()} ) SKIN_TONE_MODIFIERS = ("\U0001F3FB", "\U0001F3FC", "\U0001F3FD", "\U0001F3FE", "\U0001F3FF") SKIN_TONE_REPLACEMENTS = {f"{VARIATION_SELECTOR_16}{mod}": mod for mod in SKIN_TONE_MODIFIERS} VARIATION_SELECTOR_REPLACEMENTS = { **SKIN_TONE_REPLACEMENTS, "\U0001F408\ufe0f\u200d\u2b1b\ufe0f": "\U0001F408\u200d\u2b1b", } def add(val: str) -> str: r""" Add emoji variation selectors (16) to all emojis that have multiple forms in the given string. This will remove all variation selectors first to make sure it doesn't add duplicates. .. versionadded:: 0.12.5 Examples: >>> from mautrix.util import variation_selector >>> variation_selector.add("\U0001f44d") "\U0001f44d\ufe0f" >>> variation_selector.add("\U0001f44d\ufe0f") "\U0001f44d\ufe0f" >>> variation_selector.add("4\u20e3") "4\ufe0f\u20e3" >>> variation_selector.add("\U0001f9d0") "\U0001f9d0" Args: val: The string to add variation selectors to. Returns: The string with variation selectors added. """ added = remove(val).translate(ADD_VARIATION_TRANSLATION) for invalid_selector, replacement in VARIATION_SELECTOR_REPLACEMENTS.items(): added = added.replace(invalid_selector, replacement) return added def remove(val: str) -> str: """ Remove all emoji variation selectors in the given string. .. versionadded:: 0.12.5 Args: val: The string to remove variation selectors from. Returns: The string with variation selectors removed. 
""" return val.replace(VARIATION_SELECTOR_16, "") __all__ = ["add", "remove", "read_data", "fetch_data"] python-0.20.7/optional-requirements.txt000066400000000000000000000002231473573527000202420ustar00rootroot00000000000000python-magic ruamel.yaml SQLAlchemy<2 commonmark lxml asyncpg aiosqlite prometheus_client setuptools uvloop python-olm unpaddedbase64 pycryptodome python-0.20.7/pyproject.toml000066400000000000000000000004361473573527000160550ustar00rootroot00000000000000[tool.isort] profile = "black" force_to_top = "typing" from_first = true combine_as_imports = true line_length = 99 [tool.black] line-length = 99 target-version = ["py310"] [tool.pytest.ini_options] asyncio_mode = "auto" addopts = "--ignore mautrix/util/db/ --ignore mautrix/bridge/" python-0.20.7/requirements.txt000066400000000000000000000000231473573527000164150ustar00rootroot00000000000000aiohttp attrs yarl python-0.20.7/setup.py000066400000000000000000000033271473573527000146550ustar00rootroot00000000000000import setuptools from mautrix import __version__ encryption_dependencies = ["python-olm", "unpaddedbase64", "pycryptodome"] test_dependencies = ["aiosqlite", "asyncpg", "ruamel.yaml", *encryption_dependencies] setuptools.setup( name="mautrix", version=__version__, url="https://github.com/mautrix/python", project_urls={ "Changelog": "https://github.com/mautrix/python/blob/master/CHANGELOG.md", }, author="Tulir Asokan", author_email="tulir@maunium.net", description="A Python 3 asyncio Matrix framework.", long_description=open("README.rst").read(), packages=setuptools.find_packages(), install_requires=[ "aiohttp>=3,<4", "attrs>=18.1.0", "yarl>=1.5,<2", ], extras_require={ "detect_mimetype": ["python-magic>=0.4.15,<0.5"], "lint": ["black~=24.1", "isort"], "test": ["pytest", "pytest-asyncio", *test_dependencies], "encryption": encryption_dependencies, }, tests_require=test_dependencies, python_requires="~=3.10", classifiers=[ "Development Status :: 4 - Beta", "License :: OSI Approved :: Mozilla Public License 2.0 (MPL 2.0)", "Topic :: Communications :: Chat", "Framework :: AsyncIO", "Programming Language :: Python", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", ], package_data={ "mautrix": ["py.typed"], "mautrix.types.event": ["type.pyi"], "mautrix.util": ["opt_prometheus.pyi", "variation_selector.json"], "mautrix.util.formatter": ["html_reader.pyi"], }, )