`")
return
whitelist = set()
whitelist_whitelist = evt.config["manhole.whitelist"]
for arg in evt.args:
try:
uid = int(arg)
except ValueError:
await evt.reply(f"{arg} is not an integer.")
return
if whitelist_whitelist and uid not in whitelist_whitelist:
await evt.reply(f"{uid} is not in the list of allowed UIDs.")
return
whitelist.add(uid)
if evt.bridge.manhole:
added = [uid for uid in whitelist if uid not in evt.bridge.manhole.whitelist]
evt.bridge.manhole.whitelist |= set(added)
if len(added) == 0:
await evt.reply(
f"There's an existing manhole opened by {evt.bridge.manhole.opened_by}"
" and all the given UIDs are already whitelisted."
)
else:
added_str = (
f"{', '.join(str(uid) for uid in added[:-1])} and {added[-1]}"
if len(added) > 1
else added[0]
)
await evt.reply(
f"There's an existing manhole opened by {evt.bridge.manhole.opened_by}"
f". Added {added_str} to the whitelist."
)
evt.log.info(f"{evt.sender.mxid} added {added_str} to the manhole whitelist.")
return
namespace = await evt.bridge.manhole_global_namespace(evt.sender.mxid)
banner = evt.bridge.manhole_banner(evt.sender.mxid)
path = evt.config["manhole.path"]
wl_list = list(whitelist)
whitelist_str = (
f"{', '.join(str(uid) for uid in wl_list[:-1])} and {wl_list[-1]}"
if len(wl_list) > 1
else wl_list[0]
)
evt.log.info(f"{evt.sender.mxid} opened a manhole with {whitelist_str} whitelisted.")
server, close = await start_manhole(
path=path, banner=banner, namespace=namespace, loop=evt.loop, whitelist=whitelist
)
evt.bridge.manhole = ManholeState(
server=server, opened_by=evt.sender.mxid, close=close, whitelist=whitelist
)
plrl = "s" if len(whitelist) != 1 else ""
await evt.reply(f"Opened manhole at unix://{path} with UID{plrl} {whitelist_str} whitelisted")
await server.wait_closed()
evt.bridge.manhole = None
try:
os.unlink(path)
except FileNotFoundError:
pass
evt.log.info(f"{evt.sender.mxid}'s manhole was closed.")
try:
await evt.reply("Your manhole was closed.")
except (AttributeError, MatrixConnectionError) as e:
evt.log.warning(f"Failed to send manhole close notification: {e}")
@command_handler(
needs_auth=False,
needs_admin=True,
help_section=SECTION_ADMIN,
help_text="Close an open manhole.",
)
async def close_manhole(evt: CommandEvent) -> None:
if not evt.bridge.manhole:
await evt.reply("There is no open manhole.")
return
opened_by = evt.bridge.manhole.opened_by
evt.bridge.manhole.close()
evt.bridge.manhole = None
if opened_by != evt.sender.mxid:
await evt.reply(f"Closed manhole opened by {opened_by}")
# python-0.20.4/mautrix/bridge/commands/meta.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from mautrix.types import EventID
from .handler import (
SECTION_GENERAL,
CommandEvent,
HelpCacheKey,
HelpSection,
command_handler,
command_handlers,
)
@command_handler(
needs_auth=False, help_section=SECTION_GENERAL, help_text="Cancel an ongoing action."
)
async def cancel(evt: CommandEvent) -> EventID:
if evt.sender.command_status:
action = evt.sender.command_status["action"]
evt.sender.command_status = None
return await evt.reply(f"{action} cancelled.")
else:
return await evt.reply("No ongoing command.")
@command_handler(
needs_auth=False, help_section=SECTION_GENERAL, help_text="Get the bridge version."
)
async def version(evt: CommandEvent) -> None:
if not evt.processor.bridge:
await evt.reply("Bridge version unknown")
else:
await evt.reply(
f"[{evt.processor.bridge.name}]({evt.processor.bridge.repo_url}) "
f"{evt.processor.bridge.markdown_version or evt.processor.bridge.version}"
)
@command_handler(needs_auth=False)
async def unknown_command(evt: CommandEvent) -> EventID:
return await evt.reply("Unknown command. Try `$cmdprefix+sp help` for help.")
help_cache: dict[HelpCacheKey, str] = {}
async def _get_help_text(evt: CommandEvent) -> str:
cache_key = await evt.get_help_key()
if cache_key not in help_cache:
help_sections: dict[HelpSection, list[str]] = {}
for handler in command_handlers.values():
if (
handler.has_help
and handler.has_permission(cache_key)
and handler.is_enabled_for(evt)
):
help_sections.setdefault(handler.help_section, [])
help_sections[handler.help_section].append(handler.help + " ")
help_sorted = sorted(help_sections.items(), key=lambda item: item[0].order)
helps = ["#### {}\n{}\n".format(key.name, "\n".join(value)) for key, value in help_sorted]
help_cache[cache_key] = "\n".join(helps)
return help_cache[cache_key]
def _get_management_status(evt: CommandEvent) -> str:
if evt.is_management:
return "This is a management room: prefixing commands with `$cmdprefix` is not required."
elif evt.is_portal:
return (
"**This is a portal room**: you must always prefix commands with `$cmdprefix`.\n"
"Management commands will not be bridged."
)
return "**This is not a management room**: you must prefix commands with `$cmdprefix`."
@command_handler(
name="help",
needs_auth=False,
help_section=SECTION_GENERAL,
help_text="Show this help message.",
)
async def help_cmd(evt: CommandEvent) -> EventID:
return await evt.reply(_get_management_status(evt) + "\n" + await _get_help_text(evt))
# python-0.20.4/mautrix/bridge/commands/relay.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from mautrix.types import EventID
from .handler import SECTION_RELAY, CommandEvent, command_handler
@command_handler(
needs_auth=True,
management_only=False,
name="set-relay",
help_section=SECTION_RELAY,
help_text="Relay messages in this room through your account.",
is_enabled_for=lambda evt: evt.config["bridge.relay.enabled"],
)
async def set_relay(evt: CommandEvent) -> EventID:
if not evt.is_portal:
return await evt.reply("This is not a portal room.")
await evt.portal.set_relay_user(evt.sender)
return await evt.reply(
"Messages from non-logged-in users in this room will now be bridged "
"through your account."
)
@command_handler(
needs_auth=True,
management_only=False,
name="unset-relay",
help_section=SECTION_RELAY,
help_text="Stop relaying messages in this room.",
is_enabled_for=lambda evt: evt.config["bridge.relay.enabled"],
)
async def unset_relay(evt: CommandEvent) -> EventID:
if not evt.is_portal:
return await evt.reply("This is not a portal room.")
elif not evt.portal.has_relay:
return await evt.reply("This room does not have a relay user set.")
await evt.portal.set_relay_user(None)
return await evt.reply("Messages from non-logged-in users will no longer be bridged.")
# python-0.20.4/mautrix/bridge/config.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from typing import Any, ClassVar
from abc import ABC
import json
import os
import re
import secrets
import time
from mautrix.util.config import (
BaseFileConfig,
BaseValidatableConfig,
ConfigUpdateHelper,
ForbiddenDefault,
yaml,
)
class BaseBridgeConfig(BaseFileConfig, BaseValidatableConfig, ABC):
env_prefix: str | None = None
registration_path: str
_registration: dict | None
_check_tokens: bool
env: dict[str, Any]
def __init__(
self, path: str, registration_path: str, base_path: str, env_prefix: str | None = None
) -> None:
super().__init__(path, base_path)
self.registration_path = registration_path
self._registration = None
self._check_tokens = True
self.env = {}
if not self.env_prefix:
self.env_prefix = env_prefix
if self.env_prefix:
env_prefix = f"{self.env_prefix}_"
for key, value in os.environ.items():
if not key.startswith(env_prefix):
continue
key = key.removeprefix(env_prefix)
if value.startswith("json::"):
value = json.loads(value.removeprefix("json::"))
self.env[key] = value
def __getitem__(self, item: str) -> Any:
if self.env:
try:
sanitized_item = item.replace(".", "_").replace("[", "").replace("]", "").upper()
return self.env[sanitized_item]
except KeyError:
pass
return super().__getitem__(item)
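    # Example (hypothetical values): with env_prefix="MAUTRIX", setting
    # MAUTRIX_HOMESERVER_ADDRESS=https://matrix.example.org overrides
    # self["homeserver.address"], and MAUTRIX_MANHOLE_WHITELIST=json::[0]
    # is decoded from JSON before being stored in self.env.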
def save(self) -> None:
super().save()
if self._registration and self.registration_path:
with open(self.registration_path, "w") as stream:
yaml.dump(self._registration, stream)
@staticmethod
def _new_token() -> str:
return secrets.token_urlsafe(48)
@property
def forbidden_defaults(self) -> list[ForbiddenDefault]:
return [
ForbiddenDefault("homeserver.address", "https://example.com"),
ForbiddenDefault("homeserver.address", "https://matrix.example.com"),
ForbiddenDefault("homeserver.domain", "example.com"),
] + (
[
ForbiddenDefault(
"appservice.as_token",
"This value is generated when generating the registration",
"Did you forget to generate the registration?",
),
ForbiddenDefault(
"appservice.hs_token",
"This value is generated when generating the registration",
"Did you forget to generate the registration?",
),
]
if self._check_tokens
else []
)
def do_update(self, helper: ConfigUpdateHelper) -> None:
copy, copy_dict = helper.copy, helper.copy_dict
copy("homeserver.address")
copy("homeserver.domain")
copy("homeserver.verify_ssl")
copy("homeserver.http_retry_count")
copy("homeserver.connection_limit")
copy("homeserver.status_endpoint")
copy("homeserver.message_send_checkpoint_endpoint")
copy("homeserver.async_media")
if self.get("homeserver.asmux", False):
helper.base["homeserver.software"] = "asmux"
else:
copy("homeserver.software")
copy("appservice.address")
copy("appservice.hostname")
copy("appservice.port")
copy("appservice.max_body_size")
copy("appservice.tls_cert")
copy("appservice.tls_key")
if "appservice.database" in self and self["appservice.database"].startswith("sqlite:///"):
helper.base["appservice.database"] = self["appservice.database"].replace(
"sqlite:///", "sqlite:"
)
else:
copy("appservice.database")
copy("appservice.database_opts")
copy("appservice.id")
copy("appservice.bot_username")
copy("appservice.bot_displayname")
copy("appservice.bot_avatar")
copy("appservice.as_token")
copy("appservice.hs_token")
copy("appservice.ephemeral_events")
copy("bridge.management_room_text.welcome")
copy("bridge.management_room_text.welcome_connected")
copy("bridge.management_room_text.welcome_unconnected")
copy("bridge.management_room_text.additional_help")
copy("bridge.management_room_multiple_messages")
copy("bridge.encryption.allow")
copy("bridge.encryption.default")
copy("bridge.encryption.require")
copy("bridge.encryption.appservice")
copy("bridge.encryption.delete_keys.delete_outbound_on_ack")
copy("bridge.encryption.delete_keys.dont_store_outbound")
copy("bridge.encryption.delete_keys.ratchet_on_decrypt")
copy("bridge.encryption.delete_keys.delete_fully_used_on_decrypt")
copy("bridge.encryption.delete_keys.delete_prev_on_new_session")
copy("bridge.encryption.delete_keys.delete_on_device_delete")
copy("bridge.encryption.delete_keys.periodically_delete_expired")
copy("bridge.encryption.delete_keys.delete_outdated_inbound")
copy("bridge.encryption.verification_levels.receive")
copy("bridge.encryption.verification_levels.send")
copy("bridge.encryption.verification_levels.share")
copy("bridge.encryption.allow_key_sharing")
if self.get("bridge.encryption.key_sharing.allow", False):
helper.base["bridge.encryption.allow_key_sharing"] = True
require_verif = self.get("bridge.encryption.key_sharing.require_verification", True)
require_cs = self.get("bridge.encryption.key_sharing.require_cross_signing", False)
if require_verif:
helper.base["bridge.encryption.verification_levels.share"] = "verified"
elif not require_cs:
helper.base["bridge.encryption.verification_levels.share"] = "unverified"
# else: default (cross-signed-tofu)
copy("bridge.encryption.rotation.enable_custom")
copy("bridge.encryption.rotation.milliseconds")
copy("bridge.encryption.rotation.messages")
copy("bridge.encryption.rotation.disable_device_change_key_rotation")
copy("bridge.relay.enabled")
copy_dict("bridge.relay.message_formats", override_existing_map=False)
copy("manhole.enabled")
copy("manhole.path")
copy("manhole.whitelist")
copy("logging")
@property
def namespaces(self) -> dict[str, list[dict[str, Any]]]:
"""
Generate the user ID and room alias namespace config for the registration as specified in
https://matrix.org/docs/spec/application_service/r0.1.0.html#application-services
"""
homeserver = self["homeserver.domain"]
regex_ph = f"regexplaceholder{int(time.time())}"
username_format = self["bridge.username_template"].format(userid=regex_ph)
alias_format = (
self["bridge.alias_template"].format(groupname=regex_ph)
if "bridge.alias_template" in self
else None
)
return {
"users": [
{
"exclusive": True,
"regex": re.escape(f"@{username_format}:{homeserver}").replace(regex_ph, ".*"),
}
],
"aliases": [
{
"exclusive": True,
"regex": re.escape(f"#{alias_format}:{homeserver}").replace(regex_ph, ".*"),
}
]
if alias_format
else [],
}
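    # Hypothetical example: with homeserver.domain="example.com" and
    # bridge.username_template="bridgebot_{userid}", the "users" entry gets a
    # regex like "@bridgebot_.*:example\.com" -- re.escape runs on the full ID
    # first, and the timestamped placeholder is swapped for ".*" afterwards so
    # the wildcard survives escaping.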
def generate_registration(self) -> None:
self["appservice.as_token"] = self._new_token()
self["appservice.hs_token"] = self._new_token()
namespaces = self.namespaces
bot_username = self["appservice.bot_username"]
homeserver_domain = self["homeserver.domain"]
namespaces.setdefault("users", []).append(
{
"exclusive": True,
"regex": re.escape(f"@{bot_username}:{homeserver_domain}"),
}
)
self._registration = {
"id": self["appservice.id"],
"as_token": self["appservice.as_token"],
"hs_token": self["appservice.hs_token"],
"namespaces": namespaces,
"url": self["appservice.address"],
"sender_localpart": self._new_token(),
"rate_limited": False,
}
if self["appservice.ephemeral_events"]:
self._registration["de.sorunome.msc2409.push_ephemeral"] = True
self._registration["push_ephemeral"] = True
# python-0.20.4/mautrix/bridge/crypto_state_store.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from typing import Awaitable, Callable
from abc import ABC
from mautrix import __optional_imports__
from mautrix.bridge.portal import BasePortal
from mautrix.crypto import StateStore
from mautrix.types import RoomEncryptionStateEventContent, RoomID, UserID
from mautrix.util.async_db import Database
GetPortalFunc = Callable[[RoomID], Awaitable[BasePortal]]
class BaseCryptoStateStore(StateStore, ABC):
get_portal: GetPortalFunc
def __init__(self, get_portal: GetPortalFunc):
self.get_portal = get_portal
async def is_encrypted(self, room_id: RoomID) -> bool:
portal = await self.get_portal(room_id)
return portal.encrypted if portal else False
class PgCryptoStateStore(BaseCryptoStateStore):
db: Database
def __init__(self, db: Database, get_portal: GetPortalFunc) -> None:
super().__init__(get_portal)
self.db = db
async def find_shared_rooms(self, user_id: UserID) -> list[RoomID]:
rows = await self.db.fetch(
"SELECT room_id FROM mx_user_profile "
"LEFT JOIN portal ON portal.mxid=mx_user_profile.room_id "
"WHERE user_id=$1 AND portal.encrypted=true",
user_id,
)
return [row["room_id"] for row in rows]
async def get_encryption_info(self, room_id: RoomID) -> RoomEncryptionStateEventContent | None:
val = await self.db.fetchval(
"SELECT encryption FROM mx_room_state WHERE room_id=$1", room_id
)
if not val:
return None
return RoomEncryptionStateEventContent.parse_json(val)
# python-0.20.4/mautrix/bridge/custom_puppet.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from abc import ABC, abstractmethod
import asyncio
import hashlib
import hmac
import logging
from yarl import URL
from mautrix.appservice import AppService, IntentAPI
from mautrix.client import ClientAPI
from mautrix.errors import (
IntentError,
MatrixError,
MatrixInvalidToken,
MatrixRequestError,
WellKnownError,
)
from mautrix.types import LoginType, MatrixUserIdentifier, RoomID, UserID
from .. import bridge as br
class CustomPuppetError(MatrixError):
"""Base class for double puppeting setup errors."""
class InvalidAccessToken(CustomPuppetError):
def __init__(self):
super().__init__("The given access token was invalid.")
class OnlyLoginSelf(CustomPuppetError):
def __init__(self):
super().__init__("You may only enable double puppeting with your own Matrix account.")
class EncryptionKeysFound(CustomPuppetError):
def __init__(self):
super().__init__(
"The given access token is for a device that has encryption keys set up. "
"Please provide a fresh token, don't reuse one from another client."
)
class HomeserverURLNotFound(CustomPuppetError):
def __init__(self, domain: str):
super().__init__(
f"Could not discover a valid homeserver URL for {domain}."
" Please ensure a client .well-known file is set up, or ask the bridge administrator "
"to add the homeserver URL to the bridge config."
)
class OnlyLoginTrustedDomain(CustomPuppetError):
def __init__(self):
super().__init__(
"This bridge doesn't allow double-puppeting with accounts on untrusted servers."
)
class AutologinError(CustomPuppetError):
pass
class CustomPuppetMixin(ABC):
"""
Mixin for the Puppet class to enable Matrix puppeting.
Attributes:
sync_with_custom_puppets: Whether or not custom puppets should /sync
allow_discover_url: Allow logging into other homeservers using .well-known discovery.
homeserver_url_map: Static map from server name to URL that are always allowed to log in.
only_handle_own_synced_events: Whether or not typing notifications and read receipts by
other users should be filtered away before passing them to
the Matrix event handler.
az: The AppService object.
loop: The asyncio event loop.
log: The logger to use.
mx: The Matrix event handler to send /sync events to.
by_custom_mxid: A mapping from custom mxid to puppet object.
default_mxid: The default user ID of the puppet.
default_mxid_intent: The IntentAPI for the default user ID.
custom_mxid: The user ID of the custom puppet.
access_token: The access token for the custom puppet.
intent: The primary IntentAPI.
"""
allow_discover_url: bool = False
homeserver_url_map: dict[str, URL] = {}
only_handle_own_synced_events: bool = True
login_shared_secret_map: dict[str, bytes] = {}
login_device_name: str | None = None
az: AppService
loop: asyncio.AbstractEventLoop
log: logging.Logger
mx: br.BaseMatrixHandler
by_custom_mxid: dict[UserID, CustomPuppetMixin] = {}
default_mxid: UserID
default_mxid_intent: IntentAPI
custom_mxid: UserID | None
access_token: str | None
base_url: URL | None
intent: IntentAPI
@abstractmethod
async def save(self) -> None:
"""Save the information of this puppet. Called from :meth:`switch_mxid`"""
@property
def mxid(self) -> UserID:
"""The main Matrix user ID of this puppet."""
return self.custom_mxid or self.default_mxid
@property
def is_real_user(self) -> bool:
"""Whether this puppet uses a real Matrix user instead of an appservice-owned ID."""
return bool(self.custom_mxid and self.access_token)
def _fresh_intent(self) -> IntentAPI:
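        # access_token == "appservice-config" is a sentinel set by
        # _login_with_shared_secret when the configured shared secret starts
        # with "as_token:"; the rest of the secret is then used directly as an
        # appservice token instead of a per-user access token.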
if self.access_token == "appservice-config" and self.custom_mxid:
_, server = self.az.intent.parse_user_id(self.custom_mxid)
try:
secret = self.login_shared_secret_map[server]
except KeyError:
raise AutologinError(f"No shared secret configured for {server}")
self.log.debug(f"Using as_token for double puppeting {self.custom_mxid}")
return self.az.intent.user(
self.custom_mxid,
secret.decode("utf-8").removeprefix("as_token:"),
self.base_url,
as_token=True,
)
return (
self.az.intent.user(self.custom_mxid, self.access_token, self.base_url)
if self.is_real_user
else self.default_mxid_intent
)
@classmethod
def can_auto_login(cls, mxid: UserID) -> bool:
_, server = cls.az.intent.parse_user_id(mxid)
return server in cls.login_shared_secret_map and (
server in cls.homeserver_url_map or server == cls.az.domain
)
@classmethod
async def _login_with_shared_secret(cls, mxid: UserID) -> str:
_, server = cls.az.intent.parse_user_id(mxid)
try:
secret = cls.login_shared_secret_map[server]
except KeyError:
raise AutologinError(f"No shared secret configured for {server}")
if secret.startswith(b"as_token:"):
return "appservice-config"
try:
base_url = cls.homeserver_url_map[server]
except KeyError:
if server == cls.az.domain:
base_url = cls.az.intent.api.base_url
else:
raise AutologinError(f"No homeserver URL configured for {server}")
client = ClientAPI(base_url=base_url)
login_args = {}
if secret == b"appservice":
login_type = LoginType.APPSERVICE
client.api.token = cls.az.as_token
else:
flows = await client.get_login_flows()
flow = flows.get_first_of_type(LoginType.DEVTURE_SHARED_SECRET, LoginType.PASSWORD)
if not flow:
raise AutologinError("No supported shared secret auth login flows")
login_type = flow.type
token = hmac.new(secret, mxid.encode("utf-8"), hashlib.sha512).hexdigest()
if login_type == LoginType.DEVTURE_SHARED_SECRET:
login_args["token"] = token
elif login_type == LoginType.PASSWORD:
login_args["password"] = token
resp = await client.login(
identifier=MatrixUserIdentifier(user=mxid),
device_id=cls.login_device_name,
initial_device_display_name=cls.login_device_name,
login_type=login_type,
**login_args,
store_access_token=False,
update_hs_url=False,
)
return resp.access_token
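    # Hypothetical example: for secret=b"hunter2" and mxid="@user:example.com",
    # the derived token is
    # hmac.new(b"hunter2", b"@user:example.com", hashlib.sha512).hexdigest(),
    # sent as "token" for com.devture.shared_secret_auth or as "password" for
    # m.login.password, depending on which login flow the homeserver advertises.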
async def switch_mxid(
self, access_token: str | None, mxid: UserID | None, start_sync_task: bool = True
) -> None:
"""
Switch to a real Matrix user or away from one.
Args:
access_token: The access token for the custom account, or ``None`` to switch back to
the appservice-owned ID.
mxid: The expected Matrix user ID of the custom account, or ``None`` when
``access_token`` is None.
"""
if access_token == "auto":
access_token = await self._login_with_shared_secret(mxid)
if access_token != "appservice-config":
self.log.debug(f"Logged in for {mxid} using shared secret")
if mxid is not None:
_, mxid_domain = self.az.intent.parse_user_id(mxid)
if mxid_domain in self.homeserver_url_map:
base_url = self.homeserver_url_map[mxid_domain]
elif mxid_domain == self.az.domain:
base_url = None
else:
if not self.allow_discover_url:
raise OnlyLoginTrustedDomain()
try:
base_url = await IntentAPI.discover(mxid_domain, self.az.http_session)
except WellKnownError as e:
raise HomeserverURLNotFound(mxid_domain) from e
if base_url is None:
raise HomeserverURLNotFound(mxid_domain)
else:
base_url = None
prev_mxid = self.custom_mxid
self.custom_mxid = mxid
self.access_token = access_token
self.base_url = base_url
self.intent = self._fresh_intent()
await self.start(check_e2ee_keys=True)
try:
del self.by_custom_mxid[prev_mxid]
except KeyError:
pass
if self.mxid != self.default_mxid:
self.by_custom_mxid[self.mxid] = self
try:
await self._leave_rooms_with_default_user()
except Exception:
self.log.warning("Error when leaving rooms with default user", exc_info=True)
await self.save()
async def try_start(self, retry_auto_login: bool = True) -> None:
try:
await self.start(retry_auto_login=retry_auto_login)
except Exception:
self.log.exception("Failed to initialize custom mxid")
async def _invalidate_double_puppet(self) -> None:
if self.custom_mxid and self.by_custom_mxid.get(self.custom_mxid) == self:
del self.by_custom_mxid[self.custom_mxid]
self.custom_mxid = None
self.access_token = None
await self.save()
self.intent = self._fresh_intent()
async def start(
self,
retry_auto_login: bool = False,
start_sync_task: bool = True,
check_e2ee_keys: bool = False,
) -> None:
"""Initialize the custom account this puppet uses. Should be called at startup to start
the /sync task. Is called by :meth:`switch_mxid` automatically."""
if not self.is_real_user:
return
try:
whoami = await self.intent.whoami()
except MatrixInvalidToken as e:
if retry_auto_login and self.custom_mxid and self.can_auto_login(self.custom_mxid):
self.log.debug(f"Got {e.errcode} while trying to initialize custom mxid")
await self.switch_mxid("auto", self.custom_mxid)
return
self.log.warning(f"Got {e.errcode} while trying to initialize custom mxid")
whoami = None
if not whoami or whoami.user_id != self.custom_mxid:
prev_custom_mxid = self.custom_mxid
await self._invalidate_double_puppet()
if whoami and whoami.user_id != prev_custom_mxid:
raise OnlyLoginSelf()
raise InvalidAccessToken()
if check_e2ee_keys:
try:
devices = await self.intent.query_keys({whoami.user_id: [whoami.device_id]})
device_keys = devices.device_keys.get(whoami.user_id, {}).get(whoami.device_id)
except Exception:
self.log.warning(
"Failed to query keys to check if double puppeting token was reused",
exc_info=True,
)
else:
if device_keys and len(device_keys.keys) > 0:
await self._invalidate_double_puppet()
raise EncryptionKeysFound()
self.log.info(f"Initialized custom mxid: {whoami.user_id}")
def stop(self) -> None:
"""
No-op
.. deprecated:: 0.20.1
"""
async def default_puppet_should_leave_room(self, room_id: RoomID) -> bool:
"""
Whether or not the default puppet user should leave the given room when this puppet is
switched to using a custom user account.
Args:
room_id: The room to check.
Returns:
Whether or not the default user account should leave.
"""
return True
async def _leave_rooms_with_default_user(self) -> None:
for room_id in await self.default_mxid_intent.get_joined_rooms():
try:
if await self.default_puppet_should_leave_room(room_id):
await self.default_mxid_intent.leave_room(room_id)
await self.intent.ensure_joined(room_id)
except (IntentError, MatrixRequestError):
pass
# python-0.20.4/mautrix/bridge/disappearing_message.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from typing import TypeVar
from abc import ABC, abstractmethod
import time
from attr import dataclass
from mautrix.types import EventID, RoomID
@dataclass
class AbstractDisappearingMessage(ABC):
room_id: RoomID
event_id: EventID
expiration_seconds: int
expiration_ts: int | None = None
@abstractmethod
async def insert(self) -> None:
pass
@abstractmethod
async def update(self) -> None:
pass
def start_timer(self) -> None:
self.expiration_ts = int(time.time() * 1000) + (self.expiration_seconds * 1000)
@abstractmethod
async def delete(self) -> None:
pass
@classmethod
@abstractmethod
async def get_all_scheduled(cls: type[DisappearingMessage]) -> list[DisappearingMessage]:
pass
@classmethod
@abstractmethod
async def get_unscheduled_for_room(
cls: type[DisappearingMessage], room_id: RoomID
) -> list[DisappearingMessage]:
pass
DisappearingMessage = TypeVar("DisappearingMessage", bound=AbstractDisappearingMessage)
# python-0.20.4/mautrix/bridge/e2ee.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
import asyncio
import logging
import sys
from mautrix import __optional_imports__
from mautrix.appservice import AppService
from mautrix.client import Client, InternalEventType, SyncStore
from mautrix.crypto import CryptoStore, OlmMachine, PgCryptoStore, RejectKeyShare, StateStore
from mautrix.errors import EncryptionError, MForbidden, MNotFound, SessionNotFound
from mautrix.types import (
JSON,
DeviceIdentity,
EncryptedEvent,
EncryptedMegolmEventContent,
EventFilter,
EventType,
Filter,
LoginType,
MessageEvent,
RequestedKeyInfo,
RoomEventFilter,
RoomFilter,
RoomID,
RoomKeyWithheldCode,
Serializable,
StateEvent,
StateFilter,
TrustState,
)
from mautrix.util import background_task
from mautrix.util.async_db import Database
from mautrix.util.logging import TraceLogger
from .. import bridge as br
from .crypto_state_store import PgCryptoStateStore
class EncryptionManager:
loop: asyncio.AbstractEventLoop
log: TraceLogger = logging.getLogger("mau.bridge.e2ee")
client: Client
crypto: OlmMachine
crypto_store: CryptoStore | SyncStore
crypto_db: Database | None
state_store: StateStore
min_send_trust: TrustState
key_sharing_enabled: bool
appservice_mode: bool
periodically_delete_expired_keys: bool
delete_outdated_inbound: bool
bridge: br.Bridge
az: AppService
_id_prefix: str
_id_suffix: str
_share_session_events: dict[RoomID, asyncio.Event]
_key_delete_task: asyncio.Task | None
def __init__(
self,
bridge: br.Bridge,
homeserver_address: str,
user_id_prefix: str,
user_id_suffix: str,
db_url: str,
) -> None:
self.loop = bridge.loop or asyncio.get_event_loop()
self.bridge = bridge
self.az = bridge.az
self.device_name = bridge.name
self._id_prefix = user_id_prefix
self._id_suffix = user_id_suffix
self._share_session_events = {}
pickle_key = "mautrix.bridge.e2ee"
self.crypto_db = Database.create(
url=db_url,
upgrade_table=PgCryptoStore.upgrade_table,
log=logging.getLogger("mau.crypto.db"),
)
self.crypto_store = PgCryptoStore("", pickle_key, self.crypto_db)
self.state_store = PgCryptoStateStore(self.crypto_db, bridge.get_portal)
default_http_retry_count = bridge.config.get("homeserver.http_retry_count", None)
self.client = Client(
base_url=homeserver_address,
mxid=self.az.bot_mxid,
loop=self.loop,
sync_store=self.crypto_store,
log=self.log.getChild("client"),
default_retry_count=default_http_retry_count,
state_store=self.bridge.state_store,
)
self.crypto = OlmMachine(self.client, self.crypto_store, self.state_store)
self.client.add_event_handler(InternalEventType.SYNC_STOPPED, self._exit_on_sync_fail)
self.crypto.allow_key_share = self.allow_key_share
verification_levels = bridge.config["bridge.encryption.verification_levels"]
self.min_send_trust = TrustState.parse(verification_levels["send"])
self.crypto.share_keys_min_trust = TrustState.parse(verification_levels["share"])
self.crypto.send_keys_min_trust = TrustState.parse(verification_levels["receive"])
self.key_sharing_enabled = bridge.config["bridge.encryption.allow_key_sharing"]
self.appservice_mode = bridge.config["bridge.encryption.appservice"]
if self.appservice_mode:
self.az.otk_handler = self.crypto.handle_as_otk_counts
self.az.device_list_handler = self.crypto.handle_as_device_lists
self.az.to_device_handler = self.crypto.handle_as_to_device_event
self.periodically_delete_expired_keys = False
self.delete_outdated_inbound = False
self._key_delete_task = None
del_cfg = bridge.config["bridge.encryption.delete_keys"]
if del_cfg:
self.crypto.delete_outbound_keys_on_ack = del_cfg["delete_outbound_on_ack"]
self.crypto.dont_store_outbound_keys = del_cfg["dont_store_outbound"]
self.crypto.delete_previous_keys_on_receive = del_cfg["delete_prev_on_new_session"]
self.crypto.ratchet_keys_on_decrypt = del_cfg["ratchet_on_decrypt"]
self.crypto.delete_fully_used_keys_on_decrypt = del_cfg["delete_fully_used_on_decrypt"]
self.crypto.delete_keys_on_device_delete = del_cfg["delete_on_device_delete"]
self.periodically_delete_expired_keys = del_cfg["periodically_delete_expired"]
self.delete_outdated_inbound = del_cfg["delete_outdated_inbound"]
self.crypto.disable_device_change_key_rotation = bridge.config[
"bridge.encryption.rotation.disable_device_change_key_rotation"
]
async def _exit_on_sync_fail(self, data) -> None:
if data["error"]:
self.log.critical("Exiting due to crypto sync error")
sys.exit(32)
async def allow_key_share(self, device: DeviceIdentity, request: RequestedKeyInfo) -> bool:
if not self.key_sharing_enabled:
self.log.debug(
f"Key sharing not enabled, ignoring key request from "
f"{device.user_id}/{device.device_id}"
)
return False
elif device.trust == TrustState.BLACKLISTED:
raise RejectKeyShare(
f"Rejecting key request from blacklisted device "
f"{device.user_id}/{device.device_id}",
code=RoomKeyWithheldCode.BLACKLISTED,
reason="Your device has been blacklisted by the bridge",
)
elif await self.crypto.resolve_trust(device) >= self.crypto.share_keys_min_trust:
portal = await self.bridge.get_portal(request.room_id)
if portal is None:
raise RejectKeyShare(
f"Rejecting key request for {request.session_id} from "
f"{device.user_id}/{device.device_id}: room is not a portal",
code=RoomKeyWithheldCode.UNAVAILABLE,
reason="Requested room is not a portal",
)
user = await self.bridge.get_user(device.user_id)
if not await user.is_in_portal(portal):
raise RejectKeyShare(
f"Rejecting key request for {request.session_id} from "
f"{device.user_id}/{device.device_id}: user is not in portal",
code=RoomKeyWithheldCode.UNAUTHORIZED,
reason="You're not in that portal",
)
self.log.debug(
f"Accepting key request for {request.session_id} from "
f"{device.user_id}/{device.device_id}"
)
return True
else:
raise RejectKeyShare(
f"Rejecting key request from unverified device "
f"{device.user_id}/{device.device_id}",
code=RoomKeyWithheldCode.UNVERIFIED,
reason="Your device is not trusted by the bridge",
)
def _ignore_user(self, user_id: str) -> bool:
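        # Member events from ghost users owned by this bridge (matching the
        # puppet user ID prefix/suffix) are ignored; the bridge bot itself is
        # never ignored.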
return (
user_id.startswith(self._id_prefix)
and user_id.endswith(self._id_suffix)
and user_id != self.az.bot_mxid
)
async def handle_member_event(self, evt: StateEvent) -> None:
if self._ignore_user(evt.state_key):
# We don't want to invalidate group sessions because a ghost left or joined
return
await self.crypto.handle_member_event(evt)
async def _share_session_lock(self, room_id: RoomID) -> bool:
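        # Acts as a per-room lock for group session sharing: returns True if
        # the caller should share the session itself, or False once another
        # task that was already sharing it has finished.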
try:
event = self._share_session_events[room_id]
except KeyError:
self._share_session_events[room_id] = asyncio.Event()
return True
else:
await event.wait()
return False
async def encrypt(
self, room_id: RoomID, event_type: EventType, content: Serializable | JSON
) -> tuple[EventType, EncryptedMegolmEventContent]:
try:
encrypted = await self.crypto.encrypt_megolm_event(room_id, event_type, content)
except EncryptionError:
self.log.debug("Got EncryptionError, sharing group session and trying again")
if await self._share_session_lock(room_id):
try:
users = await self.az.state_store.get_members_filtered(
room_id, self._id_prefix, self._id_suffix, self.az.bot_mxid
)
await self.crypto.share_group_session(room_id, users)
finally:
self._share_session_events.pop(room_id).set()
encrypted = await self.crypto.encrypt_megolm_event(room_id, event_type, content)
return EventType.ROOM_ENCRYPTED, encrypted
async def decrypt(self, evt: EncryptedEvent, wait_session_timeout: int = 5) -> MessageEvent:
try:
decrypted = await self.crypto.decrypt_megolm_event(evt)
except SessionNotFound as e:
if not wait_session_timeout:
raise
self.log.debug(
f"Couldn't find session {e.session_id} trying to decrypt {evt.event_id},"
f" waiting {wait_session_timeout} seconds..."
)
got_keys = await self.crypto.wait_for_session(
evt.room_id, e.session_id, timeout=wait_session_timeout
)
if got_keys:
self.log.debug(
f"Got session {e.session_id} after waiting, "
f"trying to decrypt {evt.event_id} again"
)
decrypted = await self.crypto.decrypt_megolm_event(evt)
else:
raise
self.log.trace("Decrypted event %s: %s", evt.event_id, decrypted)
return decrypted
async def start(self) -> None:
flows = await self.client.get_login_flows()
if not flows.supports_type(LoginType.APPSERVICE):
self.log.critical(
"Encryption enabled in config, but homeserver does not support appservice login"
)
sys.exit(30)
self.log.debug("Logging in with bridge bot user")
if self.crypto_db:
try:
await self.crypto_db.start()
except Exception as e:
self.bridge._log_db_error(e)
await self.crypto_store.open()
device_id = await self.crypto_store.get_device_id()
if device_id:
self.log.debug(f"Found device ID in database: {device_id}")
# We set the API token to the AS token here to authenticate the appservice login
# It'll get overridden after the login
self.client.api.token = self.az.as_token
await self.client.login(
login_type=LoginType.APPSERVICE,
device_name=self.device_name,
device_id=device_id,
store_access_token=True,
update_hs_url=False,
)
await self.crypto.load()
if not device_id:
await self.crypto_store.put_device_id(self.client.device_id)
self.log.debug(f"Logged in with new device ID {self.client.device_id}")
elif self.crypto.account.shared:
await self._verify_keys_are_on_server()
if self.appservice_mode:
self.log.info("End-to-bridge encryption support is enabled (appservice mode)")
else:
_ = self.client.start(self._filter)
self.log.info("End-to-bridge encryption support is enabled (sync mode)")
if self.delete_outdated_inbound:
deleted = await self.crypto_store.redact_outdated_group_sessions()
if len(deleted) > 0:
self.log.debug(
f"Deleted {len(deleted)} inbound keys which lacked expiration metadata"
)
if self.periodically_delete_expired_keys:
self._key_delete_task = background_task.create(self._periodically_delete_keys())
background_task.create(self._resync_encryption_info())
async def _resync_encryption_info(self) -> None:
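        # Rooms whose cached encryption state is the placeholder
        # '{"resync":true}' get their m.room.encryption event re-fetched so
        # that max_age/max_messages can be backfilled on stored inbound megolm
        # sessions (or the placeholder is cleared if the event can't be read).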
rows = await self.crypto_db.fetch(
"""SELECT room_id FROM mx_room_state WHERE encryption='{"resync":true}'"""
)
room_ids = [row["room_id"] for row in rows]
if not room_ids:
return
self.log.debug(f"Resyncing encryption state event in rooms: {room_ids}")
for room_id in room_ids:
try:
evt = await self.client.get_state_event(room_id, EventType.ROOM_ENCRYPTION)
except (MNotFound, MForbidden) as e:
self.log.debug(f"Failed to get encryption state in {room_id}: {e}")
q = """
UPDATE mx_room_state SET encryption=NULL
WHERE room_id=$1 AND encryption='{"resync":true}'
"""
await self.crypto_db.execute(q, room_id)
else:
self.log.debug(f"Resynced encryption state in {room_id}: {evt}")
q = """
UPDATE crypto_megolm_inbound_session SET max_age=$1, max_messages=$2
WHERE room_id=$3 AND max_age IS NULL and max_messages IS NULL
"""
await self.crypto_db.execute(
q, evt.rotation_period_ms, evt.rotation_period_msgs, room_id
)
async def _verify_keys_are_on_server(self) -> None:
self.log.debug("Making sure keys are still on server")
try:
resp = await self.client.query_keys([self.client.mxid])
except Exception:
self.log.critical(
"Failed to query own keys to make sure device still exists", exc_info=True
)
sys.exit(33)
try:
own_keys = resp.device_keys[self.client.mxid][self.client.device_id]
if len(own_keys.keys) > 0:
return
except KeyError:
pass
self.log.critical("Existing device doesn't have keys on server, resetting crypto")
await self.crypto.crypto_store.delete()
await self.client.logout_all()
sys.exit(34)
async def stop(self) -> None:
if self._key_delete_task:
self._key_delete_task.cancel()
self._key_delete_task = None
self.client.stop()
await self.crypto_store.close()
if self.crypto_db:
await self.crypto_db.stop()
@property
def _filter(self) -> Filter:
all_events = EventType.find("*")
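        # Sync filter for non-appservice mode: all room, presence and
        # ephemeral event types are excluded, so /sync effectively only
        # delivers to-device messages (and account data) needed for encryption.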
return Filter(
account_data=EventFilter(types=[all_events]),
presence=EventFilter(not_types=[all_events]),
room=RoomFilter(
include_leave=False,
state=StateFilter(not_types=[all_events]),
timeline=RoomEventFilter(not_types=[all_events]),
account_data=RoomEventFilter(not_types=[all_events]),
ephemeral=RoomEventFilter(not_types=[all_events]),
),
)
async def _periodically_delete_keys(self) -> None:
while True:
deleted = await self.crypto_store.redact_expired_group_sessions()
if deleted:
self.log.info(f"Deleted expired megolm sessions: {deleted}")
else:
self.log.debug("No expired megolm sessions found")
await asyncio.sleep(24 * 60 * 60)
# python-0.20.4/mautrix/bridge/matrix.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from collections import defaultdict
import asyncio
import logging
import sys
import time
from mautrix import __optional_imports__
from mautrix.appservice import DOUBLE_PUPPET_SOURCE_KEY, AppService
from mautrix.errors import (
DecryptionError,
IntentError,
MatrixError,
MExclusive,
MForbidden,
MUnknownToken,
SessionNotFound,
)
from mautrix.types import (
BaseRoomEvent,
BeeperMessageStatusEventContent,
EncryptedEvent,
Event,
EventID,
EventType,
MediaRepoConfig,
Membership,
MemberStateEventContent,
MessageEvent,
MessageEventContent,
MessageStatus,
MessageStatusReason,
MessageType,
PresenceEvent,
ReactionEvent,
ReceiptEvent,
ReceiptType,
RedactionEvent,
RelatesTo,
RelationType,
RoomID,
RoomType,
SingleReceiptEventContent,
SpecVersions,
StateEvent,
StateUnsigned,
TextMessageEventContent,
TrustState,
TypingEvent,
UserID,
Version,
VersionsResponse,
)
from mautrix.util import background_task, markdown
from mautrix.util.logging import TraceLogger
from mautrix.util.message_send_checkpoint import (
CHECKPOINT_TYPES,
MessageSendCheckpoint,
MessageSendCheckpointReportedBy,
MessageSendCheckpointStatus,
MessageSendCheckpointStep,
)
from mautrix.util.opt_prometheus import Histogram
from .. import bridge as br
from . import commands as cmd
encryption_import_error = None
media_encrypt_import_error = None
try:
from .e2ee import EncryptionManager
except ImportError as e:
if __optional_imports__:
raise
encryption_import_error = e
EncryptionManager = None
try:
from mautrix.crypto.attachments import encrypt_attachment
except ImportError as e:
if __optional_imports__:
raise
media_encrypt_import_error = e
encrypt_attachment = None
EVENT_TIME = Histogram(
"bridge_matrix_event", "Time spent processing Matrix events", ["event_type"]
)
class UnencryptedMessageError(DecryptionError):
def __init__(self) -> None:
super().__init__("unencrypted message")
@property
def human_message(self) -> str:
return "the message is not encrypted"
class EncryptionUnsupportedError(DecryptionError):
def __init__(self) -> None:
super().__init__("encryption is not supported")
@property
def human_message(self) -> str:
return "the bridge is not configured to support encryption"
class DeviceUntrustedError(DecryptionError):
def __init__(self, trust: TrustState) -> None:
explanation = {
TrustState.BLACKLISTED: "device is blacklisted",
TrustState.UNVERIFIED: "unverified",
TrustState.UNKNOWN_DEVICE: "device info not found",
TrustState.FORWARDED: "keys were forwarded from an unknown device",
TrustState.CROSS_SIGNED_UNTRUSTED: (
"cross-signing keys changed after setting up the bridge"
),
}.get(trust)
base = "your device is not trusted"
self.message = f"{base} ({explanation})" if explanation else base
super().__init__(self.message)
@property
def human_message(self) -> str:
return self.message
class BaseMatrixHandler:
log: TraceLogger = logging.getLogger("mau.mx")
az: AppService
commands: cmd.CommandProcessor
config: config.BaseBridgeConfig
bridge: br.Bridge
e2ee: EncryptionManager | None
require_e2ee: bool
media_config: MediaRepoConfig
versions: VersionsResponse
minimum_spec_version: Version = SpecVersions.V11
room_locks: dict[str, asyncio.Lock]
user_id_prefix: str
user_id_suffix: str
def __init__(
self,
command_processor: cmd.CommandProcessor | None = None,
bridge: br.Bridge | None = None,
) -> None:
self.az = bridge.az
self.config = bridge.config
self.bridge = bridge
self.commands = command_processor or cmd.CommandProcessor(bridge=bridge)
self.media_config = MediaRepoConfig(upload_size=50 * 1024 * 1024)
self.versions = VersionsResponse.deserialize({"versions": ["v1.3"]})
self.az.matrix_event_handler(self.int_handle_event)
self.room_locks = defaultdict(asyncio.Lock)
self.e2ee = None
self.require_e2ee = False
if self.config["bridge.encryption.allow"]:
if not EncryptionManager:
self.log.fatal(
"Encryption enabled in config, but dependencies not installed.",
exc_info=encryption_import_error,
)
sys.exit(31)
if not encrypt_attachment:
self.log.fatal(
"Encryption enabled in config, but media encryption dependencies "
"not installed.",
exc_info=media_encrypt_import_error,
)
sys.exit(31)
self.e2ee = EncryptionManager(
bridge=bridge,
user_id_prefix=self.user_id_prefix,
user_id_suffix=self.user_id_suffix,
homeserver_address=self.config["homeserver.address"],
db_url=self.config["appservice.database"],
)
self.require_e2ee = self.config["bridge.encryption.require"]
self.management_room_text = self.config.get(
"bridge.management_room_text",
{
"welcome": "Hello, I'm a bridge bot.",
"welcome_connected": "Use `help` for help.",
"welcome_unconnected": "Use `help` for help on how to log in.",
},
)
self.management_room_multiple_messages = self.config.get(
"bridge.management_room_multiple_messages",
False,
)
async def check_versions(self) -> None:
if not self.versions.supports_at_least(self.minimum_spec_version):
self.log.fatal(
"The homeserver is outdated "
"(server supports Matrix %s, but the bridge requires at least %s)",
self.versions.latest_version,
self.minimum_spec_version,
)
sys.exit(18)
if self.bridge.homeserver_software.is_hungry and not self.versions.supports(
"com.beeper.hungry"
):
self.log.fatal(
"The config claims the homeserver is hungryserv, "
"but the /versions response didn't confirm it"
)
sys.exit(18)
async def wait_for_connection(self) -> None:
self.log.info("Ensuring connectivity to homeserver")
while True:
try:
self.versions = await self.az.intent.versions()
break
except Exception:
self.log.exception("Connection to homeserver failed, retrying in 10 seconds")
await asyncio.sleep(10)
await self.check_versions()
try:
await self.az.intent.whoami()
except MForbidden:
self.log.debug(
"Whoami endpoint returned M_FORBIDDEN, "
"trying to register bridge bot before retrying..."
)
await self.az.intent.ensure_registered()
await self.az.intent.whoami()
if self.versions.supports("fi.mau.msc2659.stable") or self.versions.supports_at_least(
SpecVersions.V17
):
try:
txn_id = self.az.intent.api.get_txn_id()
duration = await self.az.ping_self(txn_id)
self.log.debug(
"Homeserver->bridge connection works, "
f"roundtrip time is {duration} ms (txn ID: {txn_id})"
)
except Exception:
self.log.exception("Error checking homeserver -> bridge connection")
sys.exit(16)
else:
self.log.debug(
"Homeserver does not support checking status of homeserver -> bridge connection"
)
try:
self.media_config = await self.az.intent.get_media_repo_config()
except Exception:
self.log.warning("Failed to fetch media repo config", exc_info=True)
async def init_as_bot(self) -> None:
self.log.debug("Initializing appservice bot")
displayname = self.config["appservice.bot_displayname"]
if displayname:
try:
await self.az.intent.set_displayname(
displayname if displayname != "remove" else ""
)
except Exception:
self.log.exception("Failed to set bot displayname")
avatar = self.config["appservice.bot_avatar"]
if avatar:
try:
await self.az.intent.set_avatar_url(avatar if avatar != "remove" else "")
except Exception:
self.log.exception("Failed to set bot avatar")
if self.bridge.homeserver_software.is_hungry and self.bridge.beeper_network_name:
self.log.debug("Setting contact info on the appservice bot")
await self.az.intent.beeper_update_profile(
{
"com.beeper.bridge.service": self.bridge.beeper_service_name,
"com.beeper.bridge.network": self.bridge.beeper_network_name,
"com.beeper.bridge.is_bridge_bot": True,
}
)
async def init_encryption(self) -> None:
if self.e2ee:
await self.e2ee.start()
async def allow_message(self, user: br.BaseUser) -> bool:
return user.is_whitelisted or (
self.config["bridge.relay.enabled"] and user.relay_whitelisted
)
@staticmethod
async def allow_command(user: br.BaseUser) -> bool:
return user.is_whitelisted
@staticmethod
async def allow_bridging_message(user: br.BaseUser, portal: br.BasePortal) -> bool:
return await user.is_logged_in() or (user.relay_whitelisted and portal.has_relay)
@staticmethod
async def allow_puppet_invite(user: br.BaseUser, puppet: br.BasePuppet) -> bool:
return await user.is_logged_in()
async def handle_leave(self, room_id: RoomID, user_id: UserID, event_id: EventID) -> None:
pass
async def handle_kick(
self, room_id: RoomID, user_id: UserID, kicked_by: UserID, reason: str, event_id: EventID
) -> None:
pass
async def handle_ban(
self, room_id: RoomID, user_id: UserID, banned_by: UserID, reason: str, event_id: EventID
) -> None:
pass
async def handle_unban(
self, room_id: RoomID, user_id: UserID, unbanned_by: UserID, reason: str, event_id: EventID
) -> None:
pass
async def handle_join(self, room_id: RoomID, user_id: UserID, event_id: EventID) -> None:
pass
async def handle_knock(
self, room_id: RoomID, user_id: UserID, reason: str, event_id: EventID
) -> None:
pass
async def handle_retract_knock(
self, room_id: RoomID, user_id: UserID, reason: str, event_id: EventID
) -> None:
pass
async def handle_reject_knock(
self, room_id: RoomID, user_id: UserID, sender: UserID, reason: str, event_id: EventID
) -> None:
pass
async def handle_accept_knock(
self, room_id: RoomID, user_id: UserID, sender: UserID, reason: str, event_id: EventID
) -> None:
pass
async def handle_member_info_change(
self,
room_id: RoomID,
user_id: UserID,
content: MemberStateEventContent,
prev_content: MemberStateEventContent,
event_id: EventID,
) -> None:
pass
async def handle_puppet_group_invite(
self,
room_id: RoomID,
puppet: br.BasePuppet,
invited_by: br.BaseUser,
evt: StateEvent,
members: list[UserID],
) -> None:
if self.az.bot_mxid not in members:
await puppet.default_mxid_intent.leave_room(
room_id, reason="This ghost does not join multi-user rooms without the bridge bot."
)
async def handle_puppet_dm_invite(
self, room_id: RoomID, puppet: br.BasePuppet, invited_by: br.BaseUser, evt: StateEvent
) -> None:
portal = await invited_by.get_portal_with(puppet)
if portal:
await portal.accept_matrix_dm(room_id, invited_by, puppet)
else:
await puppet.default_mxid_intent.leave_room(
room_id, reason="This bridge does not support creating DMs."
)
async def handle_puppet_space_invite(
self, room_id: RoomID, puppet: br.BasePuppet, invited_by: br.BaseUser, evt: StateEvent
) -> None:
await puppet.default_mxid_intent.leave_room(
room_id, reason="This ghost does not join spaces."
)
async def handle_puppet_nonportal_invite(
self, room_id: RoomID, puppet: br.BasePuppet, invited_by: br.BaseUser, evt: StateEvent
) -> None:
intent = puppet.default_mxid_intent
await intent.join_room(room_id)
try:
create_evt = await intent.get_state_event(room_id, EventType.ROOM_CREATE)
members = await intent.get_room_members(room_id)
except MatrixError:
self.log.exception(f"Failed to get state after joining {room_id} as {intent.mxid}")
background_task.create(intent.leave_room(room_id, reason="Internal error"))
return
if create_evt.type == RoomType.SPACE:
await self.handle_puppet_space_invite(room_id, puppet, invited_by, evt)
elif len(members) > 2 or not evt.content.is_direct:
await self.handle_puppet_group_invite(room_id, puppet, invited_by, evt, members)
else:
await self.handle_puppet_dm_invite(room_id, puppet, invited_by, evt)
async def handle_puppet_invite(
self, room_id: RoomID, puppet: br.BasePuppet, invited_by: br.BaseUser, evt: StateEvent
) -> None:
intent = puppet.default_mxid_intent
if not await self.allow_puppet_invite(invited_by, puppet):
self.log.debug(f"Rejecting invite for {intent.mxid} to {room_id}: user can't invite")
await intent.leave_room(room_id, reason="You're not allowed to invite this ghost.")
return
async with self.room_locks[room_id]:
portal = await self.bridge.get_portal(room_id)
if portal:
try:
await portal.handle_matrix_invite(invited_by, puppet)
except br.RejectMatrixInvite as e:
await intent.leave_room(room_id, reason=e.message)
except br.IgnoreMatrixInvite:
pass
else:
await intent.join_room(room_id)
return
else:
await self.handle_puppet_nonportal_invite(room_id, puppet, invited_by, evt)
async def handle_invite(
self, room_id: RoomID, user_id: UserID, invited_by: br.BaseUser, evt: StateEvent
) -> None:
pass
async def handle_reject(
self, room_id: RoomID, user_id: UserID, reason: str, event_id: EventID
) -> None:
pass
async def handle_disinvite(
self,
room_id: RoomID,
user_id: UserID,
disinvited_by: UserID,
reason: str,
event_id: EventID,
) -> None:
pass
async def handle_event(self, evt: Event) -> None:
"""
Called by :meth:`int_handle_event` for message events other than m.room.message.
**N.B.** You may need to add the event class to :attr:`allowed_event_classes`
or override :meth:`allow_matrix_event` for it to reach here.
"""
async def handle_state_event(self, evt: StateEvent) -> None:
"""
Called by :meth:`int_handle_event` for state events other than m.room.membership.
**N.B.** You may need to add the event class to :attr:`allowed_event_classes`
or override :meth:`allow_matrix_event` for it to reach here.
"""
async def handle_ephemeral_event(
self, evt: ReceiptEvent | PresenceEvent | TypingEvent
) -> None:
if evt.type == EventType.RECEIPT:
await self.handle_receipt(evt)
async def send_permission_error(self, room_id: RoomID) -> None:
await self.az.intent.send_notice(
room_id,
text=(
"You are not whitelisted to use this bridge.\n\n"
"If you are the owner of this bridge, see the bridge.permissions "
"section in your config file."
),
html=(
"You are not whitelisted to use this bridge.
"
"If you are the owner of this bridge, see the "
"bridge.permissions
section in your config file.
"
),
)
async def accept_bot_invite(self, room_id: RoomID, inviter: br.BaseUser) -> None:
try:
await self.az.intent.join_room(room_id)
except Exception:
self.log.exception(f"Failed to join room {room_id} as bridge bot")
return
if not await self.allow_command(inviter):
await self.send_permission_error(room_id)
await self.az.intent.leave_room(room_id)
return
await self.send_welcome_message(room_id, inviter)
async def send_welcome_message(self, room_id: RoomID, inviter: br.BaseUser) -> None:
has_two_members, bridge_bot_in_room = await self._is_direct_chat(room_id)
is_management = has_two_members and bridge_bot_in_room
welcome_messages = [self.management_room_text.get("welcome")]
if is_management:
if await inviter.is_logged_in():
welcome_messages.append(self.management_room_text.get("welcome_connected"))
else:
welcome_messages.append(self.management_room_text.get("welcome_unconnected"))
additional_help = self.management_room_text.get("additional_help")
if additional_help:
welcome_messages.append(additional_help)
else:
cmd_prefix = self.commands.command_prefix
welcome_messages.append(f"Use `{cmd_prefix} help` for help.")
if self.management_room_multiple_messages:
for m in welcome_messages:
await self.az.intent.send_notice(room_id, text=m, html=markdown.render(m))
else:
combined = "\n".join(welcome_messages)
combined_html = "".join(map(markdown.render, welcome_messages))
await self.az.intent.send_notice(room_id, text=combined, html=combined_html)
async def int_handle_invite(self, evt: StateEvent) -> None:
self.log.debug(f"{evt.sender} invited {evt.state_key} to {evt.room_id}")
inviter = await self.bridge.get_user(evt.sender)
if inviter is None:
self.log.exception(f"Failed to find user with Matrix ID {evt.sender}")
return
elif evt.state_key == self.az.bot_mxid:
await self.accept_bot_invite(evt.room_id, inviter)
return
puppet = await self.bridge.get_puppet(UserID(evt.state_key))
if puppet:
await self.handle_puppet_invite(evt.room_id, puppet, inviter, evt)
return
await self.handle_invite(evt.room_id, UserID(evt.state_key), inviter, evt)
def is_command(self, message: MessageEventContent) -> tuple[bool, str]:
text = message.body
prefix = self.config["bridge.command_prefix"]
is_command = text.startswith(prefix)
if is_command:
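            # Drop the prefix plus the separator character after it; lstrip
            # removes any remaining whitespace before the command name.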
text = text[len(prefix) + 1 :].lstrip()
return is_command, text
async def _send_mss(
self,
evt: Event,
status: MessageStatus,
reason: MessageStatusReason | None = None,
error: str | None = None,
message: str | None = None,
) -> None:
if not self.config.get("bridge.message_status_events", False):
return
status_content = BeeperMessageStatusEventContent(
network="", # TODO set network properly
relates_to=RelatesTo(rel_type=RelationType.REFERENCE, event_id=evt.event_id),
status=status,
reason=reason,
error=error,
message=message,
)
await self.az.intent.send_message_event(
evt.room_id, EventType.BEEPER_MESSAGE_STATUS, status_content
)
async def _send_crypto_status_error(
self,
evt: Event,
err: DecryptionError | None = None,
retry_num: int = 0,
is_final: bool = True,
edit: EventID | None = None,
wait_for: int | None = None,
) -> EventID | None:
msg = str(err)
if isinstance(err, (SessionNotFound, UnencryptedMessageError)):
msg = err.human_message
self._send_message_checkpoint(
evt, MessageSendCheckpointStep.DECRYPTED, msg, permanent=is_final, retry_num=retry_num
)
if wait_for:
msg += f". The bridge will retry for {wait_for} seconds"
full_msg = f"\u26a0 Your message was not bridged: {msg}."
if isinstance(err, EncryptionUnsupportedError):
full_msg = "🔒️ This bridge has not been configured to support encryption"
event_id = None
if self.config.get("bridge.delivery_error_reports", True):
try:
content = TextMessageEventContent(msgtype=MessageType.NOTICE, body=full_msg)
if edit:
content.set_edit(edit)
event_id = await self.az.intent.send_message(evt.room_id, content)
except IntentError:
self.log.debug("IntentError while sending encryption error", exc_info=True)
self.log.error(
"Got IntentError while trying to send encryption error message. "
"This likely means the bridge bot is not in the room, which can "
"happen if you force-enable e2ee on the homeserver without enabling "
"it by default on the bridge (bridge -> encryption -> default)."
)
await self._send_mss(
evt,
status=MessageStatus.RETRIABLE if is_final else MessageStatus.PENDING,
reason=MessageStatusReason.UNDECRYPTABLE,
error=str(err),
message=err.human_message if err else None,
)
return event_id
async def handle_message(self, evt: MessageEvent, was_encrypted: bool = False) -> None:
room_id = evt.room_id
user_id = evt.sender
event_id = evt.event_id
message = evt.content
if not was_encrypted and self.require_e2ee:
self.log.warning(f"Dropping {event_id} from {user_id} as it's not encrypted!")
await self._send_crypto_status_error(evt, UnencryptedMessageError(), 0)
return
sender = await self.bridge.get_user(user_id)
if not sender or not await self.allow_message(sender):
self.log.debug(
f"Ignoring message {event_id} from {user_id} to {room_id}:"
" user is not whitelisted."
)
self._send_message_checkpoint(
evt, MessageSendCheckpointStep.BRIDGE, "user is not whitelisted"
)
return
self.log.debug(f"Received Matrix event {event_id} from {sender.mxid} in {room_id}")
self.log.trace("Event %s content: %s", event_id, message)
if isinstance(message, TextMessageEventContent):
message.trim_reply_fallback()
is_command, text = self.is_command(message)
portal = await self.bridge.get_portal(room_id)
if not is_command and portal:
if await self.allow_bridging_message(sender, portal):
await portal.handle_matrix_message(sender, message, event_id)
else:
self.log.debug(
f"Ignoring event {event_id} from {sender.mxid}:"
" not allowed to send to portal"
)
self._send_message_checkpoint(
evt,
MessageSendCheckpointStep.BRIDGE,
"user is not allowed to send to the portal",
)
return
if message.msgtype != MessageType.TEXT:
self.log.debug(
f"Ignoring event {event_id}: not a portal room and not a m.text message"
)
self._send_message_checkpoint(
evt, MessageSendCheckpointStep.BRIDGE, "not a portal room and not a m.text message"
)
return
elif not await self.allow_command(sender):
self.log.debug(
f"Ignoring command {event_id} from {sender.mxid}: not allowed to run commands"
)
self._send_message_checkpoint(
evt, MessageSendCheckpointStep.COMMAND, "not allowed to run commands"
)
return
has_two_members, bridge_bot_in_room = await self._is_direct_chat(room_id)
is_management = has_two_members and bridge_bot_in_room
if is_command or is_management:
try:
command, arguments = text.split(" ", 1)
args = arguments.split(" ")
except ValueError:
# Not enough values to unpack, i.e. no arguments
command = text
args = []
try:
await self.commands.handle(
room_id,
event_id,
sender,
command,
args,
message,
portal,
is_management,
bridge_bot_in_room,
)
except Exception as e:
self.log.debug(f"Error handling command {command} from {sender}: {e}")
self._send_message_checkpoint(evt, MessageSendCheckpointStep.COMMAND, e)
await self._send_mss(
evt,
status=MessageStatus.FAIL,
reason=MessageStatusReason.GENERIC_ERROR,
error="",
message="Command execution failed",
)
else:
await MessageSendCheckpoint(
event_id=event_id,
room_id=room_id,
step=MessageSendCheckpointStep.COMMAND,
timestamp=int(time.time() * 1000),
status=MessageSendCheckpointStatus.SUCCESS,
reported_by=MessageSendCheckpointReportedBy.BRIDGE,
event_type=EventType.ROOM_MESSAGE,
message_type=message.msgtype,
).send(
self.bridge.config["homeserver.message_send_checkpoint_endpoint"],
self.az.as_token,
self.log,
)
await self._send_mss(evt, status=MessageStatus.SUCCESS)
else:
self.log.debug(
f"Ignoring event {event_id} from {sender.mxid}:"
" not a command and not a portal room"
)
self._send_message_checkpoint(
evt, MessageSendCheckpointStep.COMMAND, "not a command and not a portal room"
)
await self._send_mss(
evt,
status=MessageStatus.FAIL,
reason=MessageStatusReason.UNSUPPORTED,
error="Unknown room",
message="Unknown room",
)
async def _is_direct_chat(self, room_id: RoomID) -> tuple[bool, bool]:
try:
members = await self.az.intent.get_room_members(room_id)
return len(members) == 2, self.az.bot_mxid in members
except MatrixError:
return False, False
async def handle_receipt(self, evt: ReceiptEvent) -> None:
for event_id, receipts in evt.content.items():
for user_id, data in receipts.get(ReceiptType.READ, {}).items():
user = await self.bridge.get_user(user_id, create=False)
if not user or not await user.is_logged_in():
continue
portal = await self.bridge.get_portal(evt.room_id)
if not portal:
continue
await portal.schedule_disappearing()
if (
data.get(DOUBLE_PUPPET_SOURCE_KEY) == self.az.bridge_name
and await self.bridge.get_double_puppet(user_id) is not None
):
continue
await self.handle_read_receipt(user, portal, event_id, data)
async def handle_read_receipt(
self,
user: br.BaseUser,
portal: br.BasePortal,
event_id: EventID,
data: SingleReceiptEventContent,
) -> None:
pass
async def try_handle_sync_event(self, evt: Event) -> None:
try:
if isinstance(evt, (ReceiptEvent, PresenceEvent, TypingEvent)):
await self.handle_ephemeral_event(evt)
else:
self.log.trace("Unknown event type received from sync: %s", evt)
except Exception:
self.log.exception("Error handling manually received Matrix event")
async def _post_decrypt(
self, evt: Event, retry_num: int = 0, error_event_id: EventID | None = None
) -> None:
trust_state = evt["mautrix"]["trust_state"]
if trust_state < self.e2ee.min_send_trust:
self.log.warning(
f"Dropping {evt.event_id} from {evt.sender} due to insufficient verification level"
f" (event: {trust_state}, required: {self.e2ee.min_send_trust})"
)
await self._send_crypto_status_error(
evt,
retry_num=retry_num,
err=DeviceUntrustedError(trust_state),
edit=error_event_id,
)
return
self._send_message_checkpoint(
evt, MessageSendCheckpointStep.DECRYPTED, retry_num=retry_num
)
if error_event_id:
await self.az.intent.redact(evt.room_id, error_event_id)
await self.int_handle_event(evt, was_encrypted=True)
async def handle_encrypted(self, evt: EncryptedEvent) -> None:
if not self.e2ee:
self.log.debug(
"Got encrypted message %s from %s, but encryption is not enabled",
evt.event_id,
evt.sender,
)
await self._send_crypto_status_error(evt, EncryptionUnsupportedError())
return
try:
decrypted = await self.e2ee.decrypt(evt, wait_session_timeout=3)
except SessionNotFound as e:
await self._handle_encrypted_wait(evt, e, wait=22)
except DecryptionError as e:
self.log.warning(f"Failed to decrypt {evt.event_id}: {e}")
self.log.trace("%s decryption traceback:", evt.event_id, exc_info=True)
await self._send_crypto_status_error(evt, e)
else:
await self._post_decrypt(decrypted)
async def _handle_encrypted_wait(
self, evt: EncryptedEvent, err: SessionNotFound, wait: int
) -> None:
self.log.debug(
f"Couldn't find session {err.session_id} trying to decrypt {evt.event_id},"
" waiting even longer"
)
background_task.create(
self.e2ee.crypto.request_room_key(
evt.room_id,
evt.content.sender_key,
evt.content.session_id,
from_devices={evt.sender: [evt.content.device_id]},
)
)
event_id = await self._send_crypto_status_error(evt, err, is_final=False, wait_for=wait)
got_keys = await self.e2ee.crypto.wait_for_session(
evt.room_id, err.session_id, timeout=wait
)
if got_keys:
self.log.debug(
f"Got session {err.session_id} after waiting more, "
f"trying to decrypt {evt.event_id} again"
)
try:
decrypted = await self.e2ee.decrypt(evt, wait_session_timeout=0)
except DecryptionError as e:
await self._send_crypto_status_error(evt, e, retry_num=1, edit=event_id)
self.log.warning(f"Failed to decrypt {evt.event_id}: {e}")
self.log.trace("%s decryption traceback:", evt.event_id, exc_info=True)
else:
await self._post_decrypt(decrypted, retry_num=1, error_event_id=event_id)
return
else:
self.log.warning(f"Didn't get {err.session_id}, giving up on {evt.event_id}")
await self._send_crypto_status_error(
evt, SessionNotFound(err.session_id), retry_num=1, edit=event_id
)
async def handle_encryption(self, evt: StateEvent) -> None:
await self.az.state_store.set_encryption_info(evt.room_id, evt.content)
portal = await self.bridge.get_portal(evt.room_id)
if portal:
portal.encrypted = True
await portal.save()
if portal.is_direct:
portal.log.debug("Received encryption event in direct portal: %s", evt.content)
await portal.enable_dm_encryption()
def _send_message_checkpoint(
self,
evt: Event,
step: MessageSendCheckpointStep,
err: Exception | str | None = None,
permanent: bool = True,
retry_num: int = 0,
) -> None:
endpoint = self.bridge.config["homeserver.message_send_checkpoint_endpoint"]
if not endpoint:
return
if evt.type not in CHECKPOINT_TYPES:
return
self.log.debug(f"Sending message send checkpoint for {evt.event_id} (step: {step})")
status = MessageSendCheckpointStatus.SUCCESS
if err:
status = (
MessageSendCheckpointStatus.PERM_FAILURE
if permanent
else MessageSendCheckpointStatus.WILL_RETRY
)
checkpoint = MessageSendCheckpoint(
event_id=evt.event_id,
room_id=evt.room_id,
step=step,
timestamp=int(time.time() * 1000),
status=status,
reported_by=MessageSendCheckpointReportedBy.BRIDGE,
event_type=evt.type,
message_type=evt.content.msgtype if evt.type == EventType.ROOM_MESSAGE else None,
info=str(err) if err else None,
retry_num=retry_num,
)
background_task.create(checkpoint.send(endpoint, self.az.as_token, self.log))
allowed_event_classes: tuple[type, ...] = (
MessageEvent,
StateEvent,
ReactionEvent,
EncryptedEvent,
RedactionEvent,
ReceiptEvent,
TypingEvent,
PresenceEvent,
)
async def allow_matrix_event(self, evt: Event) -> bool:
# If the event is not one of the allowed classes, ignore it.
if not isinstance(evt, self.allowed_event_classes):
return False
# For room events, make sure the message didn't originate from the bridge.
if isinstance(evt, BaseRoomEvent):
# If the event is from a bridge ghost, ignore it.
if evt.sender == self.az.bot_mxid or self.bridge.is_bridge_ghost(evt.sender):
return False
# If the event is marked as double puppeted and we can confirm that we are in fact
# double puppeting that user ID, ignore it.
if (
evt.content.get(DOUBLE_PUPPET_SOURCE_KEY) == self.az.bridge_name
and await self.bridge.get_double_puppet(evt.sender) is not None
):
return False
# For non-room events and non-bridge-originated room events, allow.
return True
async def int_handle_event(self, evt: Event, was_encrypted: bool = False) -> None:
if isinstance(evt, StateEvent) and evt.type == EventType.ROOM_MEMBER and self.e2ee:
await self.e2ee.handle_member_event(evt)
if not await self.allow_matrix_event(evt):
return
self.log.trace("Received event: %s", evt)
if not was_encrypted:
self._send_message_checkpoint(evt, MessageSendCheckpointStep.BRIDGE)
start_time = time.time()
if evt.type == EventType.ROOM_MEMBER:
evt: StateEvent
unsigned = evt.unsigned or StateUnsigned()
prev_content = unsigned.prev_content or MemberStateEventContent()
prev_membership = prev_content.membership if prev_content else Membership.JOIN
if evt.content.membership == Membership.INVITE:
if prev_membership == Membership.KNOCK:
await self.handle_accept_knock(
evt.room_id,
UserID(evt.state_key),
evt.sender,
evt.content.reason,
evt.event_id,
)
else:
await self.int_handle_invite(evt)
elif evt.content.membership == Membership.LEAVE:
if prev_membership == Membership.BAN:
await self.handle_unban(
evt.room_id,
UserID(evt.state_key),
evt.sender,
evt.content.reason,
evt.event_id,
)
elif prev_membership == Membership.INVITE:
if evt.sender == evt.state_key:
await self.handle_reject(
evt.room_id, UserID(evt.state_key), evt.content.reason, evt.event_id
)
else:
await self.handle_disinvite(
evt.room_id,
UserID(evt.state_key),
evt.sender,
evt.content.reason,
evt.event_id,
)
elif prev_membership == Membership.KNOCK:
if evt.sender == evt.state_key:
await self.handle_retract_knock(
evt.room_id, UserID(evt.state_key), evt.content.reason, evt.event_id
)
else:
await self.handle_reject_knock(
evt.room_id,
UserID(evt.state_key),
evt.sender,
evt.content.reason,
evt.event_id,
)
elif evt.sender == evt.state_key:
await self.handle_leave(evt.room_id, UserID(evt.state_key), evt.event_id)
else:
await self.handle_kick(
evt.room_id,
UserID(evt.state_key),
evt.sender,
evt.content.reason,
evt.event_id,
)
elif evt.content.membership == Membership.BAN:
await self.handle_ban(
evt.room_id,
UserID(evt.state_key),
evt.sender,
evt.content.reason,
evt.event_id,
)
elif evt.content.membership == Membership.JOIN:
if prev_membership != Membership.JOIN:
await self.handle_join(evt.room_id, UserID(evt.state_key), evt.event_id)
else:
await self.handle_member_info_change(
evt.room_id, UserID(evt.state_key), evt.content, prev_content, evt.event_id
)
elif evt.content.membership == Membership.KNOCK:
await self.handle_knock(
evt.room_id,
UserID(evt.state_key),
evt.content.reason,
evt.event_id,
)
elif evt.type in (EventType.ROOM_MESSAGE, EventType.STICKER):
evt: MessageEvent
if evt.type != EventType.ROOM_MESSAGE:
evt.content.msgtype = MessageType(str(evt.type))
await self.handle_message(evt, was_encrypted=was_encrypted)
elif evt.type == EventType.ROOM_ENCRYPTED:
await self.handle_encrypted(evt)
elif evt.type == EventType.ROOM_ENCRYPTION:
await self.handle_encryption(evt)
else:
if evt.type.is_state and isinstance(evt, StateEvent):
await self.handle_state_event(evt)
elif evt.type.is_ephemeral and isinstance(
evt, (PresenceEvent, TypingEvent, ReceiptEvent)
):
await self.handle_ephemeral_event(evt)
else:
await self.handle_event(evt)
await self.log_event_handle_duration(evt, time.time() - start_time)
async def log_event_handle_duration(self, evt: Event, duration: float) -> None:
EVENT_TIME.labels(event_type=str(evt.type)).observe(duration)
python-0.20.4/mautrix/bridge/notification_disabler.py 0000664 0000000 0000000 00000005247 14547234302 0022764 0 ustar 00root root 0000000 0000000 # Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from typing import Type
import logging
from mautrix.api import Method, Path, PathBuilder
from mautrix.appservice import IntentAPI
from mautrix.types import RoomID, UserID
from mautrix.util.logging import TraceLogger
from .puppet import BasePuppet
from .user import BaseUser
class NotificationDisabler:
puppet_cls: Type[BasePuppet]
config_enabled: bool = False
log: TraceLogger = logging.getLogger("mau.notification_disabler")
user_id: UserID
room_id: RoomID
intent: IntentAPI | None
enabled: bool
def __init__(self, room_id: RoomID, user: BaseUser) -> None:
self.user_id = user.mxid
self.room_id = room_id
self.enabled = False
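# Example usage (an illustrative sketch: ``Puppet``, ``config``, ``portal``, ``user`` and the
# config key are assumptions supplied by the concrete bridge, not by this module):
#
#     NotificationDisabler.puppet_cls = Puppet
#     NotificationDisabler.config_enabled = config["bridge.backfill.disable_notifications"]
#     async with NotificationDisabler(portal.mxid, user):
#         await backfill_history(portal, user)  # hypothetical backfill routine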
@property
def _path(self) -> PathBuilder:
return Path.v3.pushrules["global"].override[
f"net.maunium.silence_while_backfilling:{self.room_id}"
]
@property
def _rule(self) -> dict:
return {
"actions": ["dont_notify"],
"conditions": [
{
"kind": "event_match",
"key": "room_id",
"pattern": self.room_id,
}
],
}
async def __aenter__(self) -> None:
puppet = await self.puppet_cls.get_by_custom_mxid(self.user_id)
self.intent = puppet.intent if puppet and puppet.is_real_user else None
if not self.intent or not self.config_enabled:
return
self.enabled = True
try:
self.log.debug(f"Disabling notifications in {self.room_id} for {self.intent.mxid}")
await self.intent.api.request(Method.PUT, self._path, content=self._rule)
except Exception:
self.log.warning(
f"Failed to disable notifications in {self.room_id} "
f"for {self.intent.mxid} while backfilling",
exc_info=True,
)
raise
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
if not self.enabled:
return
try:
self.log.debug(f"Re-enabling notifications in {self.room_id} for {self.intent.mxid}")
await self.intent.api.request(Method.DELETE, self._path)
except Exception:
self.log.warning(
f"Failed to re-enable notifications in {self.room_id} "
f"for {self.intent.mxid} after backfilling",
exc_info=True,
)
python-0.20.4/mautrix/bridge/portal.py 0000664 0000000 0000000 00000050140 14547234302 0017722 0 ustar 00root root 0000000 0000000 # Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from typing import Any, NamedTuple
from abc import ABC, abstractmethod
from collections import defaultdict
from string import Template
import asyncio
import html
import logging
import time
from mautrix.appservice import AppService, IntentAPI
from mautrix.errors import MatrixError, MatrixRequestError, MForbidden, MNotFound
from mautrix.types import (
JSON,
EncryptionAlgorithm,
EventID,
EventType,
Format,
MessageEventContent,
MessageType,
RoomEncryptionStateEventContent,
RoomID,
RoomTombstoneStateEventContent,
TextMessageEventContent,
UserID,
)
from mautrix.util import background_task
from mautrix.util.logging import TraceLogger
from mautrix.util.simple_lock import SimpleLock
from .. import bridge as br
class RelaySender(NamedTuple):
sender: br.BaseUser | None
is_relay: bool
class RejectMatrixInvite(Exception):
def __init__(self, message: str) -> None:
super().__init__(message)
self.message = message
class IgnoreMatrixInvite(Exception):
pass
class DMCreateError(RejectMatrixInvite):
"""
An error raised by :meth:`BasePortal.prepare_remote_dm` if the DM can't be set up.
The message in the exception will be sent to the user as a message before the ghost leaves.
"""
class BasePortal(ABC):
log: TraceLogger = logging.getLogger("mau.portal")
_async_get_locks: dict[Any, asyncio.Lock] = defaultdict(lambda: asyncio.Lock())
disappearing_msg_class: type[br.AbstractDisappearingMessage] | None = None
_disappearing_lock: asyncio.Lock | None
az: AppService
matrix: br.BaseMatrixHandler
bridge: br.Bridge
loop: asyncio.AbstractEventLoop
main_intent: IntentAPI
mxid: RoomID | None
name: str | None
encrypted: bool
is_direct: bool
backfill_lock: SimpleLock
relay_user_id: UserID | None
_relay_user: br.BaseUser | None
relay_emote_to_text: bool = True
relay_formatted_body: bool = True
def __init__(self) -> None:
self._disappearing_lock = asyncio.Lock() if self.disappearing_msg_class else None
@abstractmethod
async def save(self) -> None:
pass
@abstractmethod
async def get_dm_puppet(self) -> br.BasePuppet | None:
"""
Get the ghost representing the other end of this direct chat.
Returns:
A puppet entity, or ``None`` if this is not a 1:1 chat.
"""
@abstractmethod
async def handle_matrix_message(
self, sender: br.BaseUser, message: MessageEventContent, event_id: EventID
) -> None:
pass
async def prepare_remote_dm(
self, room_id: RoomID, invited_by: br.BaseUser, puppet: br.BasePuppet
) -> str:
"""
Do whatever is needed on the remote platform to set up a direct chat between the user
and the ghost. By default, this does nothing (and lets :meth:`accept_matrix_dm` handle
everything).
Args:
room_id: The room ID that will be used.
invited_by: The Matrix user who invited the ghost.
puppet: The ghost who was invited.
Returns:
A simple message indicating what was done (will be sent as a notice to the room).
If empty, the message won't be sent.
Raises:
DMCreateError: if the DM could not be created and the ghost should leave the room.
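Example:
A minimal sketch of a bridge override; ``self.remote_api`` and its method are
hypothetical stand-ins for the bridge's own network client:
>>> class Portal(BasePortal):
...     async def prepare_remote_dm(self, room_id, invited_by, puppet):
...         try:
...             await self.remote_api.create_direct_chat(invited_by, puppet)
...         except Exception as e:
...             raise DMCreateError(f"Failed to create remote chat: {e}") from e
...         return "Created a direct chat on the remote network."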
"""
return "Portal to private chat created."
async def postprocess_matrix_dm(self, user: br.BaseUser, puppet: br.BasePuppet) -> None:
await self.update_bridge_info()
async def reject_duplicate_dm(
self, room_id: RoomID, invited_by: br.BaseUser, puppet: br.BasePuppet
) -> None:
try:
await puppet.default_mxid_intent.send_notice(
room_id,
text=f"You already have a private chat with me: {self.mxid}",
html=(
"You already have a private chat with me: "
f"Link to room"
),
)
except Exception as e:
self.log.debug(f"Failed to send notice to duplicate private chat room: {e}")
try:
await puppet.default_mxid_intent.send_state_event(
room_id,
event_type=EventType.ROOM_TOMBSTONE,
content=RoomTombstoneStateEventContent(
replacement_room=self.mxid,
body="You already have a private chat with me",
),
)
except Exception as e:
self.log.debug(f"Failed to send tombstone to duplicate private chat room: {e}")
await puppet.default_mxid_intent.leave_room(room_id)
async def accept_matrix_dm(
self, room_id: RoomID, invited_by: br.BaseUser, puppet: br.BasePuppet
) -> None:
"""
Set up a room as a direct chat portal.
The ghost has already accepted the invite at this point, so this method needs to make it
leave if the DM can't be created for some reason.
By default, this checks if there's an existing portal and redirects the user there if it
does exist. If a portal doesn't exist, this will call :meth:`prepare_remote_dm` and then
save the room ID, enable encryption and update bridge info. If the portal exists, but isn't
usable, the old room will be cleaned up and the function will continue.
Args:
room_id: The room ID that will be used.
invited_by: The Matrix user who invited the ghost.
puppet: The ghost who was invited.
"""
if self.mxid:
try:
portal_members = await self.main_intent.get_room_members(self.mxid)
except (MForbidden, MNotFound):
portal_members = []
if invited_by.mxid in portal_members:
await self.reject_duplicate_dm(room_id, invited_by, puppet)
return
self.log.debug(
f"{invited_by.mxid} isn't in old portal room {self.mxid},"
" cleaning up and accepting new room as the DM portal"
)
await self.cleanup_portal(
message="User seems to have left DM portal", puppets_only=True
)
try:
message = await self.prepare_remote_dm(room_id, invited_by, puppet)
except DMCreateError as e:
if e.message:
await puppet.default_mxid_intent.send_notice(room_id, text=e.message)
await puppet.default_mxid_intent.leave_room(room_id, reason="Failed to create DM")
return
self.mxid = room_id
e2be_ok = await self.check_dm_encryption()
await self.save()
if e2be_ok is False:
message += "\n\nWarning: Failed to enable end-to-bridge encryption."
if message:
await self._send_message(
puppet.default_mxid_intent,
TextMessageEventContent(
msgtype=MessageType.NOTICE,
body=message,
),
)
await self.postprocess_matrix_dm(invited_by, puppet)
async def handle_matrix_invite(self, invited_by: br.BaseUser, puppet: br.BasePuppet) -> None:
"""
Called when a Matrix user invites a bridge ghost to a room to process the invite (and check
if it should be accepted).
Args:
invited_by: The user who invited the ghost.
puppet: The ghost who was invited.
Raises:
RejectMatrixInvite: if the invite should be rejected.
IgnoreMatrixInvite: if the invite should be ignored (e.g. if it was already accepted).
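Example:
A hedged sketch of a bridge override that only accepts invites from logged-in users:
>>> class Portal(BasePortal):
...     async def handle_matrix_invite(self, invited_by, puppet):
...         if not await invited_by.is_logged_in():
...             raise RejectMatrixInvite("Log into the bridge before inviting ghosts.")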
"""
if self.is_direct:
raise RejectMatrixInvite("You can't invite additional users to private chats.")
raise RejectMatrixInvite("This bridge does not implement inviting users to portals.")
async def update_bridge_info(self) -> None:
"""Resend the ``m.bridge`` event into the room."""
@property
def _relay_is_implemented(self) -> bool:
return hasattr(self, "relay_user_id") and hasattr(self, "_relay_user")
@property
def has_relay(self) -> bool:
return (
self._relay_is_implemented
and self.bridge.config["bridge.relay.enabled"]
and bool(self.relay_user_id)
)
async def get_relay_user(self) -> br.BaseUser | None:
if not self.has_relay:
return None
if self._relay_user is None:
self._relay_user = await self.bridge.get_user(self.relay_user_id)
return self._relay_user if await self._relay_user.is_logged_in() else None
async def set_relay_user(self, user: br.BaseUser | None) -> None:
if not self._relay_is_implemented or not self.bridge.config["bridge.relay.enabled"]:
raise RuntimeError("Can't set_relay_user() when relay mode is not enabled")
self._relay_user = user
self.relay_user_id = user.mxid if user else None
await self.save()
async def get_relay_sender(self, sender: br.BaseUser, evt_identifier: str) -> RelaySender:
if not await sender.needs_relay(self):
return RelaySender(sender, False)
if not self.has_relay:
self.log.debug(
f"Ignoring {evt_identifier} from non-logged-in user {sender.mxid} "
f"in chat with no relay user"
)
return RelaySender(None, True)
relay_sender = await self.get_relay_user()
if not relay_sender:
self.log.debug(
f"Ignoring {evt_identifier} from non-logged-in user {sender.mxid} "
f"relay user {self.relay_user_id} is not set up correctly"
)
return RelaySender(None, True)
return RelaySender(relay_sender, True)
async def apply_relay_message_format(
self, sender: br.BaseUser, content: MessageEventContent
) -> None:
if self.relay_formatted_body and content.get("format", None) != Format.HTML:
content["format"] = Format.HTML
content["formatted_body"] = html.escape(content.body).replace("\n", "
")
tpl = self.bridge.config["bridge.relay.message_formats"].get(
content.msgtype.value, "$sender_displayname: $message"
)
displayname = await self.get_displayname(sender)
username, _ = self.az.intent.parse_user_id(sender.mxid)
tpl_args = {
"sender_mxid": sender.mxid,
"sender_username": username,
"sender_displayname": html.escape(displayname),
"formatted_body": content["formatted_body"],
"body": content.body,
"message": content.body,
}
content.body = Template(tpl).safe_substitute(tpl_args)
if self.relay_formatted_body and "formatted_body" in content:
tpl_args["message"] = content["formatted_body"]
content["formatted_body"] = Template(tpl).safe_substitute(tpl_args)
if self.relay_emote_to_text and content.msgtype == MessageType.EMOTE:
content.msgtype = MessageType.TEXT
async def get_displayname(self, user: br.BaseUser) -> str:
return await self.main_intent.get_room_displayname(self.mxid, user.mxid) or user.mxid
async def check_dm_encryption(self) -> bool | None:
try:
evt = await self.main_intent.get_state_event(self.mxid, EventType.ROOM_ENCRYPTION)
self.log.debug("Found existing encryption event in direct portal: %s", evt)
if evt and evt.algorithm == EncryptionAlgorithm.MEGOLM_V1:
self.encrypted = True
except MNotFound:
pass
if (
self.is_direct
and self.matrix.e2ee
and (self.bridge.config["bridge.encryption.default"] or self.encrypted)
):
return await self.enable_dm_encryption()
return None
def get_encryption_state_event_json(self) -> JSON:
evt = RoomEncryptionStateEventContent(EncryptionAlgorithm.MEGOLM_V1)
if self.bridge.config["bridge.encryption.rotation.enable_custom"]:
evt.rotation_period_ms = self.bridge.config["bridge.encryption.rotation.milliseconds"]
evt.rotation_period_msgs = self.bridge.config["bridge.encryption.rotation.messages"]
return evt.serialize()
async def enable_dm_encryption(self) -> bool:
self.log.debug("Inviting bridge bot to room for end-to-bridge encryption")
try:
await self.main_intent.invite_user(self.mxid, self.az.bot_mxid)
await self.az.intent.join_room_by_id(self.mxid)
if not self.encrypted:
await self.main_intent.send_state_event(
self.mxid,
EventType.ROOM_ENCRYPTION,
self.get_encryption_state_event_json(),
)
except Exception:
self.log.warning(f"Failed to enable end-to-bridge encryption", exc_info=True)
return False
self.encrypted = True
await self.update_info_from_puppet()
return True
async def update_info_from_puppet(self, puppet: br.BasePuppet | None = None) -> None:
"""
Update the room metadata to match the ghost's name/avatar.
This is called after enabling encryption, as the bridge bot needs to join for e2ee,
but that messes up the default name generation. If/when canonical DMs happen,
this might not be necessary anymore.
Args:
puppet: The ghost that is the other participant in the room.
If ``None``, the entity should be fetched as necessary.
"""
@property
def disappearing_enabled(self) -> bool:
return bool(self.disappearing_msg_class)
async def _disappear_event(self, msg: br.AbstractDisappearingMessage) -> None:
sleep_time = (msg.expiration_ts / 1000) - time.time()
self.log.trace(f"Sleeping {sleep_time:.3f} seconds before redacting {msg.event_id}")
await asyncio.sleep(sleep_time)
try:
await msg.delete()
except Exception:
self.log.exception(
f"Failed to delete disappearing message record for {msg.event_id} from database"
)
if self.mxid != msg.room_id:
self.log.debug(
f"Not redacting expired event {msg.event_id}, "
f"portal room seems to have changed ({self.mxid!r} != {msg.room_id!r})"
)
return
try:
await self._do_disappear(msg.event_id)
self.log.debug(f"Expired event {msg.event_id} disappeared successfully")
except Exception as e:
self.log.warning(f"Failed to make expired event {msg.event_id} disappear: {e}")
async def _do_disappear(self, event_id: EventID) -> None:
await self.main_intent.redact(self.mxid, event_id)
@classmethod
async def restart_scheduled_disappearing(cls) -> None:
"""
Restart disappearing message timers for all messages that were already scheduled to
disappear earlier. This should be called at bridge startup.
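Example:
An illustrative startup call, assuming ``Portal`` is the bridge's concrete portal class:
>>> background_task.create(Portal.restart_scheduled_disappearing())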
"""
if not cls.disappearing_msg_class:
return
msgs = await cls.disappearing_msg_class.get_all_scheduled()
for msg in msgs:
portal = await cls.bridge.get_portal(msg.room_id)
if portal and portal.mxid:
background_task.create(portal._disappear_event(msg))
else:
await msg.delete()
async def schedule_disappearing(self) -> None:
"""
Start the disappearing message timer for all unscheduled messages in this room.
This is automatically called from :meth:`MatrixHandler.handle_receipt`.
"""
if not self.disappearing_msg_class:
return
async with self._disappearing_lock:
msgs = await self.disappearing_msg_class.get_unscheduled_for_room(self.mxid)
for msg in msgs:
msg.start_timer()
await msg.update()
background_task.create(self._disappear_event(msg))
async def _send_message(
self,
intent: IntentAPI,
content: MessageEventContent,
event_type: EventType = EventType.ROOM_MESSAGE,
**kwargs,
) -> EventID:
if self.encrypted and self.matrix.e2ee:
event_type, content = await self.matrix.e2ee.encrypt(self.mxid, event_type, content)
event_id = await intent.send_message_event(self.mxid, event_type, content, **kwargs)
if intent.api.is_real_user:
background_task.create(intent.mark_read(self.mxid, event_id))
return event_id
@property
@abstractmethod
def bridge_info_state_key(self) -> str:
pass
@property
@abstractmethod
def bridge_info(self) -> dict[str, Any]:
pass
# region Matrix room cleanup
@abstractmethod
async def delete(self) -> None:
pass
@classmethod
async def cleanup_room(
cls,
intent: IntentAPI,
room_id: RoomID,
message: str = "Cleaning room",
puppets_only: bool = False,
) -> None:
if not puppets_only and cls.bridge.homeserver_software.is_hungry:
try:
await intent.beeper_delete_room(room_id)
return
except MNotFound as err:
cls.log.debug(f"Hungryserv yeet returned {err}, assuming the room is already gone")
return
except Exception:
cls.log.warning(
f"Failed to delete {room_id} using hungryserv yeet endpoint, "
f"falling back to normal method",
exc_info=True,
)
try:
members = await intent.get_room_members(room_id)
except MatrixError:
members = []
for user_id in members:
if user_id == intent.mxid:
continue
puppet = await cls.bridge.get_puppet(user_id, create=False)
if puppet:
await puppet.default_mxid_intent.leave_room(room_id)
continue
if not puppets_only:
custom_puppet = await cls.bridge.get_double_puppet(user_id)
left = False
if custom_puppet:
try:
await custom_puppet.intent.leave_room(room_id)
await custom_puppet.intent.forget_room(room_id)
except MatrixError:
pass
else:
left = True
if not left:
try:
await intent.kick_user(room_id, user_id, message)
except MatrixError:
pass
try:
await intent.leave_room(room_id)
except MatrixError:
cls.log.warning(f"Failed to leave room {room_id} when cleaning up room", exc_info=True)
async def cleanup_portal(self, message: str, puppets_only: bool = False) -> None:
await self.cleanup_room(self.main_intent, self.mxid, message, puppets_only)
await self.delete()
async def unbridge(self) -> None:
await self.cleanup_portal("Room unbridged", puppets_only=True)
async def cleanup_and_delete(self) -> None:
await self.cleanup_portal("Portal deleted")
async def get_authenticated_matrix_users(self) -> list[UserID]:
"""
Get the list of Matrix user IDs who can be bridged. This is used to determine if the portal
is empty (and should be cleaned up) or not. Bridges should override this to check that the
users are either logged in or the portal has a relaybot.
"""
try:
members = await self.main_intent.get_room_members(self.mxid)
except MatrixRequestError:
return []
return [
member
for member in members
if (not self.bridge.is_bridge_ghost(member) and member != self.az.bot_mxid)
]
# endregion
python-0.20.4/mautrix/bridge/puppet.py 0000664 0000000 0000000 00000002207 14547234302 0017737 0 ustar 00root root 0000000 0000000 # Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from typing import Any
from abc import ABC, abstractmethod
from collections import defaultdict
import asyncio
import logging
from mautrix.appservice import AppService, IntentAPI
from mautrix.types import UserID
from mautrix.util.logging import TraceLogger
from .. import bridge as br
from .custom_puppet import CustomPuppetMixin
class BasePuppet(CustomPuppetMixin, ABC):
log: TraceLogger = logging.getLogger("mau.puppet")
_async_get_locks: dict[Any, asyncio.Lock] = defaultdict(lambda: asyncio.Lock())
az: AppService
loop: asyncio.AbstractEventLoop
mx: br.BaseMatrixHandler
is_registered: bool
mxid: str
intent: IntentAPI
@classmethod
@abstractmethod
async def get_by_mxid(cls, mxid: UserID) -> BasePuppet:
pass
@classmethod
@abstractmethod
async def get_by_custom_mxid(cls, mxid: UserID) -> BasePuppet:
pass
python-0.20.4/mautrix/bridge/state_store/ 0000775 0000000 0000000 00000000000 14547234302 0020403 5 ustar 00root root 0000000 0000000 python-0.20.4/mautrix/bridge/state_store/__init__.py 0000664 0000000 0000000 00000000026 14547234302 0022512 0 ustar 00root root 0000000 0000000 __all__ = ["asyncpg"]
python-0.20.4/mautrix/bridge/state_store/asyncpg.py 0000664 0000000 0000000 00000002723 14547234302 0022425 0 ustar 00root root 0000000 0000000 # Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from typing import Awaitable, Callable, Union
from mautrix.appservice.state_store.asyncpg import PgASStateStore
from mautrix.types import UserID
from mautrix.util.async_db import Database
from ..puppet import BasePuppet
GetPuppetFunc = Union[
Callable[[UserID], Awaitable[BasePuppet]], Callable[[UserID, bool], Awaitable[BasePuppet]]
]
class PgBridgeStateStore(PgASStateStore):
def __init__(
self, db: Database, get_puppet: GetPuppetFunc, get_double_puppet: GetPuppetFunc
) -> None:
super().__init__(db)
self.get_puppet = get_puppet
self.get_double_puppet = get_double_puppet
async def is_registered(self, user_id: UserID) -> bool:
puppet = await self.get_puppet(user_id)
if puppet:
return puppet.is_registered
custom_puppet = await self.get_double_puppet(user_id)
if custom_puppet:
return True
return await super().is_registered(user_id)
async def registered(self, user_id: UserID) -> None:
puppet = await self.get_puppet(user_id, True)
if puppet:
puppet.is_registered = True
await puppet.save()
else:
await super().registered(user_id)
python-0.20.4/mautrix/bridge/user.py 0000664 0000000 0000000 00000023132 14547234302 0017400 0 ustar 00root root 0000000 0000000 # Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from typing import Any, NamedTuple
from abc import ABC, abstractmethod
from collections import defaultdict, deque
import asyncio
import logging
import time
from mautrix.api import Method, Path
from mautrix.appservice import AppService
from mautrix.errors import MNotFound
from mautrix.types import EventID, EventType, Membership, MessageType, RoomID, UserID
from mautrix.util import background_task
from mautrix.util.bridge_state import BridgeState, BridgeStateEvent
from mautrix.util.logging import TraceLogger
from mautrix.util.message_send_checkpoint import (
MessageSendCheckpoint,
MessageSendCheckpointReportedBy,
MessageSendCheckpointStatus,
MessageSendCheckpointStep,
)
from mautrix.util.opt_prometheus import Gauge
from .. import bridge as br
AsmuxPath = Path.unstable["com.beeper.asmux"]
class WrappedTask(NamedTuple):
task: asyncio.Task | None
class BaseUser(ABC):
log: TraceLogger = logging.getLogger("mau.user")
_async_get_locks: dict[Any, asyncio.Lock] = defaultdict(lambda: asyncio.Lock())
az: AppService
bridge: br.Bridge
loop: asyncio.AbstractEventLoop
is_whitelisted: bool
is_admin: bool
relay_whitelisted: bool
mxid: UserID
dm_update_lock: asyncio.Lock
command_status: dict[str, Any] | None
_metric_value: dict[Gauge, bool]
_prev_bridge_status: BridgeState | None
_bridge_state_queue: deque[BridgeState]
_bridge_state_loop: asyncio.Task | None
def __init__(self) -> None:
self.dm_update_lock = asyncio.Lock()
self.command_status = None
self._metric_value = defaultdict(lambda: False)
self._prev_bridge_status = None
self.log = self.log.getChild(self.mxid)
self.relay_whitelisted = False
self._bridge_state_queue = deque()
self._bridge_state_loop = None
@abstractmethod
async def is_logged_in(self) -> bool:
raise NotImplementedError()
@abstractmethod
async def get_puppet(self) -> br.BasePuppet | None:
"""
Get the ghost that represents this Matrix user on the remote network.
Returns:
The puppet entity, or ``None`` if the user is not logged in,
or it's otherwise not possible to find the remote ghost.
"""
raise NotImplementedError()
@abstractmethod
async def get_portal_with(
self, puppet: br.BasePuppet, create: bool = True
) -> br.BasePortal | None:
"""
Get a private chat portal between this user and the given ghost.
Args:
puppet: The ghost who the portal should be with.
create: ``True`` if the portal entity should be created if it doesn't exist.
Returns:
The portal entity, or ``None`` if it can't be found,
or doesn't exist and ``create`` is ``False``.
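Example:
An illustrative lookup that doesn't create a missing portal; ``user`` is assumed to be
a logged-in instance of the bridge's user class:
>>> puppet = await user.get_puppet()
>>> portal = await user.get_portal_with(puppet, create=False) if puppet else None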
"""
async def needs_relay(self, portal: br.BasePortal) -> bool:
return not await self.is_logged_in()
async def is_in_portal(self, portal: br.BasePortal) -> bool:
try:
member_event = await portal.main_intent.get_state_event(
portal.mxid, EventType.ROOM_MEMBER, self.mxid
)
except MNotFound:
return False
return member_event and member_event.membership in (Membership.JOIN, Membership.INVITE)
async def get_direct_chats(self) -> dict[UserID, list[RoomID]]:
raise NotImplementedError()
async def update_direct_chats(self, dms: dict[UserID, list[RoomID]] | None = None) -> None:
"""
Update the m.direct account data of the user.
Args:
dms: DMs to _add_ to the list. If not provided, the list is _replaced_ with the result
of :meth:`get_direct_chats`.
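Example:
An illustrative call adding a single DM to the list; the user and room IDs are placeholders:
>>> await user.update_direct_chats({UserID("@ghost:example.com"): [RoomID("!dm:example.com")]})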
"""
if not self.bridge.config["bridge.sync_direct_chat_list"]:
return
puppet = await self.bridge.get_double_puppet(self.mxid)
if not puppet or not puppet.is_real_user:
return
self.log.debug("Updating m.direct list on homeserver")
replace = dms is None
dms = dms or await self.get_direct_chats()
if self.bridge.homeserver_software.is_asmux:
# This uses a secret endpoint for atomically updating the DM list
await puppet.intent.api.request(
Method.PUT if replace else Method.PATCH,
AsmuxPath.dms,
content=dms,
headers={"X-Asmux-Auth": self.az.as_token},
)
else:
async with self.dm_update_lock:
try:
current_dms = await puppet.intent.get_account_data(EventType.DIRECT)
except MNotFound:
current_dms = {}
if replace:
# Filter away all existing DM statuses with bridge users
filtered_dms = {
user: rooms
for user, rooms in current_dms.items()
if not self.bridge.is_bridge_ghost(user)
}
else:
filtered_dms = current_dms
# Add DM statuses for all rooms in our database
new_dms = {**filtered_dms, **dms}
if current_dms != new_dms:
await puppet.intent.set_account_data(EventType.DIRECT, new_dms)
def _track_metric(self, metric: Gauge, value: bool) -> None:
if self._metric_value[metric] != value:
if value:
metric.inc(1)
else:
metric.dec(1)
self._metric_value[metric] = value
async def fill_bridge_state(self, state: BridgeState) -> None:
state.user_id = self.mxid
state.fill()
async def get_bridge_states(self) -> list[BridgeState]:
raise NotImplementedError()
async def push_bridge_state(
self,
state_event: BridgeStateEvent,
error: str | None = None,
message: str | None = None,
ttl: int | None = None,
remote_id: str | None = None,
info: dict[str, Any] | None = None,
reason: str | None = None,
) -> None:
if not self.bridge.config["homeserver.status_endpoint"]:
return
state = BridgeState(
state_event=state_event,
error=error,
message=message,
ttl=ttl,
remote_id=remote_id,
info=info,
reason=reason,
)
await self.fill_bridge_state(state)
if state.should_deduplicate(self._prev_bridge_status):
return
self._prev_bridge_status = state
self._bridge_state_queue.append(state)
if not self._bridge_state_loop or self._bridge_state_loop.done():
self.log.trace(f"Starting bridge state loop")
self._bridge_state_loop = asyncio.create_task(self._start_bridge_state_send_loop())
else:
self.log.debug(f"Queued bridge state to send later: {state.state_event}")
async def _start_bridge_state_send_loop(self):
url = self.bridge.config["homeserver.status_endpoint"]
while self._bridge_state_queue:
state = self._bridge_state_queue.popleft()
success = await state.send(url, self.az.as_token, self.log)
if not success:
if state.send_attempts_ <= 10:
retry_seconds = state.send_attempts_**2
self.log.warning(
f"Attempt #{state.send_attempts_} of sending bridge state "
f"{state.state_event} failed, retrying in {retry_seconds} seconds"
)
await asyncio.sleep(retry_seconds)
self._bridge_state_queue.appendleft(state)
else:
self.log.error(
f"Failed to send bridge state {state.state_event} "
f"after {state.send_attempts_} attempts, giving up"
)
self._bridge_state_loop = None
def send_remote_checkpoint(
self,
status: MessageSendCheckpointStatus,
event_id: EventID,
room_id: RoomID,
event_type: EventType,
message_type: MessageType | None = None,
error: str | Exception | None = None,
retry_num: int = 0,
) -> WrappedTask:
"""
Send a remote checkpoint for the given ``event_id``. This function spawns an
:class:`asyncio.Task` to send the checkpoint.
:returns: the checkpoint send task. This can be awaited if you want to block on the
checkpoint send.
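Example (illustrative; the event and room IDs are placeholders):
>>> wrapped = user.send_remote_checkpoint(
...     MessageSendCheckpointStatus.SUCCESS,
...     EventID("$eventid"),
...     RoomID("!room:example.com"),
...     EventType.ROOM_MESSAGE,
...     message_type=MessageType.TEXT,
... )
>>> if wrapped.task:
...     await wrapped.task  # only needed if you want to block on the checkpoint send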
"""
if not self.bridge.config["homeserver.message_send_checkpoint_endpoint"]:
return WrappedTask(task=None)
task = background_task.create(
MessageSendCheckpoint(
event_id=event_id,
room_id=room_id,
step=MessageSendCheckpointStep.REMOTE,
timestamp=int(time.time() * 1000),
status=status,
reported_by=MessageSendCheckpointReportedBy.BRIDGE,
event_type=event_type,
message_type=message_type,
info=str(error) if error else None,
retry_num=retry_num,
).send(
self.bridge.config["homeserver.message_send_checkpoint_endpoint"],
self.az.as_token,
self.log,
)
)
return WrappedTask(task=task)
python-0.20.4/mautrix/client/ 0000775 0000000 0000000 00000000000 14547234302 0016071 5 ustar 00root root 0000000 0000000 python-0.20.4/mautrix/client/__init__.py 0000664 0000000 0000000 00000001452 14547234302 0020204 0 ustar 00root root 0000000 0000000 from .api import ClientAPI
from .client import Client
from .dispatcher import Dispatcher, MembershipEventDispatcher, SimpleDispatcher
from .encryption_manager import DecryptionDispatcher, EncryptingAPI
from .state_store import FileStateStore, MemoryStateStore, MemorySyncStore, StateStore, SyncStore
from .store_updater import StoreUpdatingAPI
from .syncer import EventHandler, InternalEventType, Syncer, SyncStream
__all__ = [
"ClientAPI",
"Client",
"Dispatcher",
"MembershipEventDispatcher",
"SimpleDispatcher",
"DecryptionDispatcher",
"EncryptingAPI",
"FileStateStore",
"MemoryStateStore",
"MemorySyncStore",
"StateStore",
"SyncStore",
"StoreUpdatingAPI",
"EventHandler",
"InternalEventType",
"Syncer",
"SyncStream",
"state_store",
]
python-0.20.4/mautrix/client/api/ 0000775 0000000 0000000 00000000000 14547234302 0016642 5 ustar 00root root 0000000 0000000 python-0.20.4/mautrix/client/api/__init__.py 0000664 0000000 0000000 00000000411 14547234302 0020747 0 ustar 00root root 0000000 0000000 # Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from .client import ClientAPI
python-0.20.4/mautrix/client/api/authentication.py 0000664 0000000 0000000 00000016157 14547234302 0022245 0 ustar 00root root 0000000 0000000 # Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from mautrix.api import Method, Path
from mautrix.errors import MatrixResponseError
from mautrix.types import (
DeviceID,
LoginFlowList,
LoginResponse,
LoginType,
MatrixUserIdentifier,
UserID,
UserIdentifier,
WhoamiResponse,
)
from .base import BaseClientAPI
class ClientAuthenticationMethods(BaseClientAPI):
"""
Methods in section 5 Authentication of the spec. These methods are used for logging in
and out, and for getting information about the current account.
See also: `API reference `__
"""
# region 5.5 Login
# API reference: https://matrix.org/docs/spec/client_server/r0.6.1.html#login
async def get_login_flows(self) -> LoginFlowList:
"""
Get login flows supported by the homeserver.
See also: `API reference `__
Returns:
The list of login flows that the homeserver supports.
"""
resp = await self.api.request(Method.GET, Path.v3.login)
try:
return LoginFlowList.deserialize(resp)
except KeyError:
raise MatrixResponseError("`flows` not in response.")
async def login(
self,
identifier: UserIdentifier | UserID | None = None,
login_type: LoginType = LoginType.PASSWORD,
device_name: str | None = None,
device_id: str | None = None,
password: str | None = None,
store_access_token: bool = True,
update_hs_url: bool = False,
**kwargs: str,
) -> LoginResponse:
"""
Authenticates the user, and issues an access token they can use to authorize themself in
subsequent requests.
See also: `API reference `__
Args:
login_type: The login type being used.
identifier: Identification information for the user.
device_name: A display name to assign to the newly-created device.
Ignored if ``device_id`` corresponds to a known device.
device_id: ID of the client device. If this does not correspond to a known client
device, a new device will be created. The server will auto-generate a device_id
if this is not specified.
password: The user's password. Required when `type` is `m.login.password`.
store_access_token: Whether or not mautrix-python should store the returned access token
in this ClientAPI instance for future requests.
update_hs_url: Whether or not mautrix-python should use the returned homeserver URL
in this ClientAPI instance for future requests.
**kwargs: Additional arguments for other login types.
Returns:
The login response.
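Examples:
An illustrative password login; the user ID, homeserver URL and password are placeholders:
>>> client = ClientAPI("@alice:example.com", base_url="https://matrix.example.com")
>>> resp = await client.login(password="hunter2", device_name="mautrix-python example")
>>> client.api.token == resp.access_token  # stored because store_access_token defaults to True
True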
"""
if identifier is None or isinstance(identifier, str):
identifier = MatrixUserIdentifier(identifier or self.mxid)
if password is not None:
kwargs["password"] = password
if device_name is not None:
kwargs["initial_device_display_name"] = device_name
if device_id:
kwargs["device_id"] = device_id
elif self.device_id:
kwargs["device_id"] = self.device_id
resp = await self.api.request(
Method.POST,
Path.v3.login,
{
"type": str(login_type),
"identifier": identifier.serialize(),
**kwargs,
},
sensitive="password" in kwargs or "token" in kwargs,
)
resp_data = LoginResponse.deserialize(resp)
if store_access_token:
self.mxid = resp_data.user_id
self.device_id = resp_data.device_id
self.api.token = resp_data.access_token
if update_hs_url:
base_url = resp_data.well_known.homeserver.base_url
if base_url and base_url != self.api.base_url:
self.log.debug(
"Login response contained new base URL, switching from "
f"{self.api.base_url} to {base_url}"
)
self.api.base_url = base_url.rstrip("/")
return resp_data
async def logout(self, clear_access_token: bool = True) -> None:
"""
Invalidates an existing access token, so that it can no longer be used for authorization.
The device associated with the access token is also deleted.
`Device keys `__ for the
device are deleted alongside the device.
See also: `API reference `__
Args:
clear_access_token: Whether or not mautrix-python should forget the stored access token.
"""
await self.api.request(Method.POST, Path.v3.logout)
if clear_access_token:
self.api.token = ""
self.device_id = DeviceID("")
async def logout_all(self, clear_access_token: bool = True) -> None:
"""
Invalidates all access tokens for a user, so that they can no longer be used for
authorization. This includes the access token that made this request. All devices for the
user are also deleted.
`Device keys `__ for the
device are deleted alongside the device.
This endpoint does not require UI (user-interactive) authorization because UI authorization
is designed to protect against attacks where the someone gets hold of a single access token
then takes over the account. This endpoint invalidates all access tokens for the user,
including the token used in the request, and therefore the attacker is unable to take over
the account in this way.
See also: `API reference `__
Args:
clear_access_token: Whether or not mautrix-python should forget the stored access token.
"""
await self.api.request(Method.POST, Path.v3.logout.all)
if clear_access_token:
self.api.token = ""
self.device_id = DeviceID("")
# endregion
# TODO other sections
# region 5.7 Current account information
# API reference: https://matrix.org/docs/spec/client_server/r0.6.1.html#current-account-information
async def whoami(self) -> WhoamiResponse:
"""
Get information about the current user.
Returns:
The user ID and device ID of the current user.
"""
resp = await self.api.request(Method.GET, Path.v3.account.whoami)
return WhoamiResponse.deserialize(resp)
# endregion
python-0.20.4/mautrix/client/api/base.py 0000664 0000000 0000000 00000014764 14547234302 0020142 0 ustar 00root root 0000000 0000000 # Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
import json
from aiohttp import ClientError, ClientSession, ContentTypeError
from yarl import URL
from mautrix.api import HTTPAPI, Method, Path
from mautrix.errors import (
WellKnownInvalidVersionsResponse,
WellKnownMissingHomeserver,
WellKnownNotJSON,
WellKnownNotURL,
WellKnownUnexpectedStatus,
WellKnownUnsupportedScheme,
)
from mautrix.types import DeviceID, SerializerError, UserID, VersionsResponse
from mautrix.util.logging import TraceLogger
class BaseClientAPI:
"""
BaseClientAPI is the base class for :class:`ClientAPI`. This is separate from the main
ClientAPI class so that the ClientAPI methods can be split into multiple classes (that
inherit this class). All those section-specific method classes are inherited by the main
ClientAPI class to create the full class.
"""
localpart: str
domain: str
_mxid: UserID
device_id: DeviceID
api: HTTPAPI
log: TraceLogger
versions_cache: VersionsResponse | None
def __init__(
self, mxid: UserID = "", device_id: DeviceID = "", api: HTTPAPI | None = None, **kwargs
) -> None:
"""
Initialize a ClientAPI. You must either provide the ``api`` parameter with an existing
:class:`mautrix.api.HTTPAPI` instance, or provide the ``base_url`` and other arguments for
creating it as kwargs.
Args:
mxid: The Matrix ID of the user. This is used for things like setting profile metadata.
Additionally, the homeserver domain is extracted from this string and used for
setting aliases and such. This can be changed later using `set_mxid`.
device_id: The device ID corresponding to the access token used.
api: The :class:`mautrix.api.HTTPAPI` instance to use. You can also pass the ``kwargs``
to create a HTTPAPI instance rather than creating the instance yourself.
kwargs: If ``api`` is not specified, then the arguments to pass when creating a HTTPAPI.
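Examples:
Two illustrative ways to construct a client; the homeserver URL and token are placeholders:
>>> api = HTTPAPI(base_url="https://matrix-client.matrix.org", token="syt_123_456")
>>> client = BaseClientAPI("@user:matrix.org", device_id="DEV123", api=api)
>>> client = BaseClientAPI("@user:matrix.org",
...                        base_url="https://matrix-client.matrix.org",
...                        token="syt_123_456")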
"""
if mxid:
self.mxid = mxid
else:
self._mxid = None
self.localpart = None
self.domain = None
self.fill_member_event_callback = None
self.versions_cache = None
self.device_id = device_id
self.api = api or HTTPAPI(**kwargs)
self.log = self.api.log
@classmethod
def parse_user_id(cls, mxid: UserID) -> tuple[str, str]:
"""
Parse the localpart and server name from a Matrix user ID.
Args:
mxid: The Matrix user ID.
Returns:
A tuple of (localpart, server_name).
Raises:
ValueError: if the given user ID is invalid.
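Examples:
For example, for a user on the example.com homeserver:
>>> BaseClientAPI.parse_user_id("@user:example.com")
('user', 'example.com')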
"""
if len(mxid) == 0:
raise ValueError("User ID is empty")
elif mxid[0] != "@":
raise ValueError("User IDs start with @")
try:
sep = mxid.index(":")
except ValueError as e:
raise ValueError("User ID must contain domain separator") from e
if sep == len(mxid) - 1:
raise ValueError("User ID must contain domain")
return mxid[1:sep], mxid[sep + 1 :]
@property
def mxid(self) -> UserID:
return self._mxid
@mxid.setter
def mxid(self, mxid: UserID) -> None:
self.localpart, self.domain = self.parse_user_id(mxid)
self._mxid = mxid
async def versions(self, no_cache: bool = False) -> VersionsResponse:
"""
Get client-server spec versions supported by the server.
Args:
no_cache: If true, the versions will always be fetched from the server
rather than using cached results when available.
Returns:
The supported Matrix spec versions and unstable features.
"""
if no_cache or not self.versions_cache:
resp = await self.api.request(Method.GET, Path.versions)
self.versions_cache = VersionsResponse.deserialize(resp)
return self.versions_cache
@classmethod
async def discover(cls, domain: str, session: ClientSession | None = None) -> URL | None:
"""
Follow the server discovery spec to find the actual URL when given a Matrix server name.
Args:
domain: The server name (end of user ID) to discover.
session: Optionally, the aiohttp ClientSession object to use.
Returns:
The parsed URL if the discovery succeeded.
``None`` if the request returned a 404 status.
Raises:
WellKnownError: for other errors
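Example:
An illustrative lookup; the exact result depends on the server's current .well-known file:
>>> await BaseClientAPI.discover("matrix.org")
URL('https://matrix-client.matrix.org')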
"""
if session is None:
async with ClientSession(headers={"User-Agent": HTTPAPI.default_ua}) as sess:
return await cls._discover(domain, sess)
else:
return await cls._discover(domain, session)
@classmethod
async def _discover(cls, domain: str, session: ClientSession) -> URL | None:
well_known = URL.build(scheme="https", host=domain, path="/.well-known/matrix/client")
async with session.get(well_known) as resp:
if resp.status == 404:
return None
elif resp.status != 200:
raise WellKnownUnexpectedStatus(resp.status)
try:
data = await resp.json(content_type=None)
except (json.JSONDecodeError, ContentTypeError) as e:
raise WellKnownNotJSON() from e
try:
homeserver_url = data["m.homeserver"]["base_url"]
except KeyError as e:
raise WellKnownMissingHomeserver() from e
parsed_url = URL(homeserver_url)
if not parsed_url.is_absolute():
raise WellKnownNotURL()
elif parsed_url.scheme not in ("http", "https"):
raise WellKnownUnsupportedScheme(parsed_url.scheme)
try:
async with session.get(parsed_url / "_matrix/client/versions") as resp:
data = VersionsResponse.deserialize(await resp.json())
if len(data.versions) == 0:
raise ValueError("no versions defined in /_matrix/client/versions response")
except (ClientError, json.JSONDecodeError, SerializerError, ValueError) as e:
raise WellKnownInvalidVersionsResponse() from e
return parsed_url
python-0.20.4/mautrix/client/api/client.py 0000664 0000000 0000000 00000002546 14547234302 0020501 0 ustar 00root root 0000000 0000000 # Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from .authentication import ClientAuthenticationMethods
from .events import EventMethods
from .filtering import FilteringMethods
from .modules import ModuleMethods
from .rooms import RoomMethods
from .user_data import UserDataMethods
class ClientAPI(
ClientAuthenticationMethods,
FilteringMethods,
RoomMethods,
EventMethods,
UserDataMethods,
ModuleMethods,
):
"""
ClientAPI is a medium-level wrapper around the HTTPAPI that provides many easy-to-use
functions for accessing the client-server API.
This class can be used directly, but generally you should use the higher-level wrappers that
inherit from this class, such as :class:`mautrix.client.Client`
or :class:`mautrix.appservice.IntentAPI`.
Examples:
>>> from mautrix.client import ClientAPI
>>> client = ClientAPI("@user:matrix.org", base_url="https://matrix-client.matrix.org",
token="syt_123_456")
>>> await client.whoami()
WhoamiResponse(user_id="@user:matrix.org", device_id="DEV123")
>>> await client.get_joined_rooms()
["!roomid:matrix.org"]
"""
python-0.20.4/mautrix/client/api/events.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from typing import Awaitable
import json
from mautrix.api import Method, Path
from mautrix.errors import MatrixResponseError
from mautrix.types import (
JSON,
BaseFileInfo,
ContentURI,
Event,
EventContent,
EventContext,
EventID,
EventType,
FilterID,
Format,
ImageInfo,
MediaMessageEventContent,
Member,
Membership,
MessageEventContent,
MessageType,
Obj,
PaginatedMessages,
PaginationDirection,
PresenceState,
ReactionEventContent,
RelatesTo,
RelationType,
RoomEventFilter,
RoomID,
Serializable,
SerializerError,
StateEvent,
StateEventContent,
SyncToken,
TextMessageEventContent,
UserID,
)
from mautrix.util.formatter import parse_html
from .base import BaseClientAPI
class EventMethods(BaseClientAPI):
"""
Methods in section 8 Events of the spec. Includes ``/sync``'ing, getting messages and state,
setting state, sending messages and redacting messages. See also: `Events API reference`_
.. _Events API reference:
https://spec.matrix.org/v1.1/client-server-api/#events
"""
# region 8.4 Syncing
# API reference: https://spec.matrix.org/v1.1/client-server-api/#syncing
def sync(
self,
since: SyncToken | None = None,
timeout: int = 30000,
filter_id: FilterID | None = None,
full_state: bool = False,
set_presence: PresenceState | None = None,
) -> Awaitable[JSON]:
"""
Perform a sync request. See also: `/sync API reference`_
This method doesn't parse the response at all.
You should use :class:`mautrix.client.Syncer` to parse sync responses and dispatch the data
into event handlers. :class:`mautrix.client.Client` includes ``Syncer``.
Args:
since (str): Optional. A token which specifies where to continue a sync from.
timeout (int): Optional. The time in milliseconds to wait.
filter_id (str): A filter ID.
full_state (bool): Return the full state for every room the user has joined.
Defaults to false.
set_presence (str): Whether the client should be marked as "online" or "offline".
.. _/sync API reference:
https://spec.matrix.org/v1.1/client-server-api/#get_matrixclientv3sync
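Examples:
A hedged sketch of a single long-poll; ``client`` is assumed to be an instance of this class:
>>> resp = await client.sync(timeout=30000)
>>> next_batch = resp["next_batch"]  # raw JSON; pass back as ``since`` on the next call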
"""
request = {"timeout": timeout}
if since:
request["since"] = str(since)
if filter_id:
request["filter"] = str(filter_id)
if full_state:
request["full_state"] = "true" if full_state else "false"
if set_presence:
request["set_presence"] = str(set_presence)
return self.api.request(
Method.GET, Path.v3.sync, query_params=request, retry_count=0, metrics_method="sync"
)
# endregion
# region 7.5 Getting events for a room
# API reference: https://spec.matrix.org/v1.1/client-server-api/#getting-events-for-a-room
async def get_event(self, room_id: RoomID, event_id: EventID) -> Event:
"""
Get a single event based on ``room_id``/``event_id``. You must have permission to retrieve
this event e.g. by being a member in the room for this event.
See also: `API reference `__
Args:
room_id: The ID of the room the event is in.
event_id: The event ID to get.
Returns:
The event.
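Examples:
Illustrative only; the room and event IDs below are placeholders:
>>> evt = await client.get_event(RoomID("!room:example.com"), EventID("$eventid"))
>>> evt.type, evt.sender  # fields of the deserialized mautrix.types.Event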
"""
content = await self.api.request(
Method.GET, Path.v3.rooms[room_id].event[event_id], metrics_method="getEvent"
)
try:
return Event.deserialize(content)
except SerializerError as e:
raise MatrixResponseError("Invalid event in response") from e
async def get_event_context(
self,
room_id: RoomID,
event_id: EventID,
limit: int | None = 10,
filter: RoomEventFilter | None = None,
) -> EventContext:
"""
Get a number of events that happened just before and after the specified event.
This allows clients to get the context surrounding an event, as well as get the state at
an event and paginate in either direction.
Args:
room_id: The room to get events from.
event_id: The event to get context around.
limit: The maximum number of events to return. The limit applies to the total number of
events before and after the requested event. A limit of 0 means no other events
are returned, while 2 means one event before and one after are returned.
filter: A JSON RoomEventFilter_ to filter returned events with.
Returns:
The event itself, up to ``limit/2`` events before and after the event, the room state
at the event, and pagination tokens to scroll up and down.
.. _RoomEventFilter:
https://spec.matrix.org/v1.1/client-server-api/#filtering
"""
query_params = {}
if limit is not None:
query_params["limit"] = str(limit)
if filter is not None:
query_params["filter"] = (
filter.serialize() if isinstance(filter, Serializable) else filter
)
resp = await self.api.request(
Method.GET,
Path.v3.rooms[room_id].context[event_id],
query_params=query_params,
metrics_method="get_event_context",
)
return EventContext.deserialize(resp)
async def get_state_event(
self,
room_id: RoomID,
event_type: EventType,
state_key: str = "",
) -> StateEventContent:
"""
Looks up the contents of a state event in a room. If the user is joined to the room then the
state is taken from the current state of the room. If the user has left the room then the
state is taken from the state of the room when they left.
See also: `API reference `__
Args:
room_id: The ID of the room to look up the state in.
event_type: The type of state to look up.
state_key: The key of the state to look up. Defaults to empty string.
Returns:
The state event.
"""
content = await self.api.request(
Method.GET,
Path.v3.rooms[room_id].state[event_type][state_key],
metrics_method="getStateEvent",
)
content["__mautrix_event_type"] = event_type
try:
return StateEvent.deserialize_content(content)
except SerializerError as e:
raise MatrixResponseError("Invalid state event in response") from e
async def get_state(self, room_id: RoomID) -> list[StateEvent]:
"""
Get the state events for the current state of a room.
See also: `API reference `__
Args:
room_id: The ID of the room to look up the state for.
Returns:
A list of state events with the most recent of each event_type/state_key pair.
"""
content = await self.api.request(
Method.GET, Path.v3.rooms[room_id].state, metrics_method="getState"
)
try:
return [StateEvent.deserialize(event) for event in content]
except SerializerError as e:
raise MatrixResponseError("Invalid state events in response") from e
async def get_members(
self,
room_id: RoomID,
at: SyncToken | None = None,
membership: Membership | None = None,
not_membership: Membership | None = None,
) -> list[StateEvent]:
"""
Get the list of members for a room.
See also: `API reference `__
Args:
room_id: The ID of the room to get the member events for.
at: The point in time (pagination token) to return members for in the room. This token
can be obtained from a ``prev_batch`` token returned for each room by the sync API.
Defaults to the current state of the room, as determined by the server.
membership: The kind of membership to filter for. Defaults to no filtering if
unspecified. When specified alongside ``not_membership``, the two parameters create
an 'or' condition: either the ``membership`` is the same as membership or is not the
same as ``not_membership``.
not_membership: The kind of membership to exclude from the results. Defaults to no
filtering if unspecified.
Returns:
A list of most recent member events for each user.
"""
query = {}
if at:
query["at"] = at
if membership:
query["membership"] = membership.value
if not_membership:
query["not_membership"] = not_membership.value
content = await self.api.request(
Method.GET,
Path.v3.rooms[room_id].members,
query_params=query,
metrics_method="getMembers",
)
try:
return [StateEvent.deserialize(event) for event in content["chunk"]]
except KeyError:
raise MatrixResponseError("`chunk` not in response.")
except SerializerError as e:
raise MatrixResponseError("Invalid state events in response") from e
async def get_joined_members(self, room_id: RoomID) -> dict[UserID, Member]:
"""
Get a user ID -> member info map for a room. The current user must be in the room for it to
work, unless it is an Application Service in which case any of the AS's users must be in the
room. This API is primarily for Application Services and should be faster to respond than
`/members`_ as it can be implemented more efficiently on the server.
See also: `API reference `__
Args:
room_id: The ID of the room to get the members of.
Returns:
A dictionary from user IDs to Member info objects.
.. _/members:
https://spec.matrix.org/v1.1/client-server-api/#get_matrixclientv3roomsroomidmembers
"""
content = await self.api.request(
Method.GET, Path.v3.rooms[room_id].joined_members, metrics_method="getJoinedMembers"
)
try:
return {
user_id: Member(
membership=Membership.JOIN,
displayname=member.get("display_name", ""),
avatar_url=member.get("avatar_url", ""),
)
for user_id, member in content["joined"].items()
}
except KeyError:
raise MatrixResponseError("`joined` not in response.")
except SerializerError as e:
raise MatrixResponseError("Invalid member objects in response") from e
async def get_messages(
self,
room_id: RoomID,
direction: PaginationDirection,
from_token: SyncToken | None = None,
to_token: SyncToken | None = None,
limit: int | None = None,
filter_json: str | dict | RoomEventFilter | None = None,
) -> PaginatedMessages:
"""
Get a list of message and state events for a room. Pagination parameters are used to
paginate history in the room.
See also: `API reference `__
Args:
room_id: The ID of the room to get events from.
direction: The direction to return events from.
from_token: The token to start returning events from. This token can be obtained from a
``prev_batch`` token returned for each room by the `sync endpoint`_, or from a
``start`` or ``end`` token returned by a previous request to this endpoint.
Starting from Matrix v1.3, this field can be omitted to fetch events from the
beginning or end of the room.
to_token: The token to stop returning events at.
limit: The maximum number of events to return. Defaults to 10.
filter_json: A JSON RoomEventFilter_ to filter returned events with.
Returns:
The start and end pagination tokens and the list of events (a :class:`PaginatedMessages`
tuple, as constructed below).
.. _RoomEventFilter:
https://spec.matrix.org/v1.3/client-server-api/#filtering
.. _sync endpoint:
https://spec.matrix.org/v1.3/client-server-api/#get_matrixclientv3sync
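Examples:
A hedged pagination sketch; ``room_id`` is a placeholder and ``PaginationDirection.BACKWARD``
is assumed to be the enum member for reverse-chronological paging:
>>> resp = await client.get_messages(room_id, PaginationDirection.BACKWARD, limit=50)
>>> start, end, events = resp  # PaginatedMessages is built positionally from (start, end, events)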
"""
if isinstance(filter_json, Serializable):
filter_json = filter_json.json()
elif isinstance(filter_json, dict):
filter_json = json.dumps(filter_json)
query_params = {
"from": from_token,
"dir": direction.value,
"to": to_token,
"limit": str(limit) if limit else None,
"filter": filter_json,
}
content = await self.api.request(
Method.GET,
Path.v3.rooms[room_id].messages,
query_params=query_params,
metrics_method="getMessages",
)
try:
return PaginatedMessages(
content["start"],
content["end"],
[Event.deserialize(event) for event in content["chunk"]],
)
except KeyError:
if "start" not in content:
raise MatrixResponseError("`start` not in response.")
elif "end" not in content:
raise MatrixResponseError("`start` not in response.")
raise MatrixResponseError("`content` not in response.")
except SerializerError as e:
raise MatrixResponseError("Invalid events in response") from e
# endregion
# region 7.6 Sending events to a room
# API reference: https://spec.matrix.org/v1.1/client-server-api/#sending-events-to-a-room
async def send_state_event(
self,
room_id: RoomID,
event_type: EventType,
content: StateEventContent,
state_key: str = "",
ensure_joined: bool = True,
**kwargs,
) -> EventID:
"""
Send a state event to a room. State events with the same ``room_id``, ``event_type`` and
``state_key`` will be overridden.
See also: `API reference `__
Args:
room_id: The ID of the room to set the state in.
event_type: The type of state to send.
content: The content to send.
state_key: The key for the state to send. Defaults to empty string.
ensure_joined: Used by IntentAPI to determine if it should ensure the user is joined
before sending the event.
**kwargs: Optional parameters to pass to the :meth:`HTTPAPI.request` method. Used by
:class:`IntentAPI` to pass the timestamp massaging field to
:meth:`AppServiceAPI.request`.
Returns:
The ID of the event that was sent.
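Examples:
A hedged sketch that renames a room; ``RoomNameStateEventContent`` is assumed to be the
matching content class in ``mautrix.types`` and ``room_id`` is a placeholder:
>>> await client.send_state_event(room_id, EventType.ROOM_NAME, RoomNameStateEventContent(name="New room name"))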
"""
content = content.serialize() if isinstance(content, Serializable) else content
resp = await self.api.request(
Method.PUT,
Path.v3.rooms[room_id].state[event_type][state_key],
content,
**kwargs,
metrics_method="sendStateEvent",
)
try:
return resp["event_id"]
except KeyError:
raise MatrixResponseError("`event_id` not in response.")
async def send_message_event(
self,
room_id: RoomID,
event_type: EventType,
content: EventContent,
txn_id: str | None = None,
**kwargs,
) -> EventID:
"""
Send a message event to a room. Message events allow access to historical events and
pagination, making them suited for "once-off" activity in a room.
See also: `API reference `__
Args:
room_id: The ID of the room to send the message to.
event_type: The type of message to send.
content: The content to send.
txn_id: The transaction ID to use. If not provided, a random ID will be generated.
**kwargs: Optional parameters to pass to the :meth:`HTTPAPI.request` method. Used by
:class:`IntentAPI` to pass the timestamp massaging field to
:meth:`AppServiceAPI.request`.
Returns:
The ID of the event that was sent.
"""
if not room_id:
raise ValueError("Room ID not given")
elif not event_type:
raise ValueError("Event type not given")
url = Path.v3.rooms[room_id].send[event_type][txn_id or self.api.get_txn_id()]
content = content.serialize() if isinstance(content, Serializable) else content
resp = await self.api.request(
Method.PUT, url, content, **kwargs, metrics_method="sendMessageEvent"
)
try:
return resp["event_id"]
except KeyError:
raise MatrixResponseError("`event_id` not in response.")
# region Message send helper functions
def send_message(
self,
room_id: RoomID,
content: MessageEventContent,
**kwargs,
) -> Awaitable[EventID]:
"""
Send a message to a room.
Args:
room_id: The ID of the room to send the message to.
content: The content to send.
**kwargs: Optional parameters to pass to the :meth:`HTTPAPI.request` method.
Returns:
The ID of the event that was sent.
"""
return self.send_message_event(room_id, EventType.ROOM_MESSAGE, content, **kwargs)
def react(self, room_id: RoomID, event_id: EventID, key: str, **kwargs) -> Awaitable[EventID]:
content = ReactionEventContent(
relates_to=RelatesTo(rel_type=RelationType.ANNOTATION, event_id=event_id, key=key)
)
return self.send_message_event(room_id, EventType.REACTION, content, **kwargs)
async def send_text(
self,
room_id: RoomID,
text: str | None = None,
html: str | None = None,
msgtype: MessageType = MessageType.TEXT,
relates_to: RelatesTo | None = None,
**kwargs,
) -> EventID:
"""
Send a text message to a room.
Args:
room_id: The ID of the room to send the message to.
text: The text to send. If set to ``None``, the given HTML will be parsed to generate
a plaintext representation.
html: The HTML to send.
msgtype: The message type to send.
Defaults to :attr:`MessageType.TEXT` (normal text message).
relates_to: Message relation metadata used for things like replies.
**kwargs: Optional parameters to pass to the :meth:`HTTPAPI.request` method.
Returns:
The ID of the event that was sent.
Raises:
ValueError: if both ``text`` and ``html`` are ``None``.
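Examples:
Simple usage sketches; ``room_id`` is assumed to be a room the client is in:
>>> await client.send_text(room_id, "Hello, world!")
>>> await client.send_text(room_id, html="<b>Hello</b>")  # plaintext body is parsed from the HTML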
"""
if html is not None:
if text is None:
text = await parse_html(html)
content = TextMessageEventContent(
msgtype=msgtype, body=text, format=Format.HTML, formatted_body=html
)
elif text is not None:
content = TextMessageEventContent(msgtype=msgtype, body=text)
else:
raise ValueError("send_text() requires either text or html to be set")
if relates_to:
content.relates_to = relates_to
return await self.send_message(room_id, content, **kwargs)
def send_notice(
self,
room_id: RoomID,
text: str | None = None,
html: str | None = None,
relates_to: RelatesTo | None = None,
**kwargs,
) -> Awaitable[EventID]:
"""
Send a notice text message to a room. Notices are like normal text messages, but usually
sent by bots to tell other bots not to react to them. If you're a bot, please send notices
instead of normal text, unless there is a reason to do something else.
Args:
room_id: The ID of the room to send the message to.
text: The text to send. If set to ``None``, the given HTML will be parsed to generate
a plaintext representation.
html: The HTML to send.
relates_to: Message relation metadata used for things like replies.
**kwargs: Optional parameters to pass to the :meth:`HTTPAPI.request` method.
Returns:
The ID of the event that was sent.
Raises:
ValueError: if both ``text`` and ``html`` are ``None``.
"""
return self.send_text(room_id, text, html, MessageType.NOTICE, relates_to, **kwargs)
def send_emote(
self,
room_id: RoomID,
text: str | None = None,
html: str | None = None,
relates_to: RelatesTo | None = None,
**kwargs,
) -> Awaitable[EventID]:
"""
Send an emote to a room. Emotes are usually displayed by prepending a star and the user's
display name to the message, which means they're usually written in the third person.
Args:
room_id: The ID of the room to send the message to.
text: The text to send. If set to ``None``, the given HTML will be parsed to generate
a plaintext representation.
html: The HTML to send.
relates_to: Message relation metadata used for things like replies.
**kwargs: Optional parameters to pass to the :meth:`HTTPAPI.request` method.
Returns:
The ID of the event that was sent.
Raises:
ValueError: if both ``text`` and ``html`` are ``None``.
"""
return self.send_text(room_id, text, html, MessageType.EMOTE, relates_to, **kwargs)
def send_file(
self,
room_id: RoomID,
url: ContentURI,
info: BaseFileInfo | None = None,
file_name: str | None = None,
file_type: MessageType = MessageType.FILE,
relates_to: RelatesTo | None = None,
**kwargs,
) -> Awaitable[EventID]:
"""
Send a file to a room.
Args:
room_id: The ID of the room to send the message to.
url: The Matrix content repository URI of the file. You can upload files using
:meth:`~MediaRepositoryMethods.upload_media`.
info: Additional metadata about the file, e.g. mimetype, image size, video duration, etc
file_name: The name for the file to send.
file_type: The general file type to send. The file type can be further specified by
setting the ``mimetype`` field of the ``info`` parameter. Defaults to
:attr:`MessageType.FILE` (unspecified file type, e.g. document)
relates_to: Message relation metadata used for things like replies.
**kwargs: Optional parameters to pass to the :meth:`HTTPAPI.request` method.
Returns:
The ID of the event that was sent.
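Examples:
A hedged sketch that first uploads raw bytes with :meth:`~MediaRepositoryMethods.upload_media`
(defined in this package) and then sends the resulting MXC URI; the file name and contents
are placeholders:
>>> data = open("document.pdf", "rb").read()
>>> mxc = await client.upload_media(data, mime_type="application/pdf", filename="document.pdf")
>>> await client.send_file(room_id, mxc, file_name="document.pdf")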
"""
return self.send_message(
room_id,
MediaMessageEventContent(
url=url, info=info, body=file_name, relates_to=relates_to, msgtype=file_type
),
**kwargs,
)
def send_sticker(
self,
room_id: RoomID,
url: ContentURI,
info: ImageInfo | None = None,
text: str = "",
relates_to: RelatesTo | None = None,
**kwargs,
) -> Awaitable[EventID]:
"""
Send a sticker to a room. Stickers are basically images, but they're usually rendered
slightly differently.
Args:
room_id: The ID of the room to send the message to.
url: The Matrix content repository URI of the sticker. You can upload files using
:meth:`~MediaRepositoryMethods.upload_media`.
info: Additional metadata about the sticker, e.g. mimetype and image size
text: A textual description of the sticker.
relates_to: Message relation metadata used for things like replies.
**kwargs: Optional parameters to pass to the :meth:`HTTPAPI.request` method.
Returns:
The ID of the event that was sent.
"""
return self.send_message_event(
room_id,
EventType.STICKER,
MediaMessageEventContent(url=url, info=info, body=text, relates_to=relates_to),
**kwargs,
)
def send_image(
self,
room_id: RoomID,
url: ContentURI,
info: ImageInfo | None = None,
file_name: str | None = None,
relates_to: RelatesTo | None = None,
**kwargs,
) -> Awaitable[EventID]:
"""
Send an image to a room.
Args:
room_id: The ID of the room to send the message to.
url: The Matrix content repository URI of the image. You can upload files using
:meth:`~MediaRepositoryMethods.upload_media`.
info: Additional metadata about the image, e.g. mimetype and image size
file_name: The file name for the image to send.
relates_to: Message relation metadata used for things like replies.
**kwargs: Optional parameters to pass to the :meth:`HTTPAPI.request` method.
Returns:
The ID of the event that was sent.
"""
return self.send_file(
room_id, url, info, file_name, MessageType.IMAGE, relates_to, **kwargs
)
# endregion
# endregion
# region 7.7 Redactions
# API reference: https://spec.matrix.org/v1.1/client-server-api/#redactions
async def redact(
self,
room_id: RoomID,
event_id: EventID,
reason: str | None = None,
extra_content: dict[str, JSON] | None = None,
**kwargs,
) -> EventID:
"""
Send an event to redact a previous event.
Redacting an event strips all information out of an event which isn't critical to the
integrity of the server-side representation of the room.
This cannot be undone.
Users may redact their own events, and any user with a power level greater than or equal to
the redact power level of the room may redact events there.
See also: `API reference `__
Args:
room_id: The ID of the room the event is in.
event_id: The ID of the event to redact.
reason: The reason for the event being redacted.
extra_content: Extra content for the event.
**kwargs: Optional parameters to pass to the :meth:`HTTPAPI.request` method. Used by
:class:`IntentAPI` to pass the timestamp massaging field to
:meth:`AppServiceAPI.request`.
Returns:
The ID of the event that was sent to redact the other event.
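Examples:
Illustrative only; the event ID and reason below are placeholders:
>>> await client.redact(room_id, EventID("$event_to_remove"), reason="Spam")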
"""
url = Path.v3.rooms[room_id].redact[event_id][self.api.get_txn_id()]
content = extra_content or {}
if reason:
content["reason"] = reason
resp = await self.api.request(
Method.PUT, url, content=content, **kwargs, metrics_method="redact"
)
try:
return resp["event_id"]
except KeyError:
raise MatrixResponseError("`event_id` not in response.")
# endregion
python-0.20.4/mautrix/client/api/filtering.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from mautrix.api import Method, Path
from mautrix.errors import MatrixResponseError
from mautrix.types import Filter, FilterID, Serializable
from .base import BaseClientAPI
class FilteringMethods(BaseClientAPI):
"""
Methods in section 7 Filtering of the spec.
Filters can be created on the server and can be passed as a parameter to APIs which return
events. These filters alter the data returned from those APIs. Not all APIs accept filters.
See also: `API reference `__
"""
async def get_filter(self, filter_id: FilterID) -> Filter:
"""
Download a filter.
See also: `API reference `__
Args:
filter_id: The filter ID to download.
Returns:
The filter data.
"""
content = await self.api.request(Method.GET, Path.v3.user[self.mxid].filter[filter_id])
return Filter.deserialize(content)
async def create_filter(self, filter_params: Filter) -> FilterID:
"""
Upload a new filter definition to the homeserver.
See also: `API reference `__
Args:
filter_params: The filter data.
Returns:
A filter ID that can be used in future requests to refer to the uploaded filter.
"""
resp = await self.api.request(
Method.POST,
Path.v3.user[self.mxid].filter,
filter_params.serialize()
if isinstance(filter_params, Serializable)
else filter_params,
)
try:
return resp["filter_id"]
except KeyError:
raise MatrixResponseError("`filter_id` not in response.")
# endregion
python-0.20.4/mautrix/client/api/modules/__init__.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from .account_data import AccountDataMethods
from .crypto import CryptoMethods
from .media_repository import MediaRepositoryMethods
from .misc import MiscModuleMethods
from .push_rules import PushRuleMethods
from .room_tag import RoomTaggingMethods
class ModuleMethods(
MediaRepositoryMethods,
CryptoMethods,
AccountDataMethods,
MiscModuleMethods,
RoomTaggingMethods,
PushRuleMethods,
):
"""
Methods in section 13 Modules of the spec.
See also: `API reference `__
"""
# TODO: subregions 15, 21, 26, 27, others?
python-0.20.4/mautrix/client/api/modules/account_data.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from mautrix.api import Method, Path
from mautrix.types import JSON, AccountDataEventContent, EventType, RoomID, Serializable
from ..base import BaseClientAPI
class AccountDataMethods(BaseClientAPI):
"""
Methods in section 13.9 Client Config of the spec. These methods are used for storing user-local
data on the homeserver to synchronize client configuration across sessions.
See also: `API reference `__"""
async def get_account_data(self, type: EventType | str, room_id: RoomID | None = None) -> JSON:
"""
Get a specific account data event from the homeserver.
See also: `API reference `__
Args:
type: The type of the account data event to get.
room_id: Optionally, the room ID to get per-room account data.
Returns:
The data in the event.
"""
if isinstance(type, EventType) and not type.is_account_data:
raise ValueError("Event type is not an account data event type")
base_path = Path.v3.user[self.mxid]
if room_id:
base_path = base_path.rooms[room_id]
return await self.api.request(Method.GET, base_path.account_data[type])
async def set_account_data(
self,
type: EventType | str,
data: AccountDataEventContent | dict[str, JSON],
room_id: RoomID | None = None,
) -> None:
"""
Store account data on the homeserver.
See also: `API reference `__
Args:
type: The type of the account data event to set.
data: The content to store in that account data event.
room_id: Optionally, the room ID to set per-room account data.
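Examples:
A hedged sketch using a custom namespaced event type passed as a plain string:
>>> await client.set_account_data("com.example.preferences", {"theme": "dark"})
>>> await client.get_account_data("com.example.preferences")
{'theme': 'dark'}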
"""
if isinstance(type, EventType) and not type.is_account_data:
raise ValueError("Event type is not an account data event type")
base_path = Path.v3.user[self.mxid]
if room_id:
base_path = base_path.rooms[room_id]
await self.api.request(
Method.PUT,
base_path.account_data[type],
data.serialize() if isinstance(data, Serializable) else data,
)
python-0.20.4/mautrix/client/api/modules/crypto.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from typing import Any
from mautrix.api import Method, Path
from mautrix.errors import MatrixResponseError
from mautrix.types import (
ClaimKeysResponse,
DeviceID,
EncryptionKeyAlgorithm,
EventType,
QueryKeysResponse,
Serializable,
SyncToken,
ToDeviceEventContent,
UserID,
)
from ..base import BaseClientAPI
class CryptoMethods(BaseClientAPI):
"""
Methods in section `13.9 Send-to-Device messaging `__
and `13.11 End-to-End Encryption of the spec `__.
"""
async def send_to_device(
self, event_type: EventType, messages: dict[UserID, dict[DeviceID, ToDeviceEventContent]]
) -> None:
"""
Send to-device events to a set of client devices.
See also: `API reference `__
Args:
event_type: The type of event to send.
messages: The messages to send. A map from user ID, to a map from device ID to message
body. The device ID may also be ``*``, meaning all known devices for the user.
"""
if not event_type.is_to_device:
raise ValueError("Event type must be a to-device event type")
await self.api.request(
Method.PUT,
Path.v3.sendToDevice[event_type][self.api.get_txn_id()],
{
"messages": {
user_id: {
device_id: (
content.serialize() if isinstance(content, Serializable) else content
)
for device_id, content in devices.items()
}
for user_id, devices in messages.items()
},
},
)
async def send_to_one_device(
self,
event_type: EventType,
user_id: UserID,
device_id: DeviceID,
message: ToDeviceEventContent,
) -> None:
"""
Send a to-device event to a single device.
Args:
event_type: The type of event to send.
user_id: The user whose device to send the event to.
device_id: The device ID to send the event to.
message: The event content to send.
"""
return await self.send_to_device(event_type, {user_id: {device_id: message}})
async def upload_keys(
self,
one_time_keys: dict[str, Any] | None = None,
device_keys: dict[str, Any] | None = None,
) -> dict[EncryptionKeyAlgorithm, int]:
"""
Publishes end-to-end encryption keys for the device.
See also: `API reference `__
Args:
one_time_keys: One-time public keys for "pre-key" messages. The names of the properties
should be in the format ``:``. The format of the key is
determined by the key algorithm.
device_keys: Identity keys for the device. May be absent if no new identity keys are
required.
Returns:
For each key algorithm, the number of unclaimed one-time keys of that type currently
held on the server for this device.
"""
data = {}
if device_keys:
data["device_keys"] = device_keys
if one_time_keys:
data["one_time_keys"] = one_time_keys
resp = await self.api.request(Method.POST, Path.v3.keys.upload, data)
try:
return {
EncryptionKeyAlgorithm.deserialize(alg): count
for alg, count in resp["one_time_key_counts"].items()
}
except KeyError as e:
raise MatrixResponseError("`one_time_key_counts` not in response.") from e
except AttributeError as e:
raise MatrixResponseError("Invalid `one_time_key_counts` field in response.") from e
async def query_keys(
self,
device_keys: list[UserID] | set[UserID] | dict[UserID, list[DeviceID]],
token: SyncToken = "",
timeout: int = 10000,
) -> QueryKeysResponse:
"""
Fetch devices and their identity keys for the given users.
See also: `API reference `__
Args:
device_keys: The keys to be downloaded. A map from user ID, to a list of device IDs, or
to an empty list to indicate all devices for the corresponding user.
token: If the client is fetching keys as a result of a device update received in a sync
request, this should be the 'since' token of that sync request, or any later sync
token. This allows the server to ensure its response contains the keys advertised by
the notification in that sync.
timeout: The time (in milliseconds) to wait when downloading keys from remote servers.
Returns:
Information on the queried devices and errors for homeservers that could not be reached.
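Examples:
A hedged sketch; the user ID is a placeholder and an empty device list requests all of
that user's devices:
>>> resp = await client.query_keys({UserID("@user:example.com"): []})
>>> resp.device_keys  # identity keys grouped by user and device, as defined by the spec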
"""
if isinstance(device_keys, (list, set)):
device_keys = {user_id: [] for user_id in device_keys}
data = {
"timeout": timeout,
"device_keys": device_keys,
}
if token:
data["token"] = token
resp = await self.api.request(Method.POST, Path.v3.keys.query, data)
return QueryKeysResponse.deserialize(resp)
async def claim_keys(
self,
one_time_keys: dict[UserID, dict[DeviceID, EncryptionKeyAlgorithm]],
timeout: int = 10000,
) -> ClaimKeysResponse:
"""
Claim one-time keys for use in pre-key messages.
See also: `API reference `__
Args:
one_time_keys: The keys to be claimed. A map from user ID, to a map from device ID to
algorithm name.
timeout: The time (in milliseconds) to wait when downloading keys from remote servers.
Returns:
One-time keys for the queried devices and errors for homeservers that could not be
reached.
"""
resp = await self.api.request(
Method.POST,
Path.v3.keys.claim,
{
"timeout": timeout,
"one_time_keys": {
user_id: {device_id: alg.serialize() for device_id, alg in devices.items()}
for user_id, devices in one_time_keys.items()
},
},
)
return ClaimKeysResponse.deserialize(resp)
python-0.20.4/mautrix/client/api/modules/media_repository.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from typing import Any, AsyncIterable, Literal
from contextlib import contextmanager
import asyncio
import time
from mautrix import __optional_imports__
from mautrix.api import MediaPath, Method
from mautrix.errors import MatrixResponseError, make_request_error
from mautrix.types import (
ContentURI,
MediaCreateResponse,
MediaRepoConfig,
MXOpenGraph,
SerializerError,
)
from mautrix.util import background_task
from mautrix.util.async_body import async_iter_bytes
from mautrix.util.opt_prometheus import Histogram
from ..base import BaseClientAPI
try:
from mautrix.util import magic
except ImportError:
if __optional_imports__:
raise
magic = None # type: ignore
UPLOAD_TIME = Histogram(
"bridge_media_upload_time",
"Time spent uploading media (milliseconds per megabyte)",
buckets=[10, 25, 50, 100, 250, 500, 750, 1000, 2500, 5000, 10000],
)
class MediaRepositoryMethods(BaseClientAPI):
"""
Methods in section 13.8 Content Repository of the spec. These methods are used for uploading and
downloading content from the media repository and for getting URL previews without leaking
client IPs.
See also: `API reference `__
"""
async def create_mxc(self) -> MediaCreateResponse:
"""
Create a media ID for uploading media to the homeserver.
See also: `API reference `__
Returns:
A MediaCreateResponse containing the MXC URI that a file can later be uploaded to.
"""
resp = await self.api.request(Method.POST, MediaPath.v1.create)
return MediaCreateResponse.deserialize(resp)
@contextmanager
def _observe_upload_time(self, size: int | None, mxc: ContentURI | None = None) -> None:
start = time.monotonic_ns()
yield
duration = time.monotonic_ns() - start
if mxc:
duration_sec = duration / 1000**3
self.log.debug(f"Completed asynchronous upload of {mxc} in {duration_sec:.3f} seconds")
if size:
UPLOAD_TIME.observe(duration / size)
async def upload_media(
self,
data: bytes | bytearray | AsyncIterable[bytes],
mime_type: str | None = None,
filename: str | None = None,
size: int | None = None,
mxc: ContentURI | None = None,
async_upload: bool = False,
) -> ContentURI:
"""
Upload a file to the content repository.
See also: `API reference `__
Args:
data: The data to upload.
mime_type: The MIME type to send with the upload request.
filename: The filename to send with the upload request.
size: The file size to send with the upload request.
mxc: An existing MXC URI which doesn't have content yet to upload into.
async_upload: Should the media be uploaded in the background?
If ``True``, this will create a MXC URI using :meth:`create_mxc`, start uploading
in the background, and then immediately return the created URI. This is mutually
exclusive with manually passing the ``mxc`` parameter.
Returns:
The MXC URI to the uploaded file.
Raises:
MatrixResponseError: If the response does not contain a ``content_uri`` field.
ValueError: if both ``async_upload`` and ``mxc`` are provided at the same time.
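Examples:
Hedged sketches; ``image_bytes`` is placeholder data:
>>> mxc = await client.upload_media(image_bytes, mime_type="image/png", filename="cat.png")
>>> mxc = await client.upload_media(image_bytes, mime_type="image/png", async_upload=True)  # returns immediately, upload continues in the background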
"""
if magic and isinstance(data, bytes):
mime_type = mime_type or magic.mimetype(data)
headers = {}
if mime_type:
headers["Content-Type"] = mime_type
if size:
headers["Content-Length"] = str(size)
elif isinstance(data, (bytes, bytearray)):
size = len(data)
query = {}
if filename:
query["filename"] = filename
upload_url = None
if async_upload:
if mxc:
raise ValueError("async_upload and mxc can't be provided simultaneously")
create_response = await self.create_mxc()
mxc = create_response.content_uri
upload_url = create_response.unstable_upload_url
path = MediaPath.v3.upload
method = Method.POST
if mxc:
server_name, media_id = self.api.parse_mxc_uri(mxc)
if upload_url is None:
path = MediaPath.v3.upload[server_name][media_id]
method = Method.PUT
else:
path = (
MediaPath.unstable["com.beeper.msc3870"].upload[server_name][media_id].complete
)
if upload_url is not None:
task = self._upload_to_url(upload_url, path, headers, data, post_upload_query=query)
else:
task = self.api.request(
method, path, content=data, headers=headers, query_params=query
)
if async_upload:
async def _try_upload():
try:
with self._observe_upload_time(size, mxc):
await task
except Exception as e:
self.log.error(f"Failed to upload {mxc}: {type(e).__name__}: {e}")
background_task.create(_try_upload())
return mxc
else:
with self._observe_upload_time(size):
resp = await task
try:
return resp["content_uri"]
except KeyError:
raise MatrixResponseError("`content_uri` not in response.")
async def download_media(self, url: ContentURI, timeout_ms: int | None = None) -> bytes:
"""
Download a file from the content repository.
See also: `API reference `__
Args:
url: The MXC URI to download.
timeout_ms: The maximum number of milliseconds that the client is willing to wait to
start receiving data. Used for asynchronous uploads.
Returns:
The raw downloaded data.
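Examples:
Illustrative only; the MXC URI below is a placeholder:
>>> data = await client.download_media(ContentURI("mxc://example.com/abc123"))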
"""
url = self.api.get_download_url(url)
query_params: dict[str, Any] = {"allow_redirect": "true"}
if timeout_ms is not None:
query_params["timeout_ms"] = timeout_ms
req_id = self.api.log_download_request(url, query_params)
start = time.monotonic()
async with self.api.session.get(url, params=query_params) as response:
try:
response.raise_for_status()
return await response.read()
finally:
self.api.log_download_request_done(
url, req_id, time.monotonic() - start, response.status
)
async def download_thumbnail(
self,
url: ContentURI,
width: int | None = None,
height: int | None = None,
resize_method: Literal["crop", "scale"] | None = None,
allow_remote: bool = True,
timeout_ms: int | None = None,
):
"""
Download a thumbnail for a file in the content repository.
See also: `API reference `__
Args:
url: The MXC URI to download.
width: The _desired_ width of the thumbnail. The actual thumbnail may not match the size
specified.
height: The _desired_ height of the thumbnail. The actual thumbnail may not match the
size specified.
resize_method: The desired resizing method. Either ``crop`` or ``scale``.
allow_remote: Indicates to the server that it should not attempt to fetch the media if
it is deemed remote. This is to prevent routing loops where the server contacts
itself.
timeout_ms: The maximum number of milliseconds that the client is willing to wait to
start receiving data. Used for asynchronous uploads.
Returns:
The raw downloaded data.
"""
url = self.api.get_download_url(url, download_type="thumbnail")
query_params: dict[str, Any] = {"allow_redirect": "true"}
if width is not None:
query_params["width"] = width
if height is not None:
query_params["height"] = height
if resize_method is not None:
query_params["method"] = resize_method
if allow_remote is not None:
query_params["allow_remote"] = allow_remote
if timeout_ms is not None:
query_params["timeout_ms"] = timeout_ms
req_id = self.api.log_download_request(url, query_params)
start = time.monotonic()
async with self.api.session.get(url, params=query_params) as response:
try:
response.raise_for_status()
return await response.read()
finally:
self.api.log_download_request_done(
url, req_id, time.monotonic() - start, response.status
)
async def get_url_preview(self, url: str, timestamp: int | None = None) -> MXOpenGraph:
"""
Get information about a URL for a client.
See also: `API reference `__
Args:
url: The URL to get a preview of.
timestamp: The preferred point in time to return a preview for. The server may return a
newer version if it does not have the requested version available.
"""
query_params = {"url": url}
if timestamp is not None:
query_params["ts"] = timestamp
content = await self.api.request(
Method.GET, MediaPath.v3.preview_url, query_params=query_params
)
try:
return MXOpenGraph.deserialize(content)
except SerializerError as e:
raise MatrixResponseError("Invalid MXOpenGraph in response.") from e
async def get_media_repo_config(self) -> MediaRepoConfig:
"""
This endpoint allows clients to retrieve the configuration of the content repository, such
as upload limitations. Clients SHOULD use this as a guide when using content repository
endpoints. All values are intentionally left optional. Clients SHOULD follow the advice
given in the field description when the field is not available.
**NOTE:** Both clients and server administrators should be aware that proxies between the
client and the server may affect the apparent behaviour of content repository APIs, for
example, proxies may enforce a lower upload size limit than is advertised by the server on
this endpoint.
See also: `API reference `__
Returns:
The media repository config.
"""
content = await self.api.request(Method.GET, MediaPath.v3.config)
try:
return MediaRepoConfig.deserialize(content)
except SerializerError as e:
raise MatrixResponseError("Invalid MediaRepoConfig in response") from e
async def _upload_to_url(
self,
upload_url: str,
post_upload_path: str,
headers: dict[str, str],
data: bytes | bytearray | AsyncIterable[bytes],
post_upload_query: dict[str, str],
min_iter_size: int = 25 * 1024 * 1024,
) -> None:
retry_count = self.api.default_retry_count
backoff = 2
do_fake_iter = data and hasattr(data, "__len__") and len(data) > min_iter_size
if do_fake_iter:
headers["Content-Length"] = str(len(data))
while True:
self.log.debug("Uploading media to external URL %s", upload_url)
upload_response = None
try:
req_data = async_iter_bytes(data) if do_fake_iter else data
upload_response = await self.api.session.put(
upload_url, data=req_data, headers=headers
)
upload_response.raise_for_status()
except Exception as e:
if retry_count <= 0:
raise make_request_error(
http_status=upload_response.status if upload_response else -1,
text=(await upload_response.text()) if upload_response else "",
errcode="COM.BEEPER.EXTERNAL_UPLOAD_ERROR",
message=None,
)
self.log.warning(
f"Uploading media to external URL {upload_url} failed: {e}, "
f"retrying in {backoff} seconds",
)
await asyncio.sleep(backoff)
backoff *= 2
retry_count -= 1
else:
break
await self.api.request(Method.POST, post_upload_path, query_params=post_upload_query)
python-0.20.4/mautrix/client/api/modules/misc.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from mautrix.api import Method, Path
from mautrix.errors import MatrixResponseError
from mautrix.types import (
JSON,
EventID,
PresenceEventContent,
PresenceState,
RoomID,
SerializerError,
UserID,
)
from ..base import BaseClientAPI
class MiscModuleMethods(BaseClientAPI):
"""
Miscellaneous subsections in the `Modules section`_ of the API spec.
Currently included subsections:
* 13.4 `Typing Notifications`_
* 13.5 `Receipts`_
* 13.6 `Fully Read Markers`_
* 13.7 `Presence`_
.. _Modules section: https://matrix.org/docs/spec/client_server/r0.4.0.html#modules
.. _Typing Notifications: https://matrix.org/docs/spec/client_server/r0.4.0.html#id95
.. _Receipts: https://matrix.org/docs/spec/client_server/r0.4.0.html#id99
.. _Fully Read Markers: https://matrix.org/docs/spec/client_server/r0.4.0.html#fully-read-markers
.. _Presence: https://matrix.org/docs/spec/client_server/r0.4.0.html#id107
"""
# region 13.4 Typing Notifications
async def set_typing(self, room_id: RoomID, timeout: int = 0) -> None:
"""
This tells the server that the user is typing for the next N milliseconds where N is the
value specified in the timeout key. If the timeout is equal to or less than zero, it tells
the server that the user has stopped typing.
See also: `API reference `__
Args:
room_id: The ID of the room in which the user is typing.
timeout: The length of time in milliseconds to mark this user as typing.
"""
if timeout > 0:
content = {"typing": True, "timeout": timeout}
else:
content = {"typing": False}
await self.api.request(Method.PUT, Path.v3.rooms[room_id].typing[self.mxid], content)
# endregion
# region 13.5 Receipts
async def send_receipt(
self,
room_id: RoomID,
event_id: EventID,
receipt_type: str = "m.read",
) -> None:
"""
Update the marker for the given receipt type to the event ID specified.
See also: `API reference `__
Args:
room_id: The ID of the room which to send the receipt to.
event_id: The last event ID to acknowledge.
receipt_type: The type of receipt to send. Currently only ``m.read`` is supported.
"""
await self.api.request(Method.POST, Path.v3.rooms[room_id].receipt[receipt_type][event_id])
# endregion
# region 13.6 Fully read markers
async def set_fully_read_marker(
self,
room_id: RoomID,
fully_read: EventID,
read_receipt: EventID | None = None,
extra_content: dict[str, JSON] | None = None,
) -> None:
"""
Set the position of the read marker for the given room, and optionally send a new read
receipt.
See also: `API reference `__
Args:
room_id: The ID of the room which to set the read marker in.
fully_read: The last event up to which the user has either read all events or is not
interested in reading the events.
read_receipt: The new position for the user's normal read receipt, i.e. the last event
the user has seen.
extra_content: Additional fields to include in the ``/read_markers`` request.
"""
content = {
"m.fully_read": fully_read,
}
if read_receipt:
content["m.read"] = read_receipt
if extra_content:
content.update(extra_content)
await self.api.request(Method.POST, Path.v3.rooms[room_id].read_markers, content)
# endregion
# region 13.7 Presence
async def set_presence(
self, presence: PresenceState = PresenceState.ONLINE, status: str | None = None
) -> None:
"""
Set the current user's presence state. When setting the status, the activity time is updated
to reflect that activity; the client does not need to specify
:attr:`Presence.last_active_ago`.
See also: `API reference `__
Args:
presence: The new presence state to set.
status: The status message to attach to this state.
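Examples:
A hedged sketch; ``PresenceState.UNAVAILABLE`` is assumed to be the "away" enum member:
>>> await client.set_presence(PresenceState.UNAVAILABLE, status="Out for lunch")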
"""
content = {
"presence": presence.value,
}
if status:
content["status_msg"] = status
await self.api.request(Method.PUT, Path.v3.presence[self.mxid].status, content)
async def get_presence(self, user_id: UserID) -> PresenceEventContent:
"""
Get the presence info of a user.
See also: `API reference `__
Args:
user_id: The ID of the user whose presence info to get.
Returns:
The presence info of the given user.
"""
content = await self.api.request(Method.GET, Path.v3.presence[user_id].status)
try:
return PresenceEventContent.deserialize(content)
except SerializerError:
raise MatrixResponseError("Invalid presence in response")
# endregion
python-0.20.4/mautrix/client/api/modules/push_rules.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from mautrix.api import Method, Path
from mautrix.types import (
PushAction,
PushCondition,
PushRule,
PushRuleID,
PushRuleKind,
PushRuleScope,
)
from ..base import BaseClientAPI
class PushRuleMethods(BaseClientAPI):
"""
Methods in section 13.13 Push Notifications of the spec. These methods are used for modifying
what triggers push notifications.
See also: `API reference `__"""
async def get_push_rule(
self, scope: PushRuleScope, kind: PushRuleKind, rule_id: PushRuleID
) -> PushRule:
"""
Retrieve a single specified push rule.
See also: `API reference `__
Args:
scope: The scope of the push rule.
kind: The kind of rule.
rule_id: The identifier of the rule.
Returns:
The push rule information.
"""
resp = await self.api.request(Method.GET, Path.v3.pushrules[scope][kind][rule_id])
return PushRule.deserialize(resp)
async def set_push_rule(
self,
scope: PushRuleScope,
kind: PushRuleKind,
rule_id: PushRuleID,
actions: list[PushAction],
pattern: str | None = None,
before: PushRuleID | None = None,
after: PushRuleID | None = None,
conditions: list[PushCondition] | None = None,
) -> None:
"""
Create or modify a push rule.
See also: `API reference `__
Args:
scope: The scope of the push rule.
kind: The kind of rule.
rule_id: The identifier for the rule.
before:
after:
actions: The actions to perform when the conditions for the rule are met.
pattern: The glob-style pattern to match against for ``content`` rules.
conditions: The conditions for the rule for ``underride`` and ``override`` rules.
"""
query = {}
if after:
query["after"] = after
if before:
query["before"] = before
content = {"actions": [act.serialize() for act in actions]}
if conditions:
content["conditions"] = [cond.serialize() for cond in conditions]
if pattern:
content["pattern"] = pattern
await self.api.request(
Method.PUT,
Path.v3.pushrules[scope][kind][rule_id],
query_params=query,
content=content,
)
async def remove_push_rule(
self, scope: PushRuleScope, kind: PushRuleKind, rule_id: PushRuleID
) -> None:
"""
Remove a push rule.
See also: `API reference `__
Args:
scope: The scope of the push rule.
kind: The kind of rule.
rule_id: The identifier of the rule.
"""
await self.api.request(Method.DELETE, Path.v3.pushrules[scope][kind][rule_id])
python-0.20.4/mautrix/client/api/modules/room_tag.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from mautrix.api import Method, Path
from mautrix.types import RoomID, RoomTagAccountDataEventContent, RoomTagInfo, Serializable
from ..base import BaseClientAPI
class RoomTaggingMethods(BaseClientAPI):
"""
Methods in section 13.18 Room Tagging of the spec. These methods are used for organizing rooms
into tags for the local user.
See also: `API reference `__"""
async def get_room_tags(self, room_id: RoomID) -> RoomTagAccountDataEventContent:
"""
Get all tags for a specific room. This is equivalent to getting the ``m.tag`` account data
event for the room.
See also: `API reference `__
Args:
room_id: The room ID to get tags from.
Returns:
The m.tag account data event.
"""
resp = await self.api.request(Method.GET, Path.v3.user[self.mxid].rooms[room_id].tags)
return RoomTagAccountDataEventContent.deserialize(resp)
async def get_room_tag(self, room_id: RoomID, tag: str) -> RoomTagInfo | None:
"""
Get the info of a specific tag for a room.
Args:
room_id: The room to get the tag from.
tag: The tag to get.
Returns:
The info about the tag, or ``None`` if the room does not have the specified tag.
"""
resp = await self.get_room_tags(room_id)
try:
return resp.tags[tag]
except KeyError:
return None
async def set_room_tag(
self, room_id: RoomID, tag: str, info: RoomTagInfo | None = None
) -> None:
"""
Add or update a tag for a specific room.
See also: `API reference `__
Args:
room_id: The room ID to add the tag to.
tag: The tag to add.
info: Optionally, information like ordering within the tag.
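Examples:
A hedged sketch; ``order`` is assumed to be the ordering field of :class:`RoomTagInfo`:
>>> await client.set_room_tag(room_id, "m.favourite", RoomTagInfo(order=0.5))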
"""
await self.api.request(
Method.PUT,
Path.v3.user[self.mxid].rooms[room_id].tags[tag],
content=(info.serialize() if isinstance(info, Serializable) else (info or {})),
)
async def remove_room_tag(self, room_id: RoomID, tag: str) -> None:
"""
Remove a tag from a specific room.
See also: `API reference `__
Args:
room_id: The room ID to remove the tag from.
tag: The tag to remove.
"""
await self.api.request(Method.DELETE, Path.v3.user[self.mxid].rooms[room_id].tags[tag])
python-0.20.4/mautrix/client/api/rooms.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from typing import Any, Awaitable, Callable
import asyncio
from multidict import CIMultiDict
from mautrix.api import Method, Path
from mautrix.errors import (
MatrixRequestError,
MatrixResponseError,
MNotFound,
MNotJoined,
MRoomInUse,
)
from mautrix.types import (
JSON,
DirectoryPaginationToken,
EventID,
EventType,
Membership,
MemberStateEventContent,
PowerLevelStateEventContent,
RoomAlias,
RoomAliasInfo,
RoomCreatePreset,
RoomCreateStateEventContent,
RoomDirectoryResponse,
RoomDirectoryVisibility,
RoomID,
Serializable,
StateEvent,
StrippedStateEvent,
UserID,
)
from .base import BaseClientAPI
from .events import EventMethods
class RoomMethods(EventMethods, BaseClientAPI):
"""
Methods in section 8 Rooms of the spec. These methods are used for creating rooms, interacting
with the room directory and using the easy room metadata editing endpoints. Generic state
setting and sending events are in the :class:`EventMethods` (section 7) module.
See also: `API reference `__
"""
# region 8.1 Creation
# API reference: https://spec.matrix.org/v1.1/client-server-api/#creation
async def create_room(
self,
alias_localpart: str | None = None,
visibility: RoomDirectoryVisibility = RoomDirectoryVisibility.PRIVATE,
preset: RoomCreatePreset = RoomCreatePreset.PRIVATE,
name: str | None = None,
topic: str | None = None,
is_direct: bool = False,
invitees: list[UserID] | None = None,
initial_state: list[StateEvent | StrippedStateEvent | dict[str, JSON]] | None = None,
room_version: str = None,
creation_content: RoomCreateStateEventContent | dict[str, JSON] | None = None,
power_level_override: PowerLevelStateEventContent | dict[str, JSON] | None = None,
beeper_auto_join_invites: bool = False,
custom_request_fields: dict[str, Any] | None = None,
) -> RoomID:
"""
Create a new room with various configuration options.
See also: `API reference `__
Args:
alias_localpart: The desired room alias **local part**. If this is included, a room
alias will be created and mapped to the newly created room. The alias will belong
on the same homeserver which created the room. For example, if this was set to
"foo" and sent to the homeserver "example.com" the complete room alias would be
``#foo:example.com``.
visibility: A ``public`` visibility indicates that the room will be shown in the
published room list. A ``private`` visibility will hide the room from the published
room list. Defaults to ``private``. **NB:** This should not be confused with
``join_rules`` which also uses the word ``public``.
preset: Convenience parameter for setting various default state events based on a
preset. Defaults to private (invite-only).
name: If this is included, an ``m.room.name`` event will be sent into the room to
indicate the name of the room. See `Room Events`_ for more information on
``m.room.name``.
topic: If this is included, an ``m.room.topic`` event will be sent into the room to
indicate the topic for the room. See `Room Events`_ for more information on
``m.room.topic``.
is_direct: This flag makes the server set the ``is_direct`` flag on the
`m.room.member`_ events sent to the users in ``invite`` and ``invite_3pid``. See
`Direct Messaging`_ for more information.
invitees: A list of user IDs to invite to the room. This will tell the server to invite
everyone in the list to the newly created room.
initial_state: A list of state events to set in the new room. This allows the user to
override the default state events set in the new room. The expected format of the
state events is an object with ``type``, ``state_key`` and ``content`` keys set.
Takes precedence over events set by ``preset``, but gets overridden by the ``name``
and ``topic`` keys.
room_version: The room version to set for the room. If not provided, the homeserver
will use its configured default.
creation_content: Extra keys, such as ``m.federate``, to be added to the
``m.room.create`` event. The server will ignore ``creator`` and ``room_version``.
Future versions of the specification may allow the server to ignore other keys.
power_level_override: The power level content to override in the default power level
event. This object is applied on top of the generated ``m.room.power_levels`` event
content prior to it being sent to the room. Defaults to overriding nothing.
beeper_auto_join_invites: A Beeper-specific extension which auto-joins all members in
the invite array instead of sending invites.
custom_request_fields: Additional fields to put in the top-level /createRoom content.
Non-custom fields take precedence over fields here.
Returns:
The ID of the newly created room.
Raises:
MatrixResponseError: If the response does not contain a ``room_id`` field.
.. _Room Events: https://spec.matrix.org/v1.1/client-server-api/#room-events
.. _Direct Messaging: https://spec.matrix.org/v1.1/client-server-api/#direct-messaging
.. _m.room.create: https://spec.matrix.org/v1.1/client-server-api/#mroomcreate
.. _m.room.member: https://spec.matrix.org/v1.1/client-server-api/#mroommember
"""
content = {
**(custom_request_fields or {}),
"visibility": visibility.value,
"is_direct": is_direct,
"preset": preset.value,
}
if alias_localpart:
content["room_alias_name"] = alias_localpart
if invitees:
content["invite"] = invitees
if beeper_auto_join_invites:
content["com.beeper.auto_join_invites"] = True
if name:
content["name"] = name
if topic:
content["topic"] = topic
if initial_state:
content["initial_state"] = [
event.serialize() if isinstance(event, Serializable) else event
for event in initial_state
]
if room_version:
content["room_version"] = room_version
if creation_content:
content["creation_content"] = (
creation_content.serialize()
if isinstance(creation_content, Serializable)
else creation_content
)
# Remove keys that the server will ignore anyway
content["creation_content"].pop("room_version", None)
if power_level_override:
content["power_level_content_override"] = (
power_level_override.serialize()
if isinstance(power_level_override, Serializable)
else power_level_override
)
resp = await self.api.request(Method.POST, Path.v3.createRoom, content)
try:
return resp["room_id"]
except KeyError:
raise MatrixResponseError("`room_id` not in response.")
# endregion
# region 8.2 Room aliases
# API reference: https://spec.matrix.org/v1.1/client-server-api/#room-aliases
async def add_room_alias(
self, room_id: RoomID, alias_localpart: str, override: bool = False
) -> None:
"""
Create a new mapping from room alias to room ID.
See also: `API reference `__
Args:
room_id: The room ID to set.
alias_localpart: The localpart of the room alias to set.
override: Whether or not the alias should be removed and the request retried if the
server responds with HTTP 409 Conflict.
"""
room_alias = f"#{alias_localpart}:{self.domain}"
content = {"room_id": room_id}
try:
await self.api.request(Method.PUT, Path.v3.directory.room[room_alias], content)
except MatrixRequestError as e:
if e.http_status == 409:
if override:
await self.remove_room_alias(alias_localpart)
await self.api.request(Method.PUT, Path.v3.directory.room[room_alias], content)
else:
raise MRoomInUse(e.http_status, e.message) from e
else:
raise
async def remove_room_alias(self, alias_localpart: str, raise_404: bool = False) -> None:
"""
Remove a mapping of room alias to room ID.
Servers may choose to implement additional access control checks here, for instance that
room aliases can only be deleted by their creator or server administrator.
See also: `API reference `__
Args:
alias_localpart: The localpart of the room alias to remove.
raise_404: Whether 404 errors should be raised as exceptions instead of ignored.
"""
room_alias = f"#{alias_localpart}:{self.domain}"
try:
await self.api.request(Method.DELETE, Path.v3.directory.room[room_alias])
except MNotFound:
if raise_404:
raise
# else: ignore
async def resolve_room_alias(self, room_alias: RoomAlias) -> RoomAliasInfo:
"""
Request the server to resolve a room alias to a room ID.
The server will use the federation API to resolve the alias if the domain part of the alias
does not correspond to the server's own domain.
See also: `API reference `__
Args:
room_alias: The room alias.
Returns:
The room ID and a list of servers that are aware of the room.
"""
content = await self.api.request(Method.GET, Path.v3.directory.room[room_alias])
return RoomAliasInfo.deserialize(content)
# endregion
# region 8.4 Room membership
# API reference: https://spec.matrix.org/v1.1/client-server-api/#room-membership
async def get_joined_rooms(self) -> list[RoomID]:
"""Get the list of rooms the user is in."""
content = await self.api.request(Method.GET, Path.v3.joined_rooms)
try:
return content["joined_rooms"]
except KeyError:
raise MatrixResponseError("`joined_rooms` not in response.")
# region 8.4.1 Joining rooms
# API reference: https://spec.matrix.org/v1.1/client-server-api/#joining-rooms
async def join_room_by_id(
self,
room_id: RoomID,
third_party_signed: JSON = None,
extra_content: dict[str, Any] | None = None,
) -> RoomID:
"""
Start participating in a room, i.e. join it by its ID.
See also: `API reference `__
Args:
room_id: The ID of the room to join.
third_party_signed: A signature of an ``m.third_party_invite`` token to prove that this
user owns a third party identity which has been invited to the room.
extra_content: Additional properties for the join event content.
If a non-empty dict is passed, the join event will be created using
the ``PUT /state/m.room.member/...`` endpoint instead of ``POST /join``.
Returns:
The ID of the room the user joined.
"""
if extra_content:
await self.send_member_event(
room_id, self.mxid, Membership.JOIN, extra_content=extra_content
)
return room_id
content = await self.api.request(
Method.POST,
Path.v3.rooms[room_id].join,
{"third_party_signed": third_party_signed} if third_party_signed is not None else None,
)
try:
return content["room_id"]
except KeyError:
raise MatrixResponseError("`room_id` not in response.")
async def join_room(
self,
room_id_or_alias: RoomID | RoomAlias,
servers: list[str] | None = None,
third_party_signed: JSON = None,
max_retries: int = 4,
) -> RoomID:
"""
Start participating in a room, i.e. join it by its ID or alias, with an optional list of
servers to ask about the ID from.
See also: `API reference `__
Args:
room_id_or_alias: The ID of the room to join, or an alias pointing to the room.
servers: A list of servers to ask about the room ID to join. Not applicable for aliases,
as aliases already contain the necessary server information.
third_party_signed: A signature of an ``m.third_party_invite`` token to prove that this
user owns a third party identity which has been invited to the room.
max_retries: The maximum number of retries. Used to circumvent a Synapse bug with
accepting invites over federation. 0 means only one join call will be attempted.
See: `matrix-org/synapse#2807 `__
Returns:
The ID of the room the user joined.
"""
max_retries = max(0, max_retries)
tries = 0
content = (
{"third_party_signed": third_party_signed} if third_party_signed is not None else None
)
query_params = CIMultiDict()
for server_name in servers or []:
query_params.add("server_name", server_name)
while tries <= max_retries:
try:
content = await self.api.request(
Method.POST,
Path.v3.join[room_id_or_alias],
content=content,
query_params=query_params,
)
break
except MatrixRequestError:
tries += 1
if tries <= max_retries:
wait = tries * 10
self.log.exception(
f"Failed to join room {room_id_or_alias}, retrying in {wait} seconds..."
)
await asyncio.sleep(wait)
else:
raise
try:
return content["room_id"]
except KeyError:
raise MatrixResponseError("`room_id` not in response.")
fill_member_event_callback: Callable[
[RoomID, UserID, MemberStateEventContent], Awaitable[MemberStateEventContent | None]
] | None = None
async def fill_member_event(
self, room_id: RoomID, user_id: UserID, content: MemberStateEventContent
) -> MemberStateEventContent | None:
"""
Fill a membership event content that is going to be sent in :meth:`send_member_event`.
This is used to set default fields like the displayname and avatar, which are usually set
by the server in the sugar membership endpoints like /join and /invite, but are not set
automatically when sending member events manually.
This default implementation only calls :attr:`fill_member_event_callback`.
Args:
room_id: The room where the member event is going to be sent.
user_id: The user whose membership is changing.
content: The new member event content.
Returns:
The filled member event content.
"""
if self.fill_member_event_callback is not None:
return await self.fill_member_event_callback(room_id, user_id, content)
return None
async def send_member_event(
self,
room_id: RoomID,
user_id: UserID,
membership: Membership,
extra_content: dict[str, JSON] | None = None,
) -> EventID:
"""
Send a membership event manually.
Args:
room_id: The room to send the event to.
user_id: The user whose membership to change.
membership: The membership status.
extra_content: Additional content to put in the member event.
Returns:
The event ID of the new member event.
"""
content = MemberStateEventContent(membership=membership)
for key, value in (extra_content or {}).items():
content[key] = value
content = await self.fill_member_event(room_id, user_id, content) or content
return await self.send_state_event(
room_id, EventType.ROOM_MEMBER, content=content, state_key=user_id, ensure_joined=False
)
async def invite_user(
self,
room_id: RoomID,
user_id: UserID,
reason: str | None = None,
extra_content: dict[str, JSON] | None = None,
) -> None:
"""
Invite a user to participate in a particular room. They do not start participating in the
room until they actually join the room.
Only users currently in the room can invite other users to join that room.
If the user was invited to the room, the homeserver will add a `m.room.member`_ event to
the room.
See also: `API reference `__
.. _m.room.member: https://spec.matrix.org/v1.1/client-server-api/#mroommember
Args:
room_id: The ID of the room to which to invite the user.
user_id: The fully qualified user ID of the invitee.
reason: The reason the user was invited. This will be supplied as the ``reason`` on
the `m.room.member`_ event.
extra_content: Additional properties for the invite event content.
If a non-empty dict is passed, the invite event will be created using
the ``PUT /state/m.room.member/...`` endpoint instead of ``POST /invite``.
"""
if extra_content:
await self.send_member_event(
room_id, user_id, Membership.INVITE, extra_content=extra_content
)
else:
data = {"user_id": user_id}
if reason:
data["reason"] = reason
await self.api.request(Method.POST, Path.v3.rooms[room_id].invite, content=data)
# endregion
# region 8.4.2 Leaving rooms
# API reference: https://spec.matrix.org/v1.1/client-server-api/#leaving-rooms
async def leave_room(
self,
room_id: RoomID,
reason: str | None = None,
extra_content: dict[str, JSON] | None = None,
raise_not_in_room: bool = False,
) -> None:
"""
Stop participating in a particular room, i.e. leave the room.
If the user was already in the room, they will no longer be able to see new events in the
room. If the room requires an invite to join, they will need to be re-invited before they
can re-join.
If the user was invited to the room, but had not joined, this call serves to reject the
invite.
The user will still be allowed to retrieve history from the room which they were previously
allowed to see.
See also: `API reference `__
Args:
room_id: The ID of the room to leave.
reason: The reason for leaving the room. This will be supplied as the ``reason`` on
the updated `m.room.member`_ event.
extra_content: Additional properties for the leave event content.
If a non-empty dict is passed, the leave event will be created using
the ``PUT /state/m.room.member/...`` endpoint instead of ``POST /leave``.
raise_not_in_room: Should errors about the user not being in the room be raised?
"""
try:
if extra_content:
await self.send_member_event(
room_id, self.mxid, Membership.LEAVE, extra_content=extra_content
)
else:
data = {}
if reason:
data["reason"] = reason
await self.api.request(Method.POST, Path.v3.rooms[room_id].leave, content=data)
except MNotJoined:
if raise_not_in_room:
raise
except MatrixRequestError as e:
# TODO remove this once MSC3848 is released and minimum spec version is bumped
if "not in room" not in e.message or raise_not_in_room:
raise
async def knock_room(
self,
room_id_or_alias: RoomID | RoomAlias,
reason: str | None = None,
servers: list[str] | None = None,
) -> RoomID:
"""
Knock on a room, i.e. request to join it by its ID or alias, with an optional list of
servers to ask about the ID from.
See also: `API reference `__
Args:
room_id_or_alias: The ID of the room to knock on, or an alias pointing to the room.
reason: The reason for knocking on the room. This will be supplied as the ``reason`` on
the updated `m.room.member`_ event.
servers: A list of servers to ask about the room ID to knock. Not applicable for aliases,
as aliases already contain the necessary server information.
Returns:
The ID of the room the user knocked on.
"""
data = {}
if reason:
data["reason"] = reason
query_params = CIMultiDict()
for server_name in servers or []:
query_params.add("server_name", server_name)
content = await self.api.request(
Method.POST,
Path.v3.knock[room_id_or_alias],
content=data,
query_params=query_params,
)
try:
return content["room_id"]
except KeyError:
raise MatrixResponseError("`room_id` not in response.")
async def forget_room(self, room_id: RoomID) -> None:
"""
Stop remembering a particular room, i.e. forget it.
In general, history is a first class citizen in Matrix. After this API is called, however,
a user will no longer be able to retrieve history for this room. If all users on a
homeserver forget a room, the room is eligible for deletion from that homeserver.
If the user is currently joined to the room, they must leave the room before calling this
API.
See also: `API reference `__
Args:
room_id: The ID of the room to forget.
"""
await self.api.request(Method.POST, Path.v3.rooms[room_id].forget)
async def kick_user(
self,
room_id: RoomID,
user_id: UserID,
reason: str = "",
extra_content: dict[str, JSON] | None = None,
) -> None:
"""
Kick a user from the room.
The caller must have the required power level in order to perform this operation.
Kicking a user adjusts the target member's membership state to be ``leave`` with an optional
``reason``. Like with other membership changes, a user can directly adjust the target
member's state by calling :meth:`EventMethods.send_state_event` with
:attr:`EventType.ROOM_MEMBER` as the event type and the ``user_id`` as the state key.
See also: `API reference `__
Args:
room_id: The ID of the room from which the user should be kicked.
user_id: The fully qualified user ID of the user being kicked.
reason: The reason the user has been kicked. This will be supplied as the ``reason`` on
the target's updated `m.room.member`_ event.
extra_content: Additional properties for the kick event content.
If a non-empty dict is passed, the kick event will be created using
the ``PUT /state/m.room.member/...`` endpoint instead of ``POST /kick``.
.. _m.room.member: https://spec.matrix.org/v1.1/client-server-api/#mroommember
"""
if extra_content:
if reason and "reason" not in extra_content:
extra_content["reason"] = reason
await self.send_member_event(
room_id, user_id, Membership.LEAVE, extra_content=extra_content
)
return
await self.api.request(
Method.POST, Path.v3.rooms[room_id].kick, {"user_id": user_id, "reason": reason}
)
# endregion
# region 8.4.2.1 Banning users in a room
# API reference: https://spec.matrix.org/v1.1/client-server-api/#banning-users-in-a-room
async def ban_user(
self,
room_id: RoomID,
user_id: UserID,
reason: str = "",
extra_content: dict[str, JSON] | None = None,
) -> None:
"""
Ban a user in the room. If the user is currently in the room, also kick them. When a user is
banned from a room, they may not join it or be invited to it until they are unbanned. The
caller must have the required power level in order to perform this operation.
See also: `API reference `__
Args:
room_id: The ID of the room from which the user should be banned.
user_id: The fully qualified user ID of the user being banned.
reason: The reason the user has been banned. This will be supplied as the ``reason`` on
the target's updated `m.room.member`_ event.
extra_content: Additional properties for the ban event content.
If a non-empty dict is passed, the ban will be created using
the ``PUT /state/m.room.member/...`` endpoint instead of ``POST /ban``.
.. _m.room.member: https://spec.matrix.org/v1.1/client-server-api/#mroommember
"""
if extra_content:
if reason and "reason" not in extra_content:
extra_content["reason"] = reason
await self.send_member_event(
room_id, user_id, Membership.BAN, extra_content=extra_content
)
return
await self.api.request(
Method.POST, Path.v3.rooms[room_id].ban, {"user_id": user_id, "reason": reason}
)
async def unban_user(
self,
room_id: RoomID,
user_id: UserID,
reason: str = "",
extra_content: dict[str, JSON] | None = None,
) -> None:
"""
Unban a user from the room. This allows them to be invited to the room, and join if they
would otherwise be allowed to join according to its join rules. The caller must have the
required power level in order to perform this operation.
See also: `API reference `__
Args:
room_id: The ID of the room from which the user should be unbanned.
user_id: The fully qualified user ID of the user being unbanned.
reason: The reason the user has been unbanned. This will be supplied as the ``reason`` on
the target's updated `m.room.member`_ event.
extra_content: Additional properties for the unban (leave) event content.
If a non-empty dict is passed, the unban will be created using
the ``PUT /state/m.room.member/...`` endpoint instead of ``POST /unban``.
"""
if extra_content:
if reason and "reason" not in extra_content:
extra_content["reason"] = reason
await self.send_member_event(
room_id, user_id, Membership.LEAVE, extra_content=extra_content
)
return
await self.api.request(
Method.POST, Path.v3.rooms[room_id].unban, {"user_id": user_id, "reason": reason}
)
# endregion
# endregion
# region 8.5 Listing rooms
# API reference: https://spec.matrix.org/v1.1/client-server-api/#listing-rooms
async def get_room_directory_visibility(self, room_id: RoomID) -> RoomDirectoryVisibility:
"""
Get the visibility of the room on the server's public room directory.
See also: `API reference `__
Args:
room_id: The ID of the room.
Returns:
The visibility of the room in the directory.
"""
resp = await self.api.request(Method.GET, Path.v3.directory.list.room[room_id])
try:
return RoomDirectoryVisibility(resp["visibility"])
except KeyError:
raise MatrixResponseError("`visibility` not in response.")
except ValueError:
raise MatrixResponseError(
f"Invalid value for `visibility` in response: {resp['visibility']}"
)
async def set_room_directory_visibility(
self, room_id: RoomID, visibility: RoomDirectoryVisibility
) -> None:
"""
Set the visibility of the room in the server's public room directory.
Servers may choose to implement additional access control checks here, for instance that
room visibility can only be changed by the room creator or a server administrator.
Args:
room_id: The ID of the room.
visibility: The new visibility setting for the room.
.. _API reference: https://spec.matrix.org/v1.1/client-server-api/#put_matrixclientv3directorylistroomroomid
"""
await self.api.request(
Method.PUT,
Path.v3.directory.list.room[room_id],
{
"visibility": visibility.value,
},
)
async def get_room_directory(
self,
limit: int | None = None,
server: str | None = None,
since: DirectoryPaginationToken | None = None,
search_query: str | None = None,
include_all_networks: bool | None = None,
third_party_instance_id: str | None = None,
) -> RoomDirectoryResponse:
"""
Get a list of public rooms from the server's room directory.
See also: `API reference `__
Args:
limit: The maximum number of results to return.
server: The server to fetch the room directory from. Defaults to the user's server.
since: A pagination token from a previous request, allowing clients to get the next (or
previous) batch of rooms. The direction of pagination is specified solely by which
token is supplied, rather than via an explicit flag.
search_query: A string to search for in the room metadata, e.g. name, topic, canonical
alias etc.
include_all_networks: Whether or not to include rooms from all known networks/protocols
from application services on the homeserver. Defaults to false.
third_party_instance_id: The specific third party network/protocol to request from the
homeserver. Can only be used if ``include_all_networks`` is false.
Returns:
The relevant pagination tokens, an estimate of the total number of public rooms and the
paginated chunk of public rooms.
"""
method = (
Method.GET
if (
search_query is None
and include_all_networks is None
and third_party_instance_id is None
)
else Method.POST
)
content = {}
if limit is not None:
content["limit"] = limit
if since is not None:
content["since"] = since
if search_query is not None:
content["filter"] = {"generic_search_term": search_query}
if include_all_networks is not None:
content["include_all_networks"] = include_all_networks
if third_party_instance_id is not None:
content["third_party_instance_id"] = third_party_instance_id
query_params = {"server": server} if server is not None else None
content = await self.api.request(
method, Path.v3.publicRooms, content, query_params=query_params
)
return RoomDirectoryResponse.deserialize(content)
# endregion
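# --- Illustrative sketch (not part of the library) ---------------------------
# Minimal usage of the room-management methods above: create an invite-only
# room, invite another user, and map an extra alias to it. Assumes `client` is
# an already-authenticated object that mixes in RoomMethods (for example
# mautrix.client.Client); the alias local parts and user IDs are placeholders.
async def _example_room_setup(client: RoomMethods) -> RoomID:
    room_id = await client.create_room(
        alias_localpart="bridge-room",
        name="Bridge room",
        topic="Created via RoomMethods.create_room",
        invitees=[UserID("@friend:example.com")],
    )
    await client.invite_user(room_id, UserID("@other:example.com"), reason="demo")
    # override=True removes a conflicting alias and retries on HTTP 409.
    await client.add_room_alias(room_id, "bridge-room-extra", override=True)
    return room_id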
python-0.20.4/mautrix/client/api/user_data.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from typing import Any
from mautrix.api import Method, Path
from mautrix.errors import MatrixResponseError, MNotFound
from mautrix.types import ContentURI, Member, SerializerError, User, UserID, UserSearchResults
from .base import BaseClientAPI
class UserDataMethods(BaseClientAPI):
"""
Methods in section 10 User Data of the spec. These methods are used for setting and getting user
metadata and searching for users.
See also: `API reference `__
"""
# region 10.1 User Directory
# API reference: https://matrix.org/docs/spec/client_server/r0.4.0.html#user-directory
async def search_users(self, search_query: str, limit: int | None = 10) -> UserSearchResults:
"""
Performs a search for users on the homeserver. The homeserver may determine which subset of
users are searched, however the homeserver MUST at a minimum consider the users the
requesting user shares a room with and those who reside in public rooms (known to the
homeserver). The search MUST consider local users to the homeserver, and SHOULD query remote
users as part of the search.
The search is performed case-insensitively on user IDs and display names preferably using a
collation determined based upon the Accept-Language header provided in the request, if
present.
See also: `API reference `__
Args:
search_query: The query to search for.
limit: The maximum number of results to return.
Returns:
The results of the search and whether or not the results were limited.
"""
content = await self.api.request(
Method.POST,
Path.v3.user_directory.search,
{
"search_term": search_query,
"limit": limit,
},
)
try:
return UserSearchResults(
[User.deserialize(user) for user in content["results"]], content["limited"]
)
except SerializerError as e:
raise MatrixResponseError("Invalid user in search results") from e
except KeyError:
if "results" not in content:
raise MatrixResponseError("`results` not in content.")
elif "limited" not in content:
raise MatrixResponseError("`limited` not in content.")
raise
# endregion
# region 10.2 Profiles
# API reference: https://matrix.org/docs/spec/client_server/r0.4.0.html#profiles
async def set_displayname(self, displayname: str, check_current: bool = True) -> None:
"""
Set the display name of the current user.
See also: `API reference `__
Args:
displayname: The new display name for the user.
check_current: Whether or not to check if the displayname is already set.
"""
if check_current and await self.get_displayname(self.mxid) == displayname:
return
await self.api.request(
Method.PUT,
Path.v3.profile[self.mxid].displayname,
{
"displayname": displayname,
},
)
async def get_displayname(self, user_id: UserID) -> str | None:
"""
Get the display name of a user.
See also: `API reference `__
Args:
user_id: The ID of the user whose display name to get.
Returns:
The display name of the given user.
"""
try:
content = await self.api.request(Method.GET, Path.v3.profile[user_id].displayname)
except MNotFound:
return None
try:
return content["displayname"]
except KeyError:
return None
async def set_avatar_url(self, avatar_url: ContentURI, check_current: bool = True) -> None:
"""
Set the avatar of the current user.
See also: `API reference `__
Args:
avatar_url: The ``mxc://`` URI to the new avatar.
check_current: Whether or not to check if the avatar is already set.
"""
if check_current and await self.get_avatar_url(self.mxid) == avatar_url:
return
await self.api.request(
Method.PUT,
Path.v3.profile[self.mxid].avatar_url,
{
"avatar_url": avatar_url,
},
)
async def get_avatar_url(self, user_id: UserID) -> ContentURI | None:
"""
Get the avatar URL of a user.
See also: `API reference `__
Args:
user_id: The ID of the user whose avatar to get.
Returns:
The ``mxc://`` URI to the user's avatar.
"""
try:
content = await self.api.request(Method.GET, Path.v3.profile[user_id].avatar_url)
except MNotFound:
return None
try:
return content["avatar_url"]
except KeyError:
return None
async def get_profile(self, user_id: UserID) -> Member:
"""
Get the combined profile information for a user.
See also: `API reference `__
Args:
user_id: The ID of the user whose profile to get.
Returns:
The profile information of the given user.
"""
content = await self.api.request(Method.GET, Path.v3.profile[user_id])
try:
return Member.deserialize(content)
except SerializerError as e:
raise MatrixResponseError("Invalid member in response") from e
# endregion
# region Beeper Custom Fields API
async def beeper_update_profile(self, custom_fields: dict[str, Any]) -> None:
"""
Set custom fields on the user's profile. Only works on Hungryserv.
Args:
custom_fields: A dictionary of fields to set in the custom content of the profile.
"""
await self.api.request(Method.PATCH, Path.v3.profile[self.mxid], custom_fields)
# endregion
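# --- Illustrative sketch (not part of the library) ---------------------------
# Minimal usage of the profile and directory methods above: update the bot's
# own display name and avatar (skipping the request when the value is already
# current) and run a user directory search. Assumes `client` mixes in
# UserDataMethods and is logged in; the values are placeholders, and the
# attribute names on the returned User objects are assumptions.
async def _example_profile_and_search(client: UserDataMethods) -> None:
    await client.set_displayname("Example Bot", check_current=True)
    await client.set_avatar_url(ContentURI("mxc://example.com/abc123"), check_current=True)
    results, limited = await client.search_users("alice", limit=5)
    for user in results:
        print(user.user_id, user.displayname)
    if limited:
        print("More results available than the requested limit")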
python-0.20.4/mautrix/client/client.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from mautrix import __optional_imports__
from mautrix.types import Event, EventType, StateEvent
from .encryption_manager import DecryptionDispatcher, EncryptingAPI
from .state_store import StateStore, SyncStore
from .syncer import Syncer
if __optional_imports__:
from .. import crypto as crypt
class Client(EncryptingAPI, Syncer):
"""Client is a high-level wrapper around the client API."""
def __init__(
self,
*args,
sync_store: SyncStore | None = None,
state_store: StateStore | None = None,
**kwargs,
) -> None:
EncryptingAPI.__init__(self, *args, state_store=state_store, **kwargs)
Syncer.__init__(self, sync_store)
self.add_event_handler(EventType.ALL, self._update_state)
async def _update_state(self, evt: Event) -> None:
if not isinstance(evt, StateEvent) or not self.state_store:
return
await self.state_store.update_state(evt)
@EncryptingAPI.crypto.setter
def crypto(self, crypto: crypt.OlmMachine | None) -> None:
"""
Set the olm machine and enable the automatic event decryptor.
Args:
crypto: The olm machine to use for crypto
Raises:
ValueError: if :attr:`state_store` is not set.
"""
if not self.state_store:
raise ValueError("State store must be set to use encryption")
self._crypto = crypto
if self.crypto_enabled:
self.add_dispatcher(DecryptionDispatcher)
else:
self.remove_dispatcher(DecryptionDispatcher)
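# --- Illustrative sketch (not part of the library) ---------------------------
# Wiring the Client wrapper above together with in-memory stores and a message
# handler. The constructor keyword names mxid/base_url/token follow the usual
# ClientAPI surface but are assumptions here, as are the credentials; the
# state_store/sync_store arguments are the ones accepted by Client.__init__
# above.
async def _example_client_setup() -> Client:
    from mautrix.types import MessageEvent, UserID

    from .state_store import MemoryStateStore, MemorySyncStore

    bot = Client(
        mxid=UserID("@bot:example.com"),
        base_url="https://matrix.example.com",
        token="syt_placeholder_token",
        state_store=MemoryStateStore(),
        sync_store=MemorySyncStore(),
    )

    async def on_message(evt: MessageEvent) -> None:
        # Plain room messages end up here; with crypto enabled, decrypted
        # events are re-dispatched through the same handlers.
        print(f"Message in {evt.room_id}: {evt.content.body}")

    bot.add_event_handler(EventType.ROOM_MESSAGE, on_message)
    return bot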
python-0.20.4/mautrix/client/dispatcher.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from typing import ClassVar
from abc import ABC, abstractmethod
from mautrix.types import Event, EventType, Membership, StateEvent
from . import syncer
class Dispatcher(ABC):
client: syncer.Syncer
def __init__(self, client: syncer.Syncer) -> None:
self.client = client
@abstractmethod
def register(self) -> None:
pass
@abstractmethod
def unregister(self) -> None:
pass
class SimpleDispatcher(Dispatcher, ABC):
event_type: ClassVar[EventType]
def register(self) -> None:
self.client.add_event_handler(self.event_type, self.handle)
def unregister(self) -> None:
self.client.remove_event_handler(self.event_type, self.handle)
@abstractmethod
async def handle(self, evt: Event) -> None:
pass
class MembershipEventDispatcher(SimpleDispatcher):
event_type = EventType.ROOM_MEMBER
async def handle(self, evt: StateEvent) -> None:
if evt.type != EventType.ROOM_MEMBER:
return
if evt.content.membership == Membership.JOIN:
if evt.prev_content.membership != Membership.JOIN:
change_type = syncer.InternalEventType.JOIN
else:
change_type = syncer.InternalEventType.PROFILE_CHANGE
elif evt.content.membership == Membership.INVITE:
change_type = syncer.InternalEventType.INVITE
elif evt.content.membership == Membership.LEAVE:
if evt.prev_content.membership == Membership.BAN:
change_type = syncer.InternalEventType.UNBAN
elif evt.prev_content.membership == Membership.INVITE:
if evt.state_key == evt.sender:
change_type = syncer.InternalEventType.REJECT_INVITE
else:
change_type = syncer.InternalEventType.DISINVITE
elif evt.state_key == evt.sender:
change_type = syncer.InternalEventType.LEAVE
else:
change_type = syncer.InternalEventType.KICK
elif evt.content.membership == Membership.BAN:
change_type = syncer.InternalEventType.BAN
else:
return
self.client.dispatch_manual_event(change_type, evt)
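# --- Illustrative sketch (not part of the library) ---------------------------
# MembershipEventDispatcher is opt-in: registering it makes the syncer emit the
# finer-grained InternalEventType values computed in handle() above (JOIN,
# INVITE, KICK, ...), which handlers can then subscribe to individually.
# Assumes `client` is a Syncer-based client such as mautrix.client.Client.
def _example_enable_membership_events(client: syncer.Syncer) -> None:
    client.add_dispatcher(MembershipEventDispatcher)

    async def on_join(evt: StateEvent) -> None:
        # Called only for genuine joins, not profile changes.
        print(f"{evt.state_key} joined {evt.room_id}")

    client.add_event_handler(syncer.InternalEventType.JOIN, on_join)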
python-0.20.4/mautrix/client/encryption_manager.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
import asyncio
import logging
from mautrix import __optional_imports__
from mautrix.errors import DecryptionError, EncryptionError, MNotFound
from mautrix.types import (
EncryptedEvent,
EncryptedMegolmEventContent,
EventContent,
EventID,
EventType,
RoomID,
)
from mautrix.util.logging import TraceLogger
from . import client, dispatcher, store_updater
if __optional_imports__:
from .. import crypto as crypt
class EncryptingAPI(store_updater.StoreUpdatingAPI):
"""
EncryptingAPI is a wrapper around StoreUpdatingAPI that automatically encrypts messages.
For automatic decryption, see :class:`DecryptionDispatcher`.
"""
_crypto: crypt.OlmMachine | None
encryption_blacklist: set[EventType] = {EventType.REACTION}
"""A set of event types which shouldn't be encrypted even in encrypted rooms."""
crypto_log: TraceLogger = logging.getLogger("mau.client.crypto")
"""The logger to use for crypto-related things."""
_share_session_events: dict[RoomID, asyncio.Event]
def __init__(self, *args, crypto_log: TraceLogger | None = None, **kwargs) -> None:
super().__init__(*args, **kwargs)
if crypto_log:
self.crypto_log = crypto_log
self._crypto = None
self._share_session_events = {}
@property
def crypto(self) -> crypt.OlmMachine | None:
"""The :class:`crypto.OlmMachine` to use for e2ee stuff."""
return self._crypto
@crypto.setter
def crypto(self, crypto: crypt.OlmMachine) -> None:
"""
Args:
crypto: The olm machine to use for crypto
Raises:
ValueError: if :attr:`state_store` is not set.
"""
if not self.state_store:
raise ValueError("State store must be set to use encryption")
self._crypto = crypto
@property
def crypto_enabled(self) -> bool:
"""``True`` if both the olm machine and state store are set properly."""
return bool(self.crypto) and bool(self.state_store)
async def encrypt(
self, room_id: RoomID, event_type: EventType, content: EventContent
) -> EncryptedMegolmEventContent:
"""
Encrypt a message for the given room. Automatically creates and shares a group session
if necessary.
Args:
room_id: The room to encrypt the event to.
event_type: The type of event.
content: The content of the event.
Returns:
The content of the encrypted event.
"""
try:
return await self.crypto.encrypt_megolm_event(room_id, event_type, content)
except EncryptionError:
self.crypto_log.debug("Got EncryptionError, sharing group session and trying again")
await self.share_group_session(room_id)
self.crypto_log.trace(
f"Shared group session, now trying to encrypt in {room_id} again"
)
return await self.crypto.encrypt_megolm_event(room_id, event_type, content)
async def _share_session_lock(self, room_id: RoomID) -> bool:
try:
event = self._share_session_events[room_id]
except KeyError:
self._share_session_events[room_id] = asyncio.Event()
return True
else:
await event.wait()
return False
async def share_group_session(self, room_id: RoomID) -> None:
"""
Create and share a Megolm session for the given room.
Args:
room_id: The room to share the session for.
"""
if not await self._share_session_lock(room_id):
self.log.silly("Group session was already being shared, so didn't share new one")
return
try:
if not await self.state_store.has_full_member_list(room_id):
self.crypto_log.trace(
f"Don't have full member list for {room_id}, fetching from server"
)
members = list((await self.get_joined_members(room_id)).keys())
else:
self.crypto_log.trace(f"Fetching member list for {room_id} from state store")
members = await self.state_store.get_members(room_id)
await self.crypto.share_group_session(room_id, members)
finally:
self._share_session_events.pop(room_id).set()
async def send_message_event(
self,
room_id: RoomID,
event_type: EventType,
content: EventContent,
disable_encryption: bool = False,
**kwargs,
) -> EventID:
"""
A wrapper around :meth:`ClientAPI.send_message_event` that encrypts messages if the target
room is encrypted.
Args:
room_id: The room to send the message to.
event_type: The unencrypted event type.
content: The unencrypted event content.
disable_encryption: Set to ``True`` if you want to force-send an unencrypted message.
**kwargs: Additional parameters to pass to :meth:`ClientAPI.send_message_event`.
Returns:
The ID of the event that was sent.
"""
if self.crypto and event_type not in self.encryption_blacklist and not disable_encryption:
is_encrypted = await self.state_store.is_encrypted(room_id)
if is_encrypted is None:
try:
await self.get_state_event(room_id, EventType.ROOM_ENCRYPTION)
is_encrypted = True
except MNotFound:
is_encrypted = False
if is_encrypted:
content = await self.encrypt(room_id, event_type, content)
event_type = EventType.ROOM_ENCRYPTED
return await super().send_message_event(room_id, event_type, content, **kwargs)
class DecryptionDispatcher(dispatcher.SimpleDispatcher):
"""
DecryptionDispatcher is a dispatcher that can be used with a :class:`client.Syncer`
to automatically decrypt events and dispatch the unencrypted versions for event handlers.
The easiest way to use this is with :class:`client.Client`, which automatically registers
this dispatcher when :attr:`EncryptingAPI.crypto` is set.
"""
event_type = EventType.ROOM_ENCRYPTED
client: client.Client
async def handle(self, evt: EncryptedEvent) -> None:
try:
self.client.crypto_log.trace(f"Decrypting {evt.event_id} in {evt.room_id}...")
decrypted = await self.client.crypto.decrypt_megolm_event(evt)
except DecryptionError as e:
self.client.crypto_log.warning(f"Failed to decrypt {evt.event_id}: {e}")
return
self.client.crypto_log.trace(f"Decrypted {evt.event_id}: {decrypted}")
self.client.dispatch_event(decrypted, evt.source)
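# --- Illustrative sketch (not part of the library) ---------------------------
# Enabling end-to-end encryption with the classes above: assigning an
# OlmMachine to Client.crypto (the property setter in client.py) registers
# DecryptionDispatcher automatically, and send_message_event() then encrypts
# outgoing messages for rooms that have m.room.encryption state. Building the
# OlmMachine itself (crypto store, device keys) is out of scope here, so
# `make_olm_machine` below is a hypothetical factory, not a real mautrix API.
async def _example_enable_encryption(matrix_client: client.Client, make_olm_machine) -> None:
    # A state store must already be attached, otherwise the crypto setter
    # raises ValueError (see Client.crypto above).
    matrix_client.crypto = await make_olm_machine(matrix_client)  # hypothetical factory
    # From here on, regular sends are encrypted when the target room requires it:
    #     await matrix_client.send_message_event(room_id, event_type, content)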
python-0.20.4/mautrix/client/state_store/__init__.py
from .abstract import StateStore
from .file import FileStateStore
from .memory import MemoryStateStore
from .sync import MemorySyncStore, SyncStore
__all__ = [
"StateStore",
"FileStateStore",
"MemoryStateStore",
"MemorySyncStore",
"SyncStore",
"asyncpg",
]
python-0.20.4/mautrix/client/state_store/abstract.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from typing import Any, Awaitable
from abc import ABC, abstractmethod
from mautrix.types import (
EventType,
Member,
Membership,
MemberStateEventContent,
PowerLevelStateEventContent,
RoomEncryptionStateEventContent,
RoomID,
StateEvent,
UserID,
)
class StateStore(ABC):
async def open(self) -> None:
pass
async def close(self) -> None:
await self.flush()
async def flush(self) -> None:
pass
@abstractmethod
async def get_member(self, room_id: RoomID, user_id: UserID) -> Member | None:
pass
@abstractmethod
async def set_member(
self, room_id: RoomID, user_id: UserID, member: Member | MemberStateEventContent
) -> None:
pass
@abstractmethod
async def set_membership(
self, room_id: RoomID, user_id: UserID, membership: Membership
) -> None:
pass
@abstractmethod
async def get_member_profiles(
self,
room_id: RoomID,
memberships: tuple[Membership, ...] = (Membership.JOIN, Membership.INVITE),
) -> dict[UserID, Member]:
pass
async def get_members(
self,
room_id: RoomID,
memberships: tuple[Membership, ...] = (Membership.JOIN, Membership.INVITE),
) -> list[UserID]:
profiles = await self.get_member_profiles(room_id, memberships)
return list(profiles.keys())
async def get_members_filtered(
self,
room_id: RoomID,
not_prefix: str,
not_suffix: str,
not_id: str,
memberships: tuple[Membership, ...] = (Membership.JOIN, Membership.INVITE),
) -> list[UserID]:
"""
A filtered version of get_members that only returns user IDs that aren't operated by a
bridge. This should return the same as :meth:`get_members`, except users where the user ID
is equal to not_id OR it starts with not_prefix AND ends with not_suffix.
The default implementation simply calls :meth:`get_members`, but databases can implement
this more efficiently.
Args:
room_id: The room ID to find.
not_prefix: The user ID prefix to disallow.
not_suffix: The user ID suffix to disallow.
not_id: The user ID to disallow.
memberships: The membership states to include.
"""
members = await self.get_members(room_id, memberships=memberships)
return [
user_id
for user_id in members
if user_id != not_id
and not (user_id.startswith(not_prefix) and user_id.endswith(not_suffix))
]
@abstractmethod
async def set_members(
self,
room_id: RoomID,
members: dict[UserID, Member | MemberStateEventContent],
only_membership: Membership | None = None,
) -> None:
pass
@abstractmethod
async def has_full_member_list(self, room_id: RoomID) -> bool:
pass
@abstractmethod
async def has_power_levels_cached(self, room_id: RoomID) -> bool:
pass
@abstractmethod
async def get_power_levels(self, room_id: RoomID) -> PowerLevelStateEventContent | None:
pass
@abstractmethod
async def set_power_levels(
self, room_id: RoomID, content: PowerLevelStateEventContent
) -> None:
pass
@abstractmethod
async def has_encryption_info_cached(self, room_id: RoomID) -> bool:
pass
@abstractmethod
async def is_encrypted(self, room_id: RoomID) -> bool | None:
pass
@abstractmethod
async def get_encryption_info(self, room_id: RoomID) -> RoomEncryptionStateEventContent | None:
pass
@abstractmethod
async def set_encryption_info(
self, room_id: RoomID, content: RoomEncryptionStateEventContent | dict[str, Any]
) -> None:
pass
async def update_state(self, evt: StateEvent) -> None:
if evt.type == EventType.ROOM_POWER_LEVELS:
await self.set_power_levels(evt.room_id, evt.content)
elif evt.type == EventType.ROOM_MEMBER:
evt.unsigned["mautrix_prev_membership"] = await self.get_member(
evt.room_id, UserID(evt.state_key)
)
await self.set_member(evt.room_id, UserID(evt.state_key), evt.content)
elif evt.type == EventType.ROOM_ENCRYPTION:
await self.set_encryption_info(evt.room_id, evt.content)
async def get_membership(self, room_id: RoomID, user_id: UserID) -> Membership:
member = await self.get_member(room_id, user_id)
return member.membership if member else Membership.LEAVE
async def is_joined(self, room_id: RoomID, user_id: UserID) -> bool:
return (await self.get_membership(room_id, user_id)) == Membership.JOIN
def joined(self, room_id: RoomID, user_id: UserID) -> Awaitable[None]:
return self.set_membership(room_id, user_id, Membership.JOIN)
def invited(self, room_id: RoomID, user_id: UserID) -> Awaitable[None]:
return self.set_membership(room_id, user_id, Membership.INVITE)
def left(self, room_id: RoomID, user_id: UserID) -> Awaitable[None]:
return self.set_membership(room_id, user_id, Membership.LEAVE)
async def has_power_level(
self, room_id: RoomID, user_id: UserID, event_type: EventType
) -> bool | None:
room_levels = await self.get_power_levels(room_id)
if not room_levels:
return None
return room_levels.get_user_level(user_id) >= room_levels.get_event_level(event_type)
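# --- Illustrative sketch (not part of the library) ---------------------------
# Typical queries against a StateStore implementation: check whether a user has
# enough power to send a given state event and list members that are not
# bridge ghosts. The ghost prefix/suffix values are placeholders for a bridge's
# user ID pattern.
async def _example_state_queries(store: StateStore, room_id: RoomID, user_id: UserID) -> None:
    if await store.has_power_level(room_id, user_id, EventType.ROOM_NAME):
        print(f"{user_id} may rename {room_id}")
    human_members = await store.get_members_filtered(
        room_id,
        not_prefix="@bridge_",
        not_suffix=":example.com",
        not_id="@bridgebot:example.com",
    )
    print(f"{len(human_members)} non-ghost members in {room_id}")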
python-0.20.4/mautrix/client/state_store/asyncpg/__init__.py
from .store import PgStateStore
__all__ = ["PgStateStore"]
python-0.20.4/mautrix/client/state_store/asyncpg/store.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from typing import Any, NamedTuple
import json
from mautrix.types import (
Member,
Membership,
MemberStateEventContent,
PowerLevelStateEventContent,
RoomEncryptionStateEventContent,
RoomID,
Serializable,
UserID,
)
from mautrix.util.async_db import Database, Scheme
from ..abstract import StateStore
from .upgrade import upgrade_table
class RoomState(NamedTuple):
is_encrypted: bool
has_full_member_list: bool
encryption: RoomEncryptionStateEventContent
power_levels: PowerLevelStateEventContent
class PgStateStore(StateStore):
upgrade_table = upgrade_table
db: Database
def __init__(self, db: Database) -> None:
self.db = db
async def get_member(self, room_id: RoomID, user_id: UserID) -> Member | None:
res = await self.db.fetchrow(
"SELECT membership, displayname, avatar_url "
"FROM mx_user_profile WHERE room_id=$1 AND user_id=$2",
room_id,
user_id,
)
if res is None:
return None
return Member(
membership=Membership.deserialize(res["membership"]),
displayname=res["displayname"],
avatar_url=res["avatar_url"],
)
async def set_member(
self, room_id: RoomID, user_id: UserID, member: Member | MemberStateEventContent
) -> None:
q = (
"INSERT INTO mx_user_profile (room_id, user_id, membership, displayname, avatar_url) "
"VALUES ($1, $2, $3, $4, $5)"
"ON CONFLICT (room_id, user_id) DO UPDATE SET membership=$3, displayname=$4,"
" avatar_url=$5"
)
await self.db.execute(
q, room_id, user_id, member.membership.value, member.displayname, member.avatar_url
)
async def set_membership(
self, room_id: RoomID, user_id: UserID, membership: Membership
) -> None:
q = (
"INSERT INTO mx_user_profile (room_id, user_id, membership) VALUES ($1, $2, $3) "
"ON CONFLICT (room_id, user_id) DO UPDATE SET membership=$3"
)
await self.db.execute(q, room_id, user_id, membership.value)
async def get_members(
self,
room_id: RoomID,
memberships: tuple[Membership, ...] = (Membership.JOIN, Membership.INVITE),
) -> list[UserID]:
membership_values = [membership.value for membership in memberships]
if self.db.scheme in (Scheme.POSTGRES, Scheme.COCKROACH):
q = "SELECT user_id FROM mx_user_profile WHERE room_id=$1 AND membership=ANY($2)"
res = await self.db.fetch(q, room_id, membership_values)
else:
membership_placeholders = ("?," * len(memberships)).rstrip(",")
q = (
"SELECT user_id FROM mx_user_profile "
f"WHERE room_id=? AND membership IN ({membership_placeholders})"
)
res = await self.db.fetch(q, room_id, *membership_values)
return [profile["user_id"] for profile in res]
async def get_member_profiles(
self,
room_id: RoomID,
memberships: tuple[Membership, ...] = (Membership.JOIN, Membership.INVITE),
) -> dict[UserID, Member]:
membership_values = [membership.value for membership in memberships]
if self.db.scheme in (Scheme.POSTGRES, Scheme.COCKROACH):
q = (
"SELECT user_id, membership, displayname, avatar_url FROM mx_user_profile "
"WHERE room_id=$1 AND membership=ANY($2)"
)
res = await self.db.fetch(q, room_id, membership_values)
else:
membership_placeholders = ("?," * len(memberships)).rstrip(",")
q = (
"SELECT user_id, membership, displayname, avatar_url FROM mx_user_profile "
f"WHERE room_id=? AND membership IN ({membership_placeholders})"
)
res = await self.db.fetch(q, room_id, *membership_values)
return {profile["user_id"]: Member.deserialize(profile) for profile in res}
async def get_members_filtered(
self,
room_id: RoomID,
not_prefix: str,
not_suffix: str,
not_id: str,
memberships: tuple[Membership, ...] = (Membership.JOIN, Membership.INVITE),
) -> list[UserID]:
not_like = f"{not_prefix}%{not_suffix}"
membership_values = [membership.value for membership in memberships]
if self.db.scheme in (Scheme.POSTGRES, Scheme.COCKROACH):
q = (
"SELECT user_id FROM mx_user_profile "
"WHERE room_id=$1 AND membership=ANY($2)"
"AND user_id != $3 AND user_id NOT LIKE $4"
)
res = await self.db.fetch(q, room_id, membership_values, not_id, not_like)
else:
membership_placeholders = ("?," * len(memberships)).rstrip(",")
q = (
"SELECT user_id FROM mx_user_profile "
f"WHERE room_id=? AND membership IN ({membership_placeholders})"
"AND user_id != ? AND user_id NOT LIKE ?"
)
res = await self.db.fetch(q, room_id, *membership_values, not_id, not_like)
return [profile["user_id"] for profile in res]
async def set_members(
self,
room_id: RoomID,
members: dict[UserID, Member | MemberStateEventContent],
only_membership: Membership | None = None,
) -> None:
columns = ["room_id", "user_id", "membership", "displayname", "avatar_url"]
records = [
(room_id, user_id, str(member.membership), member.displayname, member.avatar_url)
for user_id, member in members.items()
]
async with self.db.acquire() as conn, conn.transaction():
del_q = "DELETE FROM mx_user_profile WHERE room_id=$1"
if only_membership is None:
await conn.execute(del_q, room_id)
elif self.db.scheme in (Scheme.POSTGRES, Scheme.COCKROACH):
del_q = f"{del_q} AND (membership=$2 OR user_id = ANY($3))"
await conn.execute(del_q, room_id, only_membership.value, list(members.keys()))
else:
member_placeholders = ("?," * len(members)).rstrip(",")
del_q = f"{del_q} AND (membership=? OR user_id IN ({member_placeholders}))"
await conn.execute(del_q, room_id, only_membership.value, *members.keys())
if self.db.scheme == Scheme.POSTGRES:
await conn.copy_records_to_table(
"mx_user_profile", records=records, columns=columns
)
else:
q = (
"INSERT INTO mx_user_profile (room_id, user_id, membership, "
"displayname, avatar_url) VALUES ($1, $2, $3, $4, $5)"
)
await conn.executemany(q, records)
if not only_membership or only_membership == Membership.JOIN:
await conn.execute(
"UPDATE mx_room_state SET has_full_member_list=true WHERE room_id=$1",
room_id,
)
async def find_shared_rooms(self, user_id: UserID) -> list[RoomID]:
q = (
"SELECT mx_user_profile.room_id FROM mx_user_profile "
"LEFT JOIN mx_room_state ON mx_room_state.room_id=mx_user_profile.room_id "
"WHERE user_id=$1 AND mx_room_state.is_encrypted=true"
)
rows = await self.db.fetch(q, user_id)
return [row["room_id"] for row in rows]
async def has_full_member_list(self, room_id: RoomID) -> bool:
return bool(
await self.db.fetchval(
"SELECT has_full_member_list FROM mx_room_state WHERE room_id=$1", room_id
)
)
async def has_power_levels_cached(self, room_id: RoomID) -> bool:
return bool(
await self.db.fetchval(
"SELECT power_levels IS NOT NULL FROM mx_room_state WHERE room_id=$1", room_id
)
)
async def get_power_levels(self, room_id: RoomID) -> PowerLevelStateEventContent | None:
power_levels_json = await self.db.fetchval(
"SELECT power_levels FROM mx_room_state WHERE room_id=$1", room_id
)
if power_levels_json is None:
return None
return PowerLevelStateEventContent.parse_json(power_levels_json)
async def set_power_levels(
self, room_id: RoomID, content: PowerLevelStateEventContent | dict[str, Any]
) -> None:
await self.db.execute(
"INSERT INTO mx_room_state (room_id, power_levels) VALUES ($1, $2) "
"ON CONFLICT (room_id) DO UPDATE SET power_levels=$2",
room_id,
json.dumps(content.serialize() if isinstance(content, Serializable) else content),
)
async def has_encryption_info_cached(self, room_id: RoomID) -> bool:
return bool(
await self.db.fetchval(
"SELECT encryption IS NULL FROM mx_room_state WHERE room_id=$1", room_id
)
)
async def is_encrypted(self, room_id: RoomID) -> bool | None:
return await self.db.fetchval(
"SELECT is_encrypted FROM mx_room_state WHERE room_id=$1", room_id
)
async def get_encryption_info(self, room_id: RoomID) -> RoomEncryptionStateEventContent | None:
row = await self.db.fetchrow(
"SELECT is_encrypted, encryption FROM mx_room_state WHERE room_id=$1", room_id
)
if row is None or not row["is_encrypted"]:
return None
return RoomEncryptionStateEventContent.parse_json(row["encryption"])
async def set_encryption_info(
self, room_id: RoomID, content: RoomEncryptionStateEventContent | dict[str, Any]
) -> None:
q = (
"INSERT INTO mx_room_state (room_id, is_encrypted, encryption) VALUES ($1, true, $2) "
"ON CONFLICT (room_id) DO UPDATE SET is_encrypted=true, encryption=$2"
)
await self.db.execute(
q,
room_id,
json.dumps(content.serialize() if isinstance(content, Serializable) else content),
)
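# --- Illustrative sketch (not part of the library) ---------------------------
# Hooking PgStateStore up to a database. The Database.create()/start() calls
# follow the usual mautrix.util.async_db surface but should be treated as
# assumptions; the connection string is a placeholder. Passing
# PgStateStore.upgrade_table lets start() run the migrations from upgrade.py.
async def _example_pg_state_store() -> PgStateStore:
    db = Database.create(
        "postgres://user:password@localhost/mautrix",
        upgrade_table=PgStateStore.upgrade_table,
    )
    await db.start()  # connect and apply pending schema upgrades
    return PgStateStore(db)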
python-0.20.4/mautrix/client/state_store/asyncpg/upgrade.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
import logging
from mautrix.util.async_db import Connection, Scheme, UpgradeTable
upgrade_table = UpgradeTable(
version_table_name="mx_version",
database_name="matrix state cache",
log=logging.getLogger("mau.client.db.upgrade"),
)
@upgrade_table.register(description="Latest revision", upgrades_to=3)
async def upgrade_blank_to_v3(conn: Connection, scheme: Scheme) -> None:
await conn.execute(
"""CREATE TABLE mx_room_state (
room_id TEXT PRIMARY KEY,
is_encrypted BOOLEAN,
has_full_member_list BOOLEAN,
encryption TEXT,
power_levels TEXT
)"""
)
membership_check = ""
if scheme != Scheme.SQLITE:
await conn.execute(
"CREATE TYPE membership AS ENUM ('join', 'leave', 'invite', 'ban', 'knock')"
)
else:
membership_check = "CHECK (membership IN ('join', 'leave', 'invite', 'ban', 'knock'))"
await conn.execute(
f"""CREATE TABLE mx_user_profile (
room_id TEXT,
user_id TEXT,
membership membership NOT NULL {membership_check},
displayname TEXT,
avatar_url TEXT,
PRIMARY KEY (room_id, user_id)
)"""
)
@upgrade_table.register(description="Stop using size-limited string fields")
async def upgrade_v2(conn: Connection, scheme: Scheme) -> None:
if scheme == Scheme.SQLITE:
# SQLite doesn't care about types
return
await conn.execute("ALTER TABLE mx_room_state ALTER COLUMN room_id TYPE TEXT")
await conn.execute("ALTER TABLE mx_user_profile ALTER COLUMN room_id TYPE TEXT")
await conn.execute("ALTER TABLE mx_user_profile ALTER COLUMN user_id TYPE TEXT")
await conn.execute("ALTER TABLE mx_user_profile ALTER COLUMN displayname TYPE TEXT")
await conn.execute("ALTER TABLE mx_user_profile ALTER COLUMN avatar_url TYPE TEXT")
@upgrade_table.register(description="Mark rooms that need crypto state event resynced")
async def upgrade_v3(conn: Connection) -> None:
if await conn.table_exists("portal"):
await conn.execute(
"""
INSERT INTO mx_room_state (room_id, encryption)
SELECT portal.mxid, '{"resync":true}' FROM portal
WHERE portal.encrypted=true AND portal.mxid IS NOT NULL
ON CONFLICT (room_id) DO UPDATE
SET encryption=excluded.encryption
WHERE mx_room_state.encryption IS NULL
"""
)
python-0.20.4/mautrix/client/state_store/file.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from typing import IO, Any
from pathlib import Path
from mautrix.types import (
Member,
Membership,
MemberStateEventContent,
PowerLevelStateEventContent,
RoomEncryptionStateEventContent,
RoomID,
UserID,
)
from mautrix.util.file_store import Filer, FileStore
from .memory import MemoryStateStore
class FileStateStore(MemoryStateStore, FileStore):
def __init__(
self,
path: str | Path | IO,
filer: Filer | None = None,
binary: bool = True,
save_interval: float = 60.0,
) -> None:
FileStore.__init__(self, path, filer, binary, save_interval)
MemoryStateStore.__init__(self)
async def set_membership(
self, room_id: RoomID, user_id: UserID, membership: Membership
) -> None:
await super().set_membership(room_id, user_id, membership)
self._time_limited_flush()
async def set_member(
self, room_id: RoomID, user_id: UserID, member: Member | MemberStateEventContent
) -> None:
await super().set_member(room_id, user_id, member)
self._time_limited_flush()
async def set_members(
self,
room_id: RoomID,
members: dict[UserID, Member | MemberStateEventContent],
only_membership: Membership | None = None,
) -> None:
await super().set_members(room_id, members, only_membership)
self._time_limited_flush()
async def set_encryption_info(
self, room_id: RoomID, content: RoomEncryptionStateEventContent | dict[str, Any]
) -> None:
await super().set_encryption_info(room_id, content)
self._time_limited_flush()
async def set_power_levels(
self, room_id: RoomID, content: PowerLevelStateEventContent
) -> None:
await super().set_power_levels(room_id, content)
self._time_limited_flush()
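# --- Illustrative sketch (not part of the library) ---------------------------
# FileStateStore keeps everything in memory (via MemoryStateStore) and flushes
# it to disk through FileStore at most once per save_interval. A minimal
# sketch; open()/close() loading and flushing the backing file is an assumption
# about the FileStore base class, and the path is a placeholder.
async def _example_file_state_store() -> None:
    store = FileStateStore("state_store.pkl", binary=True, save_interval=30.0)
    await store.open()
    try:
        # Pass `store` as state_store= to a Client and let syncing populate it.
        ...
    finally:
        await store.close()  # close() flushes any pending changes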
python-0.20.4/mautrix/client/state_store/memory.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from typing import Any, TypedDict
from mautrix.types import (
Member,
Membership,
MemberStateEventContent,
PowerLevelStateEventContent,
RoomEncryptionStateEventContent,
RoomID,
UserID,
)
from .abstract import StateStore
class SerializedStateStore(TypedDict):
members: dict[RoomID, dict[UserID, Any]]
full_member_list: dict[RoomID, bool]
power_levels: dict[RoomID, Any]
encryption: dict[RoomID, Any]
class MemoryStateStore(StateStore):
members: dict[RoomID, dict[UserID, Member]]
full_member_list: dict[RoomID, bool]
power_levels: dict[RoomID, PowerLevelStateEventContent]
encryption: dict[RoomID, RoomEncryptionStateEventContent | None]
def __init__(self) -> None:
self.members = {}
self.full_member_list = {}
self.power_levels = {}
self.encryption = {}
def serialize(self) -> SerializedStateStore:
"""
Convert the data in the store into a JSON-friendly dict.
Returns: A dict that can be safely serialized with most object serialization methods.
"""
return {
"members": {
room_id: {user_id: member.serialize() for user_id, member in members.items()}
for room_id, members in self.members.items()
},
"full_member_list": self.full_member_list,
"power_levels": {
room_id: content.serialize() for room_id, content in self.power_levels.items()
},
"encryption": {
room_id: (content.serialize() if content is not None else None)
for room_id, content in self.encryption.items()
},
}
def deserialize(self, data: SerializedStateStore) -> None:
"""
Parse a previously serialized dict into this state store.
Args:
data: A dict returned by :meth:`serialize`.
"""
self.members = {
room_id: {user_id: Member.deserialize(member) for user_id, member in members.items()}
for room_id, members in data["members"].items()
}
self.full_member_list = data["full_member_list"]
self.power_levels = {
room_id: PowerLevelStateEventContent.deserialize(content)
for room_id, content in data["power_levels"].items()
}
self.encryption = {
room_id: (
RoomEncryptionStateEventContent.deserialize(content)
if content is not None
else None
)
for room_id, content in data["encryption"].items()
}
async def get_member(self, room_id: RoomID, user_id: UserID) -> Member | None:
try:
return self.members[room_id][user_id]
except KeyError:
return None
async def set_member(
self, room_id: RoomID, user_id: UserID, member: Member | MemberStateEventContent
) -> None:
if not isinstance(member, Member):
member = Member(
membership=member.membership,
avatar_url=member.avatar_url,
displayname=member.displayname,
)
try:
self.members[room_id][user_id] = member
except KeyError:
self.members[room_id] = {user_id: member}
async def set_membership(
self, room_id: RoomID, user_id: UserID, membership: Membership
) -> None:
try:
room_members = self.members[room_id]
except KeyError:
self.members[room_id] = {user_id: Member(membership=membership)}
return
try:
room_members[user_id].membership = membership
except (KeyError, TypeError):
room_members[user_id] = Member(membership=membership)
async def get_member_profiles(
self,
room_id: RoomID,
memberships: tuple[Membership, ...] = (Membership.JOIN, Membership.INVITE),
) -> dict[UserID, Member]:
try:
return {
user_id: member
for user_id, member in self.members[room_id].items()
if member.membership in memberships
}
except KeyError:
return {}
async def set_members(
self,
room_id: RoomID,
members: dict[UserID, Member | MemberStateEventContent],
only_membership: Membership | None = None,
) -> None:
old_members = {}
if only_membership is not None:
old_members = {
user_id: member
for user_id, member in self.members.get(room_id, {}).items()
if member.membership != only_membership
}
self.members[room_id] = {
user_id: (
member
if isinstance(member, Member)
else Member(
membership=member.membership,
avatar_url=member.avatar_url,
displayname=member.displayname,
)
)
for user_id, member in members.items()
}
self.members[room_id].update(old_members)
self.full_member_list[room_id] = True
async def has_full_member_list(self, room_id: RoomID) -> bool:
return self.full_member_list.get(room_id, False)
async def has_power_levels_cached(self, room_id: RoomID) -> bool:
return room_id in self.power_levels
async def get_power_levels(self, room_id: RoomID) -> PowerLevelStateEventContent | None:
return self.power_levels.get(room_id)
async def set_power_levels(
self, room_id: RoomID, content: PowerLevelStateEventContent | dict[str, Any]
) -> None:
if not isinstance(content, PowerLevelStateEventContent):
content = PowerLevelStateEventContent.deserialize(content)
self.power_levels[room_id] = content
async def has_encryption_info_cached(self, room_id: RoomID) -> bool:
return room_id in self.encryption
async def is_encrypted(self, room_id: RoomID) -> bool | None:
try:
return self.encryption[room_id] is not None
except KeyError:
return None
async def get_encryption_info(self, room_id: RoomID) -> RoomEncryptionStateEventContent | None:
return self.encryption.get(room_id)
async def set_encryption_info(
self, room_id: RoomID, content: RoomEncryptionStateEventContent | dict[str, Any]
) -> None:
if not isinstance(content, RoomEncryptionStateEventContent):
content = RoomEncryptionStateEventContent.deserialize(content)
self.encryption[room_id] = content
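# Illustrative sketch (not part of the library): serialize() and deserialize() work with
# plain dicts, so a MemoryStateStore can be persisted with e.g. the json module. The Matrix
# IDs below are hypothetical.
async def _memory_state_store_roundtrip_example() -> None:
    import json

    room_id = RoomID("!room:example.com")
    user_id = UserID("@user:example.com")

    store = MemoryStateStore()
    await store.set_membership(room_id, user_id, Membership.JOIN)
    dumped = json.dumps(store.serialize())

    restored = MemoryStateStore()
    restored.deserialize(json.loads(dumped))
    assert (await restored.get_member(room_id, user_id)).membership == Membership.JOIN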
python-0.20.4/mautrix/client/state_store/sync.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from abc import ABC, abstractmethod
from mautrix.types import SyncToken
class SyncStore(ABC):
"""SyncStore persists information used by /sync."""
@abstractmethod
async def put_next_batch(self, next_batch: SyncToken) -> None:
pass
@abstractmethod
async def get_next_batch(self) -> SyncToken:
pass
class MemorySyncStore(SyncStore):
"""MemorySyncStore is a :class:`SyncStore` implementation that stores the data in memory."""
def __init__(self, next_batch: SyncToken | None = None) -> None:
self._next_batch: SyncToken | None = next_batch
async def put_next_batch(self, next_batch: SyncToken) -> None:
self._next_batch = next_batch
async def get_next_batch(self) -> SyncToken:
return self._next_batch
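# Illustrative sketch (not part of the library): a SyncStore only needs to remember the
# latest next_batch token between /sync requests. The token value below is hypothetical.
async def _sync_store_example() -> None:
    store = MemorySyncStore()
    await store.put_next_batch(SyncToken("s72594_4483_1934"))
    assert await store.get_next_batch() == "s72594_4483_1934"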
python-0.20.4/mautrix/client/state_store/tests/__init__.py
python-0.20.4/mautrix/client/state_store/tests/joined_members.json
{
"!telegram-group:example.com": {
"@telegrambot:example.com": {
"avatar_url": "mxc://maunium.net/tJCRmUyJDsgRNgqhOgoiHWbX",
"display_name": "Telegram bridge bot"
},
"@telegram_84359547:example.com": {
"avatar_url": "mxc://example.com/321cba",
"display_name": "tulir (Telegram)"
},
"@telegram_5647382910:example.com": {
"avatar_url": "mxc://example.com/o3sSOTEE6F7aSZELL8PfP3N7",
"display_name": "Tulir #4 (Telegram)"
},
"@tulir:example.com": {
"avatar_url": "mxc://example.com/123abc",
"display_name": "tulir"
},
"@telegram_374880943:example.com": {
"avatar_url": "mxc://example.com/9yk1G8ZmSHbvP4JI2IFbfSAn",
"display_name": "Mautrix (Telegram)"
},
"@telegram_987654321:example.com": {
"avatar_url": "mxc://example.com/TTi6q0SABMNNNlC38014KMBV",
"display_name": "Mautrix / Testing (Telegram)"
},
"@telegram_123456789:example.com": {
"avatar_url": null,
"display_name": "Tulir #3 (Telegram)"
}
}
}
python-0.20.4/mautrix/client/state_store/tests/members.json
{
"!telegram-group:example.com": [
{
"content": {
"avatar_url": "mxc://maunium.net/tJCRmUyJDsgRNgqhOgoiHWbX",
"displayname": "Telegram bridge bot",
"membership": "join"
},
"origin_server_ts": 1637017135583,
"sender": "@telegrambot:example.com",
"state_key": "@telegrambot:example.com",
"type": "m.room.member",
"event_id": "$ToMVBrkfMXcaMt6lrNn_AACxGtWGk9acDQplsYMjdu8"
},
{
"content": {
"avatar_url": "mxc://example.com/321cba",
"displayname": "tulir (Telegram)",
"membership": "join"
},
"origin_server_ts": 1637017138997,
"sender": "@telegram_84359547:example.com",
"state_key": "@telegram_84359547:example.com",
"type": "m.room.member",
"unsigned": {
"replaces_state": "$tRU6IBdkgFoG9PD-PRjtKOFMPAQVMMFdU-ZfB09hm48",
"prev_content": {
"avatar_url": "mxc://example.com/321cba",
"displayname": "tulir (Telegram)",
"membership": "invite"
},
"prev_sender": "@telegrambot:example.com"
},
"event_id": "$CnpXeISvK8MkbzgryzlrMKKk134KLwuvdIDTLp8v1ME"
},
{
"content": {
"avatar_url": "mxc://example.com/o3sSOTEE6F7aSZELL8PfP3N7",
"displayname": "Tulir #4 (Telegram)",
"membership": "join"
},
"origin_server_ts": 1637017139441,
"sender": "@telegram_5647382910:example.com",
"state_key": "@telegram_5647382910:example.com",
"type": "m.room.member",
"unsigned": {
"replaces_state": "$_yUbIfhy46AVZ8Aj3bEoGUvzJ6xC38mbFIZ79JOrIhw",
"prev_content": {
"avatar_url": "mxc://example.com/o3sSOTEE6F7aSZELL8PfP3N7",
"displayname": "Tulir #4 (Telegram)",
"membership": "invite"
},
"prev_sender": "@telegrambot:example.com"
},
"event_id": "$n_dhyHuytT4lwJ1rzdkGeRIAaCQNtdM-6KhI5rQcVSw"
},
{
"content": {
"avatar_url": "mxc://example.com/123abc",
"displayname": "tulir",
"membership": "join"
},
"origin_server_ts": 1637017136989,
"sender": "@tulir:example.com",
"state_key": "@tulir:example.com",
"type": "m.room.member",
"unsigned": {
"replaces_state": "$jMwIGQPtxaRLfCfUm0XhYRgWK1WsVRnI-2shdpLpJqg",
"prev_content": {
"avatar_url": "mxc://example.com/123abc",
"displayname": "tulir",
"fi.mau.will_auto_accept": true,
"is_direct": false,
"membership": "invite"
},
"prev_sender": "@telegrambot:example.com"
},
"event_id": "$xv9FUkljVHVY2aHvQgtq9yPAo8YSD6IX80bs2eeLKQo"
},
{
"content": {
"avatar_url": "mxc://example.com/9yk1G8ZmSHbvP4JI2IFbfSAn",
"displayname": "Mautrix (Telegram)",
"membership": "join"
},
"origin_server_ts": 1637017140015,
"sender": "@telegram_374880943:example.com",
"state_key": "@telegram_374880943:example.com",
"type": "m.room.member",
"unsigned": {
"replaces_state": "$phLlRQ3dHd7rVOZmfMWVscQCnwAT3k-CAhh8aByHJOQ",
"prev_content": {
"avatar_url": "mxc://example.com/9yk1G8ZmSHbvP4JI2IFbfSAn",
"displayname": "Mautrix (Telegram)",
"membership": "invite"
},
"prev_sender": "@telegrambot:example.com"
},
"event_id": "$agGmH9w9I-KgBnOoUBXp6RmXxemGaYwN3PnwOgIvmEk"
},
{
"content": {
"avatar_url": "mxc://example.com/TTi6q0SABMNNNlC38014KMBV",
"displayname": "Mautrix / Testing (Telegram)",
"membership": "join"
},
"origin_server_ts": 1637017141173,
"sender": "@telegram_987654321:example.com",
"state_key": "@telegram_987654321:example.com",
"type": "m.room.member",
"unsigned": {
"replaces_state": "$VYToMy5ap8j2no4bL5neeYrDqAOPnt2JpY0yfdbnbWk",
"prev_content": {
"avatar_url": "mxc://example.com/TTi6q0SABMNNNlC38014KMBV",
"displayname": "Mautrix / Testing (Telegram)",
"membership": "invite"
},
"prev_sender": "@telegrambot:example.com"
},
"event_id": "$DZxbFRtQ8oKgw9hGiQRCpFLV-PF20uhS6DMDfFjOx38"
},
{
"content": {
"displayname": "Tulir #3 (Telegram)",
"membership": "join"
},
"origin_server_ts": 1637017138285,
"sender": "@telegram_123456789:example.com",
"state_key": "@telegram_123456789:example.com",
"type": "m.room.member",
"unsigned": {
"replaces_state": "$WLvGliXjUCuXxJV1wqbZd7suumC3HiOf-pVP94_JUAE",
"prev_content": {
"displayname": "Tulir #3 (Telegram)",
"membership": "invite"
},
"prev_sender": "@telegrambot:example.com"
},
"event_id": "$DW23zfhHKNBTYJNq40R7leaIZAPzf54eyWgxmeQ7WYc"
},
{
"content": {
"membership": "leave"
},
"origin_server_ts": 1637017153623,
"sender": "@telegram_476034259:example.com",
"state_key": "@telegram_476034259:example.com",
"type": "m.room.member",
"unsigned": {
"replaces_state": "$2yMikO4Th2GYcegL-AWLVMbH-5gBw7ru4pJIZG_XsHc"
},
"event_id": "$r8T_b9upDSscRZqulUOJSQeK80RnN7LKDRSvzXflxlQ"
},
{
"content": {
"membership": "ban",
"avatar_url": "mxc://maunium.net/456def",
"displayname": "WhatsApp bridge bot"
},
"origin_server_ts": 1637406936936,
"sender": "@tulir:example.com",
"state_key": "@whatsappbot:example.com",
"type": "m.room.member",
"event_id": "$9zhRDMhfvo37qFlRYZ45H0afcx9KSyGkn-gr72S6TsU"
}
]
}
python-0.20.4/mautrix/client/state_store/tests/new_state.json
{
"!telegram-group:example.com": [
{
"content": {
"algorithm": "m.megolm.v1.aes-sha2"
},
"origin_server_ts": 1637017136383,
"sender": "@telegrambot:example.com",
"state_key": "",
"type": "m.room.encryption",
"event_id": "$bFRKXissP1baFP9XmqnPqhNF5qqQPNFaxzToadnVGzA"
},
{
"content": {
"ban": 50,
"events": {
"m.room.avatar": 0,
"m.room.encryption": 50,
"m.room.history_visibility": 75,
"m.room.name": 0,
"m.room.pinned_events": 0,
"m.room.power_levels": 75,
"m.room.tombstone": 99,
"m.room.topic": 0,
"m.sticker": 0
},
"events_default": 0,
"invite": 0,
"kick": 50,
"redact": 50,
"state_default": 50,
"users": {
"@telegrambot:example.com": 100,
"@tulir:example.com": 95
},
"users_default": 0
},
"origin_server_ts": 1637017135659,
"sender": "@telegrambot:example.com",
"state_key": "",
"type": "m.room.power_levels",
"event_id": "$98TzHn07MQ_fogX-hzYQfsb-1KlTL2XMVIOe3vvWY3g"
},
{
"content": {
"avatar_url": "mxc://example.com/123abc",
"displayname": "tulir",
"fi.mau.will_auto_accept": true,
"is_direct": false,
"membership": "invite"
},
"origin_server_ts": 1637017136716,
"sender": "@telegrambot:example.com",
"state_key": "@tulir:example.com",
"type": "m.room.member",
"event_id": "$jMwIGQPtxaRLfCfUm0XhYRgWK1WsVRnI-2shdpLpJqg"
},
{
"content": {
"avatar_url": "mxc://example.com/123abc",
"displayname": "tulir",
"membership": "join"
},
"origin_server_ts": 1637017136989,
"sender": "@tulir:example.com",
"state_key": "@tulir:example.com",
"type": "m.room.member",
"unsigned": {
"replaces_state": "$jMwIGQPtxaRLfCfUm0XhYRgWK1WsVRnI-2shdpLpJqg",
"prev_content": {
"avatar_url": "mxc://example.com/123abc",
"displayname": "tulir",
"fi.mau.will_auto_accept": true,
"is_direct": false,
"membership": "invite"
},
"prev_sender": "@telegrambot:example.com"
},
"event_id": "$xv9FUkljVHVY2aHvQgtq9yPAo8YSD6IX80bs2eeLKQo"
},
{
"content": {
"avatar_url": "mxc://example.com/321cba",
"displayname": "tulir (Telegram)",
"membership": "invite"
},
"origin_server_ts": 1637017138734,
"sender": "@telegrambot:example.com",
"state_key": "@telegram_84359547:example.com",
"type": "m.room.member",
"event_id": "$tRU6IBdkgFoG9PD-PRjtKOFMPAQVMMFdU-ZfB09hm48"
},
{
"content": {
"avatar_url": "mxc://example.com/321cba",
"displayname": "tulir (Telegram)",
"membership": "join"
},
"origin_server_ts": 1637017138997,
"sender": "@telegram_84359547:example.com",
"state_key": "@telegram_84359547:example.com",
"type": "m.room.member",
"unsigned": {
"replaces_state": "$tRU6IBdkgFoG9PD-PRjtKOFMPAQVMMFdU-ZfB09hm48",
"prev_content": {
"avatar_url": "mxc://example.com/321cba",
"displayname": "tulir (Telegram)",
"membership": "invite"
},
"prev_sender": "@telegrambot:example.com",
"age": 1234
},
"event_id": "$CnpXeISvK8MkbzgryzlrMKKk134KLwuvdIDTLp8v1ME"
}
],
"!unencrypted-room:example.com": [
{
"content": {
"ban": 50,
"events": {
"m.room.avatar": 50,
"m.room.canonical_alias": 50,
"m.room.encryption": 100,
"m.room.history_visibility": 100,
"m.room.name": 50,
"m.room.power_levels": 100,
"m.room.server_acl": 100,
"m.room.tombstone": 100
},
"events_default": 0,
"historical": 100,
"invite": 0,
"kick": 50,
"redact": 50,
"state_default": 50,
"users": {
"@tulir:example.com": 9001
},
"users_default": 0
},
"origin_server_ts": 1637343767626,
"sender": "@tulir:example.com",
"state_key": "",
"type": "m.room.power_levels",
"event_id": "$TMUQddqg0CzcNviUBPT38_3i51-SC86p2w3oADNhyZA"
},
{
"content": {
"avatar_url": "mxc://example.com/987zyx",
"displayname": "Maubot",
"membership": "invite"
},
"origin_server_ts": 1637343773909,
"sender": "@tulir:example.com",
"state_key": "@maubot:example.com",
"type": "m.room.member",
"event_id": "$9XvlRmVmzCdNna0uvgm4NeMiVKJLmd-8c_tX10gmg4k"
},
{
"content": {
"avatar_url": "mxc://example.com/987zyx",
"displayname": "Maubot",
"membership": "join"
},
"origin_server_ts": 1637343773991,
"sender": "@maubot:example.com",
"state_key": "@maubot:example.com",
"type": "m.room.member",
"unsigned": {
"replaces_state": "$9XvlRmVmzCdNna0uvgm4NeMiVKJLmd-8c_tX10gmg4k",
"prev_content": {
"avatar_url": "mxc://example.com/987zyx",
"displayname": "Maubot",
"membership": "invite"
},
"prev_sender": "@tulir:example.com",
"age": 1234
},
"event_id": "$XZMALYE9N30jP5_x8S1IWFlzt5F6tZB--W2kkoKGJDM"
},
{
"content": {
"avatar_url": "mxc://maunium.net/456def",
"displayname": "WhatsApp bridge bot",
"membership": "invite"
},
"origin_server_ts": 1637406933363,
"sender": "@tulir:example.com",
"state_key": "@whatsappbot:example.com",
"type": "m.room.member",
"event_id": "$jODwwttaZd-flc2eyh0JirR0p6EtDkX5BaJmPfaXcrc"
},
{
"content": {
"membership": "ban",
"avatar_url": "mxc://maunium.net/456def",
"displayname": "WhatsApp bridge bot"
},
"origin_server_ts": 1637406936936,
"sender": "@tulir:example.com",
"state_key": "@whatsappbot:example.com",
"type": "m.room.member",
"unsigned": {
"replaces_state": "$jODwwttaZd-flc2eyh0JirR0p6EtDkX5BaJmPfaXcrc",
"prev_content": {
"avatar_url": "mxc://maunium.net/456def",
"displayname": "WhatsApp bridge bot",
"membership": "invite"
},
"prev_sender": "@tulir:example.com",
"age": 65
},
"event_id": "$9zhRDMhfvo37qFlRYZ45H0afcx9KSyGkn-gr72S6TsU"
}
]
}
python-0.20.4/mautrix/client/state_store/tests/store_test.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from typing import AsyncContextManager, AsyncIterator, Callable
from contextlib import asynccontextmanager
import json
import os
import pathlib
import random
import string
import time
import asyncpg
import pytest
from mautrix.types import EncryptionAlgorithm, Member, Membership, RoomID, StateEvent, UserID
from mautrix.util.async_db import Database
from .. import MemoryStateStore, StateStore
from ..asyncpg import PgStateStore
@asynccontextmanager
async def async_postgres_store() -> AsyncIterator[PgStateStore]:
try:
pg_url = os.environ["MEOW_TEST_PG_URL"]
except KeyError:
pytest.skip("Skipped Postgres tests (MEOW_TEST_PG_URL not specified)")
return
conn: asyncpg.Connection = await asyncpg.connect(pg_url)
schema_name = "".join(random.choices(string.ascii_lowercase, k=8))
schema_name = f"test_schema_{schema_name}_{int(time.time())}"
await conn.execute(f"CREATE SCHEMA {schema_name}")
db = Database.create(
pg_url,
upgrade_table=PgStateStore.upgrade_table,
db_args={"min_size": 1, "max_size": 3, "server_settings": {"search_path": schema_name}},
)
store = PgStateStore(db)
await db.start()
yield store
await db.stop()
await conn.execute(f"DROP SCHEMA {schema_name} CASCADE")
await conn.close()
@asynccontextmanager
async def async_sqlite_store() -> AsyncIterator[PgStateStore]:
db = Database.create(
"sqlite::memory:", upgrade_table=PgStateStore.upgrade_table, db_args={"min_size": 1}
)
store = PgStateStore(db)
await db.start()
yield store
await db.stop()
@asynccontextmanager
async def memory_store() -> AsyncIterator[MemoryStateStore]:
yield MemoryStateStore()
@pytest.fixture(params=[async_postgres_store, async_sqlite_store, memory_store])
async def store(request) -> AsyncIterator[StateStore]:
param: Callable[[], AsyncContextManager[StateStore]] = request.param
async with param() as state_store:
yield state_store
def read_state_file(request, file) -> dict[RoomID, list[StateEvent]]:
path = pathlib.Path(request.node.fspath).with_name(file)
with path.open() as fp:
content = json.load(fp)
return {
room_id: [StateEvent.deserialize({**evt, "room_id": room_id}) for evt in events]
for room_id, events in content.items()
}
async def store_room_state(request, store: StateStore) -> None:
room_state_changes = read_state_file(request, "new_state.json")
for events in room_state_changes.values():
for evt in events:
await store.update_state(evt)
async def get_all_members(request, store: StateStore) -> None:
room_state = read_state_file(request, "members.json")
for room_id, member_events in room_state.items():
await store.set_members(room_id, {evt.state_key: evt.content for evt in member_events})
async def get_joined_members(request, store: StateStore) -> None:
path = pathlib.Path(request.node.fspath).with_name("joined_members.json")
with path.open() as fp:
content = json.load(fp)
for room_id, members in content.items():
parsed_members = {
user_id: Member(
membership=Membership.JOIN,
displayname=member.get("display_name", ""),
avatar_url=member.get("avatar_url", ""),
)
for user_id, member in members.items()
}
await store.set_members(room_id, parsed_members, only_membership=Membership.JOIN)
async def test_basic(store: StateStore) -> None:
room_id = RoomID("!foo:example.com")
user_id = UserID("@tulir:example.com")
assert not await store.is_encrypted(room_id)
assert not await store.is_joined(room_id, user_id)
await store.joined(room_id, user_id)
assert await store.is_joined(room_id, user_id)
assert not await store.has_encryption_info_cached(RoomID("!unknown-room:example.com"))
assert await store.is_encrypted(RoomID("!unknown-room:example.com")) is None
async def test_basic_updated(request, store: StateStore) -> None:
await store_room_state(request, store)
test_group = RoomID("!telegram-group:example.com")
assert await store.is_encrypted(test_group)
assert (await store.get_encryption_info(test_group)).algorithm == EncryptionAlgorithm.MEGOLM_V1
assert not await store.is_encrypted(RoomID("!unencrypted-room:example.com"))
async def test_updates(request, store: StateStore) -> None:
await store_room_state(request, store)
room_id = RoomID("!telegram-group:example.com")
initial_members = {"@tulir:example.com", "@telegram_84359547:example.com"}
joined_members = initial_members | {
"@telegrambot:example.com",
"@telegram_5647382910:example.com",
"@telegram_374880943:example.com",
"@telegram_987654321:example.com",
"@telegram_123456789:example.com",
}
left_members = {"@telegram_476034259:example.com", "@whatsappbot:example.com"}
full_members = joined_members | left_members
any_membership = (
Membership.JOIN,
Membership.INVITE,
Membership.LEAVE,
Membership.BAN,
Membership.KNOCK,
)
leave_memberships = (Membership.BAN, Membership.LEAVE)
assert set(await store.get_members(room_id)) == initial_members
await get_all_members(request, store)
assert set(await store.get_members(room_id)) == joined_members
assert set(await store.get_members(room_id, memberships=any_membership)) == full_members
await get_joined_members(request, store)
assert set(await store.get_members(room_id)) == joined_members
assert set(await store.get_members(room_id, memberships=any_membership)) == full_members
assert set(await store.get_members(room_id, memberships=leave_memberships)) == left_members
assert set(
await store.get_members_filtered(
room_id,
memberships=leave_memberships,
not_id="",
not_prefix="@telegram_",
not_suffix=":example.com",
)
) == {"@whatsappbot:example.com"}
python-0.20.4/mautrix/client/store_updater.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
import asyncio
from mautrix.errors import MForbidden, MNotFound
from mautrix.types import (
JSON,
EventID,
EventType,
Member,
Membership,
MemberStateEventContent,
RoomAlias,
RoomID,
StateEvent,
StateEventContent,
SyncToken,
UserID,
)
from .api import ClientAPI
from .state_store import StateStore
class StoreUpdatingAPI(ClientAPI):
"""
StoreUpdatingAPI is a wrapper around the medium-level ClientAPI that optionally updates
a client state store with outgoing state events (after they're successfully sent).
"""
state_store: StateStore | None
def __init__(self, *args, state_store: StateStore | None = None, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.state_store = state_store
async def join_room_by_id(
self,
room_id: RoomID,
third_party_signed: JSON = None,
extra_content: dict[str, JSON] | None = None,
) -> RoomID:
room_id = await super().join_room_by_id(
room_id, third_party_signed=third_party_signed, extra_content=extra_content
)
if room_id and not extra_content and self.state_store:
await self.state_store.set_membership(room_id, self.mxid, Membership.JOIN)
return room_id
async def join_room(
self,
room_id_or_alias: RoomID | RoomAlias,
servers: list[str] | None = None,
third_party_signed: JSON = None,
max_retries: int = 4,
) -> RoomID:
room_id = await super().join_room(
room_id_or_alias, servers, third_party_signed, max_retries
)
if room_id and self.state_store:
await self.state_store.set_membership(room_id, self.mxid, Membership.JOIN)
return room_id
async def leave_room(
self,
room_id: RoomID,
reason: str | None = None,
extra_content: dict[str, JSON] | None = None,
raise_not_in_room: bool = False,
) -> None:
await super().leave_room(room_id, reason, extra_content, raise_not_in_room)
if not extra_content and self.state_store:
await self.state_store.set_membership(room_id, self.mxid, Membership.LEAVE)
async def knock_room(
self,
room_id_or_alias: RoomID | RoomAlias,
reason: str | None = None,
servers: list[str] | None = None,
) -> RoomID:
room_id = await super().knock_room(room_id_or_alias, reason, servers)
if room_id and self.state_store:
await self.state_store.set_membership(room_id, self.mxid, Membership.KNOCK)
return room_id
async def invite_user(
self,
room_id: RoomID,
user_id: UserID,
reason: str | None = None,
extra_content: dict[str, JSON] | None = None,
) -> None:
await super().invite_user(room_id, user_id, reason, extra_content=extra_content)
if not extra_content and self.state_store:
await self.state_store.set_membership(room_id, user_id, Membership.INVITE)
async def kick_user(
self,
room_id: RoomID,
user_id: UserID,
reason: str = "",
extra_content: dict[str, JSON] | None = None,
) -> None:
await super().kick_user(room_id, user_id, reason=reason, extra_content=extra_content)
if not extra_content and self.state_store:
await self.state_store.set_membership(room_id, user_id, Membership.LEAVE)
async def ban_user(
self,
room_id: RoomID,
user_id: UserID,
reason: str = "",
extra_content: dict[str, JSON] | None = None,
) -> None:
await super().ban_user(room_id, user_id, reason=reason, extra_content=extra_content)
if not extra_content and self.state_store:
await self.state_store.set_membership(room_id, user_id, Membership.BAN)
async def unban_user(
self,
room_id: RoomID,
user_id: UserID,
reason: str = "",
extra_content: dict[str, JSON] | None = None,
) -> None:
await super().unban_user(room_id, user_id, reason=reason, extra_content=extra_content)
if self.state_store:
await self.state_store.set_membership(room_id, user_id, Membership.LEAVE)
async def get_state(self, room_id: RoomID) -> list[StateEvent]:
state = await super().get_state(room_id)
if self.state_store:
update_members = self.state_store.set_members(
room_id,
{evt.state_key: evt.content for evt in state if evt.type == EventType.ROOM_MEMBER},
)
await asyncio.gather(
update_members,
*[
self.state_store.update_state(evt)
for evt in state
if evt.type != EventType.ROOM_MEMBER
],
)
return state
async def create_room(self, *args, **kwargs) -> RoomID:
room_id = await super().create_room(*args, **kwargs)
if self.state_store:
invitee_membership = Membership.INVITE
if kwargs.get("beeper_auto_join_invites"):
invitee_membership = Membership.JOIN
for user_id in kwargs.get("invitees", []):
await self.state_store.set_membership(room_id, user_id, invitee_membership)
for evt in kwargs.get("initial_state", []):
await self.state_store.update_state(
StateEvent(
type=EventType.find(evt["type"], t_class=EventType.Class.STATE),
room_id=room_id,
event_id=EventID("$fake-create-id"),
sender=self.mxid,
state_key=evt.get("state_key", ""),
timestamp=0,
content=evt["content"],
)
)
return room_id
async def send_state_event(
self,
room_id: RoomID,
event_type: EventType,
content: StateEventContent | dict[str, JSON],
state_key: str = "",
**kwargs,
) -> EventID:
event_id = await super().send_state_event(
room_id, event_type, content, state_key, **kwargs
)
if self.state_store:
fake_event = StateEvent(
type=event_type,
room_id=room_id,
event_id=event_id,
sender=self.mxid,
state_key=state_key,
timestamp=0,
content=content,
)
await self.state_store.update_state(fake_event)
return event_id
async def get_state_event(
self, room_id: RoomID, event_type: EventType, state_key: str = ""
) -> StateEventContent:
event = await super().get_state_event(room_id, event_type, state_key)
if self.state_store:
fake_event = StateEvent(
type=event_type,
room_id=room_id,
event_id=EventID(""),
sender=UserID(""),
state_key=state_key,
timestamp=0,
content=event,
)
await self.state_store.update_state(fake_event)
return event
async def get_joined_members(self, room_id: RoomID) -> dict[UserID, Member]:
members = await super().get_joined_members(room_id)
if self.state_store:
await self.state_store.set_members(room_id, members, only_membership=Membership.JOIN)
return members
async def get_members(
self,
room_id: RoomID,
at: SyncToken | None = None,
membership: Membership | None = None,
not_membership: Membership | None = None,
) -> list[StateEvent]:
member_events = await super().get_members(room_id, at, membership, not_membership)
if self.state_store and not_membership != Membership.JOIN:
await self.state_store.set_members(
room_id,
{evt.state_key: evt.content for evt in member_events},
only_membership=membership,
)
return member_events
async def fill_member_event(
self,
room_id: RoomID,
user_id: UserID,
content: MemberStateEventContent,
) -> MemberStateEventContent | None:
"""
Fill a membership event content that is going to be sent in :meth:`send_member_event`.
This is used to set default fields like the displayname and avatar, which are usually set
by the server in the sugar membership endpoints like /join and /invite, but are not set
automatically when sending member events manually.
This implementation in StoreUpdatingAPI will first try to call the default implementation
(which calls :attr:`fill_member_event_callback`). If that doesn't return anything, this
will try to get the profile from the current member event, and then fall back to fetching
the global profile from the server.
Args:
room_id: The room where the member event is going to be sent.
user_id: The user whose membership is changing.
content: The new member event content.
Returns:
The filled member event content.
"""
callback_content = await super().fill_member_event(room_id, user_id, content)
if callback_content is not None:
self.log.trace("Filled new member event for %s using callback", user_id)
return callback_content
if content.displayname is None and content.avatar_url is None:
# The state store is optional on this class, so only consult it when it's set.
existing_member = await self.state_store.get_member(room_id, user_id) if self.state_store else None
if existing_member is not None:
self.log.trace(
"Found existing member event %s to fill new member event for %s",
existing_member,
user_id,
)
content.displayname = existing_member.displayname
content.avatar_url = existing_member.avatar_url
return content
try:
profile = await self.get_profile(user_id)
except (MNotFound, MForbidden):
profile = None
if profile:
self.log.trace(
"Fetched profile %s to fill new member event of %s", profile, user_id
)
content.displayname = profile.displayname
content.avatar_url = profile.avatar_url
return content
else:
self.log.trace("Didn't find profile info to fill new member event of %s", user_id)
else:
self.log.trace(
"Member event for %s already contains displayname or avatar, not re-filling",
user_id,
)
return None
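# Illustrative sketch (not part of the library): any client built on StoreUpdatingAPI (such
# as mautrix.client.Client) mirrors successful membership changes into its state store. The
# client instance and Matrix IDs below are hypothetical.
async def _store_updating_example(client: StoreUpdatingAPI) -> None:
    room_id = RoomID("!room:example.com")
    user_id = UserID("@friend:example.com")
    await client.invite_user(room_id, user_id)
    if client.state_store:
        # The invite above was recorded locally, so no extra /state request is needed here.
        assert await client.state_store.get_member(room_id, user_id) is not None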
python-0.20.4/mautrix/client/syncer.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from typing import Any, Awaitable, Callable, Type, TypeVar
from abc import ABC, abstractmethod
from contextlib import suppress
from enum import Enum, Flag, auto
import asyncio
import time
from mautrix.errors import MUnknownToken
from mautrix.types import (
JSON,
AccountDataEvent,
BaseMessageEventContentFuncs,
DeviceLists,
DeviceOTKCount,
EphemeralEvent,
Event,
EventType,
Filter,
FilterID,
GenericEvent,
MessageEvent,
PresenceState,
SerializerError,
StateEvent,
StrippedStateEvent,
SyncToken,
ToDeviceEvent,
UserID,
)
from mautrix.util import background_task
from mautrix.util.logging import TraceLogger
from . import dispatcher
from .state_store import MemorySyncStore, SyncStore
EventHandler = Callable[[Event], Awaitable[None]]
T = TypeVar("T", bound=Event)
class SyncStream(Flag):
INTERNAL = auto()
JOINED_ROOM = auto()
INVITED_ROOM = auto()
LEFT_ROOM = auto()
TIMELINE = auto()
STATE = auto()
EPHEMERAL = auto()
ACCOUNT_DATA = auto()
TO_DEVICE = auto()
class InternalEventType(Enum):
SYNC_STARTED = auto()
SYNC_ERRORED = auto()
SYNC_SUCCESSFUL = auto()
SYNC_STOPPED = auto()
JOIN = auto()
PROFILE_CHANGE = auto()
INVITE = auto()
REJECT_INVITE = auto()
DISINVITE = auto()
LEAVE = auto()
KICK = auto()
BAN = auto()
UNBAN = auto()
DEVICE_LISTS = auto()
DEVICE_OTK_COUNT = auto()
class Syncer(ABC):
loop: asyncio.AbstractEventLoop
log: TraceLogger
mxid: UserID
global_event_handlers: list[tuple[EventHandler, bool]]
event_handlers: dict[EventType | InternalEventType, list[tuple[EventHandler, bool]]]
dispatchers: dict[Type[dispatcher.Dispatcher], dispatcher.Dispatcher]
syncing_task: asyncio.Task | None
ignore_initial_sync: bool
ignore_first_sync: bool
presence: PresenceState
sync_store: SyncStore
def __init__(self, sync_store: SyncStore) -> None:
self.global_event_handlers = []
self.event_handlers = {}
self.dispatchers = {}
self.syncing_task = None
self.ignore_initial_sync = False
self.ignore_first_sync = False
self.presence = PresenceState.ONLINE
self.sync_store = sync_store or MemorySyncStore()
def on(
self, var: EventHandler | EventType | InternalEventType
) -> EventHandler | Callable[[EventHandler], EventHandler]:
"""
Add a new event handler. This method is for decorator usage.
Use :meth:`add_event_handler` if you don't use a decorator.
Args:
var: Either the handler function or the event type to handle.
Returns:
If ``var`` was the handler function, the handler function is returned.
If ``var`` was an event type, a function that takes the handler function as an argument
is returned.
Examples:
>>> from mautrix.client import Client
>>> cli = Client(...)
>>> @cli.on(EventType.ROOM_MESSAGE)
... async def handler(event: MessageEvent) -> None:
... pass
"""
if isinstance(var, (EventType, InternalEventType)):
def decorator(func: EventHandler) -> EventHandler:
self.add_event_handler(var, func)
return func
return decorator
else:
self.add_event_handler(EventType.ALL, var)
return var
def add_dispatcher(self, dispatcher_type: Type[dispatcher.Dispatcher]) -> None:
if dispatcher_type in self.dispatchers:
return
self.log.debug(f"Enabling {dispatcher_type.__name__}")
self.dispatchers[dispatcher_type] = dispatcher_type(self)
self.dispatchers[dispatcher_type].register()
def remove_dispatcher(self, dispatcher_type: Type[dispatcher.Dispatcher]) -> None:
if dispatcher_type not in self.dispatchers:
return
self.log.debug(f"Disabling {dispatcher_type.__name__}")
self.dispatchers[dispatcher_type].unregister()
del self.dispatchers[dispatcher_type]
def add_event_handler(
self,
event_type: InternalEventType | EventType,
handler: EventHandler,
wait_sync: bool = False,
) -> None:
"""
Add a new event handler.
Args:
event_type: The event type to handle. Use ``EventType.ALL`` to have the handler
called for all event types.
handler: The handler function to add.
wait_sync: Whether or not the handler should be awaited before the next sync request.
"""
if not isinstance(event_type, (EventType, InternalEventType)):
raise ValueError("Invalid event type")
if event_type == EventType.ALL:
self.global_event_handlers.append((handler, wait_sync))
else:
self.event_handlers.setdefault(event_type, []).append((handler, wait_sync))
def remove_event_handler(
self, event_type: EventType | InternalEventType, handler: EventHandler
) -> None:
"""
Remove an event handler.
Args:
handler: The handler function to remove.
event_type: The event type to remove the handler function from.
"""
if not isinstance(event_type, (EventType, InternalEventType)):
raise ValueError("Invalid event type")
try:
handler_list = (
self.global_event_handlers
if event_type == EventType.ALL
else self.event_handlers[event_type]
)
except KeyError:
# No handlers for this event type registered
return
# FIXME this is a bit hacky
with suppress(ValueError):
handler_list.remove((handler, True))
with suppress(ValueError):
handler_list.remove((handler, False))
if len(handler_list) == 0 and event_type != EventType.ALL:
del self.event_handlers[event_type]
def dispatch_event(self, event: Event | None, source: SyncStream) -> list[asyncio.Task]:
"""
Send the given event to all applicable event handlers.
Args:
event: The event to send.
source: The sync stream the event was received in.
"""
if event is None:
return []
if isinstance(event.content, BaseMessageEventContentFuncs):
event.content.trim_reply_fallback()
if getattr(event, "state_key", None) is not None:
event.type = event.type.with_class(EventType.Class.STATE)
elif source & SyncStream.EPHEMERAL:
event.type = event.type.with_class(EventType.Class.EPHEMERAL)
elif source & SyncStream.ACCOUNT_DATA:
event.type = event.type.with_class(EventType.Class.ACCOUNT_DATA)
elif source & SyncStream.TO_DEVICE:
event.type = event.type.with_class(EventType.Class.TO_DEVICE)
else:
event.type = event.type.with_class(EventType.Class.MESSAGE)
setattr(event, "source", source)
return self.dispatch_manual_event(event.type, event, include_global_handlers=True)
async def _catch_errors(self, handler: EventHandler, data: Any) -> None:
try:
await handler(data)
except Exception:
self.log.exception("Failed to run handler")
def dispatch_manual_event(
self,
event_type: EventType | InternalEventType,
data: Any,
include_global_handlers: bool = False,
force_synchronous: bool = False,
) -> list[asyncio.Task]:
handlers = self.event_handlers.get(event_type, [])
if include_global_handlers:
handlers = self.global_event_handlers + handlers
tasks = []
for handler, wait_sync in handlers:
if force_synchronous or wait_sync:
tasks.append(asyncio.create_task(self._catch_errors(handler, data)))
else:
background_task.create(self._catch_errors(handler, data))
return tasks
async def run_internal_event(
self, event_type: InternalEventType, custom_type: Any = None, **kwargs: Any
) -> None:
kwargs["source"] = SyncStream.INTERNAL
tasks = self.dispatch_manual_event(
event_type,
custom_type if custom_type is not None else kwargs,
include_global_handlers=False,
)
await asyncio.gather(*tasks)
def dispatch_internal_event(
self, event_type: InternalEventType, custom_type: Any = None, **kwargs: Any
) -> list[asyncio.Task]:
kwargs["source"] = SyncStream.INTERNAL
return self.dispatch_manual_event(
event_type,
custom_type if custom_type is not None else kwargs,
include_global_handlers=False,
)
def _try_deserialize(self, type: Type[T], data: JSON) -> T | GenericEvent:
try:
return type.deserialize(data)
except SerializerError as e:
self.log.trace("Deserialization error traceback", exc_info=True)
self.log.warning(f"Failed to deserialize {data} into {type.__name__}: {e}")
try:
return GenericEvent.deserialize(data)
except SerializerError:
return None
def handle_sync(self, data: JSON) -> list[asyncio.Task]:
"""
Handle a /sync object.
Args:
data: The data from a /sync request.
"""
tasks = []
otk_count = data.get("device_one_time_keys_count", {})
tasks += self.dispatch_internal_event(
InternalEventType.DEVICE_OTK_COUNT,
custom_type=DeviceOTKCount(
curve25519=otk_count.get("curve25519", 0),
signed_curve25519=otk_count.get("signed_curve25519", 0),
),
)
device_lists = data.get("device_lists", {})
tasks += self.dispatch_internal_event(
InternalEventType.DEVICE_LISTS,
custom_type=DeviceLists(
changed=device_lists.get("changed", []),
left=device_lists.get("left", []),
),
)
for raw_event in data.get("account_data", {}).get("events", []):
tasks += self.dispatch_event(
self._try_deserialize(AccountDataEvent, raw_event), source=SyncStream.ACCOUNT_DATA
)
for raw_event in data.get("ephemeral", {}).get("events", []):
tasks += self.dispatch_event(
self._try_deserialize(EphemeralEvent, raw_event), source=SyncStream.EPHEMERAL
)
for raw_event in data.get("to_device", {}).get("events", []):
tasks += self.dispatch_event(
self._try_deserialize(ToDeviceEvent, raw_event), source=SyncStream.TO_DEVICE
)
rooms = data.get("rooms", {})
for room_id, room_data in rooms.get("join", {}).items():
for raw_event in room_data.get("state", {}).get("events", []):
raw_event["room_id"] = room_id
tasks += self.dispatch_event(
self._try_deserialize(StateEvent, raw_event),
source=SyncStream.JOINED_ROOM | SyncStream.STATE,
)
for raw_event in room_data.get("timeline", {}).get("events", []):
raw_event["room_id"] = room_id
tasks += self.dispatch_event(
self._try_deserialize(Event, raw_event),
source=SyncStream.JOINED_ROOM | SyncStream.TIMELINE,
)
for raw_event in room_data.get("ephemeral", {}).get("events", []):
raw_event["room_id"] = room_id
tasks += self.dispatch_event(
self._try_deserialize(EphemeralEvent, raw_event),
source=SyncStream.JOINED_ROOM | SyncStream.EPHEMERAL,
)
for room_id, room_data in rooms.get("invite", {}).items():
events: list[dict[str, JSON]] = room_data.get("invite_state", {}).get("events", [])
for raw_event in events:
raw_event["room_id"] = room_id
raw_invite = next(
raw_event
for raw_event in events
if raw_event.get("type", "") == "m.room.member"
and raw_event.get("state_key", "") == self.mxid
)
# These aren't required by the spec, so make sure they're set
raw_invite.setdefault("event_id", None)
raw_invite.setdefault("origin_server_ts", int(time.time() * 1000))
invite = self._try_deserialize(StateEvent, raw_invite)
invite.unsigned.invite_room_state = [
self._try_deserialize(StrippedStateEvent, raw_event)
for raw_event in events
if raw_event != raw_invite
]
tasks += self.dispatch_event(invite, source=SyncStream.INVITED_ROOM | SyncStream.STATE)
for room_id, room_data in rooms.get("leave", {}).items():
for raw_event in room_data.get("timeline", {}).get("events", []):
if "state_key" in raw_event:
raw_event["room_id"] = room_id
tasks += self.dispatch_event(
self._try_deserialize(StateEvent, raw_event),
source=SyncStream.LEFT_ROOM | SyncStream.TIMELINE,
)
return tasks
def start(self, filter_data: FilterID | Filter | None) -> asyncio.Future:
"""
Start syncing with the server. Can be stopped with :meth:`stop`.
Args:
filter_data: The filter data or filter ID to use for syncing.
"""
if self.syncing_task is not None:
self.syncing_task.cancel()
self.syncing_task = asyncio.create_task(self._try_start(filter_data))
return self.syncing_task
async def _try_start(self, filter_data: FilterID | Filter | None) -> None:
try:
if isinstance(filter_data, Filter):
filter_data = await self.create_filter(filter_data)
await self._start(filter_data)
except asyncio.CancelledError:
self.log.debug("Syncing cancelled")
except Exception as e:
self.log.critical("Fatal error while syncing", exc_info=True)
await self.run_internal_event(InternalEventType.SYNC_STOPPED, error=e)
return
except BaseException as e:
self.log.warning(
f"Syncing stopped with unexpected {e.__class__.__name__}", exc_info=True
)
raise
else:
self.log.debug("Syncing stopped without exception")
await self.run_internal_event(InternalEventType.SYNC_STOPPED, error=None)
async def _start(self, filter_id: FilterID | None) -> None:
fail_sleep = 5
is_first = True
self.log.debug("Starting syncing")
next_batch = await self.sync_store.get_next_batch()
await self.run_internal_event(InternalEventType.SYNC_STARTED)
timeout = 30
while True:
current_batch = next_batch
start = time.monotonic()
try:
data = await self.sync(
since=current_batch,
filter_id=filter_id,
set_presence=self.presence,
timeout=timeout * 1000,
)
except (asyncio.CancelledError, MUnknownToken):
raise
except Exception as e:
self.log.warning(
f"Sync request errored: {type(e).__name__}: {e}, waiting {fail_sleep}"
" seconds before continuing"
)
await self.run_internal_event(
InternalEventType.SYNC_ERRORED, error=e, sleep_for=fail_sleep
)
await asyncio.sleep(fail_sleep)
if fail_sleep < 320:
fail_sleep *= 2
continue
if fail_sleep != 5:
self.log.debug("Sync error resolved")
fail_sleep = 5
duration = time.monotonic() - start
if current_batch and duration > timeout + 10:
self.log.warning(f"Sync request ({current_batch}) took {duration:.3f} seconds")
is_initial = not current_batch
data["net.maunium.mautrix"] = {
"is_initial": is_initial,
"is_first": is_first,
}
next_batch = data.get("next_batch")
try:
await self.sync_store.put_next_batch(next_batch)
except Exception:
self.log.warning("Failed to store next batch", exc_info=True)
await self.run_internal_event(InternalEventType.SYNC_SUCCESSFUL, data=data)
if (self.ignore_first_sync and is_first) or (self.ignore_initial_sync and is_initial):
is_first = False
continue
is_first = False
self.log.silly(f"Starting sync handling ({current_batch})")
start = time.monotonic()
try:
tasks = self.handle_sync(data)
await asyncio.gather(*tasks)
except Exception:
self.log.exception(f"Sync handling ({current_batch}) errored")
else:
self.log.silly(f"Finished sync handling ({current_batch})")
finally:
duration = time.monotonic() - start
if duration > 10:
self.log.warning(
f"Sync handling ({current_batch}) took {duration:.3f} seconds"
)
def stop(self) -> None:
"""
Stop a sync started with :meth:`start`.
"""
if self.syncing_task:
self.syncing_task.cancel()
self.syncing_task = None
@abstractmethod
async def create_filter(self, filter_params: Filter) -> FilterID:
pass
@abstractmethod
async def sync(
self,
since: SyncToken = None,
timeout: int = 30000,
filter_id: FilterID = None,
full_state: bool = False,
set_presence: PresenceState = None,
) -> JSON:
pass
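# Illustrative sketch (not part of the library): a concrete subclass such as
# mautrix.client.Client implements create_filter() and sync(); the handler registration and
# start/stop lifecycle below come from Syncer. The client and handler are hypothetical.
async def _syncer_example(client: Syncer) -> None:
    async def on_message(evt: MessageEvent) -> None:
        client.log.info(f"Message in {evt.room_id}: {evt.content.body}")

    # wait_sync=True makes handle_sync await this handler before the next /sync request.
    client.add_event_handler(EventType.ROOM_MESSAGE, on_message, wait_sync=True)
    sync_task = client.start(None)  # start() returns the syncing asyncio task
    await asyncio.sleep(60)  # let it sync for a while
    client.stop()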
python-0.20.4/mautrix/crypto/__init__.py
from .account import OlmAccount
from .key_share import RejectKeyShare
from .sessions import InboundGroupSession, OutboundGroupSession, RatchetSafety, Session
# These have to be last
from .store import ( # isort: skip
CryptoStore,
MemoryCryptoStore,
PgCryptoStateStore,
PgCryptoStore,
StateStore,
)
from .machine import OlmMachine # isort: skip
__all__ = [
"OlmAccount",
"RejectKeyShare",
"InboundGroupSession",
"OutboundGroupSession",
"Session",
"CryptoStore",
"MemoryCryptoStore",
"PgCryptoStateStore",
"PgCryptoStore",
"StateStore",
"OlmMachine",
"attachments",
]
python-0.20.4/mautrix/crypto/account.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from typing import Any, Dict, Optional, cast
from datetime import datetime
import olm
from mautrix.types import (
DeviceID,
EncryptionAlgorithm,
EncryptionKeyAlgorithm,
IdentityKey,
SigningKey,
UserID,
)
from . import base
from .sessions import Session
class OlmAccount(olm.Account):
shared: bool
_signing_key: Optional[SigningKey]
_identity_key: Optional[IdentityKey]
def __init__(self) -> None:
super().__init__()
self.shared = False
self._signing_key = None
self._identity_key = None
@property
def signing_key(self) -> SigningKey:
if self._signing_key is None:
self._signing_key = SigningKey(self.identity_keys["ed25519"])
return self._signing_key
@property
def identity_key(self) -> IdentityKey:
if self._identity_key is None:
self._identity_key = IdentityKey(self.identity_keys["curve25519"])
return self._identity_key
@property
def fingerprint(self) -> str:
"""
Fingerprint is the base64-encoded signing key of this account, with spaces every 4
characters. This is what is used for manual device verification.
"""
key = self.signing_key
return " ".join([key[i : i + 4] for i in range(0, len(key), 4)])
@classmethod
def from_pickle(cls, pickle: bytes, passphrase: str, shared: bool) -> "OlmAccount":
account = cast(OlmAccount, super().from_pickle(pickle, passphrase))
account.shared = shared
account._signing_key = None
account._identity_key = None
return account
def new_inbound_session(self, sender_key: IdentityKey, ciphertext: str) -> Session:
session = olm.InboundSession(self, olm.OlmPreKeyMessage(ciphertext), sender_key)
self.remove_one_time_keys(session)
return Session.from_pickle(
session.pickle("roundtrip"), passphrase="roundtrip", creation_time=datetime.now()
)
def new_outbound_session(self, target_key: IdentityKey, one_time_key: IdentityKey) -> Session:
session = olm.OutboundSession(self, target_key, one_time_key)
return Session.from_pickle(
session.pickle("roundtrip"), passphrase="roundtrip", creation_time=datetime.now()
)
def get_device_keys(self, user_id: UserID, device_id: DeviceID) -> Dict[str, Any]:
device_keys = {
"user_id": user_id,
"device_id": device_id,
"algorithms": [EncryptionAlgorithm.OLM_V1.value, EncryptionAlgorithm.MEGOLM_V1.value],
"keys": {
f"{algorithm}:{device_id}": key for algorithm, key in self.identity_keys.items()
},
}
signature = self.sign(base.canonical_json(device_keys))
device_keys["signatures"] = {
user_id: {f"{EncryptionKeyAlgorithm.ED25519}:{device_id}": signature}
}
return device_keys
def get_one_time_keys(
self, user_id: UserID, device_id: DeviceID, current_otk_count: int
) -> Dict[str, Any]:
new_count = self.max_one_time_keys // 2 - current_otk_count
if new_count > 0:
self.generate_one_time_keys(new_count)
keys = {}
for key_id, key in self.one_time_keys.get("curve25519", {}).items():
signature = self.sign(base.canonical_json({"key": key}))
keys[f"{EncryptionKeyAlgorithm.SIGNED_CURVE25519}:{key_id}"] = {
"key": key,
"signatures": {
user_id: {f"{EncryptionKeyAlgorithm.ED25519}:{device_id}": signature}
},
}
self.mark_keys_as_published()
return keys
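# Illustrative sketch (not part of the library): creating a fresh account and inspecting the
# keys that would be uploaded to the homeserver. Requires libolm; the user and device IDs
# below are hypothetical.
def _olm_account_example() -> None:
    account = OlmAccount()
    print("identity key:", account.identity_key)
    print("fingerprint:", account.fingerprint)  # spaced signing key for manual verification
    device_keys = account.get_device_keys(UserID("@user:example.com"), DeviceID("EXAMPLEDEVICE"))
    assert "signatures" in device_keys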
python-0.20.4/mautrix/crypto/attachments/__init__.py
from .async_attachments import (
async_encrypt_attachment,
async_generator_from_data,
async_inplace_encrypt_attachment,
)
from .attachments import (
decrypt_attachment,
encrypt_attachment,
encrypted_attachment_generator,
inplace_encrypt_attachment,
)
__all__ = [
"async_encrypt_attachment",
"async_generator_from_data",
"async_inplace_encrypt_attachment",
"decrypt_attachment",
"encrypt_attachment",
"encrypted_attachment_generator",
"inplace_encrypt_attachment",
]
python-0.20.4/mautrix/crypto/attachments/async_attachments.py
# Copyright © 2018, 2019 Damir Jelić
# Copyright © 2019 miruka
#
# Permission to use, copy, modify, and/or distribute this software for
# any purpose with or without fee is hereby granted, provided that the
# above copyright notice and this permission notice appear in all copies.
#
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from typing import AsyncGenerator, AsyncIterable, Iterable
from functools import partial
import asyncio
import io
from mautrix.types import EncryptedFile
from .attachments import _get_decryption_info, _prepare_encryption, inplace_encrypt_attachment
async def async_encrypt_attachment(
data: bytes | Iterable[bytes] | AsyncIterable[bytes] | io.BufferedIOBase,
) -> AsyncGenerator[bytes | EncryptedFile, None]:
"""Async generator to encrypt data in order to send it as an encrypted
attachment.
This function lazily encrypts and yields data, thus it can be used to
encrypt large files without fully loading them into memory if an iterable
or async iterable of bytes is passed as data.
Args:
data: The data to encrypt.
Passing an async iterable allows the file data to be read in an asynchronous and lazy
(without reading the entire file into memory) way.
Passing a non-async iterable or standard open binary file object will still allow the
data to be read lazily, but not asynchronously.
Yields:
The encrypted bytes for each chunk of data.
The last yielded value will be an :class:`EncryptedFile` containing the info
needed to decrypt the data. Its fields are:
| key: AES-CTR JWK key object.
| iv: Base64 encoded 16 byte AES-CTR IV.
| hashes.sha256: Base64 encoded SHA-256 hash of the ciphertext.
"""
key, iv, cipher, sha256 = _prepare_encryption()
loop = asyncio.get_running_loop()
async for chunk in async_generator_from_data(data):
update_crypt = partial(cipher.encrypt, chunk)
crypt_chunk = await loop.run_in_executor(None, update_crypt)
update_hash = partial(sha256.update, crypt_chunk)
await loop.run_in_executor(None, update_hash)
yield crypt_chunk
yield _get_decryption_info(key, iv, sha256)
async def async_inplace_encrypt_attachment(data: bytearray) -> EncryptedFile:
loop = asyncio.get_running_loop()
return await loop.run_in_executor(None, partial(inplace_encrypt_attachment, data))
async def async_generator_from_data(
data: bytes | Iterable[bytes] | AsyncIterable[bytes] | io.BufferedIOBase,
chunk_size: int = 4 * 1024,
) -> AsyncGenerator[bytes, None]:
if isinstance(data, bytes):
chunks = (data[i : i + chunk_size] for i in range(0, len(data), chunk_size))
for chunk in chunks:
yield chunk
elif isinstance(data, io.BufferedIOBase):
while True:
chunk = data.read(chunk_size)
if not chunk:
return
yield chunk
elif isinstance(data, Iterable):
for chunk in data:
yield chunk
elif isinstance(data, AsyncIterable):
async for chunk in data:
yield chunk
else:
raise TypeError(f"Unknown type for data: {data!r}")
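# Illustrative sketch (not part of the library): collecting the encrypted chunks and the
# trailing EncryptedFile from the async generator, mirroring the pattern used in the tests.
async def _async_encrypt_example(plaintext: bytes) -> tuple[bytes, EncryptedFile]:
    chunks: list[bytes] = []
    async for item in async_encrypt_attachment(plaintext):
        if isinstance(item, EncryptedFile):
            # The generator always yields the decryption info last.
            return b"".join(chunks), item
        chunks.append(item)
    raise RuntimeError("encryption generator ended without yielding decryption info")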
python-0.20.4/mautrix/crypto/attachments/async_attachments_test.py
# Copyright © 2019 Damir Jelić (under the Apache 2.0 license)
# Copyright © 2019 miruka (under the Apache 2.0 license)
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from mautrix.types import EncryptedFile
from .async_attachments import async_encrypt_attachment, async_inplace_encrypt_attachment
from .attachments import decrypt_attachment
try:
from Crypto import Random
except ImportError:
from Cryptodome import Random
async def _get_data_cypher_keys(data: bytes) -> tuple[bytes, EncryptedFile]:
*chunks, keys = [i async for i in async_encrypt_attachment(data)]
return b"".join(chunks), keys
async def test_async_encrypt():
data = b"Test bytes"
cyphertext, keys = await _get_data_cypher_keys(data)
plaintext = decrypt_attachment(cyphertext, keys.key.key, keys.hashes["sha256"], keys.iv)
assert data == plaintext
async def test_async_inplace_encrypt():
orig_data = b"Test bytes"
data = bytearray(orig_data)
keys = await async_inplace_encrypt_attachment(data)
assert data != orig_data
decrypt_attachment(data, keys.key.key, keys.hashes["sha256"], keys.iv, inplace=True)
assert data == orig_data
python-0.20.4/mautrix/crypto/attachments/attachments.py
# Copyright 2018 Zil0 (under the Apache 2.0 license)
# Copyright © 2019 Damir Jelić (under the Apache 2.0 license)
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from typing import Generator, Iterable
import binascii
import struct
import unpaddedbase64
from mautrix.errors import DecryptionError
from mautrix.types import EncryptedFile, JSONWebKey
try:
from Crypto import Random
from Crypto.Cipher import AES
from Crypto.Hash import SHA256
from Crypto.Util import Counter
except ImportError:
from Cryptodome import Random
from Cryptodome.Cipher import AES
from Cryptodome.Hash import SHA256
from Cryptodome.Util import Counter
def decrypt_attachment(
ciphertext: bytes | bytearray | memoryview, key: str, hash: str, iv: str, inplace: bool = False
) -> bytes:
"""Decrypt an encrypted attachment.
Args:
ciphertext: The data to decrypt.
key: Unpadded base64-encoded AES-CTR key (the ``key`` field of the JSON web key).
hash: Base64 encoded SHA-256 hash of the ciphertext.
iv: Base64 encoded 16 byte AES-CTR IV.
inplace: Should the decryption be performed in-place?
The input must be a bytearray or writable memoryview to use this.
Returns:
The plaintext bytes.
Raises:
        DecryptionError: If the integrity check fails or the key/IV can't be decoded.
"""
expected_hash = unpaddedbase64.decode_base64(hash)
h = SHA256.new()
h.update(ciphertext)
if h.digest() != expected_hash:
raise DecryptionError("Mismatched SHA-256 digest")
try:
byte_key: bytes = unpaddedbase64.decode_base64(key)
except (binascii.Error, TypeError):
raise DecryptionError("Error decoding key")
try:
byte_iv: bytes = unpaddedbase64.decode_base64(iv)
if len(byte_iv) != 16:
raise DecryptionError("Invalid IV length")
prefix = byte_iv[:8]
# A non-zero IV counter is not spec-compliant, but some clients still do it,
# so decode the counter part too.
initial_value = struct.unpack(">Q", byte_iv[8:])[0]
except (binascii.Error, TypeError, IndexError, struct.error):
raise DecryptionError("Error decoding IV")
ctr = Counter.new(64, prefix=prefix, initial_value=initial_value)
try:
cipher = AES.new(byte_key, AES.MODE_CTR, counter=ctr)
except ValueError as e:
raise DecryptionError("Failed to create AES cipher") from e
if inplace:
cipher.decrypt(ciphertext, ciphertext)
return ciphertext
else:
return cipher.decrypt(ciphertext)
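# Example (illustrative sketch): decrypting a downloaded attachment using the fields of an
# EncryptedFile taken from an encrypted media event. ``ciphertext`` would be the downloaded
# bytes; the parameter names are arbitrary.
def _example_decrypt_download(ciphertext: bytes, file: EncryptedFile) -> bytes:
    return decrypt_attachment(ciphertext, file.key.key, file.hashes["sha256"], file.iv)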
def encrypt_attachment(plaintext: bytes) -> tuple[bytes, EncryptedFile]:
"""Encrypt data in order to send it as an encrypted attachment.
Args:
plaintext: The data to encrypt.
Returns:
        A tuple with the encrypted bytes and an ``EncryptedFile`` containing the info
        needed to decrypt the data (see ``encrypted_attachment_generator()``).
"""
values = list(encrypted_attachment_generator(plaintext))
return b"".join(values[:-1]), values[-1]
def _prepare_encryption() -> tuple[bytes, bytes, AES, SHA256.SHA256Hash]:
key = Random.new().read(32)
# 8 bytes IV
iv = Random.new().read(8)
# 8 bytes counter, prefixed by the IV
ctr = Counter.new(64, prefix=iv, initial_value=0)
cipher = AES.new(key, AES.MODE_CTR, counter=ctr)
sha256 = SHA256.new()
return key, iv, cipher, sha256
def inplace_encrypt_attachment(data: bytearray | memoryview) -> EncryptedFile:
key, iv, cipher, sha256 = _prepare_encryption()
cipher.encrypt(plaintext=data, output=data)
sha256.update(data)
return _get_decryption_info(key, iv, sha256)
def encrypted_attachment_generator(
data: bytes | Iterable[bytes],
) -> Generator[bytes | EncryptedFile, None, None]:
"""Generator to encrypt data in order to send it as an encrypted
attachment.
Unlike ``encrypt_attachment()``, this function lazily encrypts and yields
data, thus it can be used to encrypt large files without fully loading them
into memory if an iterable of bytes is passed as data.
Args:
data: The data to encrypt.
Yields:
The encrypted bytes for each chunk of data.
        The last yielded value will be an ``EncryptedFile`` containing the info
        needed to decrypt the data.
"""
key, iv, cipher, sha256 = _prepare_encryption()
if isinstance(data, bytes):
data = [data]
for chunk in data:
encrypted_chunk = cipher.encrypt(chunk) # in executor
sha256.update(encrypted_chunk) # in executor
yield encrypted_chunk
yield _get_decryption_info(key, iv, sha256)
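# Example (illustrative sketch): consuming ``encrypted_attachment_generator`` with a source
# file read in chunks, writing ciphertext to a target file and keeping only the final
# ``EncryptedFile`` in memory. The chunk size and paths are arbitrary.
def _example_encrypt_file_to_file(source_path: str, target_path: str) -> EncryptedFile:
    def _chunks(file, size: int = 64 * 1024):
        while True:
            chunk = file.read(size)
            if not chunk:
                return
            yield chunk

    decryption_info = None
    with open(source_path, "rb") as source, open(target_path, "wb") as target:
        for part in encrypted_attachment_generator(_chunks(source)):
            if isinstance(part, EncryptedFile):
                # The last yielded value is the decryption info for the whole attachment.
                decryption_info = part
            else:
                target.write(part)
    return decryption_info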
def _get_decryption_info(key: bytes, iv: bytes, sha256: SHA256.SHA256Hash) -> EncryptedFile:
return EncryptedFile(
version="v2",
iv=unpaddedbase64.encode_base64(iv + b"\x00" * 8),
hashes={"sha256": unpaddedbase64.encode_base64(sha256.digest())},
key=JSONWebKey(
key_type="oct",
algorithm="A256CTR",
extractable=True,
key_ops=["encrypt", "decrypt"],
key=unpaddedbase64.encode_base64(key, urlsafe=True),
),
)
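# Example (illustrative, values purely hypothetical): serialized, the EncryptedFile built
# above corresponds to the ``file`` object of an encrypted media event, roughly:
#
#     {
#         "v": "v2",
#         "key": {"kty": "oct", "alg": "A256CTR", "ext": true,
#                 "key_ops": ["encrypt", "decrypt"], "k": "<urlsafe unpadded base64 key>"},
#         "iv": "<8 random bytes + 8 zero bytes, unpadded base64>",
#         "hashes": {"sha256": "<unpadded base64 SHA-256 of the ciphertext>"}
#     }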
python-0.20.4/mautrix/crypto/attachments/attachments_test.py
# Copyright © 2019 Damir Jelić (under the Apache 2.0 license)
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
import pytest
import unpaddedbase64
from mautrix.errors import DecryptionError
from .attachments import decrypt_attachment, encrypt_attachment, inplace_encrypt_attachment
try:
from Crypto import Random
except ImportError:
from Cryptodome import Random
def test_encrypt():
data = b"Test bytes"
cyphertext, keys = encrypt_attachment(data)
plaintext = decrypt_attachment(cyphertext, keys.key.key, keys.hashes["sha256"], keys.iv)
assert data == plaintext
def test_inplace_encrypt():
orig_data = b"Test bytes"
data = bytearray(orig_data)
keys = inplace_encrypt_attachment(data)
assert data != orig_data
decrypt_attachment(data, keys.key.key, keys.hashes["sha256"], keys.iv, inplace=True)
assert data == orig_data
def test_hash_verification():
data = b"Test bytes"
cyphertext, keys = encrypt_attachment(data)
with pytest.raises(DecryptionError):
decrypt_attachment(cyphertext, keys.key.key, "Fake hash", keys.iv)
def test_invalid_key():
data = b"Test bytes"
cyphertext, keys = encrypt_attachment(data)
with pytest.raises(DecryptionError):
decrypt_attachment(cyphertext, "Fake key", keys.hashes["sha256"], keys.iv)
def test_invalid_iv():
data = b"Test bytes"
cyphertext, keys = encrypt_attachment(data)
with pytest.raises(DecryptionError):
decrypt_attachment(cyphertext, keys.key.key, keys.hashes["sha256"], "Fake iv")
def test_short_key():
data = b"Test bytes"
cyphertext, keys = encrypt_attachment(data)
with pytest.raises(DecryptionError):
decrypt_attachment(
cyphertext,
unpaddedbase64.encode_base64(b"Fake key", urlsafe=True),
keys["hashes"]["sha256"],
keys["iv"],
)
def test_short_iv():
data = b"Test bytes"
cyphertext, keys = encrypt_attachment(data)
with pytest.raises(DecryptionError):
decrypt_attachment(
cyphertext,
keys.key.key,
keys.hashes["sha256"],
unpaddedbase64.encode_base64(b"F" + b"\x00" * 8),
)
def test_fake_key():
data = b"Test bytes"
cyphertext, keys = encrypt_attachment(data)
fake_key = Random.new().read(32)
plaintext = decrypt_attachment(
cyphertext,
unpaddedbase64.encode_base64(fake_key, urlsafe=True),
keys["hashes"]["sha256"],
keys["iv"],
)
assert plaintext != data
python-0.20.4/mautrix/crypto/base.py
# Copyright (c) 2023 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from typing import Any, Awaitable, Callable, TypedDict
import asyncio
import functools
import json
import olm
from mautrix.errors import MForbidden, MNotFound
from mautrix.types import (
DeviceID,
EncryptionKeyAlgorithm,
EventType,
IdentityKey,
KeyID,
RequestedKeyInfo,
RoomEncryptionStateEventContent,
RoomID,
RoomKeyEventContent,
SessionID,
SigningKey,
TrustState,
UserID,
)
from mautrix.util.logging import TraceLogger
from .. import client as cli, crypto
class SignedObject(TypedDict):
signatures: dict[UserID, dict[str, str]]
unsigned: Any
class BaseOlmMachine:
client: cli.Client
log: TraceLogger
crypto_store: crypto.CryptoStore
state_store: crypto.StateStore
account: account.OlmAccount
send_keys_min_trust: TrustState
share_keys_min_trust: TrustState
allow_key_share: Callable[[crypto.DeviceIdentity, RequestedKeyInfo], Awaitable[bool]]
delete_outbound_keys_on_ack: bool
dont_store_outbound_keys: bool
delete_previous_keys_on_receive: bool
ratchet_keys_on_decrypt: bool
delete_fully_used_keys_on_decrypt: bool
delete_keys_on_device_delete: bool
disable_device_change_key_rotation: bool
# Futures that wait for responses to a key request
_key_request_waiters: dict[SessionID, asyncio.Future]
# Futures that wait for a session to be received (either normally or through a key request)
_inbound_session_waiters: dict[SessionID, asyncio.Future]
_prev_unwedge: dict[IdentityKey, float]
_fetch_keys_lock: asyncio.Lock
_megolm_decrypt_lock: asyncio.Lock
_share_keys_lock: asyncio.Lock
_last_key_share: float
_cs_fetch_attempted: set[UserID]
async def wait_for_session(
self, room_id: RoomID, session_id: SessionID, timeout: float = 3
) -> bool:
try:
fut = self._inbound_session_waiters[session_id]
except KeyError:
fut = asyncio.get_running_loop().create_future()
self._inbound_session_waiters[session_id] = fut
try:
return await asyncio.wait_for(asyncio.shield(fut), timeout)
except asyncio.TimeoutError:
return await self.crypto_store.has_group_session(room_id, session_id)
def _mark_session_received(self, session_id: SessionID) -> None:
try:
self._inbound_session_waiters.pop(session_id).set_result(True)
except KeyError:
return
async def _fill_encryption_info(self, evt: RoomKeyEventContent) -> None:
encryption_info = await self.state_store.get_encryption_info(evt.room_id)
if not encryption_info:
self.log.warning(
f"Encryption info for {evt.room_id} not found in state store, fetching from server"
)
try:
encryption_info = await self.client.get_state_event(
evt.room_id, EventType.ROOM_ENCRYPTION
)
except (MNotFound, MForbidden) as e:
self.log.warning(
f"Failed to get encryption info for {evt.room_id} from server: {e},"
" using defaults"
)
encryption_info = RoomEncryptionStateEventContent()
if not encryption_info:
self.log.warning(
f"Didn't find encryption info for {evt.room_id} on server either,"
" using defaults"
)
encryption_info = RoomEncryptionStateEventContent()
if not evt.beeper_max_age_ms:
evt.beeper_max_age_ms = encryption_info.rotation_period_ms
if not evt.beeper_max_messages:
evt.beeper_max_messages = encryption_info.rotation_period_msgs
canonical_json = functools.partial(
json.dumps, ensure_ascii=False, separators=(",", ":"), sort_keys=True
)
def verify_signature_json(
data: "SignedObject", user_id: UserID, key_name: DeviceID | str, key: SigningKey
) -> bool:
data_copy = {**data}
data_copy.pop("unsigned", None)
signatures = data_copy.pop("signatures")
key_id = str(KeyID(EncryptionKeyAlgorithm.ED25519, key_name))
try:
signature = signatures[user_id][key_id]
except KeyError:
return False
signed_data = canonical_json(data_copy)
try:
olm.ed25519_verify(key, signed_data, signature)
return True
except olm.OlmVerifyError:
return False
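# Example (illustrative sketch): verifying the signature on a device keys dict from a
# hypothetical /keys/query response with ``verify_signature_json``. The user ID and device
# ID are made up; ``device_keys_json`` is the raw dict for that single device.
def _example_verify_device_keys(device_keys_json: dict) -> bool:
    user_id = UserID("@alice:example.com")
    device_id = DeviceID("HYPOTHETICAL")
    claimed_key = SigningKey(device_keys_json["keys"][f"ed25519:{device_id}"])
    return verify_signature_json(device_keys_json, user_id, device_id, claimed_key)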
python-0.20.4/mautrix/crypto/decrypt_megolm.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
import json
import olm
from mautrix.errors import (
DecryptedPayloadError,
DecryptionError,
DuplicateMessageIndex,
MismatchingRoomError,
SessionNotFound,
VerificationError,
)
from mautrix.types import (
EncryptedEvent,
EncryptedMegolmEventContent,
EncryptionAlgorithm,
Event,
SessionID,
TrustState,
)
from .device_lists import DeviceListMachine
from .sessions import InboundGroupSession
class MegolmDecryptionMachine(DeviceListMachine):
async def decrypt_megolm_event(self, evt: EncryptedEvent) -> Event:
"""
Decrypt an event that was encrypted using Megolm.
Args:
evt: The whole encrypted event.
Returns:
The decrypted event, including some unencrypted metadata from the input event.
Raises:
DecryptionError: If decryption failed.
"""
if not isinstance(evt.content, EncryptedMegolmEventContent):
raise DecryptionError("Unsupported event content class")
elif evt.content.algorithm != EncryptionAlgorithm.MEGOLM_V1:
raise DecryptionError("Unsupported event encryption algorithm")
async with self._megolm_decrypt_lock:
session = await self.crypto_store.get_group_session(
evt.room_id, evt.content.session_id
)
if session is None:
# TODO check if olm session is wedged
raise SessionNotFound(evt.content.session_id, evt.content.sender_key)
try:
plaintext, index = session.decrypt(evt.content.ciphertext)
except olm.OlmGroupSessionError as e:
raise DecryptionError("Failed to decrypt megolm event") from e
if not await self.crypto_store.validate_message_index(
session.sender_key, SessionID(session.id), evt.event_id, index, evt.timestamp
):
raise DuplicateMessageIndex()
await self._ratchet_session(session, index)
forwarded_keys = False
if (
evt.content.device_id == self.client.device_id
and session.signing_key == self.account.signing_key
and session.sender_key == self.account.identity_key
and not session.forwarding_chain
):
trust_level = TrustState.VERIFIED
else:
device = await self.get_or_fetch_device_by_key(evt.sender, session.sender_key)
if not session.forwarding_chain or (
len(session.forwarding_chain) == 1
and session.forwarding_chain[0] == session.sender_key
):
if not device:
self.log.debug(
f"Couldn't resolve trust level of session {session.id}: "
f"sent by unknown device {evt.sender}/{session.sender_key}"
)
trust_level = TrustState.UNKNOWN_DEVICE
elif (
device.signing_key != session.signing_key
or device.identity_key != session.sender_key
):
raise VerificationError()
else:
trust_level = await self.resolve_trust(device)
else:
forwarded_keys = True
last_chain_item = session.forwarding_chain[-1]
received_from = await self.crypto_store.find_device_by_key(
evt.sender, last_chain_item
)
if received_from:
trust_level = await self.resolve_trust(received_from)
else:
self.log.debug(
f"Couldn't resolve trust level of session {session.id}: "
f"forwarding chain ends with unknown device {last_chain_item}"
)
trust_level = TrustState.FORWARDED
try:
data = json.loads(plaintext)
room_id = data["room_id"]
event_type = data["type"]
content = data["content"]
except json.JSONDecodeError as e:
raise DecryptedPayloadError("Failed to parse megolm payload") from e
except KeyError as e:
raise DecryptedPayloadError("Megolm payload is missing fields") from e
if room_id != evt.room_id:
raise MismatchingRoomError()
if evt.content.relates_to and "m.relates_to" not in content:
content["m.relates_to"] = evt.content.relates_to.serialize()
result = Event.deserialize(
{
"room_id": evt.room_id,
"event_id": evt.event_id,
"sender": evt.sender,
"origin_server_ts": evt.timestamp,
"type": event_type,
"content": content,
}
)
result.unsigned = evt.unsigned
result.type = result.type.with_class(evt.type.t_class)
result["mautrix"] = {
"trust_state": trust_level,
"forwarded_keys": forwarded_keys,
"was_encrypted": True,
}
return result
async def _ratchet_session(self, sess: InboundGroupSession, index: int) -> None:
expected_message_index = sess.ratchet_safety.next_index
did_modify = True
if index > expected_message_index:
sess.ratchet_safety.missed_indices += list(range(expected_message_index, index))
sess.ratchet_safety.next_index = index + 1
elif index == expected_message_index:
sess.ratchet_safety.next_index = index + 1
else:
try:
sess.ratchet_safety.missed_indices.remove(index)
except ValueError:
did_modify = False
# Use presence of received_at as a sign that this is a recent megolm session,
# and therefore it's safe to drop missed indices entirely.
if (
sess.received_at
and sess.ratchet_safety.missed_indices
and sess.ratchet_safety.missed_indices[0] < expected_message_index - 10
):
i = 0
for i, lost_index in enumerate(sess.ratchet_safety.missed_indices):
if lost_index < expected_message_index - 10:
sess.ratchet_safety.lost_indices.append(lost_index)
else:
break
sess.ratchet_safety.missed_indices = sess.ratchet_safety.missed_indices[i + 1 :]
ratchet_target_index = sess.ratchet_safety.next_index
if len(sess.ratchet_safety.missed_indices) > 0:
ratchet_target_index = min(sess.ratchet_safety.missed_indices)
self.log.debug(
f"Ratchet safety info for {sess.id}: {sess.ratchet_safety}, {ratchet_target_index=}"
)
sess_id = SessionID(sess.id)
if (
sess.max_messages
and ratchet_target_index >= sess.max_messages
and not sess.ratchet_safety.missed_indices
and self.delete_fully_used_keys_on_decrypt
):
self.log.info(f"Deleting fully used session {sess.id}")
await self.crypto_store.redact_group_session(
sess.room_id, sess_id, reason="maximum messages reached"
)
return
elif sess.first_known_index < ratchet_target_index and self.ratchet_keys_on_decrypt:
self.log.info(f"Ratcheting session {sess.id} to {ratchet_target_index}")
sess = sess.ratchet_to(ratchet_target_index)
elif not did_modify:
return
await self.crypto_store.put_group_session(sess.room_id, sess.sender_key, sess_id, sess)
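# Example (illustrative sketch, assuming ``machine`` is an initialized OlmMachine, which
# also provides the key request helpers): decrypting an incoming m.room.encrypted event and
# falling back to a key request when the Megolm session is missing.
async def _example_decrypt_or_request(machine, evt: EncryptedEvent) -> "Event | None":
    try:
        return await machine.decrypt_megolm_event(evt)
    except SessionNotFound:
        # Ask the sending device for the key without blocking on the request itself...
        await machine.request_room_key(
            evt.room_id,
            evt.content.sender_key,
            evt.content.session_id,
            from_devices={evt.sender: [evt.content.device_id]},
            timeout=0,
        )
        # ...then wait a bit for the session to show up and retry the decryption.
        if await machine.wait_for_session(evt.room_id, evt.content.session_id, timeout=10):
            return await machine.decrypt_megolm_event(evt)
        return None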
python-0.20.4/mautrix/crypto/decrypt_olm.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from typing import Optional
import asyncio
import olm
from mautrix.errors import DecryptionError, MatchingSessionDecryptionError
from mautrix.types import (
DecryptedOlmEvent,
EncryptedOlmEventContent,
EncryptionAlgorithm,
IdentityKey,
OlmCiphertext,
OlmMsgType,
ToDeviceEvent,
UserID,
)
from mautrix.util import background_task
from .base import BaseOlmMachine
from .sessions import Session
class OlmDecryptionMachine(BaseOlmMachine):
async def _decrypt_olm_event(self, evt: ToDeviceEvent) -> DecryptedOlmEvent:
if not isinstance(evt.content, EncryptedOlmEventContent):
raise DecryptionError("unsupported event content class")
elif evt.content.algorithm != EncryptionAlgorithm.OLM_V1:
raise DecryptionError("unsupported event encryption algorithm")
try:
own_content = evt.content.ciphertext[self.account.identity_key]
except KeyError:
raise DecryptionError("olm event doesn't contain ciphertext for this device")
self.log.debug(
f"Decrypting to-device olm event from {evt.sender}/{evt.content.sender_key}"
)
plaintext = await self._decrypt_olm_ciphertext(
evt.sender, evt.content.sender_key, own_content
)
try:
decrypted_evt: DecryptedOlmEvent = DecryptedOlmEvent.parse_json(plaintext)
except Exception:
self.log.trace("Failed to parse olm event plaintext: %s", plaintext)
raise
if decrypted_evt.sender != evt.sender:
raise DecryptionError("mismatched sender in olm payload")
elif decrypted_evt.recipient != self.client.mxid:
raise DecryptionError("mismatched recipient in olm payload")
elif decrypted_evt.recipient_keys.ed25519 != self.account.signing_key:
raise DecryptionError("mismatched recipient key in olm payload")
decrypted_evt.sender_key = evt.content.sender_key
decrypted_evt.source = evt
self.log.debug(
f"Successfully decrypted olm event from {evt.sender}/{decrypted_evt.sender_device} "
f"(sender key: {decrypted_evt.sender_key} into a {decrypted_evt.type}"
)
return decrypted_evt
async def _decrypt_olm_ciphertext(
self, sender: UserID, sender_key: IdentityKey, message: OlmCiphertext
) -> str:
if message.type not in (OlmMsgType.PREKEY, OlmMsgType.MESSAGE):
raise DecryptionError("unsupported olm message type")
try:
plaintext = await self._try_decrypt_olm_ciphertext(sender_key, message)
except MatchingSessionDecryptionError:
self.log.warning(
f"Found matching session yet decryption failed for sender {sender}"
f" with key {sender_key}"
)
background_task.create(self._unwedge_session(sender, sender_key))
raise
if not plaintext:
if message.type != OlmMsgType.PREKEY:
background_task.create(self._unwedge_session(sender, sender_key))
raise DecryptionError("Decryption failed for normal message")
self.log.trace(f"Trying to create inbound session for {sender}/{sender_key}")
try:
session = await self._create_inbound_session(sender_key, message.body)
except olm.OlmSessionError as e:
background_task.create(self._unwedge_session(sender, sender_key))
raise DecryptionError("Failed to create new session from prekey message") from e
self.log.debug(
f"Created inbound session {session.id} for {sender} (sender key: {sender_key})"
)
try:
plaintext = session.decrypt(message)
except olm.OlmSessionError as e:
raise DecryptionError(
"Failed to decrypt olm event with session created from prekey message"
) from e
await self.crypto_store.update_session(sender_key, session)
return plaintext
async def _try_decrypt_olm_ciphertext(
self, sender_key: IdentityKey, message: OlmCiphertext
) -> Optional[str]:
sessions = await self.crypto_store.get_sessions(sender_key)
for session in sessions:
if message.type == OlmMsgType.PREKEY and not session.matches(message.body):
continue
try:
plaintext = session.decrypt(message)
except olm.OlmSessionError as e:
if message.type == OlmMsgType.PREKEY:
raise MatchingSessionDecryptionError(
"decryption failed with matching session"
) from e
else:
await self.crypto_store.update_session(sender_key, session)
return plaintext
return None
async def _create_inbound_session(self, sender_key: IdentityKey, ciphertext: str) -> Session:
session = self.account.new_inbound_session(sender_key, ciphertext)
await self.crypto_store.put_account(self.account)
await self.crypto_store.add_session(sender_key, session)
return session
async def _unwedge_session(self, sender: UserID, sender_key: IdentityKey) -> None:
raise NotImplementedError()
python-0.20.4/mautrix/crypto/device_lists.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from mautrix.errors import DeviceValidationError
from mautrix.types import (
CrossSigner,
CrossSigningKeys,
CrossSigningUsage,
DeviceID,
DeviceIdentity,
DeviceKeys,
EncryptionKeyAlgorithm,
IdentityKey,
KeyID,
QueryKeysResponse,
SigningKey,
SyncToken,
TrustState,
UserID,
)
from .base import BaseOlmMachine, verify_signature_json
class DeviceListMachine(BaseOlmMachine):
async def _fetch_keys(
self, users: list[UserID], since: SyncToken = "", include_untracked: bool = False
) -> dict[UserID, dict[DeviceID, DeviceIdentity]]:
if not include_untracked:
users = await self.crypto_store.filter_tracked_users(users)
if len(users) == 0:
return {}
users = set(users)
self.log.trace(f"Querying keys for {users}")
resp = await self.client.query_keys(users, token=since)
missing_users = users.copy()
for server, err in resp.failures.items():
self.log.warning(f"Query keys failure for {server}: {err}")
data = {}
for user_id, devices in resp.device_keys.items():
missing_users.remove(user_id)
new_devices = {}
existing_devices = (await self.crypto_store.get_devices(user_id)) or {}
self.log.trace(
f"Updating devices for {user_id}, got {len(devices)}, "
f"have {len(existing_devices)} in store"
)
changed = False
ssks = resp.self_signing_keys.get(user_id)
ssk = ssks.first_ed25519_key if ssks else None
for device_id, device_keys in devices.items():
try:
existing = existing_devices[device_id]
except KeyError:
existing = None
changed = True
self.log.trace(f"Validating device {device_keys} of {user_id}")
try:
new_device = await self._validate_device(
user_id, device_id, device_keys, existing
)
except DeviceValidationError as e:
self.log.warning(f"Failed to validate device {device_id} of {user_id}: {e}")
else:
if new_device:
new_devices[device_id] = new_device
await self._store_device_self_signatures(device_keys, ssk)
self.log.debug(
f"Storing new device list for {user_id} containing {len(new_devices)} devices"
)
await self.crypto_store.put_devices(user_id, new_devices)
data[user_id] = new_devices
if changed or len(new_devices) != len(existing_devices):
if self.delete_keys_on_device_delete:
for device_id in existing_devices.keys() - new_devices.keys():
device = existing_devices[device_id]
removed_ids = await self.crypto_store.redact_group_sessions(
room_id=None, sender_key=device.identity_key, reason="device removed"
)
self.log.info(
"Redacted megolm sessions sent by removed device "
f"{device.user_id}/{device.device_id}: {removed_ids}"
)
await self.on_devices_changed(user_id)
for user_id in missing_users:
self.log.warning(f"Didn't get any devices for user {user_id}")
for user_id in users:
await self._store_cross_signing_keys(resp, user_id)
return data
async def _store_device_self_signatures(
self, device_keys: DeviceKeys, self_signing_key: SigningKey | None
) -> None:
device_desc = f"Device {device_keys.user_id}/{device_keys.device_id}"
try:
self_signatures = device_keys.signatures[device_keys.user_id].copy()
except KeyError:
self.log.warning(f"{device_desc} doesn't have any signatures from the user")
return
if len(device_keys.signatures) > 1:
self.log.debug(
f"{device_desc} has signatures from other users (%s)",
set(device_keys.signatures.keys()) - {device_keys.user_id},
)
device_self_sig = self_signatures.pop(
KeyID(EncryptionKeyAlgorithm.ED25519, device_keys.device_id)
)
target = CrossSigner(device_keys.user_id, device_keys.ed25519)
# This one is already validated by _validate_device
await self.crypto_store.put_signature(target, target, device_self_sig)
try:
cs_self_sig = self_signatures.pop(
KeyID(EncryptionKeyAlgorithm.ED25519, self_signing_key)
)
except KeyError:
self.log.warning(f"{device_desc} isn't cross-signed")
else:
is_valid_self_sig = verify_signature_json(
device_keys.serialize(), device_keys.user_id, self_signing_key, self_signing_key
)
if is_valid_self_sig:
signer = CrossSigner(device_keys.user_id, self_signing_key)
await self.crypto_store.put_signature(target, signer, cs_self_sig)
else:
self.log.warning(f"{device_desc} doesn't have a valid cross-signing signature")
if len(self_signatures) > 0:
self.log.debug(
f"{device_desc} has signatures from unexpected keys (%s)",
set(self_signatures.keys()),
)
async def _store_cross_signing_keys(self, resp: QueryKeysResponse, user_id: UserID) -> None:
new_keys: dict[CrossSigningUsage, CrossSigningKeys] = {}
try:
master = new_keys[CrossSigningUsage.MASTER] = resp.master_keys[user_id]
except KeyError:
self.log.debug(f"Didn't get a cross-signing master key for {user_id}")
return
try:
new_keys[CrossSigningUsage.SELF] = resp.self_signing_keys[user_id]
except KeyError:
self.log.debug(f"Didn't get a cross-signing self-signing key for {user_id}")
return
try:
new_keys[CrossSigningUsage.USER] = resp.user_signing_keys[user_id]
except KeyError:
pass
current_keys = await self.crypto_store.get_cross_signing_keys(user_id)
for usage, key in current_keys.items():
if usage in new_keys and key.key != new_keys[usage].first_ed25519_key:
num = await self.crypto_store.drop_signatures_by_key(CrossSigner(user_id, key.key))
if num >= 0:
self.log.debug(
f"Dropped {num} signatures made by key {user_id}/{key.key} ({usage})"
" as it has been replaced"
)
for usage, key in new_keys.items():
actual_key = key.first_ed25519_key
self.log.debug(f"Storing cross-signing key for {user_id}: {actual_key} (type {usage})")
await self.crypto_store.put_cross_signing_key(user_id, usage, actual_key)
if usage != CrossSigningUsage.MASTER and (
KeyID(EncryptionKeyAlgorithm.ED25519, master.first_ed25519_key)
not in key.signatures[user_id]
):
self.log.warning(
f"Cross-signing key {user_id}/{actual_key}/{usage}"
" doesn't seem to have a signature from the master key"
)
for signer_user_id, signatures in key.signatures.items():
for key_id, signature in signatures.items():
signing_key = SigningKey(key_id.key_id)
if signer_user_id == user_id:
try:
device = resp.device_keys[signer_user_id][DeviceID(key_id.key_id)]
signing_key = device.ed25519
except KeyError:
pass
if len(signing_key) != 43:
self.log.debug(
f"Cross-signing key {user_id}/{actual_key} has a signature from "
f"an unknown key {key_id}"
)
continue
signing_key_log = signing_key
if signing_key != key_id.key_id:
signing_key_log = f"{signing_key} ({key_id})"
self.log.debug(
f"Verifying cross-signing key {user_id}/{actual_key} "
f"with key {signer_user_id}/{signing_key_log}"
)
is_valid_sig = verify_signature_json(
key.serialize(), signer_user_id, key_id.key_id, signing_key
)
if is_valid_sig:
self.log.debug(f"Signature from {signing_key_log} for {key_id} verified")
await self.crypto_store.put_signature(
target=CrossSigner(user_id, actual_key),
signer=CrossSigner(signer_user_id, signing_key),
signature=signature,
)
else:
self.log.warning(f"Invalid signature from {signing_key_log} for {key_id}")
async def get_or_fetch_device(
self, user_id: UserID, device_id: DeviceID
) -> DeviceIdentity | None:
device = await self.crypto_store.get_device(user_id, device_id)
if device is not None:
return device
devices = await self._fetch_keys([user_id], include_untracked=True)
try:
return devices[user_id][device_id]
except KeyError:
return None
async def get_or_fetch_device_by_key(
self, user_id: UserID, identity_key: IdentityKey
) -> DeviceIdentity | None:
device = await self.crypto_store.find_device_by_key(user_id, identity_key)
if device is not None:
return device
devices = await self._fetch_keys([user_id], include_untracked=True)
for device in devices.get(user_id, {}).values():
if device.identity_key == identity_key:
return device
return None
async def on_devices_changed(self, user_id: UserID) -> None:
if self.disable_device_change_key_rotation:
return
shared_rooms = await self.state_store.find_shared_rooms(user_id)
self.log.debug(
f"Devices of {user_id} changed, invalidating group session in {shared_rooms}"
)
await self.crypto_store.remove_outbound_group_sessions(shared_rooms)
@staticmethod
async def _validate_device(
user_id: UserID,
device_id: DeviceID,
device_keys: DeviceKeys,
existing: DeviceIdentity | None = None,
) -> DeviceIdentity:
if user_id != device_keys.user_id:
raise DeviceValidationError(
f"mismatching user ID (expected {user_id}, got {device_keys.user_id})"
)
elif device_id != device_keys.device_id:
raise DeviceValidationError(
f"mismatching device ID (expected {device_id}, got {device_keys.device_id})"
)
signing_key = device_keys.ed25519
if not signing_key:
raise DeviceValidationError("didn't find ed25519 signing key")
identity_key = device_keys.curve25519
if not identity_key:
raise DeviceValidationError("didn't find curve25519 identity key")
if existing and existing.signing_key != signing_key:
raise DeviceValidationError(
f"received update for device with different signing key "
f"(expected {existing.signing_key}, got {signing_key})"
)
if not verify_signature_json(device_keys.serialize(), user_id, device_id, signing_key):
raise DeviceValidationError("invalid signature on device keys")
name = device_keys.unsigned.device_display_name or device_id
return DeviceIdentity(
user_id=user_id,
device_id=device_id,
identity_key=identity_key,
signing_key=signing_key,
trust=TrustState.UNVERIFIED,
name=name,
deleted=False,
)
async def resolve_trust(self, device: DeviceIdentity) -> TrustState:
try:
return await self._try_resolve_trust(device)
except Exception:
self.log.exception(f"Failed to resolve trust of {device.user_id}/{device.device_id}")
return TrustState.UNVERIFIED
async def _try_resolve_trust(self, device: DeviceIdentity) -> TrustState:
if device.trust in (TrustState.VERIFIED, TrustState.BLACKLISTED):
return device.trust
their_keys = await self.crypto_store.get_cross_signing_keys(device.user_id)
if len(their_keys) == 0 and device.user_id not in self._cs_fetch_attempted:
self.log.debug(f"Didn't find any cross-signing keys for {device.user_id}, fetching...")
async with self._fetch_keys_lock:
if device.user_id not in self._cs_fetch_attempted:
self._cs_fetch_attempted.add(device.user_id)
await self._fetch_keys([device.user_id])
their_keys = await self.crypto_store.get_cross_signing_keys(device.user_id)
try:
msk = their_keys[CrossSigningUsage.MASTER]
ssk = their_keys[CrossSigningUsage.SELF]
except KeyError as e:
self.log.error(f"Didn't find cross-signing key {e.args[0]} of {device.user_id}")
return TrustState.UNVERIFIED
ssk_signed = await self.crypto_store.is_key_signed_by(
target=CrossSigner(device.user_id, ssk.key),
signer=CrossSigner(device.user_id, msk.key),
)
if not ssk_signed:
self.log.warning(
f"Self-signing key of {device.user_id} is not signed by their master key"
)
return TrustState.UNVERIFIED
device_signed = await self.crypto_store.is_key_signed_by(
target=CrossSigner(device.user_id, device.signing_key),
signer=CrossSigner(device.user_id, ssk.key),
)
if device_signed:
if await self.is_user_trusted(device.user_id):
return TrustState.CROSS_SIGNED_TRUSTED
elif msk.key == msk.first:
return TrustState.CROSS_SIGNED_TOFU
return TrustState.CROSS_SIGNED_UNTRUSTED
return TrustState.UNVERIFIED
async def is_user_trusted(self, user_id: UserID) -> bool:
# TODO implement once own cross-signing key stuff is ready
return False
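# Example (illustrative sketch, assuming ``machine`` is an initialized machine using this
# mixin): fetching a device identity and resolving its trust state before deciding whether
# it is safe to encrypt for it.
async def _example_check_device_trust(
    machine: DeviceListMachine, user_id: UserID, device_id: DeviceID
) -> bool:
    device = await machine.get_or_fetch_device(user_id, device_id)
    if device is None:
        # The device wasn't found in the store or on the server.
        return False
    trust = await machine.resolve_trust(device)
    return trust in (
        TrustState.VERIFIED,
        TrustState.CROSS_SIGNED_TRUSTED,
        TrustState.CROSS_SIGNED_TOFU,
    )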
python-0.20.4/mautrix/crypto/encrypt_megolm.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from typing import Any, Dict, List, Tuple, Union
from collections import defaultdict
from datetime import datetime, timedelta
import asyncio
import json
import time
from mautrix.errors import EncryptionError, SessionShareError
from mautrix.types import (
DeviceID,
DeviceIdentity,
EncryptedMegolmEventContent,
EncryptionAlgorithm,
EventType,
IdentityKey,
RelatesTo,
RoomID,
RoomKeyWithheldCode,
RoomKeyWithheldEventContent,
Serializable,
SessionID,
SigningKey,
TrustState,
UserID,
)
from .device_lists import DeviceListMachine
from .encrypt_olm import OlmEncryptionMachine
from .sessions import InboundGroupSession, OutboundGroupSession, Session
class Sentinel:
pass
already_shared = Sentinel()
key_missing = Sentinel()
DeviceSessionWrapper = Tuple[Session, DeviceIdentity]
DeviceMap = Dict[UserID, Dict[DeviceID, DeviceSessionWrapper]]
SessionEncryptResult = Union[
type(already_shared), # already shared
DeviceSessionWrapper, # share successful
RoomKeyWithheldEventContent, # won't share
type(key_missing), # missing device
]
class MegolmEncryptionMachine(OlmEncryptionMachine, DeviceListMachine):
_megolm_locks: Dict[RoomID, asyncio.Lock]
_sharing_group_session: Dict[RoomID, asyncio.Event]
def __init__(self) -> None:
super().__init__()
self._megolm_locks = defaultdict(lambda: asyncio.Lock())
self._sharing_group_session = {}
async def encrypt_megolm_event(
self, room_id: RoomID, event_type: EventType, content: Any
) -> EncryptedMegolmEventContent:
"""
Encrypt an event for a specific room using Megolm.
Args:
room_id: The room to encrypt the message for.
event_type: The event type.
content: The event content. Using the content structs in the mautrix.types
module is recommended.
Returns:
The encrypted event content.
Raises:
EncryptionError: If a group session has not been shared.
Use :meth:`share_group_session` to share a group session if this error is raised.
"""
# The crypto store is async, so we need to make sure only one thing is writing at a time.
async with self._megolm_locks[room_id]:
return await self._encrypt_megolm_event(room_id, event_type, content)
async def _encrypt_megolm_event(
self, room_id: RoomID, event_type: EventType, content: Any
) -> EncryptedMegolmEventContent:
self.log.debug(f"Encrypting event of type {event_type} for {room_id}")
session = await self.crypto_store.get_outbound_group_session(room_id)
if not session:
raise EncryptionError("No group session created")
ciphertext = session.encrypt(
json.dumps(
{
"room_id": room_id,
"type": event_type.serialize(),
"content": content.serialize()
if isinstance(content, Serializable)
else content,
}
)
)
try:
relates_to = content.relates_to
except AttributeError:
try:
relates_to = RelatesTo.deserialize(content["m.relates_to"])
except KeyError:
relates_to = None
await self.crypto_store.update_outbound_group_session(session)
return EncryptedMegolmEventContent(
sender_key=self.account.identity_key,
device_id=self.client.device_id,
ciphertext=ciphertext,
session_id=SessionID(session.id),
relates_to=relates_to,
)
def is_sharing_group_session(self, room_id: RoomID) -> bool:
"""
        Check if there's a group session being shared for a specific room.
Args:
room_id: The room ID to check.
Returns:
            True if a group session share is in progress, False if not.
"""
return room_id in self._sharing_group_session
async def wait_group_session_share(self, room_id: RoomID) -> None:
"""
Wait for a group session to be shared.
Args:
room_id: The room ID to wait for.
"""
try:
event = self._sharing_group_session[room_id]
await event.wait()
except KeyError:
pass
async def share_group_session(self, room_id: RoomID, users: List[UserID]) -> None:
"""
Create a Megolm session for a specific room and share it with the given list of users.
Note that you must not call this method again before the previous share has finished.
You should either lock calls yourself, or use :meth:`wait_group_session_share` to use
built-in locking capabilities.
Args:
room_id: The room to create the session for.
users: The list of users to share the session with.
Raises:
SessionShareError: If something went wrong while sharing the session.
"""
if room_id in self._sharing_group_session:
raise SessionShareError("Already sharing group session for that room")
self._sharing_group_session[room_id] = asyncio.Event()
try:
await self._share_group_session(room_id, users)
finally:
self._sharing_group_session.pop(room_id).set()
async def _share_group_session(self, room_id: RoomID, users: List[UserID]) -> None:
session = await self.crypto_store.get_outbound_group_session(room_id)
if session and session.shared and not session.expired:
raise SessionShareError("Group session has already been shared")
if not session or session.expired:
session = await self._new_outbound_group_session(room_id)
self.log.debug(f"Sharing group session {session.id} for room {room_id} with {users}")
olm_sessions: DeviceMap = defaultdict(lambda: {})
withhold_key_msgs = defaultdict(lambda: {})
missing_sessions: Dict[UserID, Dict[DeviceID, DeviceIdentity]] = defaultdict(lambda: {})
fetch_keys = []
for user_id in users:
devices = await self.crypto_store.get_devices(user_id)
if devices is None:
self.log.debug(
f"get_devices returned nil for {user_id}, will fetch keys and retry"
)
fetch_keys.append(user_id)
elif len(devices) == 0:
self.log.debug(f"{user_id} has no devices, skipping")
else:
self.log.debug(f"Trying to encrypt group session {session.id} for {user_id}")
for device_id, device in devices.items():
result = await self._find_olm_sessions(session, user_id, device_id, device)
if isinstance(result, RoomKeyWithheldEventContent):
withhold_key_msgs[user_id][device_id] = result
elif result == key_missing:
missing_sessions[user_id][device_id] = device
elif isinstance(result, tuple):
olm_sessions[user_id][device_id] = result
if fetch_keys:
self.log.debug(f"Fetching missing keys for {fetch_keys}")
fetched_keys = await self._fetch_keys(users, include_untracked=True)
for user_id, devices in fetched_keys.items():
missing_sessions[user_id] = devices
if missing_sessions:
self.log.debug(f"Creating missing outbound sessions {missing_sessions}")
try:
await self._create_outbound_sessions(missing_sessions)
except Exception:
self.log.exception("Failed to create missing outbound sessions")
for user_id, devices in missing_sessions.items():
for device_id, device in devices.items():
result = await self._find_olm_sessions(session, user_id, device_id, device)
if isinstance(result, RoomKeyWithheldEventContent):
withhold_key_msgs[user_id][device_id] = result
elif isinstance(result, tuple):
olm_sessions[user_id][device_id] = result
# We don't care about missing keys at this point
if len(olm_sessions) > 0:
async with self._olm_lock:
await self._encrypt_and_share_group_session(session, olm_sessions)
if len(withhold_key_msgs) > 0:
            event_count = sum(len(devices) for devices in withhold_key_msgs.values())
self.log.debug(
f"Sending {event_count} to-device events to report {session.id} is withheld"
)
await self.client.send_to_device(EventType.ROOM_KEY_WITHHELD, withhold_key_msgs)
await self.client.send_to_device(
EventType.ORG_MATRIX_ROOM_KEY_WITHHELD, withhold_key_msgs
)
self.log.info(f"Group session {session.id} for {room_id} successfully shared")
session.shared = True
await self.crypto_store.add_outbound_group_session(session)
async def _new_outbound_group_session(self, room_id: RoomID) -> OutboundGroupSession:
session = OutboundGroupSession(room_id)
encryption_info = await self.state_store.get_encryption_info(room_id)
if encryption_info:
if encryption_info.algorithm != EncryptionAlgorithm.MEGOLM_V1:
raise SessionShareError("Room encryption algorithm is not supported")
session.max_messages = encryption_info.rotation_period_msgs or session.max_messages
session.max_age = (
timedelta(milliseconds=encryption_info.rotation_period_ms)
if encryption_info.rotation_period_ms
else session.max_age
)
self.log.debug(
"Got stored encryption state event and configured session to rotate "
f"after {session.max_messages} messages or {session.max_age}"
)
if not self.dont_store_outbound_keys:
await self._create_group_session(
self.account.identity_key,
self.account.signing_key,
room_id,
SessionID(session.id),
session.session_key,
max_messages=session.max_messages,
max_age=session.max_age,
is_scheduled=False,
)
return session
async def _encrypt_and_share_group_session(
self, session: OutboundGroupSession, olm_sessions: DeviceMap
):
msgs = defaultdict(lambda: {})
count = 0
for user_id, devices in olm_sessions.items():
count += len(devices)
for device_id, (olm_session, device_identity) in devices.items():
msgs[user_id][device_id] = await self._encrypt_olm_event(
olm_session, device_identity, EventType.ROOM_KEY, session.share_content
)
self.log.debug(
f"Sending to-device events to {count} devices of {len(msgs)} users "
f"to share {session.id}"
)
await self.client.send_to_device(EventType.TO_DEVICE_ENCRYPTED, msgs)
async def _create_group_session(
self,
sender_key: IdentityKey,
signing_key: SigningKey,
room_id: RoomID,
session_id: SessionID,
session_key: str,
max_age: Union[timedelta, int],
max_messages: int,
is_scheduled: bool = False,
) -> None:
start = time.monotonic()
session = InboundGroupSession(
session_key=session_key,
signing_key=signing_key,
sender_key=sender_key,
room_id=room_id,
received_at=datetime.utcnow(),
max_age=max_age,
max_messages=max_messages,
is_scheduled=is_scheduled,
)
olm_duration = time.monotonic() - start
if olm_duration > 5:
self.log.warning(f"Creating inbound group session took {olm_duration:.3f} seconds!")
if session_id != session.id:
self.log.warning(f"Mismatching session IDs: expected {session_id}, got {session.id}")
session_id = session.id
await self.crypto_store.put_group_session(room_id, sender_key, session_id, session)
self._mark_session_received(session_id)
self.log.debug(
f"Created inbound group session {room_id}/{sender_key}/{session_id} "
f"(max {max_age} / {max_messages} messages, {is_scheduled=})"
)
async def _find_olm_sessions(
self,
session: OutboundGroupSession,
user_id: UserID,
device_id: DeviceID,
device: DeviceIdentity,
) -> SessionEncryptResult:
key = (user_id, device_id)
if key in session.users_ignored or key in session.users_shared_with:
return already_shared
elif user_id == self.client.mxid and device_id == self.client.device_id:
session.users_ignored.add(key)
return already_shared
trust = await self.resolve_trust(device)
if trust == TrustState.BLACKLISTED:
self.log.debug(
f"Not encrypting group session {session.id} for {device_id} "
f"of {user_id}: device is blacklisted"
)
session.users_ignored.add(key)
return RoomKeyWithheldEventContent(
room_id=session.room_id,
algorithm=EncryptionAlgorithm.MEGOLM_V1,
session_id=SessionID(session.id),
sender_key=self.account.identity_key,
code=RoomKeyWithheldCode.BLACKLISTED,
reason="Device is blacklisted",
)
elif self.send_keys_min_trust > trust:
self.log.debug(
f"Not encrypting group session {session.id} for {device_id} "
f"of {user_id}: device is not trusted "
f"(min: {self.send_keys_min_trust}, device: {trust})"
)
session.users_ignored.add(key)
return RoomKeyWithheldEventContent(
room_id=session.room_id,
algorithm=EncryptionAlgorithm.MEGOLM_V1,
session_id=SessionID(session.id),
sender_key=self.account.identity_key,
code=RoomKeyWithheldCode.UNVERIFIED,
reason="This device does not encrypt messages for unverified devices",
)
device_session = await self.crypto_store.get_latest_session(device.identity_key)
if not device_session:
return key_missing
session.users_shared_with.add(key)
return device_session, device
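# Example (illustrative sketch, assuming ``client`` is the mautrix Client the machine was
# created with and ``machine`` is an initialized machine using this mixin): the flow the
# docstrings above describe: wait for any in-progress share, encrypt, and share a new group
# session with the current members if none exists yet.
async def _example_send_encrypted(machine, client, room_id: RoomID, content: Any) -> None:
    event_type = EventType.ROOM_MESSAGE
    if machine.is_sharing_group_session(room_id):
        await machine.wait_group_session_share(room_id)
    try:
        encrypted = await machine.encrypt_megolm_event(room_id, event_type, content)
    except EncryptionError:
        # No usable group session yet: share one with the current room members and retry.
        members = await client.get_joined_members(room_id)
        await machine.share_group_session(room_id, list(members.keys()))
        encrypted = await machine.encrypt_megolm_event(room_id, event_type, content)
    await client.send_message_event(room_id, EventType.ROOM_ENCRYPTED, encrypted)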
python-0.20.4/mautrix/crypto/encrypt_olm.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from typing import Any, Dict
import asyncio
from mautrix.types import (
DecryptedOlmEvent,
DeviceID,
DeviceIdentity,
EncryptedOlmEventContent,
EncryptionKeyAlgorithm,
EventType,
OlmEventKeys,
ToDeviceEventContent,
UserID,
)
from .base import BaseOlmMachine, verify_signature_json
from .sessions import Session
ClaimKeysList = Dict[UserID, Dict[DeviceID, DeviceIdentity]]
class OlmEncryptionMachine(BaseOlmMachine):
_claim_keys_lock: asyncio.Lock
_olm_lock: asyncio.Lock
def __init__(self):
self._claim_keys_lock = asyncio.Lock()
self._olm_lock = asyncio.Lock()
async def _encrypt_olm_event(
self, session: Session, recipient: DeviceIdentity, event_type: EventType, content: Any
) -> EncryptedOlmEventContent:
evt = DecryptedOlmEvent(
sender=self.client.mxid,
sender_device=self.client.device_id,
keys=OlmEventKeys(ed25519=self.account.signing_key),
recipient=recipient.user_id,
recipient_keys=OlmEventKeys(ed25519=recipient.signing_key),
type=event_type,
content=content,
)
ciphertext = session.encrypt(evt.json())
await self.crypto_store.update_session(recipient.identity_key, session)
return EncryptedOlmEventContent(
ciphertext={recipient.identity_key: ciphertext}, sender_key=self.account.identity_key
)
async def _create_outbound_sessions(
self, users: ClaimKeysList, _force_recreate_session: bool = False
) -> None:
async with self._claim_keys_lock:
return await self._create_outbound_sessions_locked(users, _force_recreate_session)
async def _create_outbound_sessions_locked(
self, users: ClaimKeysList, _force_recreate_session: bool = False
) -> None:
request: Dict[UserID, Dict[DeviceID, EncryptionKeyAlgorithm]] = {}
expected_devices = set()
for user_id, devices in users.items():
request[user_id] = {}
for device_id, identity in devices.items():
if _force_recreate_session or not await self.crypto_store.has_session(
identity.identity_key
):
request[user_id][device_id] = EncryptionKeyAlgorithm.SIGNED_CURVE25519
expected_devices.add((user_id, device_id))
if not request[user_id]:
del request[user_id]
if not request:
return
request_device_count = len(expected_devices)
keys = await self.client.claim_keys(request)
for server, info in (keys.failures or {}).items():
self.log.warning(f"Key claim failure for {server}: {info}")
for user_id, devices in keys.one_time_keys.items():
for device_id, one_time_keys in devices.items():
expected_devices.discard((user_id, device_id))
key_id, one_time_key_data = one_time_keys.popitem()
one_time_key = one_time_key_data["key"]
identity = users[user_id][device_id]
if not verify_signature_json(
one_time_key_data, user_id, device_id, identity.signing_key
):
self.log.warning(f"Invalid signature for {device_id} of {user_id}")
else:
session = self.account.new_outbound_session(
identity.identity_key, one_time_key
)
await self.crypto_store.add_session(identity.identity_key, session)
self.log.debug(
f"Created new Olm session with {user_id}/{device_id} "
f"(OTK ID: {key_id})"
)
if expected_devices:
if request_device_count == 1:
raise Exception(
"Key claim response didn't contain key "
f"for queried device {expected_devices.pop()}"
)
else:
self.log.warning(
"Key claim response didn't contain keys for %d out of %d expected devices: %s",
len(expected_devices),
request_device_count,
expected_devices,
)
async def send_encrypted_to_device(
self,
device: DeviceIdentity,
event_type: EventType,
content: ToDeviceEventContent,
_force_recreate_session: bool = False,
) -> None:
await self._create_outbound_sessions(
{device.user_id: {device.device_id: device}},
_force_recreate_session=_force_recreate_session,
)
session = await self.crypto_store.get_latest_session(device.identity_key)
async with self._olm_lock:
encrypted_content = await self._encrypt_olm_event(session, device, event_type, content)
await self.client.send_to_one_device(
EventType.TO_DEVICE_ENCRYPTED, device.user_id, device.device_id, encrypted_content
)
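# Example (illustrative sketch, assuming ``machine`` also has the device list mixin, as the
# full machine does): Olm-encrypting an arbitrary to-device event for one of the user's own
# devices. The event type and content are supplied by the caller.
async def _example_send_to_own_device(
    machine, device_id: DeviceID, event_type: EventType, content: ToDeviceEventContent
) -> None:
    device = await machine.get_or_fetch_device(machine.client.mxid, device_id)
    if device is None:
        raise ValueError(f"Device {device_id} not found")
    await machine.send_encrypted_to_device(device, event_type, content)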
python-0.20.4/mautrix/crypto/key_request.py
# Copyright (c) 2023 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from typing import Dict, List, Optional, Union
import asyncio
import uuid
from mautrix.types import (
DecryptedOlmEvent,
DeviceID,
EncryptionAlgorithm,
EventType,
ForwardedRoomKeyEventContent,
IdentityKey,
KeyRequestAction,
RequestedKeyInfo,
RoomID,
RoomKeyRequestEventContent,
SessionID,
UserID,
)
from .base import BaseOlmMachine
from .sessions import InboundGroupSession
class KeyRequestingMachine(BaseOlmMachine):
async def request_room_key(
self,
room_id: RoomID,
sender_key: IdentityKey,
session_id: SessionID,
from_devices: Dict[UserID, List[DeviceID]],
timeout: Optional[Union[int, float]] = None,
) -> bool:
"""
Request keys for a Megolm group session from other devices.
Once the keys are received, or if this task is cancelled (via the ``timeout`` parameter),
a cancel request event is sent to the remaining devices. If the ``timeout`` is set to zero
or less, this will return immediately, and the extra key requests will not be cancelled.
Args:
room_id: The room where the session is used.
sender_key: The key of the user who created the session.
session_id: The ID of the session.
from_devices: A dict from user ID to list of device IDs whom to ask for the keys.
timeout: The maximum number of seconds to wait for the keys. If the timeout is
``None``, the wait time is not limited, but the task can still be cancelled.
If it's zero or less, this returns immediately and will never cancel requests.
Returns:
``True`` if the keys were received and are now in the crypto store,
``False`` otherwise (including if the method didn't wait at all).
"""
request_id = str(uuid.uuid1())
request = RoomKeyRequestEventContent(
action=KeyRequestAction.REQUEST,
body=RequestedKeyInfo(
algorithm=EncryptionAlgorithm.MEGOLM_V1,
room_id=room_id,
sender_key=sender_key,
session_id=session_id,
),
request_id=request_id,
requesting_device_id=self.client.device_id,
)
wait = timeout is None or timeout > 0
fut: Optional[asyncio.Future] = None
if wait:
fut = asyncio.get_running_loop().create_future()
self._key_request_waiters[session_id] = fut
await self.client.send_to_device(
EventType.ROOM_KEY_REQUEST,
{
user_id: {device_id: request for device_id in devices}
for user_id, devices in from_devices.items()
},
)
if not wait:
# Timeout is set and <=0, don't wait for keys
return False
assert fut is not None
got_keys = False
try:
user_id, device_id = await asyncio.wait_for(fut, timeout=timeout)
got_keys = True
try:
                # from_devices maps each user ID to a list of device IDs, so remove the
                # answering device from the list instead of deleting by index.
                from_devices[user_id].remove(device_id)
                if len(from_devices[user_id]) == 0:
                    del from_devices[user_id]
            except (KeyError, ValueError):
                pass
except (asyncio.CancelledError, asyncio.TimeoutError):
pass
del self._key_request_waiters[session_id]
if len(from_devices) > 0:
cancel = RoomKeyRequestEventContent(
action=KeyRequestAction.CANCEL,
request_id=str(request_id),
requesting_device_id=self.client.device_id,
)
await self.client.send_to_device(
EventType.ROOM_KEY_REQUEST,
{
user_id: {device_id: cancel for device_id in devices}
for user_id, devices in from_devices.items()
},
)
return got_keys
async def _receive_forwarded_room_key(self, evt: DecryptedOlmEvent) -> None:
key: ForwardedRoomKeyEventContent = evt.content
if await self.crypto_store.has_group_session(key.room_id, key.session_id):
self.log.debug(
f"Ignoring received session {key.session_id} from {evt.sender}/"
f"{evt.sender_device}, as crypto store says we have it already"
)
return
if not key.beeper_max_messages or not key.beeper_max_age_ms:
await self._fill_encryption_info(key)
key.forwarding_key_chain.append(evt.sender_key)
sess = InboundGroupSession.import_session(
key.session_key,
key.signing_key,
key.sender_key,
key.room_id,
key.forwarding_key_chain,
max_age=key.beeper_max_age_ms,
max_messages=key.beeper_max_messages,
is_scheduled=key.beeper_is_scheduled,
)
if key.session_id != sess.id:
self.log.warning(
f"Mismatched session ID while importing forwarded key from "
f"{evt.sender}/{evt.sender_device}: '{key.session_id}' != '{sess.id}'"
)
return
await self.crypto_store.put_group_session(
key.room_id, key.sender_key, key.session_id, sess
)
self._mark_session_received(key.session_id)
self.log.debug(
f"Imported {key.session_id} for {key.room_id} "
f"from {evt.sender}/{evt.sender_device}"
)
try:
task = self._key_request_waiters[key.session_id]
except KeyError:
pass
else:
task.set_result((evt.sender, evt.sender_device))
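# Example (illustrative sketch, assuming ``machine`` is an initialized machine using this
# mixin and ``evt`` is an undecryptable m.room.encrypted event): asking one of the user's
# own devices for the missing session and waiting up to 30 seconds. The device ID below is
# hypothetical.
async def _example_request_key(machine: KeyRequestingMachine, evt) -> bool:
    return await machine.request_room_key(
        room_id=evt.room_id,
        sender_key=evt.content.sender_key,
        session_id=evt.content.session_id,
        from_devices={machine.client.mxid: [DeviceID("HYPOTHETICALID")]},
        timeout=30,
    )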
python-0.20.4/mautrix/crypto/key_share.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from typing import Optional
from mautrix.errors import MatrixConnectionError, MatrixError, MatrixRequestError
from mautrix.types import (
DeviceIdentity,
EncryptionAlgorithm,
EventType,
ForwardedRoomKeyEventContent,
KeyRequestAction,
RequestedKeyInfo,
RoomKeyRequestEventContent,
RoomKeyWithheldCode,
RoomKeyWithheldEventContent,
SessionID,
ToDeviceEvent,
TrustState,
)
from .device_lists import DeviceListMachine
from .encrypt_olm import OlmEncryptionMachine
class RejectKeyShare(MatrixError):
def __init__(
self,
log_message: str = "",
code: Optional[RoomKeyWithheldCode] = None,
reason: Optional[str] = None,
) -> None:
"""
RejectKeyShare is an error used to signal that a key share request should be rejected.
Args:
log_message: The message to log when rejecting the request.
code: The m.room_key.withheld code, or ``None`` to reject silently.
reason: The human-readable reason for the rejection.
"""
super().__init__(log_message)
self.code = code
self.reason = reason
class KeySharingMachine(OlmEncryptionMachine, DeviceListMachine):
async def default_allow_key_share(
self, device: DeviceIdentity, request: RequestedKeyInfo
) -> bool:
"""
Check whether or not the given key request should be fulfilled. You can set a custom
function in :attr:`allow_key_share` to override this.
Args:
device: The identity of the device requesting keys.
request: The requested key details.
Returns:
``True`` if the key share should be accepted,
``False`` if it should be silently ignored.
Raises:
RejectKeyShare: if the key share should be rejected.
"""
if device.user_id != self.client.mxid:
raise RejectKeyShare(
f"Ignoring key request from a different user ({device.user_id})", code=None
)
elif device.device_id == self.client.device_id:
raise RejectKeyShare("Ignoring key request from ourselves", code=None)
elif device.trust == TrustState.BLACKLISTED:
raise RejectKeyShare(
f"Rejecting key request from blacklisted device {device.device_id}",
code=RoomKeyWithheldCode.BLACKLISTED,
reason="You have been blacklisted by this device",
)
elif await self.resolve_trust(device) >= self.share_keys_min_trust:
self.log.debug(f"Accepting key request from trusted device {device.device_id}")
return True
else:
raise RejectKeyShare(
f"Rejecting key request from untrusted device {device.device_id}",
code=RoomKeyWithheldCode.UNVERIFIED,
reason="You have not been verified by this device",
)
async def handle_room_key_request(
self, evt: ToDeviceEvent, raise_exceptions: bool = False
) -> None:
"""
Handle a ``m.room_key_request`` where the action is ``request``.
This is automatically registered as an event handler and therefore called if the client you
passed to the OlmMachine is syncing. You shouldn't need to call this yourself unless you
do syncing in some manual way.
Args:
evt: The to-device event.
raise_exceptions: Whether or not errors while handling should be raised.
"""
request: RoomKeyRequestEventContent = evt.content
if request.action != KeyRequestAction.REQUEST:
return
elif (
evt.sender == self.client.mxid
and request.requesting_device_id == self.client.device_id
):
self.log.debug(f"Ignoring key request {request.request_id} from ourselves")
return
try:
device = await self.get_or_fetch_device(evt.sender, request.requesting_device_id)
except Exception:
self.log.warning(
f"Failed to get device {evt.sender}/{request.requesting_device_id} to "
f"handle key request {request.request_id}",
exc_info=True,
)
if raise_exceptions:
raise
return
if not device:
self.log.warning(
f"Couldn't find device {evt.sender}/{request.requesting_device_id} to "
f"handle key request {request.request_id}"
)
return
self.log.debug(
f"Received key request {request.request_id} from {device.user_id}/"
f"{device.device_id} for session {request.body.session_id}"
)
try:
await self._handle_room_key_request(device, request.body)
except RejectKeyShare as e:
self.log.debug(f"Rejecting key request {request.request_id}: {e}")
await self._reject_key_request(e, device, request.body)
except (MatrixRequestError, MatrixConnectionError):
self.log.exception(
f"API error while handling key request {request.request_id} "
f"(not sending rejection)"
)
if raise_exceptions:
raise
except Exception:
self.log.exception(
f"Error while handling key request {request.request_id}, sending rejection..."
)
error = RejectKeyShare(
code=RoomKeyWithheldCode.UNAVAILABLE,
reason="An internal error occurred while trying to share the requested session",
)
await self._reject_key_request(error, device, request.body)
if raise_exceptions:
raise
async def _handle_room_key_request(
self, device: DeviceIdentity, request: RequestedKeyInfo
) -> None:
if not await self.allow_key_share(device, request):
return
sess = await self.crypto_store.get_group_session(request.room_id, request.session_id)
if sess is None:
raise RejectKeyShare(
f"Didn't find group session {request.session_id} to forward to "
f"{device.user_id}/{device.device_id}",
code=RoomKeyWithheldCode.UNAVAILABLE,
reason="Requested session ID not found on this device",
)
exported_key = sess.export_session(sess.first_known_index)
forward_content = ForwardedRoomKeyEventContent(
algorithm=EncryptionAlgorithm.MEGOLM_V1,
room_id=sess.room_id,
session_id=SessionID(sess.id),
session_key=exported_key,
sender_key=sess.sender_key,
forwarding_key_chain=sess.forwarding_chain,
signing_key=sess.signing_key,
)
await self.send_encrypted_to_device(device, EventType.FORWARDED_ROOM_KEY, forward_content)
async def _reject_key_request(
self, rejection: RejectKeyShare, device: DeviceIdentity, request: RequestedKeyInfo
) -> None:
if not rejection.code:
# Silent rejection
return
content = RoomKeyWithheldEventContent(
room_id=request.room_id,
algorithm=request.algorithm,
session_id=request.session_id,
sender_key=request.sender_key,
code=rejection.code,
reason=rejection.reason,
)
try:
await self.client.send_to_one_device(
EventType.ROOM_KEY_WITHHELD, device.user_id, device.device_id, content
)
await self.client.send_to_one_device(
EventType.ORG_MATRIX_ROOM_KEY_WITHHELD, device.user_id, device.device_id, content
)
except MatrixError:
self.log.warning(
f"Failed to send key share rejection {rejection.code} "
f"to {device.user_id}/{device.device_id}",
exc_info=True,
)
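# Illustrative sketch (assumption, not part of the library): the pieces above can be
# combined into a custom key sharing policy by assigning a coroutine to
# OlmMachine.allow_key_share. Only names defined or imported in this module are used;
# the policy itself is an example, not the library default.
def _example_install_strict_key_share_policy(machine: KeySharingMachine) -> None:
    async def allow(device: DeviceIdentity, request: RequestedKeyInfo) -> bool:
        if device.trust == TrustState.BLACKLISTED:
            # Rejecting with a code sends an m.room_key.withheld event to the requester.
            raise RejectKeyShare(
                f"Rejecting key request from blacklisted device {device.device_id}",
                code=RoomKeyWithheldCode.BLACKLISTED,
                reason="You have been blacklisted by this device",
            )
        # Require at least cross-signed (TOFU) trust; anything below that is ignored
        # silently by returning False (no withheld event is sent).
        return await machine.resolve_trust(device) >= TrustState.CROSS_SIGNED_TOFU

    machine.allow_key_share = allow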
python-0.20.4/mautrix/crypto/machine.py
# Copyright (c) 2023 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from typing import Optional
import asyncio
import logging
import time
from mautrix import client as cli
from mautrix.errors import GroupSessionWithheldError
from mautrix.types import (
ASToDeviceEvent,
DecryptedOlmEvent,
DeviceID,
DeviceLists,
DeviceOTKCount,
EncryptionAlgorithm,
EncryptionKeyAlgorithm,
EventType,
Member,
Membership,
StateEvent,
ToDeviceEvent,
TrustState,
UserID,
)
from mautrix.util import background_task
from mautrix.util.logging import TraceLogger
from .account import OlmAccount
from .decrypt_megolm import MegolmDecryptionMachine
from .encrypt_megolm import MegolmEncryptionMachine
from .key_request import KeyRequestingMachine
from .key_share import KeySharingMachine
from .store import CryptoStore, StateStore
from .unwedge import OlmUnwedgingMachine
class OlmMachine(
MegolmEncryptionMachine,
MegolmDecryptionMachine,
OlmUnwedgingMachine,
KeySharingMachine,
KeyRequestingMachine,
):
"""
OlmMachine is the main class for handling things related to Matrix end-to-end encryption with
Olm and Megolm. Users primarily need :meth:`encrypt_megolm_event`, :meth:`share_group_session`,
and :meth:`decrypt_megolm_event`. Tracking device lists, establishing Olm sessions and handling
Megolm group sessions is handled internally.
"""
client: cli.Client
log: TraceLogger
crypto_store: CryptoStore
state_store: StateStore
account: Optional[OlmAccount]
def __init__(
self,
client: cli.Client,
crypto_store: CryptoStore,
state_store: StateStore,
log: Optional[TraceLogger] = None,
) -> None:
super().__init__()
self.client = client
self.log = log or logging.getLogger("mau.crypto")
self.crypto_store = crypto_store
self.state_store = state_store
self.account = None
self.send_keys_min_trust = TrustState.UNVERIFIED
self.share_keys_min_trust = TrustState.CROSS_SIGNED_TOFU
self.allow_key_share = self.default_allow_key_share
self.delete_outbound_keys_on_ack = False
self.dont_store_outbound_keys = False
self.delete_previous_keys_on_receive = False
self.ratchet_keys_on_decrypt = False
self.delete_fully_used_keys_on_decrypt = False
self.delete_keys_on_device_delete = False
self.disable_device_change_key_rotation = False
self._fetch_keys_lock = asyncio.Lock()
self._megolm_decrypt_lock = asyncio.Lock()
self._share_keys_lock = asyncio.Lock()
self._last_key_share = time.monotonic() - 60
self._key_request_waiters = {}
self._inbound_session_waiters = {}
self._prev_unwedge = {}
self._cs_fetch_attempted = set()
self.client.add_event_handler(
cli.InternalEventType.DEVICE_OTK_COUNT, self.handle_otk_count, wait_sync=True
)
self.client.add_event_handler(cli.InternalEventType.DEVICE_LISTS, self.handle_device_lists)
self.client.add_event_handler(EventType.TO_DEVICE_ENCRYPTED, self.handle_to_device_event)
self.client.add_event_handler(EventType.ROOM_KEY_REQUEST, self.handle_room_key_request)
self.client.add_event_handler(EventType.BEEPER_ROOM_KEY_ACK, self.handle_beep_room_key_ack)
# self.client.add_event_handler(EventType.ROOM_KEY_WITHHELD, self.handle_room_key_withheld)
# self.client.add_event_handler(EventType.ORG_MATRIX_ROOM_KEY_WITHHELD,
# self.handle_room_key_withheld)
self.client.add_event_handler(EventType.ROOM_MEMBER, self.handle_member_event)
async def load(self) -> None:
"""Load the Olm account into memory, or create one if the store doesn't have one stored."""
self.account = await self.crypto_store.get_account()
if self.account is None:
self.account = OlmAccount()
await self.crypto_store.put_account(self.account)
async def handle_as_otk_counts(
self, otk_counts: dict[UserID, dict[DeviceID, DeviceOTKCount]]
) -> None:
for user_id, devices in otk_counts.items():
for device_id, count in devices.items():
if user_id == self.client.mxid and device_id == self.client.device_id:
await self.handle_otk_count(count)
else:
self.log.warning(f"Got OTK count for unknown device {user_id}/{device_id}")
async def handle_as_device_lists(self, device_lists: DeviceLists) -> None:
background_task.create(self.handle_device_lists(device_lists))
async def handle_as_to_device_event(self, evt: ASToDeviceEvent) -> None:
if evt.to_user_id != self.client.mxid or evt.to_device_id != self.client.device_id:
self.log.warning(
f"Got to-device event for unknown device {evt.to_user_id}/{evt.to_device_id}"
)
return
if evt.type == EventType.TO_DEVICE_ENCRYPTED:
await self.handle_to_device_event(evt)
elif evt.type == EventType.ROOM_KEY_REQUEST:
await self.handle_room_key_request(evt)
elif evt.type == EventType.BEEPER_ROOM_KEY_ACK:
await self.handle_beep_room_key_ack(evt)
else:
self.log.debug(f"Got unknown to-device event {evt.type} from {evt.sender}")
async def handle_otk_count(self, otk_count: DeviceOTKCount) -> None:
"""
Handle the ``device_one_time_keys_count`` data in a sync response.
This is automatically registered as an event handler and therefore called if the client you
passed to the OlmMachine is syncing. You shouldn't need to call this yourself unless you
do syncing in some manual way.
"""
if otk_count.signed_curve25519 < self.account.max_one_time_keys // 2:
self.log.debug(
f"Sync response said we have {otk_count.signed_curve25519} signed"
" curve25519 keys left, sharing new ones..."
)
await self.share_keys(otk_count.signed_curve25519)
async def handle_device_lists(self, device_lists: DeviceLists) -> None:
"""
Handle the ``device_lists`` data in a sync response.
This is automatically registered as an event handler and therefore called if the client you
passed to the OlmMachine is syncing. You shouldn't need to call this yourself unless you
do syncing in some manual way.
"""
if len(device_lists.changed) > 0:
async with self._fetch_keys_lock:
await self._fetch_keys(device_lists.changed, include_untracked=False)
async def handle_member_event(self, evt: StateEvent) -> None:
"""
Handle a new member event.
This is automatically registered as an event handler and therefore called if the client you
passed to the OlmMachine is syncing. You shouldn't need to call this yourself unless you
receive events in some manual way (e.g. through appservice transactions)
"""
if not await self.state_store.is_encrypted(evt.room_id):
return
prev = evt.prev_content.membership
cur = evt.content.membership
ignored_changes = {
Membership.INVITE: Membership.JOIN,
Membership.BAN: Membership.LEAVE,
Membership.LEAVE: Membership.BAN,
}
if prev == cur or ignored_changes.get(prev) == cur:
return
src = getattr(evt, "source", None)
prev_cache = evt.unsigned.get("mautrix_prev_membership")
if isinstance(prev_cache, Member) and prev_cache.membership == cur:
self.log.debug(
f"Got duplicate membership state event in {evt.room_id} changing {evt.state_key} "
f"from {prev} to {cur}, cached state was {prev_cache} (event ID: {evt.event_id}, "
f"sync source: {src})"
)
return
self.log.debug(
f"Got membership state event in {evt.room_id} changing {evt.state_key} from "
f"{prev} to {cur} (event ID: {evt.event_id}, sync source: {src}, "
f"cached: {prev_cache.membership if prev_cache else None}), invalidating group session"
)
await self.crypto_store.remove_outbound_group_session(evt.room_id)
async def handle_to_device_event(self, evt: ToDeviceEvent) -> None:
"""
Handle an encrypted to-device event.
This is automatically registered as an event handler and therefore called if the client you
passed to the OlmMachine is syncing. You shouldn't need to call this yourself unless you
do syncing in some manual way.
"""
self.log.trace(
f"Handling encrypted to-device event from {evt.content.sender_key} ({evt.sender})"
)
decrypted_evt = await self._decrypt_olm_event(evt)
if decrypted_evt.type == EventType.ROOM_KEY:
await self._receive_room_key(decrypted_evt)
elif decrypted_evt.type == EventType.FORWARDED_ROOM_KEY:
await self._receive_forwarded_room_key(decrypted_evt)
async def _receive_room_key(self, evt: DecryptedOlmEvent) -> None:
# TODO nio had a comment saying "handle this better"
# for the case where evt.Keys.Ed25519 is none?
if evt.content.algorithm != EncryptionAlgorithm.MEGOLM_V1 or not evt.keys.ed25519:
return
if not evt.content.beeper_max_messages or not evt.content.beeper_max_age_ms:
await self._fill_encryption_info(evt.content)
if self.delete_previous_keys_on_receive and not evt.content.beeper_is_scheduled:
removed_ids = await self.crypto_store.redact_group_sessions(
evt.content.room_id, evt.sender_key, reason="received new key from device"
)
self.log.info(f"Redacted previous megolm sessions: {removed_ids}")
await self._create_group_session(
evt.sender_key,
evt.keys.ed25519,
evt.content.room_id,
evt.content.session_id,
evt.content.session_key,
max_age=evt.content.beeper_max_age_ms,
max_messages=evt.content.beeper_max_messages,
is_scheduled=evt.content.beeper_is_scheduled,
)
async def handle_beep_room_key_ack(self, evt: ToDeviceEvent) -> None:
try:
sess = await self.crypto_store.get_group_session(
evt.content.room_id, evt.content.session_id
)
except GroupSessionWithheldError:
self.log.debug(
f"Ignoring room key ack for session {evt.content.session_id}"
" that was already redacted"
)
return
if not sess:
self.log.debug(f"Ignoring room key ack for unknown session {evt.content.session_id}")
return
if (
sess.sender_key == self.account.identity_key
and self.delete_outbound_keys_on_ack
and evt.content.first_message_index == 0
):
self.log.debug("Redacting inbound copy of outbound group session after ack")
await self.crypto_store.redact_group_session(
evt.content.room_id, evt.content.session_id, reason="outbound session acked"
)
else:
self.log.debug(f"Received room key ack for {sess.id}")
async def share_keys(self, current_otk_count: int | None = None) -> None:
"""
Share any keys that need to be shared. This is automatically called from
:meth:`handle_otk_count`, so you should not need to call this yourself.
Args:
current_otk_count: The current number of signed curve25519 keys present on the server.
If omitted, the count will be fetched from the server.
"""
async with self._share_keys_lock:
await self._share_keys(current_otk_count)
async def _share_keys(self, current_otk_count: int | None) -> None:
if current_otk_count is None or (
# If the last key share was recent and the new count is very low, re-check the count
# from the server to avoid any race conditions.
self._last_key_share + 60 > time.monotonic()
and current_otk_count < 10
):
self.log.debug("Checking OTK count on server")
current_otk_count = (await self.client.upload_keys()).get(
EncryptionKeyAlgorithm.SIGNED_CURVE25519
)
device_keys = (
self.account.get_device_keys(self.client.mxid, self.client.device_id)
if not self.account.shared
else None
)
one_time_keys = self.account.get_one_time_keys(
self.client.mxid, self.client.device_id, current_otk_count
)
if not device_keys and not one_time_keys:
self.log.warning("No one-time keys nor device keys got when trying to share keys")
return
if device_keys:
self.log.debug("Going to upload initial account keys")
self.log.debug(f"Uploading {len(one_time_keys)} one-time keys")
resp = await self.client.upload_keys(one_time_keys=one_time_keys, device_keys=device_keys)
self.account.shared = True
self._last_key_share = time.monotonic()
await self.crypto_store.put_account(self.account)
self.log.debug(f"Shared keys and saved account, new keys: {resp}")
python-0.20.4/mautrix/crypto/sessions.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from typing import List, Optional, Set, Tuple, Union, cast
from datetime import datetime, timedelta
from _libolm import ffi, lib
from attr import dataclass
import olm
from mautrix.errors import EncryptionError
from mautrix.types import (
DeviceID,
EncryptionAlgorithm,
IdentityKey,
OlmCiphertext,
OlmMsgType,
RoomID,
RoomKeyEventContent,
SerializableAttrs,
SigningKey,
UserID,
field,
)
class Session(olm.Session):
creation_time: datetime
last_encrypted: datetime
last_decrypted: datetime
def __init__(self):
super().__init__()
self.creation_time = datetime.now()
self.last_encrypted = datetime.now()
self.last_decrypted = datetime.now()
def __new__(cls, *args, **kwargs):
return super().__new__(cls)
@property
def expired(self):
return False
@classmethod
def from_pickle(
cls,
pickle: bytes,
passphrase: str,
creation_time: datetime,
last_encrypted: Optional[datetime] = None,
last_decrypted: Optional[datetime] = None,
) -> "Session":
session = super().from_pickle(pickle, passphrase=passphrase)
session.creation_time = creation_time
session.last_encrypted = last_encrypted or creation_time
session.last_decrypted = last_decrypted or creation_time
return session
def matches(self, ciphertext: str) -> bool:
return super().matches(olm.OlmPreKeyMessage(ciphertext))
def decrypt(self, ciphertext: OlmCiphertext) -> str:
plaintext = super().decrypt(
olm.OlmPreKeyMessage(ciphertext.body)
if ciphertext.type == OlmMsgType.PREKEY
else olm.OlmMessage(ciphertext.body)
)
self.last_decrypted = datetime.now()
return plaintext
def encrypt(self, plaintext: str) -> OlmCiphertext:
self.last_encrypted = datetime.now()
result = super().encrypt(plaintext)
return OlmCiphertext(
type=(
OlmMsgType.PREKEY
if isinstance(result, olm.OlmPreKeyMessage)
else OlmMsgType.MESSAGE
),
body=result.ciphertext,
)
def describe(self) -> str:
parent = super()
if hasattr(parent, "describe"):
return parent.describe()
elif hasattr(lib, "olm_session_describe"):
describe_length = 600
describe_buffer = ffi.new("char[]", describe_length)
lib.olm_session_describe(self._session, describe_buffer, describe_length)
return ffi.string(describe_buffer).decode("utf-8")
else:
return "describe not supported"
@dataclass
class RatchetSafety(SerializableAttrs):
next_index: int = 0
missed_indices: List[int] = field(factory=lambda: [])
lost_indices: List[int] = field(factory=lambda: [])
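# Illustrative sketch (assumption, not part of the library): RatchetSafety is a
# SerializableAttrs dataclass, so it round-trips through JSON with json()/parse_json(),
# which is how the asyncpg crypto store persists the ratchet_safety column later on.
def _example_ratchet_safety_roundtrip() -> "RatchetSafety":
    original = RatchetSafety(next_index=5, missed_indices=[1, 3])
    restored = RatchetSafety.parse_json(original.json())
    return restored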
class InboundGroupSession(olm.InboundGroupSession):
room_id: RoomID
signing_key: SigningKey
sender_key: IdentityKey
forwarding_chain: List[IdentityKey]
ratchet_safety: RatchetSafety
received_at: datetime
max_age: timedelta
max_messages: int
is_scheduled: bool
def __init__(
self,
session_key: str,
signing_key: SigningKey,
sender_key: IdentityKey,
room_id: RoomID,
forwarding_chain: Optional[List[IdentityKey]] = None,
ratchet_safety: Optional[RatchetSafety] = None,
received_at: Optional[datetime] = None,
max_age: Union[timedelta, int, None] = None,
max_messages: Optional[int] = None,
is_scheduled: bool = False,
) -> None:
self.signing_key = signing_key
self.sender_key = sender_key
self.room_id = room_id
self.forwarding_chain = forwarding_chain or []
self.ratchet_safety = ratchet_safety or RatchetSafety()
self.received_at = received_at or datetime.utcnow()
if isinstance(max_age, int):
max_age = timedelta(milliseconds=max_age)
self.max_age = max_age
self.max_messages = max_messages
self.is_scheduled = is_scheduled
super().__init__(session_key)
def __new__(cls, *args, **kwargs):
return super().__new__(cls)
@classmethod
def from_pickle(
cls,
pickle: bytes,
passphrase: str,
signing_key: SigningKey,
sender_key: IdentityKey,
room_id: RoomID,
forwarding_chain: Optional[List[IdentityKey]] = None,
ratchet_safety: Optional[RatchetSafety] = None,
received_at: Optional[datetime] = None,
max_age: Optional[timedelta] = None,
max_messages: Optional[int] = None,
is_scheduled: bool = False,
) -> "InboundGroupSession":
session = super().from_pickle(pickle, passphrase)
session.signing_key = signing_key
session.sender_key = sender_key
session.room_id = room_id
session.forwarding_chain = forwarding_chain or []
session.ratchet_safety = ratchet_safety or RatchetSafety()
session.received_at = received_at
session.max_age = max_age
session.max_messages = max_messages
session.is_scheduled = is_scheduled
return session
@classmethod
def import_session(
cls,
session_key: str,
signing_key: SigningKey,
sender_key: IdentityKey,
room_id: RoomID,
forwarding_chain: Optional[List[str]] = None,
ratchet_safety: Optional[RatchetSafety] = None,
received_at: Optional[datetime] = None,
max_age: Union[timedelta, int, None] = None,
max_messages: Optional[int] = None,
is_scheduled: bool = False,
) -> "InboundGroupSession":
session = super().import_session(session_key)
session.signing_key = signing_key
session.sender_key = sender_key
session.room_id = room_id
session.forwarding_chain = forwarding_chain or []
session.ratchet_safety = ratchet_safety or RatchetSafety()
session.received_at = received_at or datetime.utcnow()
if isinstance(max_age, int):
max_age = timedelta(milliseconds=max_age)
session.max_age = max_age
session.max_messages = max_messages
session.is_scheduled = is_scheduled
return session
def ratchet_to(self, index: int) -> "InboundGroupSession":
exported = self.export_session(index)
return self.import_session(
exported,
signing_key=self.signing_key,
sender_key=self.sender_key,
room_id=self.room_id,
forwarding_chain=self.forwarding_chain,
ratchet_safety=self.ratchet_safety,
received_at=self.received_at,
max_age=self.max_age,
max_messages=self.max_messages,
is_scheduled=self.is_scheduled,
)
class OutboundGroupSession(olm.OutboundGroupSession):
"""Outbound group session aware of the users it is shared with.
Also remembers the time it was created and the number of messages it has
encrypted, in order to know if it needs to be rotated.
"""
max_age: timedelta
max_messages: int
creation_time: datetime
use_time: datetime
message_count: int
room_id: RoomID
users_shared_with: Set[Tuple[UserID, DeviceID]]
users_ignored: Set[Tuple[UserID, DeviceID]]
shared: bool
def __init__(self, room_id: RoomID) -> None:
self.max_age = timedelta(days=7)
self.max_messages = 100
self.creation_time = datetime.now()
self.use_time = datetime.now()
self.message_count = 0
self.room_id = room_id
self.users_shared_with = set()
self.users_ignored = set()
self.shared = False
super().__init__()
def __new__(cls, *args, **kwargs):
return super().__new__(cls)
@property
def expired(self):
return (
self.message_count >= self.max_messages
or datetime.now() - self.creation_time >= self.max_age
)
def encrypt(self, plaintext):
if not self.shared:
raise EncryptionError("Group session has not been shared")
if self.expired:
raise EncryptionError("Group session has expired")
self.message_count += 1
self.use_time = datetime.now()
return super().encrypt(plaintext)
@classmethod
def from_pickle(
cls,
pickle: bytes,
passphrase: str,
max_age: timedelta,
max_messages: int,
creation_time: datetime,
use_time: datetime,
message_count: int,
room_id: RoomID,
shared: bool,
) -> "OutboundGroupSession":
session = cast(OutboundGroupSession, super().from_pickle(pickle, passphrase))
session.max_age = max_age
session.max_messages = max_messages
session.creation_time = creation_time
session.use_time = use_time
session.message_count = message_count
session.room_id = room_id
session.users_shared_with = set()
session.users_ignored = set()
session.shared = shared
return session
@property
def share_content(self) -> RoomKeyEventContent:
return RoomKeyEventContent(
algorithm=EncryptionAlgorithm.MEGOLM_V1,
room_id=self.room_id,
session_id=self.id,
session_key=self.session_key,
)
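# Illustrative sketch (assumption, not part of the library): the lifecycle described in
# the OutboundGroupSession docstring. A fresh session must be marked as shared (normally
# only after the room key has been distributed to devices) before encrypt() accepts
# plaintext; once max_messages or max_age is exceeded, `expired` becomes True and the
# session should be rotated instead of reused.
def _example_outbound_lifecycle(room_id: RoomID, plaintext: str) -> str:
    session = OutboundGroupSession(room_id)
    session.shared = True  # in the real flow this is set only after sharing session.share_content
    ciphertext = session.encrypt(plaintext)
    return ciphertext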
python-0.20.4/mautrix/crypto/store/__init__.py
from mautrix import __optional_imports__
from .abstract import CryptoStore, StateStore
from .memory import MemoryCryptoStore
try:
from .asyncpg import PgCryptoStateStore, PgCryptoStore
except ImportError:
if __optional_imports__:
raise
PgCryptoStore = PgCryptoStateStore = None
__all__ = ["CryptoStore", "StateStore", "MemoryCryptoStore", "PgCryptoStateStore", "PgCryptoStore"]
python-0.20.4/mautrix/crypto/store/abstract.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from typing import NamedTuple
from abc import ABC, abstractmethod
from mautrix.types import (
CrossSigner,
CrossSigningUsage,
DeviceID,
DeviceIdentity,
EventID,
IdentityKey,
RoomEncryptionStateEventContent,
RoomID,
SessionID,
SigningKey,
TOFUSigningKey,
UserID,
)
from ..account import OlmAccount
from ..sessions import InboundGroupSession, OutboundGroupSession, Session
class StateStore(ABC):
@abstractmethod
async def is_encrypted(self, room_id: RoomID) -> bool:
pass
@abstractmethod
async def get_encryption_info(self, room_id: RoomID) -> RoomEncryptionStateEventContent | None:
pass
@abstractmethod
async def find_shared_rooms(self, user_id: UserID) -> list[RoomID]:
pass
class CryptoStore(ABC):
"""
CryptoStore is used by :class:`OlmMachine` to store Olm and Megolm sessions, user device lists
and message indices.
"""
account_id: str
"""The unique identifier for the account that is stored in this CryptoStore."""
pickle_key: str
"""The pickle key to use when pickling Olm objects."""
@abstractmethod
async def get_device_id(self) -> DeviceID | None:
"""
Get the device ID corresponding to this account_id
Returns:
The device ID in the store.
"""
@abstractmethod
async def put_device_id(self, device_id: DeviceID) -> None:
"""
Store a device ID.
Args:
device_id: The device ID to store.
"""
async def open(self) -> None:
"""
Open the store. If the store doesn't require opening any resources beforehand or only opens
when flushing, this can be a no-op
"""
async def close(self) -> None:
"""
Close the store when it will no longer be used. The default implementation will simply call
.flush(). If the store doesn't keep any persistent resources, the default implementation is
sufficient.
"""
await self.flush()
async def flush(self) -> None:
"""Flush the store. If all the methods persist data immediately, this can be a no-op."""
@abstractmethod
async def delete(self) -> None:
"""Delete the data in the store."""
@abstractmethod
async def put_account(self, account: OlmAccount) -> None:
"""Insert or update the OlmAccount in the store."""
@abstractmethod
async def get_account(self) -> OlmAccount | None:
"""Get the OlmAccount that was previously inserted with :meth:`put_account`.
If no account has been inserted, this must return ``None``."""
@abstractmethod
async def has_session(self, key: IdentityKey) -> bool:
"""
Check whether or not the store has a session for a specific device.
Args:
key: The curve25519 identity key of the device to check.
Returns:
            ``True`` if the store has at least one Olm session for the given identity key,
``False`` otherwise.
"""
@abstractmethod
async def get_sessions(self, key: IdentityKey) -> list[Session]:
"""
Get all Olm sessions in the store for the specific device.
Args:
key: The curve25519 identity key of the device whose sessions to get.
Returns:
A list of Olm sessions for the given identity key.
If the store contains no sessions, an empty list.
"""
@abstractmethod
async def get_latest_session(self, key: IdentityKey) -> Session | None:
"""
        Get the Olm session with the highest session ID (sorted lexicographically) for a specific
device. It's usually safe to return the most recently added session if sorting by session
ID is too difficult.
Args:
key: The curve25519 identity key of the device whose session to get.
Returns:
The most recent session for the given device.
If the store contains no sessions, ``None``.
"""
@abstractmethod
async def add_session(self, key: IdentityKey, session: Session) -> None:
"""
Insert an Olm session into the store.
Args:
key: The curve25519 identity key of the device with whom this session was made.
session: The session itself.
"""
@abstractmethod
async def update_session(self, key: IdentityKey, session: Session) -> None:
"""
Update a session in the store. Implementations may assume that the given session was
previously either inserted with :meth:`add_session` or fetched with either
:meth:`get_sessions` or :meth:`get_latest_session`.
Args:
key: The curve25519 identity key of the device with whom this session was made.
session: The session itself.
"""
@abstractmethod
async def put_group_session(
self,
room_id: RoomID,
sender_key: IdentityKey,
session_id: SessionID,
session: InboundGroupSession,
) -> None:
"""
Insert an inbound Megolm session into the store.
Args:
room_id: The room ID for which this session was made.
sender_key: The curve25519 identity key of the user who made this session.
session_id: The unique identifier for this session.
session: The session itself.
"""
@abstractmethod
async def get_group_session(
self, room_id: RoomID, session_id: SessionID
) -> InboundGroupSession | None:
"""
Get an inbound Megolm group session that was previously inserted with
:meth:`put_group_session`.
Args:
room_id: The room ID for which the session was made.
session_id: The unique identifier of the session.
Returns:
The :class:`InboundGroupSession` object, or ``None`` if not found.
"""
@abstractmethod
async def redact_group_session(
self, room_id: RoomID, session_id: SessionID, reason: str
) -> None:
"""
Remove the keys for a specific Megolm group session.
Args:
room_id: The room where the session is.
session_id: The session ID to remove.
reason: The reason the session is being removed.
"""
@abstractmethod
async def redact_group_sessions(
self, room_id: RoomID | None, sender_key: IdentityKey | None, reason: str
) -> list[SessionID]:
"""
Remove the keys for multiple Megolm group sessions,
based on the room ID and/or sender device.
Args:
room_id: The room ID to delete keys from.
sender_key: The Olm identity key of the device to delete keys from.
reason: The reason why the keys are being deleted.
Returns:
The list of session IDs that were deleted.
"""
@abstractmethod
async def redact_expired_group_sessions(self) -> list[SessionID]:
"""
Remove all Megolm group sessions where at least twice the maximum age has passed since
receiving the keys.
Returns:
The list of session IDs that were deleted.
"""
@abstractmethod
async def redact_outdated_group_sessions(self) -> list[SessionID]:
"""
Remove all Megolm group sessions which lack the metadata to determine when they should
expire.
Returns:
The list of session IDs that were deleted.
"""
@abstractmethod
async def has_group_session(self, room_id: RoomID, session_id: SessionID) -> bool:
"""
Check whether or not a specific inbound Megolm session is in the store. This is used before
importing forwarded keys.
Args:
room_id: The room ID for which the session was made.
session_id: The unique identifier of the session.
Returns:
``True`` if the store has a session with the given ID, ``False`` otherwise.
"""
@abstractmethod
async def add_outbound_group_session(self, session: OutboundGroupSession) -> None:
"""
Insert an outbound Megolm session into the store.
The store should index inserted sessions by the room_id field of the session to support
getting and removing sessions. There will only be one outbound session per room ID at a
time.
Args:
session: The session itself.
"""
@abstractmethod
async def update_outbound_group_session(self, session: OutboundGroupSession) -> None:
"""
Update an outbound Megolm session in the store. Implementations may assume that the given
session was previously either inserted with :meth:`add_outbound_group_session` or fetched
with :meth:`get_outbound_group_session`.
Args:
session: The session itself.
"""
@abstractmethod
async def get_outbound_group_session(self, room_id: RoomID) -> OutboundGroupSession | None:
"""
Get the stored outbound Megolm session from the store.
Args:
room_id: The room whose session to get.
Returns:
The :class:`OutboundGroupSession` object, or ``None`` if not found.
"""
@abstractmethod
async def remove_outbound_group_session(self, room_id: RoomID) -> None:
"""
Remove the stored outbound Megolm session for a specific room.
This is used when a membership change is received in a specific room.
Args:
room_id: The room whose session to remove.
"""
@abstractmethod
async def remove_outbound_group_sessions(self, rooms: list[RoomID]) -> None:
"""
Remove the stored outbound Megolm session for multiple rooms.
This is used when the device list of a user changes.
Args:
rooms: The list of rooms whose sessions to remove.
"""
@abstractmethod
async def validate_message_index(
self,
sender_key: IdentityKey,
session_id: SessionID,
event_id: EventID,
index: int,
timestamp: int,
) -> bool:
"""
Validate that a specific message isn't a replay attack.
Implementations should store a map from ``(sender_key, session_id, index)`` to
``(event_id, timestamp)``, then use that map to check whether or not the message
index is valid:
* If the map key doesn't exist, the given values should be stored and the message is valid.
* If the map key exists and the stored values match the given values, the message is valid.
* If the map key exists, but the stored values do not match the given values, the message
is not valid.
Args:
sender_key: The curve25519 identity key of the user who sent the message.
session_id: The Megolm session ID for the session with which the message was encrypted.
event_id: The event ID of the message.
index: The Megolm message index of the message.
timestamp: The timestamp of the message.
Returns:
``True`` if the message is valid, ``False`` if not.
"""
@abstractmethod
async def get_devices(self, user_id: UserID) -> dict[DeviceID, DeviceIdentity] | None:
"""
Get all devices for a given user.
Args:
user_id: The ID of the user whose devices to get.
Returns:
If there has been a previous call to :meth:`put_devices` with the same user ID (even
with an empty dict), a dict from device ID to :class:`DeviceIdentity` object.
Otherwise, ``None``.
"""
@abstractmethod
async def get_device(self, user_id: UserID, device_id: DeviceID) -> DeviceIdentity | None:
"""
Get a specific device identity.
Args:
user_id: The ID of the user whose device to get.
device_id: The ID of the device to get.
Returns:
The :class:`DeviceIdentity` object, or ``None`` if not found.
"""
@abstractmethod
async def find_device_by_key(
self, user_id: UserID, identity_key: IdentityKey
) -> DeviceIdentity | None:
"""
Find a specific device identity based on the identity key.
Args:
user_id: The ID of the user whose device to get.
identity_key: The identity key of the device to get.
Returns:
The :class:`DeviceIdentity` object, or ``None`` if not found.
"""
@abstractmethod
async def put_devices(self, user_id: UserID, devices: dict[DeviceID, DeviceIdentity]) -> None:
"""
Replace the stored device list for a specific user.
Args:
user_id: The ID of the user whose device list to update.
devices: A dict from device ID to :class:`DeviceIdentity` object. The dict may be empty.
"""
@abstractmethod
async def filter_tracked_users(self, users: list[UserID]) -> list[UserID]:
"""
Filter a list of user IDs to only include users whose device lists are being tracked.
Args:
users: The list of user IDs to filter.
Returns:
A filtered version of the input list that only includes users who have had a previous
call to :meth:`put_devices` (even if the call was with an empty dict).
"""
@abstractmethod
async def put_cross_signing_key(
self, user_id: UserID, usage: CrossSigningUsage, key: SigningKey
) -> None:
"""
Store a single cross-signing key.
Args:
user_id: The user whose cross-signing key is being stored.
usage: The type of key being stored.
key: The key itself.
"""
@abstractmethod
async def get_cross_signing_keys(
self, user_id: UserID
) -> dict[CrossSigningUsage, TOFUSigningKey]:
"""
Retrieve stored cross-signing keys for a specific user.
Args:
user_id: The user whose cross-signing keys to get.
Returns:
A map from the type of key to a tuple containing the current key and the key that was
seen first. If the keys are different, it should be treated as a local TOFU violation.
"""
@abstractmethod
async def put_signature(
self, target: CrossSigner, signer: CrossSigner, signature: str
) -> None:
"""
Store a signature for a given key from a given key.
Args:
target: The user ID and key being signed.
signer: The user ID and key who are doing the signing.
signature: The signature.
"""
@abstractmethod
async def is_key_signed_by(self, target: CrossSigner, signer: CrossSigner) -> bool:
"""
Check if a given key is signed by the given signer.
Args:
target: The key to check.
signer: The signer who is expected to have signed the key.
Returns:
``True`` if the database contains a signature for the key, ``False`` otherwise.
"""
@abstractmethod
async def drop_signatures_by_key(self, signer: CrossSigner) -> int:
"""
Delete signatures made by the given key.
Args:
signer: The key whose signatures to delete.
Returns:
The number of signatures deleted.
"""
python-0.20.4/mautrix/crypto/store/asyncpg/__init__.py
from .store import PgCryptoStateStore, PgCryptoStore
__all__ = ["PgCryptoStore", "PgCryptoStateStore"]
python-0.20.4/mautrix/crypto/store/asyncpg/store.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from collections import defaultdict
from datetime import timedelta
from asyncpg import UniqueViolationError
from mautrix.client.state_store import SyncStore
from mautrix.client.state_store.asyncpg import PgStateStore
from mautrix.errors import GroupSessionWithheldError
from mautrix.types import (
CrossSigner,
CrossSigningUsage,
DeviceID,
DeviceIdentity,
EventID,
IdentityKey,
RoomID,
RoomKeyWithheldCode,
SessionID,
SigningKey,
SyncToken,
TOFUSigningKey,
TrustState,
UserID,
)
from mautrix.util.async_db import Database, Scheme
from mautrix.util.logging import TraceLogger
from ... import InboundGroupSession, OlmAccount, OutboundGroupSession, RatchetSafety, Session
from ..abstract import CryptoStore, StateStore
from .upgrade import upgrade_table
try:
from sqlite3 import IntegrityError, sqlite_version_info as sqlite_version
from aiosqlite import Cursor
except ImportError:
Cursor = None
sqlite_version = (0, 0, 0)
class IntegrityError(Exception):
pass
class PgCryptoStateStore(PgStateStore, StateStore):
"""
This class ensures that the PgStateStore in the client module implements the StateStore
methods needed by the crypto module.
"""
class PgCryptoStore(CryptoStore, SyncStore):
upgrade_table = upgrade_table
db: Database
account_id: str
pickle_key: str
log: TraceLogger
_sync_token: SyncToken | None
_device_id: DeviceID | None
_account: OlmAccount | None
_olm_cache: dict[IdentityKey, dict[SessionID, Session]]
def __init__(self, account_id: str, pickle_key: str, db: Database) -> None:
self.db = db
self.account_id = account_id
self.pickle_key = pickle_key
self.log = db.log
self._sync_token = None
self._device_id = DeviceID("")
self._account = None
self._olm_cache = defaultdict(lambda: {})
async def delete(self) -> None:
tables = ("crypto_account", "crypto_olm_session", "crypto_megolm_outbound_session")
async with self.db.acquire() as conn, conn.transaction():
for table in tables:
await conn.execute(f"DELETE FROM {table} WHERE account_id=$1", self.account_id)
async def get_device_id(self) -> DeviceID | None:
q = "SELECT device_id FROM crypto_account WHERE account_id=$1"
device_id = await self.db.fetchval(q, self.account_id)
self._device_id = device_id or self._device_id
return self._device_id
async def put_device_id(self, device_id: DeviceID) -> None:
q = "UPDATE crypto_account SET device_id=$1 WHERE account_id=$2"
        await self.db.execute(q, device_id, self.account_id)
self._device_id = device_id
async def put_next_batch(self, next_batch: SyncToken) -> None:
self._sync_token = next_batch
q = "UPDATE crypto_account SET sync_token=$1 WHERE account_id=$2"
await self.db.execute(q, self._sync_token, self.account_id)
async def get_next_batch(self) -> SyncToken:
if self._sync_token is None:
q = "SELECT sync_token FROM crypto_account WHERE account_id=$1"
self._sync_token = await self.db.fetchval(q, self.account_id)
return self._sync_token
async def put_account(self, account: OlmAccount) -> None:
self._account = account
pickle = account.pickle(self.pickle_key)
q = """
INSERT INTO crypto_account (account_id, device_id, shared, sync_token, account)
VALUES ($1, $2, $3, $4, $5)
ON CONFLICT (account_id) DO UPDATE
SET shared=excluded.shared, sync_token=excluded.sync_token, account=excluded.account
"""
await self.db.execute(
q,
self.account_id,
self._device_id or "",
account.shared,
self._sync_token or "",
pickle,
)
    async def get_account(self) -> OlmAccount | None:
if self._account is None:
q = "SELECT shared, account, device_id FROM crypto_account WHERE account_id=$1"
row = await self.db.fetchrow(q, self.account_id)
if row is not None:
self._account = OlmAccount.from_pickle(
row["account"], passphrase=self.pickle_key, shared=row["shared"]
)
return self._account
async def has_session(self, key: IdentityKey) -> bool:
if len(self._olm_cache[key]) > 0:
return True
q = "SELECT session_id FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2"
val = await self.db.fetchval(q, key, self.account_id)
return val is not None
async def get_sessions(self, key: IdentityKey) -> list[Session]:
q = """
SELECT session_id, session, created_at, last_encrypted, last_decrypted
FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2
ORDER BY last_decrypted DESC
"""
rows = await self.db.fetch(q, key, self.account_id)
sessions = []
for row in rows:
try:
sess = self._olm_cache[key][row["session_id"]]
except KeyError:
sess = Session.from_pickle(
row["session"],
passphrase=self.pickle_key,
creation_time=row["created_at"],
last_encrypted=row["last_encrypted"],
last_decrypted=row["last_decrypted"],
)
self._olm_cache[key][SessionID(sess.id)] = sess
sessions.append(sess)
return sessions
async def get_latest_session(self, key: IdentityKey) -> Session | None:
q = """
SELECT session_id, session, created_at, last_encrypted, last_decrypted
FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2
ORDER BY last_decrypted DESC LIMIT 1
"""
row = await self.db.fetchrow(q, key, self.account_id)
if row is None:
return None
try:
return self._olm_cache[key][row["session_id"]]
except KeyError:
sess = Session.from_pickle(
row["session"],
passphrase=self.pickle_key,
creation_time=row["created_at"],
last_encrypted=row["last_encrypted"],
last_decrypted=row["last_decrypted"],
)
self._olm_cache[key][SessionID(sess.id)] = sess
return sess
async def add_session(self, key: IdentityKey, session: Session) -> None:
if session.id in self._olm_cache[key]:
self.log.warning(f"Cache already contains Olm session with ID {session.id}")
self._olm_cache[key][SessionID(session.id)] = session
pickle = session.pickle(self.pickle_key)
q = """
INSERT INTO crypto_olm_session (
session_id, sender_key, session, created_at, last_encrypted, last_decrypted, account_id
) VALUES ($1, $2, $3, $4, $5, $6, $7)
"""
await self.db.execute(
q,
session.id,
key,
pickle,
session.creation_time,
session.last_encrypted,
session.last_decrypted,
self.account_id,
)
async def update_session(self, key: IdentityKey, session: Session) -> None:
try:
assert self._olm_cache[key][SessionID(session.id)] == session
except (KeyError, AssertionError) as e:
self.log.warning(
f"Cached olm session with ID {session.id} "
f"isn't equal to the one being saved to the database ({e})"
)
pickle = session.pickle(self.pickle_key)
q = """
UPDATE crypto_olm_session SET session=$1, last_encrypted=$2, last_decrypted=$3
WHERE session_id=$4 AND account_id=$5
"""
await self.db.execute(
q, pickle, session.last_encrypted, session.last_decrypted, session.id, self.account_id
)
async def put_group_session(
self,
room_id: RoomID,
sender_key: IdentityKey,
session_id: SessionID,
session: InboundGroupSession,
) -> None:
pickle = session.pickle(self.pickle_key)
forwarding_chains = ",".join(session.forwarding_chain)
q = """
INSERT INTO crypto_megolm_inbound_session (
session_id, sender_key, signing_key, room_id, session, forwarding_chains,
ratchet_safety, received_at, max_age, max_messages, is_scheduled, account_id
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
ON CONFLICT (session_id, account_id) DO UPDATE
SET withheld_code=NULL, withheld_reason=NULL, sender_key=excluded.sender_key,
signing_key=excluded.signing_key, room_id=excluded.room_id, session=excluded.session,
forwarding_chains=excluded.forwarding_chains, ratchet_safety=excluded.ratchet_safety,
received_at=excluded.received_at, max_age=excluded.max_age,
max_messages=excluded.max_messages, is_scheduled=excluded.is_scheduled
"""
try:
await self.db.execute(
q,
session_id,
sender_key,
session.signing_key,
room_id,
pickle,
forwarding_chains,
session.ratchet_safety.json(),
session.received_at,
int(session.max_age.total_seconds() * 1000) if session.max_age else None,
session.max_messages,
session.is_scheduled,
self.account_id,
)
except (IntegrityError, UniqueViolationError):
self.log.exception(f"Failed to insert megolm session {session_id}")
async def get_group_session(
self, room_id: RoomID, session_id: SessionID
) -> InboundGroupSession | None:
q = """
SELECT
sender_key, signing_key, session, forwarding_chains, withheld_code,
ratchet_safety, received_at, max_age, max_messages, is_scheduled
FROM crypto_megolm_inbound_session
WHERE room_id=$1 AND session_id=$2 AND account_id=$3
"""
row = await self.db.fetchrow(q, room_id, session_id, self.account_id)
if row is None:
return None
if row["withheld_code"] is not None:
raise GroupSessionWithheldError(session_id, row["withheld_code"])
forwarding_chain = row["forwarding_chains"].split(",") if row["forwarding_chains"] else []
return InboundGroupSession.from_pickle(
row["session"],
passphrase=self.pickle_key,
signing_key=row["signing_key"],
sender_key=row["sender_key"],
room_id=room_id,
forwarding_chain=forwarding_chain,
ratchet_safety=RatchetSafety.parse_json(row["ratchet_safety"] or "{}"),
received_at=row["received_at"],
max_age=timedelta(milliseconds=row["max_age"]) if row["max_age"] else None,
max_messages=row["max_messages"],
is_scheduled=row["is_scheduled"],
)
async def redact_group_session(
self, room_id: RoomID, session_id: SessionID, reason: str
) -> None:
q = """
UPDATE crypto_megolm_inbound_session
SET withheld_code=$1, withheld_reason=$2, session=NULL, forwarding_chains=NULL
WHERE session_id=$3 AND account_id=$4 AND session IS NOT NULL
"""
await self.db.execute(
q,
RoomKeyWithheldCode.BEEPER_REDACTED.value,
f"Session redacted: {reason}",
session_id,
self.account_id,
)
async def redact_group_sessions(
self, room_id: RoomID, sender_key: IdentityKey, reason: str
) -> list[SessionID]:
if not room_id and not sender_key:
raise ValueError("Either room_id or sender_key must be provided")
q = """
UPDATE crypto_megolm_inbound_session
SET withheld_code=$1, withheld_reason=$2, session=NULL, forwarding_chains=NULL
WHERE (room_id=$3 OR $3='') AND (sender_key=$4 OR $4='') AND account_id=$5
AND session IS NOT NULL AND is_scheduled=false AND received_at IS NOT NULL
RETURNING session_id
"""
rows = await self.db.fetch(
q,
RoomKeyWithheldCode.BEEPER_REDACTED.value,
f"Session redacted: {reason}",
room_id,
sender_key,
self.account_id,
)
return [row["session_id"] for row in rows]
async def redact_expired_group_sessions(self) -> list[SessionID]:
if self.db.scheme == Scheme.SQLITE:
q = """
UPDATE crypto_megolm_inbound_session
SET withheld_code=$1, withheld_reason=$2, session=NULL, forwarding_chains=NULL
WHERE account_id=$3 AND session IS NOT NULL AND is_scheduled=false
AND received_at IS NOT NULL and max_age IS NOT NULL
AND unixepoch(received_at) + (2 * max_age / 1000) < unixepoch(date('now'))
RETURNING session_id
"""
elif self.db.scheme in (Scheme.POSTGRES, Scheme.COCKROACH):
q = """
UPDATE crypto_megolm_inbound_session
SET withheld_code=$1, withheld_reason=$2, session=NULL, forwarding_chains=NULL
WHERE account_id=$3 AND session IS NOT NULL AND is_scheduled=false
AND received_at IS NOT NULL and max_age IS NOT NULL
AND received_at + 2 * (max_age * interval '1 millisecond') < now()
RETURNING session_id
"""
else:
raise RuntimeError(f"Unsupported dialect {self.db.scheme}")
rows = await self.db.fetch(
q,
RoomKeyWithheldCode.BEEPER_REDACTED.value,
f"Session redacted: expired",
self.account_id,
)
return [row["session_id"] for row in rows]
async def redact_outdated_group_sessions(self) -> list[SessionID]:
q = """
UPDATE crypto_megolm_inbound_session
SET withheld_code=$1, withheld_reason=$2, session=NULL, forwarding_chains=NULL
WHERE account_id=$3 AND session IS NOT NULL AND received_at IS NULL
RETURNING session_id
"""
rows = await self.db.fetch(
q,
RoomKeyWithheldCode.BEEPER_REDACTED.value,
f"Session redacted: outdated",
self.account_id,
)
return [row["session_id"] for row in rows]
async def has_group_session(self, room_id: RoomID, session_id: SessionID) -> bool:
q = """
SELECT COUNT(session) FROM crypto_megolm_inbound_session
WHERE room_id=$1 AND session_id=$2 AND account_id=$3 AND session IS NOT NULL
"""
count = await self.db.fetchval(q, room_id, session_id, self.account_id)
return count > 0
async def add_outbound_group_session(self, session: OutboundGroupSession) -> None:
pickle = session.pickle(self.pickle_key)
max_age = int(session.max_age.total_seconds() * 1000)
q = """
INSERT INTO crypto_megolm_outbound_session (
room_id, session_id, session, shared, max_messages, message_count,
max_age, created_at, last_used, account_id
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
ON CONFLICT (account_id, room_id) DO UPDATE
SET session_id=excluded.session_id, session=excluded.session, shared=excluded.shared,
max_messages=excluded.max_messages, message_count=excluded.message_count,
max_age=excluded.max_age, created_at=excluded.created_at, last_used=excluded.last_used
"""
await self.db.execute(
q,
session.room_id,
session.id,
pickle,
session.shared,
session.max_messages,
session.message_count,
max_age,
session.creation_time,
session.use_time,
self.account_id,
)
async def update_outbound_group_session(self, session: OutboundGroupSession) -> None:
pickle = session.pickle(self.pickle_key)
q = """
UPDATE crypto_megolm_outbound_session SET session=$1, message_count=$2, last_used=$3
WHERE room_id=$4 AND session_id=$5 AND account_id=$6
"""
await self.db.execute(
q,
pickle,
session.message_count,
session.use_time,
session.room_id,
session.id,
self.account_id,
)
async def get_outbound_group_session(self, room_id: RoomID) -> OutboundGroupSession | None:
q = """
SELECT room_id, session_id, session, shared, max_messages, message_count, max_age,
created_at, last_used
FROM crypto_megolm_outbound_session WHERE room_id=$1 AND account_id=$2
"""
row = await self.db.fetchrow(q, room_id, self.account_id)
if row is None:
return None
return OutboundGroupSession.from_pickle(
row["session"],
passphrase=self.pickle_key,
room_id=row["room_id"],
shared=row["shared"],
max_messages=row["max_messages"],
message_count=row["message_count"],
max_age=timedelta(milliseconds=row["max_age"]),
use_time=row["last_used"],
creation_time=row["created_at"],
)
async def remove_outbound_group_session(self, room_id: RoomID) -> None:
q = "DELETE FROM crypto_megolm_outbound_session WHERE room_id=$1 AND account_id=$2"
await self.db.execute(q, room_id, self.account_id)
async def remove_outbound_group_sessions(self, rooms: list[RoomID]) -> None:
if self.db.scheme in (Scheme.POSTGRES, Scheme.COCKROACH):
q = """
DELETE FROM crypto_megolm_outbound_session WHERE account_id=$1 AND room_id=ANY($2)
"""
await self.db.execute(q, self.account_id, rooms)
else:
params = ",".join(["?"] * len(rooms))
q = f"""
DELETE FROM crypto_megolm_outbound_session WHERE account_id=? AND room_id IN ({params})
"""
await self.db.execute(q, self.account_id, *rooms)
_validate_message_index_query = """
INSERT INTO crypto_message_index (sender_key, session_id, "index", event_id, timestamp)
VALUES ($1, $2, $3, $4, $5)
-- have to update something so that RETURNING * always returns the row
ON CONFLICT (sender_key, session_id, "index") DO UPDATE SET sender_key=excluded.sender_key
RETURNING *
"""
async def validate_message_index(
self,
sender_key: IdentityKey,
session_id: SessionID,
event_id: EventID,
index: int,
timestamp: int,
) -> bool:
if self.db.scheme in (Scheme.POSTGRES, Scheme.COCKROACH) or (
# RETURNING was added in SQLite 3.35.0 https://www.sqlite.org/lang_returning.html
self.db.scheme == Scheme.SQLITE
and sqlite_version >= (3, 35)
):
row = await self.db.fetchrow(
self._validate_message_index_query,
sender_key,
session_id,
index,
event_id,
timestamp,
)
return row["event_id"] == event_id and row["timestamp"] == timestamp
else:
row = await self.db.fetchrow(
"SELECT event_id, timestamp FROM crypto_message_index "
'WHERE sender_key=$1 AND session_id=$2 AND "index"=$3',
sender_key,
session_id,
index,
)
if row is not None:
return row["event_id"] == event_id and row["timestamp"] == timestamp
await self.db.execute(
"INSERT INTO crypto_message_index(sender_key, session_id, "
' "index", event_id, timestamp) '
"VALUES ($1, $2, $3, $4, $5)",
sender_key,
session_id,
index,
event_id,
timestamp,
)
return True
async def get_devices(self, user_id: UserID) -> dict[DeviceID, DeviceIdentity] | None:
q = "SELECT user_id FROM crypto_tracked_user WHERE user_id=$1"
tracked_user_id = await self.db.fetchval(q, user_id)
if tracked_user_id is None:
return None
q = """
SELECT device_id, identity_key, signing_key, trust, deleted, name
FROM crypto_device WHERE user_id=$1
"""
rows = await self.db.fetch(q, user_id)
result = {}
for row in rows:
result[row["device_id"]] = DeviceIdentity(
user_id=user_id,
device_id=row["device_id"],
identity_key=row["identity_key"],
signing_key=row["signing_key"],
trust=TrustState(row["trust"]),
deleted=row["deleted"],
name=row["name"],
)
return result
async def get_device(self, user_id: UserID, device_id: DeviceID) -> DeviceIdentity | None:
q = """
SELECT identity_key, signing_key, trust, deleted, name FROM crypto_device
WHERE user_id=$1 AND device_id=$2
"""
row = await self.db.fetchrow(q, user_id, device_id)
if row is None:
return None
return DeviceIdentity(
user_id=user_id,
device_id=device_id,
name=row["name"],
identity_key=row["identity_key"],
signing_key=row["signing_key"],
trust=TrustState(row["trust"]),
deleted=row["deleted"],
)
async def find_device_by_key(
self, user_id: UserID, identity_key: IdentityKey
) -> DeviceIdentity | None:
q = """
SELECT device_id, signing_key, trust, deleted, name FROM crypto_device
WHERE user_id=$1 AND identity_key=$2
"""
row = await self.db.fetchrow(
q,
user_id,
identity_key,
)
if row is None:
return None
return DeviceIdentity(
user_id=user_id,
device_id=row["device_id"],
name=row["name"],
identity_key=identity_key,
signing_key=row["signing_key"],
trust=TrustState(row["trust"]),
deleted=row["deleted"],
)
async def put_devices(self, user_id: UserID, devices: dict[DeviceID, DeviceIdentity]) -> None:
data = [
(
user_id,
device_id,
identity.identity_key,
identity.signing_key,
identity.trust,
identity.deleted,
identity.name,
)
for device_id, identity in devices.items()
]
columns = [
"user_id",
"device_id",
"identity_key",
"signing_key",
"trust",
"deleted",
"name",
]
async with self.db.acquire() as conn, conn.transaction():
q = """
INSERT INTO crypto_tracked_user (user_id) VALUES ($1) ON CONFLICT (user_id) DO NOTHING
"""
await conn.execute(q, user_id)
await conn.execute("DELETE FROM crypto_device WHERE user_id=$1", user_id)
if self.db.scheme == Scheme.POSTGRES:
await conn.copy_records_to_table("crypto_device", records=data, columns=columns)
else:
q = """
INSERT INTO crypto_device (
user_id, device_id, identity_key, signing_key, trust, deleted, name
) VALUES ($1, $2, $3, $4, $5, $6, $7)
"""
await conn.executemany(q, data)
async def filter_tracked_users(self, users: list[UserID]) -> list[UserID]:
if self.db.scheme in (Scheme.POSTGRES, Scheme.COCKROACH):
q = "SELECT user_id FROM crypto_tracked_user WHERE user_id = ANY($1)"
rows = await self.db.fetch(q, users)
else:
params = ",".join(["?"] * len(users))
q = f"SELECT user_id FROM crypto_tracked_user WHERE user_id IN ({params})"
rows = await self.db.fetch(q, *users)
return [row["user_id"] for row in rows]
async def put_cross_signing_key(
self, user_id: UserID, usage: CrossSigningUsage, key: SigningKey
) -> None:
q = """
INSERT INTO crypto_cross_signing_keys (user_id, usage, key, first_seen_key)
VALUES ($1, $2, $3, $4)
ON CONFLICT (user_id, usage) DO UPDATE SET key=excluded.key
"""
try:
await self.db.execute(q, user_id, usage.value, key, key)
except Exception:
self.log.exception(f"Failed to store cross-signing key {user_id}/{key}/{usage}")
async def get_cross_signing_keys(
self, user_id: UserID
) -> dict[CrossSigningUsage, TOFUSigningKey]:
q = "SELECT usage, key, first_seen_key FROM crypto_cross_signing_keys WHERE user_id=$1"
return {
CrossSigningUsage(row["usage"]): TOFUSigningKey(
key=SigningKey(row["key"]),
first=SigningKey(row["first_seen_key"]),
)
for row in await self.db.fetch(q, user_id)
}
async def put_signature(
self, target: CrossSigner, signer: CrossSigner, signature: str
) -> None:
q = """
INSERT INTO crypto_cross_signing_signatures (
signed_user_id, signed_key, signer_user_id, signer_key, signature
) VALUES ($1, $2, $3, $4, $5)
ON CONFLICT (signed_user_id, signed_key, signer_user_id, signer_key)
DO UPDATE SET signature=excluded.signature
"""
signed_user_id, signed_key = target
signer_user_id, signer_key = signer
try:
await self.db.execute(
q, signed_user_id, signed_key, signer_user_id, signer_key, signature
)
except Exception:
self.log.exception(
f"Failed to store signature from {signer_user_id}/{signer_key} "
f"for {signed_user_id}/{signed_key}"
)
async def is_key_signed_by(self, target: CrossSigner, signer: CrossSigner) -> bool:
q = """
SELECT EXISTS(
SELECT 1 FROM crypto_cross_signing_signatures
WHERE signed_user_id=$1 AND signed_key=$2 AND signer_user_id=$3 AND signer_key=$4
)
"""
signed_user_id, signed_key = target
signer_user_id, signer_key = signer
return await self.db.fetchval(q, signed_user_id, signed_key, signer_user_id, signer_key)
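# drop_signatures_by_key returns the number of deleted rows, or -1 if the delete
# failed or the driver's response didn't include a usable row count.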
async def drop_signatures_by_key(self, signer: CrossSigner) -> int:
signer_user_id, signer_key = signer
q = "DELETE FROM crypto_cross_signing_signatures WHERE signer_user_id=$1 AND signer_key=$2"
try:
res = await self.db.execute(q, signer_user_id, signer_key)
except Exception:
self.log.exception(
f"Failed to drop old signatures made by replaced key {signer_user_id}/{signer_key}"
)
return -1
if Cursor is not None and isinstance(res, Cursor):
return res.rowcount
elif (
isinstance(res, str)
and res.startswith("DELETE ")
and (intPart := res[len("DELETE ") :]).isdecimal()
):
return int(intPart)
return -1
python-0.20.4/mautrix/crypto/store/asyncpg/upgrade.py
# Copyright (c) 2023 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
import logging
from mautrix.util.async_db import Connection, Scheme, UpgradeTable
upgrade_table = UpgradeTable(
version_table_name="crypto_version",
database_name="crypto store",
log=logging.getLogger("mau.crypto.db.upgrade"),
)
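# The first registered upgrade runs on blank databases and creates the full
# current schema, jumping straight to version 10; existing databases apply the
# incremental upgrades registered below instead.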
@upgrade_table.register(description="Latest revision", upgrades_to=10)
async def upgrade_blank_to_latest(conn: Connection) -> None:
await conn.execute(
"""CREATE TABLE IF NOT EXISTS crypto_account (
account_id TEXT PRIMARY KEY,
device_id TEXT NOT NULL,
shared BOOLEAN NOT NULL,
sync_token TEXT NOT NULL,
account bytea NOT NULL
)"""
)
await conn.execute(
"""CREATE TABLE IF NOT EXISTS crypto_message_index (
sender_key CHAR(43),
session_id CHAR(43),
"index" INTEGER,
event_id TEXT NOT NULL,
timestamp BIGINT NOT NULL,
PRIMARY KEY (sender_key, session_id, "index")
)"""
)
await conn.execute(
"""CREATE TABLE IF NOT EXISTS crypto_tracked_user (
user_id TEXT PRIMARY KEY
)"""
)
await conn.execute(
"""CREATE TABLE IF NOT EXISTS crypto_device (
user_id TEXT,
device_id TEXT,
identity_key CHAR(43) NOT NULL,
signing_key CHAR(43) NOT NULL,
trust SMALLINT NOT NULL,
deleted BOOLEAN NOT NULL,
name TEXT NOT NULL,
PRIMARY KEY (user_id, device_id)
)"""
)
await conn.execute(
"""CREATE TABLE IF NOT EXISTS crypto_olm_session (
account_id TEXT,
session_id CHAR(43),
sender_key CHAR(43) NOT NULL,
session bytea NOT NULL,
created_at timestamp NOT NULL,
last_decrypted timestamp NOT NULL,
last_encrypted timestamp NOT NULL,
PRIMARY KEY (account_id, session_id)
)"""
)
await conn.execute(
"""CREATE TABLE IF NOT EXISTS crypto_megolm_inbound_session (
account_id TEXT,
session_id CHAR(43),
sender_key CHAR(43) NOT NULL,
signing_key CHAR(43),
room_id TEXT NOT NULL,
session bytea,
forwarding_chains TEXT,
withheld_code TEXT,
withheld_reason TEXT,
ratchet_safety jsonb,
received_at timestamp,
max_age BIGINT,
max_messages INTEGER,
is_scheduled BOOLEAN NOT NULL DEFAULT false,
PRIMARY KEY (account_id, session_id)
)"""
)
await conn.execute(
"""CREATE TABLE IF NOT EXISTS crypto_megolm_outbound_session (
account_id TEXT,
room_id TEXT,
session_id CHAR(43) NOT NULL UNIQUE,
session bytea NOT NULL,
shared BOOLEAN NOT NULL,
max_messages INTEGER NOT NULL,
message_count INTEGER NOT NULL,
max_age BIGINT NOT NULL,
created_at timestamp NOT NULL,
last_used timestamp NOT NULL,
PRIMARY KEY (account_id, room_id)
)"""
)
await conn.execute(
"""CREATE TABLE crypto_cross_signing_keys (
user_id TEXT,
usage TEXT,
key CHAR(43) NOT NULL,
first_seen_key CHAR(43) NOT NULL,
PRIMARY KEY (user_id, usage)
)"""
)
await conn.execute(
"""CREATE TABLE crypto_cross_signing_signatures (
signed_user_id TEXT,
signed_key TEXT,
signer_user_id TEXT,
signer_key TEXT,
signature CHAR(88) NOT NULL,
PRIMARY KEY (signed_user_id, signed_key, signer_user_id, signer_key)
)"""
)
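# SQLite's ALTER TABLE can't add a column to an existing primary key, so the v2
# upgrade below recreates the affected tables there; on Postgres it adds the
# account_id column in place and rebuilds the primary key constraint.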
@upgrade_table.register(description="Add account_id primary key column")
async def upgrade_v2(conn: Connection, scheme: Scheme) -> None:
if scheme == Scheme.SQLITE:
await conn.execute("DROP TABLE crypto_account")
await conn.execute("DROP TABLE crypto_olm_session")
await conn.execute("DROP TABLE crypto_megolm_inbound_session")
await conn.execute("DROP TABLE crypto_megolm_outbound_session")
await conn.execute(
"""CREATE TABLE crypto_account (
account_id VARCHAR(255) PRIMARY KEY,
device_id VARCHAR(255) NOT NULL,
shared BOOLEAN NOT NULL,
sync_token TEXT NOT NULL,
account bytea NOT NULL
)"""
)
await conn.execute(
"""CREATE TABLE crypto_olm_session (
account_id VARCHAR(255),
session_id CHAR(43),
sender_key CHAR(43) NOT NULL,
session bytea NOT NULL,
created_at timestamp NOT NULL,
last_used timestamp NOT NULL,
PRIMARY KEY (account_id, session_id)
)"""
)
await conn.execute(
"""CREATE TABLE crypto_megolm_inbound_session (
account_id VARCHAR(255),
session_id CHAR(43),
sender_key CHAR(43) NOT NULL,
signing_key CHAR(43) NOT NULL,
room_id VARCHAR(255) NOT NULL,
session bytea NOT NULL,
forwarding_chains TEXT NOT NULL,
PRIMARY KEY (account_id, session_id)
)"""
)
await conn.execute(
"""CREATE TABLE crypto_megolm_outbound_session (
account_id VARCHAR(255),
room_id VARCHAR(255),
session_id CHAR(43) NOT NULL UNIQUE,
session bytea NOT NULL,
shared BOOLEAN NOT NULL,
max_messages INTEGER NOT NULL,
message_count INTEGER NOT NULL,
max_age BIGINT NOT NULL,
created_at timestamp NOT NULL,
last_used timestamp NOT NULL,
PRIMARY KEY (account_id, room_id)
)"""
)
else:
async def add_account_id_column(table: str, pkey_columns: list[str]) -> None:
await conn.execute(f"ALTER TABLE {table} ADD COLUMN account_id VARCHAR(255)")
await conn.execute(f"UPDATE {table} SET account_id=''")
await conn.execute(f"ALTER TABLE {table} ALTER COLUMN account_id SET NOT NULL")
await conn.execute(f"ALTER TABLE {table} DROP CONSTRAINT {table}_pkey")
pkey_columns.append("account_id")
pkey_columns_str = ", ".join(f'"{col}"' for col in pkey_columns)
await conn.execute(
f"ALTER TABLE {table} ADD CONSTRAINT {table}_pkey "
f"PRIMARY KEY ({pkey_columns_str})"
)
await add_account_id_column("crypto_account", [])
await add_account_id_column("crypto_olm_session", ["session_id"])
await add_account_id_column("crypto_megolm_inbound_session", ["session_id"])
await add_account_id_column("crypto_megolm_outbound_session", ["room_id"])
@upgrade_table.register(description="Stop using size-limited string fields")
async def upgrade_v3(conn: Connection, scheme: Scheme) -> None:
if scheme == Scheme.SQLITE:
return
await conn.execute("ALTER TABLE crypto_account ALTER COLUMN account_id TYPE TEXT")
await conn.execute("ALTER TABLE crypto_account ALTER COLUMN device_id TYPE TEXT")
await conn.execute("ALTER TABLE crypto_message_index ALTER COLUMN event_id TYPE TEXT")
await conn.execute("ALTER TABLE crypto_tracked_user ALTER COLUMN user_id TYPE TEXT")
await conn.execute("ALTER TABLE crypto_device ALTER COLUMN user_id TYPE TEXT")
await conn.execute("ALTER TABLE crypto_device ALTER COLUMN device_id TYPE TEXT")
await conn.execute("ALTER TABLE crypto_device ALTER COLUMN name TYPE TEXT")
await conn.execute("ALTER TABLE crypto_olm_session ALTER COLUMN account_id TYPE TEXT")
await conn.execute(
"ALTER TABLE crypto_megolm_inbound_session ALTER COLUMN account_id TYPE TEXT"
)
await conn.execute("ALTER TABLE crypto_megolm_inbound_session ALTER COLUMN room_id TYPE TEXT")
await conn.execute(
"ALTER TABLE crypto_megolm_outbound_session ALTER COLUMN account_id TYPE TEXT"
)
await conn.execute("ALTER TABLE crypto_megolm_outbound_session ALTER COLUMN room_id TYPE TEXT")
@upgrade_table.register(description="Split last_used into last_encrypted and last_decrypted")
async def upgrade_v4(conn: Connection, scheme: Scheme) -> None:
await conn.execute("ALTER TABLE crypto_olm_session RENAME COLUMN last_used TO last_decrypted")
await conn.execute("ALTER TABLE crypto_olm_session ADD COLUMN last_encrypted timestamp")
await conn.execute("UPDATE crypto_olm_session SET last_encrypted=last_decrypted")
if scheme == Scheme.POSTGRES:
# This is too hard to do on sqlite, so let's just do it on postgres
await conn.execute(
"ALTER TABLE crypto_olm_session ALTER COLUMN last_encrypted SET NOT NULL"
)
@upgrade_table.register(description="Add cross-signing key and signature caches")
async def upgrade_v5(conn: Connection) -> None:
await conn.execute(
"""CREATE TABLE crypto_cross_signing_keys (
user_id TEXT,
usage TEXT,
key CHAR(43),
first_seen_key CHAR(43),
PRIMARY KEY (user_id, usage)
)"""
)
await conn.execute(
"""CREATE TABLE crypto_cross_signing_signatures (
signed_user_id TEXT,
signed_key TEXT,
signer_user_id TEXT,
signer_key TEXT,
signature TEXT,
PRIMARY KEY (signed_user_id, signed_key, signer_user_id, signer_key)
)"""
)
@upgrade_table.register(description="Update trust state values")
async def upgrade_v6(conn: Connection) -> None:
await conn.execute("UPDATE crypto_device SET trust=300 WHERE trust=1") # verified
await conn.execute("UPDATE crypto_device SET trust=-100 WHERE trust=2") # blacklisted
await conn.execute("UPDATE crypto_device SET trust=0 WHERE trust=3") # ignored -> unset
@upgrade_table.register(
description="Synchronize schema with mautrix-go", upgrades_to=9, transaction=False
)
async def upgrade_v9(conn: Connection, scheme: Scheme) -> None:
if scheme == Scheme.POSTGRES:
async with conn.transaction():
await upgrade_v9_postgres(conn)
else:
await upgrade_v9_sqlite(conn)
# These two are never used because the previous one jumps from 6 to 9.
@upgrade_table.register
async def upgrade_noop_7_to_8(_: Connection) -> None:
pass
@upgrade_table.register
async def upgrade_noop_8_to_9(_: Connection) -> None:
pass
async def upgrade_v9_postgres(conn: Connection) -> None:
await conn.execute("UPDATE crypto_account SET device_id='' WHERE device_id IS NULL")
await conn.execute("ALTER TABLE crypto_account ALTER COLUMN device_id SET NOT NULL")
await conn.execute(
"ALTER TABLE crypto_megolm_inbound_session ALTER COLUMN signing_key DROP NOT NULL"
)
await conn.execute(
"ALTER TABLE crypto_megolm_inbound_session ALTER COLUMN session DROP NOT NULL"
)
await conn.execute(
"ALTER TABLE crypto_megolm_inbound_session ALTER COLUMN forwarding_chains DROP NOT NULL"
)
await conn.execute("ALTER TABLE crypto_megolm_inbound_session ADD COLUMN withheld_code TEXT")
await conn.execute("ALTER TABLE crypto_megolm_inbound_session ADD COLUMN withheld_reason TEXT")
await conn.execute("DELETE FROM crypto_cross_signing_keys WHERE key IS NULL")
await conn.execute(
"UPDATE crypto_cross_signing_keys SET first_seen_key=key WHERE first_seen_key IS NULL"
)
await conn.execute("ALTER TABLE crypto_cross_signing_keys ALTER COLUMN key SET NOT NULL")
await conn.execute(
"ALTER TABLE crypto_cross_signing_keys ALTER COLUMN first_seen_key SET NOT NULL"
)
await conn.execute("DELETE FROM crypto_cross_signing_signatures WHERE signature IS NULL")
await conn.execute(
"ALTER TABLE crypto_cross_signing_signatures ALTER COLUMN signature SET NOT NULL"
)
await conn.execute(
"ALTER TABLE crypto_megolm_outbound_session ALTER COLUMN max_age TYPE BIGINT "
"USING (EXTRACT(EPOCH from max_age)*1000)::bigint"
)
async def upgrade_v9_sqlite(conn: Connection) -> None:
await conn.execute("PRAGMA foreign_keys = OFF")
async with conn.transaction():
await conn.execute(
"""CREATE TABLE new_crypto_account (
account_id TEXT PRIMARY KEY,
device_id TEXT NOT NULL,
shared BOOLEAN NOT NULL,
sync_token TEXT NOT NULL,
account bytea NOT NULL
)"""
)
await conn.execute(
"""
INSERT INTO new_crypto_account (account_id, device_id, shared, sync_token, account)
SELECT account_id, COALESCE(device_id, ''), shared, sync_token, account
FROM crypto_account
"""
)
await conn.execute("DROP TABLE crypto_account")
await conn.execute("ALTER TABLE new_crypto_account RENAME TO crypto_account")
await conn.execute(
"""CREATE TABLE new_crypto_megolm_inbound_session (
account_id TEXT,
session_id CHAR(43),
sender_key CHAR(43) NOT NULL,
signing_key CHAR(43),
room_id TEXT NOT NULL,
session bytea,
forwarding_chains TEXT,
withheld_code TEXT,
withheld_reason TEXT,
PRIMARY KEY (account_id, session_id)
)"""
)
await conn.execute(
"""
INSERT INTO new_crypto_megolm_inbound_session (
account_id, session_id, sender_key, signing_key, room_id, session,
forwarding_chains
)
SELECT account_id, session_id, sender_key, signing_key, room_id, session,
forwarding_chains
FROM crypto_megolm_inbound_session
"""
)
await conn.execute("DROP TABLE crypto_megolm_inbound_session")
await conn.execute(
"ALTER TABLE new_crypto_megolm_inbound_session RENAME TO crypto_megolm_inbound_session"
)
await conn.execute("UPDATE crypto_megolm_outbound_session SET max_age=max_age*1000")
await conn.execute(
"""CREATE TABLE new_crypto_cross_signing_keys (
user_id TEXT,
usage TEXT,
key CHAR(43) NOT NULL,
first_seen_key CHAR(43) NOT NULL,
PRIMARY KEY (user_id, usage)
)"""
)
await conn.execute(
"""
INSERT INTO new_crypto_cross_signing_keys (user_id, usage, key, first_seen_key)
SELECT user_id, usage, key, COALESCE(first_seen_key, key)
FROM crypto_cross_signing_keys
WHERE key IS NOT NULL
"""
)
await conn.execute("DROP TABLE crypto_cross_signing_keys")
await conn.execute(
"ALTER TABLE new_crypto_cross_signing_keys RENAME TO crypto_cross_signing_keys"
)
await conn.execute(
"""CREATE TABLE new_crypto_cross_signing_signatures (
signed_user_id TEXT,
signed_key TEXT,
signer_user_id TEXT,
signer_key TEXT,
signature CHAR(88) NOT NULL,
PRIMARY KEY (signed_user_id, signed_key, signer_user_id, signer_key)
)"""
)
await conn.execute(
"""
INSERT INTO new_crypto_cross_signing_signatures (
signed_user_id, signed_key, signer_user_id, signer_key, signature
)
SELECT signed_user_id, signed_key, signer_user_id, signer_key, signature
FROM crypto_cross_signing_signatures
WHERE signature IS NOT NULL
"""
)
await conn.execute("DROP TABLE crypto_cross_signing_signatures")
await conn.execute(
"ALTER TABLE new_crypto_cross_signing_signatures "
"RENAME TO crypto_cross_signing_signatures"
)
await conn.execute("PRAGMA foreign_key_check")
await conn.execute("PRAGMA foreign_keys = ON")
@upgrade_table.register(
description="Add metadata for detecting when megolm sessions are safe to delete"
)
async def upgrade_v10(conn: Connection) -> None:
await conn.execute("ALTER TABLE crypto_megolm_inbound_session ADD COLUMN ratchet_safety jsonb")
await conn.execute(
"ALTER TABLE crypto_megolm_inbound_session ADD COLUMN received_at timestamp"
)
await conn.execute("ALTER TABLE crypto_megolm_inbound_session ADD COLUMN max_age BIGINT")
await conn.execute("ALTER TABLE crypto_megolm_inbound_session ADD COLUMN max_messages INTEGER")
await conn.execute(
"ALTER TABLE crypto_megolm_inbound_session "
"ADD COLUMN is_scheduled BOOLEAN NOT NULL DEFAULT false"
)
python-0.20.4/mautrix/crypto/store/memory.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from mautrix.client.state_store import SyncStore
from mautrix.types import (
CrossSigner,
CrossSigningUsage,
DeviceID,
DeviceIdentity,
EventID,
IdentityKey,
RoomID,
SessionID,
SigningKey,
SyncToken,
TOFUSigningKey,
UserID,
)
from ..account import OlmAccount
from ..sessions import InboundGroupSession, OutboundGroupSession, Session
from .abstract import CryptoStore
class MemoryCryptoStore(CryptoStore, SyncStore):
_device_id: DeviceID | None
_sync_token: SyncToken | None
_account: OlmAccount | None
_message_indices: dict[tuple[IdentityKey, SessionID, int], tuple[EventID, int]]
_devices: dict[UserID, dict[DeviceID, DeviceIdentity]]
_olm_sessions: dict[IdentityKey, list[Session]]
_inbound_sessions: dict[tuple[RoomID, SessionID], InboundGroupSession]
_outbound_sessions: dict[RoomID, OutboundGroupSession]
_signatures: dict[CrossSigner, dict[CrossSigner, str]]
_cross_signing_keys: dict[UserID, dict[CrossSigningUsage, TOFUSigningKey]]
def __init__(self, account_id: str, pickle_key: str) -> None:
self.account_id = account_id
self.pickle_key = pickle_key
self._sync_token = None
self._device_id = None
self._account = None
self._message_indices = {}
self._devices = {}
self._olm_sessions = {}
self._inbound_sessions = {}
self._outbound_sessions = {}
self._signatures = {}
self._cross_signing_keys = {}
async def get_device_id(self) -> DeviceID | None:
return self._device_id
async def put_device_id(self, device_id: DeviceID) -> None:
self._device_id = device_id
async def put_next_batch(self, next_batch: SyncToken) -> None:
self._sync_token = next_batch
async def get_next_batch(self) -> SyncToken:
return self._sync_token
async def delete(self) -> None:
self._account = None
self._device_id = None
self._olm_sessions = {}
self._outbound_sessions = {}
async def put_account(self, account: OlmAccount) -> None:
self._account = account
async def get_account(self) -> OlmAccount:
return self._account
async def has_session(self, key: IdentityKey) -> bool:
return key in self._olm_sessions
async def get_sessions(self, key: IdentityKey) -> list[Session]:
return self._olm_sessions.get(key, [])
async def get_latest_session(self, key: IdentityKey) -> Session | None:
try:
return self._olm_sessions[key][-1]
except (KeyError, IndexError):
return None
async def add_session(self, key: IdentityKey, session: Session) -> None:
self._olm_sessions.setdefault(key, []).append(session)
async def update_session(self, key: IdentityKey, session: Session) -> None:
# This is a no-op as the session object is the same one previously added.
pass
async def put_group_session(
self,
room_id: RoomID,
sender_key: IdentityKey,
session_id: SessionID,
session: InboundGroupSession,
) -> None:
self._inbound_sessions[(room_id, session_id)] = session
async def get_group_session(
self, room_id: RoomID, session_id: SessionID
) -> InboundGroupSession:
return self._inbound_sessions.get((room_id, session_id))
async def redact_group_session(
self, room_id: RoomID, session_id: SessionID, reason: str
) -> None:
self._inbound_sessions.pop((room_id, session_id), None)
async def redact_group_sessions(
self, room_id: RoomID, sender_key: IdentityKey, reason: str
) -> list[SessionID]:
if not room_id and not sender_key:
raise ValueError("Either room_id or sender_key must be provided")
deleted = []
keys = list(self._inbound_sessions.keys())
for key in keys:
item = self._inbound_sessions[key]
if (not room_id or item.room_id == room_id) and (
not sender_key or item.sender_key == sender_key
):
deleted.append(SessionID(item.id))
del self._inbound_sessions[key]
return deleted
async def redact_expired_group_sessions(self) -> list[SessionID]:
raise NotImplementedError()
async def redact_outdated_group_sessions(self) -> list[SessionID]:
raise NotImplementedError()
async def has_group_session(self, room_id: RoomID, session_id: SessionID) -> bool:
return (room_id, session_id) in self._inbound_sessions
async def add_outbound_group_session(self, session: OutboundGroupSession) -> None:
self._outbound_sessions[session.room_id] = session
async def update_outbound_group_session(self, session: OutboundGroupSession) -> None:
# This is a no-op as the session object is the same one previously added.
pass
async def get_outbound_group_session(self, room_id: RoomID) -> OutboundGroupSession | None:
return self._outbound_sessions.get(room_id)
async def remove_outbound_group_session(self, room_id: RoomID) -> None:
self._outbound_sessions.pop(room_id, None)
async def remove_outbound_group_sessions(self, rooms: list[RoomID]) -> None:
for room_id in rooms:
self._outbound_sessions.pop(room_id, None)
async def validate_message_index(
self,
sender_key: IdentityKey,
session_id: SessionID,
event_id: EventID,
index: int,
timestamp: int,
) -> bool:
try:
return self._message_indices[(sender_key, session_id, index)] == (event_id, timestamp)
except KeyError:
self._message_indices[(sender_key, session_id, index)] = (event_id, timestamp)
return True
async def get_devices(self, user_id: UserID) -> dict[DeviceID, DeviceIdentity] | None:
return self._devices.get(user_id)
async def get_device(self, user_id: UserID, device_id: DeviceID) -> DeviceIdentity | None:
return self._devices.get(user_id, {}).get(device_id)
async def find_device_by_key(
self, user_id: UserID, identity_key: IdentityKey
) -> DeviceIdentity | None:
for device in self._devices.get(user_id, {}).values():
if device.identity_key == identity_key:
return device
return None
async def put_devices(self, user_id: UserID, devices: dict[DeviceID, DeviceIdentity]) -> None:
self._devices[user_id] = devices
async def filter_tracked_users(self, users: list[UserID]) -> list[UserID]:
return [user_id for user_id in users if user_id in self._devices]
async def put_cross_signing_key(
self, user_id: UserID, usage: CrossSigningUsage, key: SigningKey
) -> None:
try:
current = self._cross_signing_keys[user_id][usage]
except KeyError:
self._cross_signing_keys.setdefault(user_id, {})[usage] = TOFUSigningKey(
key=key, first=key
)
else:
current.key = key
async def get_cross_signing_keys(
self, user_id: UserID
) -> dict[CrossSigningUsage, TOFUSigningKey]:
return self._cross_signing_keys.get(user_id, {})
async def put_signature(
self, target: CrossSigner, signer: CrossSigner, signature: str
) -> None:
self._signatures.setdefault(signer, {})[target] = signature
async def is_key_signed_by(self, target: CrossSigner, signer: CrossSigner) -> bool:
return target in self._signatures.get(signer, {})
async def drop_signatures_by_key(self, signer: CrossSigner) -> int:
deleted = self._signatures.pop(signer, None)
return len(deleted) if deleted is not None else 0
python-0.20.4/mautrix/crypto/store/tests/__init__.py
python-0.20.4/mautrix/crypto/store/tests/store_test.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from typing import AsyncContextManager, AsyncIterator, Callable
from contextlib import asynccontextmanager
import os
import random
import string
import time
import asyncpg
import pytest
from mautrix.client.state_store import SyncStore
from mautrix.crypto import InboundGroupSession, OlmAccount, OutboundGroupSession
from mautrix.types import DeviceID, EventID, RoomID, SessionID, SyncToken
from mautrix.util.async_db import Database
from .. import CryptoStore, MemoryCryptoStore, PgCryptoStore
@asynccontextmanager
async def async_postgres_store() -> AsyncIterator[PgCryptoStore]:
try:
pg_url = os.environ["MEOW_TEST_PG_URL"]
except KeyError:
pytest.skip("Skipped Postgres tests (MEOW_TEST_PG_URL not specified)")
return
conn: asyncpg.Connection = await asyncpg.connect(pg_url)
schema_name = "".join(random.choices(string.ascii_lowercase, k=8))
schema_name = f"test_schema_{schema_name}_{int(time.time())}"
await conn.execute(f"CREATE SCHEMA {schema_name}")
db = Database.create(
pg_url,
upgrade_table=PgCryptoStore.upgrade_table,
db_args={"min_size": 1, "max_size": 3, "server_settings": {"search_path": schema_name}},
)
store = PgCryptoStore("", "test", db)
await db.start()
yield store
await db.stop()
await conn.execute(f"DROP SCHEMA {schema_name} CASCADE")
await conn.close()
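# The Postgres-backed tests need MEOW_TEST_PG_URL to point at a database where the
# test role may create schemas (a libpq-style URL such as
# postgres://user:pass@localhost/mautrix_test with placeholder credentials); each
# run creates and then drops a randomly named schema.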
@asynccontextmanager
async def async_sqlite_store() -> AsyncIterator[PgCryptoStore]:
db = Database.create(
"sqlite::memory:", upgrade_table=PgCryptoStore.upgrade_table, db_args={"min_size": 1}
)
store = PgCryptoStore("", "test", db)
await db.start()
yield store
await db.stop()
@asynccontextmanager
async def memory_store() -> AsyncIterator[MemoryCryptoStore]:
yield MemoryCryptoStore("", "test")
@pytest.fixture(params=[async_postgres_store, async_sqlite_store, memory_store])
async def crypto_store(request) -> AsyncIterator[CryptoStore]:
param: Callable[[], AsyncContextManager[CryptoStore]] = request.param
async with param() as state_store:
yield state_store
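# Each test below receives the crypto_store fixture once per backend, so the same
# assertions run against Postgres (when configured), SQLite and the in-memory store.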
async def test_basic(crypto_store: CryptoStore) -> None:
acc = OlmAccount()
keys = acc.identity_keys
await crypto_store.put_account(acc)
await crypto_store.put_device_id(DeviceID("TEST"))
if isinstance(crypto_store, SyncStore):
await crypto_store.put_next_batch(SyncToken("TEST"))
assert await crypto_store.get_device_id() == "TEST"
assert (await crypto_store.get_account()).identity_keys == keys
if isinstance(crypto_store, SyncStore):
assert await crypto_store.get_next_batch() == "TEST"
def _make_group_sess(
acc: OlmAccount, room_id: RoomID
) -> tuple[InboundGroupSession, OutboundGroupSession]:
outbound = OutboundGroupSession(room_id)
inbound = InboundGroupSession(
session_key=outbound.session_key,
signing_key=acc.signing_key,
sender_key=acc.identity_key,
room_id=room_id,
)
return inbound, outbound
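# The inbound session is built from the outbound session's key, signing key and
# sender key, so messages encrypted with the outbound side can be decrypted and
# index-checked in the test below.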
async def test_validate_message_index(crypto_store: CryptoStore) -> None:
acc = OlmAccount()
inbound, outbound = _make_group_sess(acc, RoomID("!foo:bar.com"))
outbound.shared = True
orig_plaintext = "hello world"
ciphertext = outbound.encrypt(orig_plaintext)
ts = int(time.time() * 1000)
plaintext, index = inbound.decrypt(ciphertext)
assert plaintext == orig_plaintext
assert await crypto_store.validate_message_index(
acc.identity_key, SessionID(inbound.id), EventID("$foo"), index, ts
), "Initial validation returns True"
assert await crypto_store.validate_message_index(
acc.identity_key, SessionID(inbound.id), EventID("$foo"), index, ts
), "Validating the same details again returns True"
assert not await crypto_store.validate_message_index(
acc.identity_key, SessionID(inbound.id), EventID("$bar"), index, ts
), "Different event ID causes validation to fail"
assert not await crypto_store.validate_message_index(
acc.identity_key, SessionID(inbound.id), EventID("$foo"), index, ts + 1
), "Different timestamp causes validation to fail"
assert not await crypto_store.validate_message_index(
acc.identity_key, SessionID(inbound.id), EventID("$foo"), index, ts + 1
), "Validating incorrect details twice fails"
assert await crypto_store.validate_message_index(
acc.identity_key, SessionID(inbound.id), EventID("$foo"), index, ts
), "Validating the same details after fails still returns True"
# TODO tests for device identity storage, group session storage
# and cross-signing key/signature storage
python-0.20.4/mautrix/crypto/unwedge.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
import time
from mautrix.types import EventType, IdentityKey, Obj, UserID
from .decrypt_olm import OlmDecryptionMachine
from .device_lists import DeviceListMachine
from .encrypt_olm import OlmEncryptionMachine
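# Minimum delay between Olm session unwedge attempts for the same sender key,
# in seconds (compared against time.monotonic() deltas below).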
MIN_UNWEDGE_INTERVAL = 1 * 60 * 60
class OlmUnwedgingMachine(OlmDecryptionMachine, OlmEncryptionMachine, DeviceListMachine):
async def _unwedge_session(self, sender: UserID, sender_key: IdentityKey) -> None:
try:
prev_unwedge = self._prev_unwedge[sender_key]
except KeyError:
pass
else:
delta = time.monotonic() - prev_unwedge
if delta < MIN_UNWEDGE_INTERVAL:
self.log.debug(
f"Not creating new Olm session with {sender}/{sender_key}, "
f"previous recreation was {delta}s ago"
)
return
self._prev_unwedge[sender_key] = time.monotonic()
try:
device = await self.get_or_fetch_device_by_key(sender, sender_key)
if device is None:
self.log.warning(
f"Didn't find identity of {sender}/{sender_key}, can't unwedge session"
)
return
self.log.debug(
f"Creating new Olm session with {sender}/{device.user_id} (key: {sender_key})"
)
await self.send_encrypted_to_device(
device, EventType.TO_DEVICE_DUMMY, Obj(), _force_recreate_session=True
)
except Exception:
self.log.exception(f"Error unwedging session with {sender}/{sender_key}")
python-0.20.4/mautrix/errors/__init__.py
from .base import IntentError, MatrixConnectionError, MatrixError, MatrixResponseError
from .crypto import (
CryptoError,
DecryptedPayloadError,
DecryptionError,
DeviceValidationError,
DuplicateMessageIndex,
EncryptionError,
GroupSessionWithheldError,
MatchingSessionDecryptionError,
MismatchingRoomError,
SessionNotFound,
SessionShareError,
VerificationError,
)
from .request import (
MAlreadyJoined,
MatrixBadContent,
MatrixBadRequest,
MatrixInvalidToken,
MatrixRequestError,
MatrixStandardRequestError,
MatrixUnknownRequestError,
MBadJSON,
MBadState,
MCaptchaInvalid,
MCaptchaNeeded,
MExclusive,
MForbidden,
MGuestAccessForbidden,
MIncompatibleRoomVersion,
MInsufficientPower,
MInvalidParam,
MInvalidRoomState,
MInvalidUsername,
MLimitExceeded,
MMissingParam,
MMissingToken,
MNotFound,
MNotJoined,
MNotJSON,
MRoomInUse,
MTooLarge,
MUnauthorized,
MUnknown,
MUnknownEndpoint,
MUnknownToken,
MUnrecognized,
MUnsupportedRoomVersion,
MUserDeactivated,
MUserInUse,
make_request_error,
standard_error,
)
from .well_known import (
WellKnownError,
WellKnownInvalidVersionsResponse,
WellKnownMissingHomeserver,
WellKnownNotJSON,
WellKnownNotURL,
WellKnownUnexpectedStatus,
WellKnownUnsupportedScheme,
)
__all__ = [
"IntentError",
"MatrixConnectionError",
"MatrixError",
"MatrixResponseError",
"CryptoError",
"DecryptedPayloadError",
"DecryptionError",
"DeviceValidationError",
"DuplicateMessageIndex",
"EncryptionError",
"GroupSessionWithheldError",
"MatchingSessionDecryptionError",
"MismatchingRoomError",
"SessionNotFound",
"SessionShareError",
"VerificationError",
"MAlreadyJoined",
"MatrixBadContent",
"MatrixBadRequest",
"MatrixInvalidToken",
"MatrixRequestError",
"MatrixStandardRequestError",
"MatrixUnknownRequestError",
"MBadJSON",
"MBadState",
"MCaptchaInvalid",
"MCaptchaNeeded",
"MExclusive",
"MForbidden",
"MGuestAccessForbidden",
"MIncompatibleRoomVersion",
"MInsufficientPower",
"MInvalidParam",
"MInvalidRoomState",
"MInvalidUsername",
"MLimitExceeded",
"MMissingParam",
"MMissingToken",
"MNotFound",
"MNotJoined",
"MNotJSON",
"MRoomInUse",
"MTooLarge",
"MUnauthorized",
"MUnknown",
"MUnknownEndpoint",
"MUnknownToken",
"MUnrecognized",
"MUnsupportedRoomVersion",
"MUserDeactivated",
"MUserInUse",
"make_request_error",
"standard_error",
"WellKnownError",
"WellKnownInvalidVersionsResponse",
"WellKnownMissingHomeserver",
"WellKnownNotJSON",
"WellKnownNotURL",
"WellKnownUnexpectedStatus",
"WellKnownUnsupportedScheme",
]
python-0.20.4/mautrix/errors/base.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
class MatrixError(Exception):
"""A generic Matrix error. Specific errors will subclass this."""
pass
class MatrixConnectionError(MatrixError):
pass
class MatrixResponseError(MatrixError):
"""The response from the homeserver did not fulfill expectations."""
def __init__(self, message: str) -> None:
super().__init__(message)
class IntentError(MatrixError):
"""An intent execution failure, most likely caused by a `MatrixRequestError`."""
pass
python-0.20.4/mautrix/errors/crypto.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
import warnings
from mautrix.types import IdentityKey, SessionID
from .base import MatrixError
class CryptoError(MatrixError):
def __init__(self, message: str) -> None:
super().__init__(message)
self.message = message
class EncryptionError(CryptoError):
pass
class SessionShareError(CryptoError):
pass
class DecryptionError(CryptoError):
@property
def human_message(self) -> str:
return "the bridge failed to decrypt the message"
class MatchingSessionDecryptionError(DecryptionError):
pass
class GroupSessionWithheldError(DecryptionError):
def __init__(self, session_id: SessionID, withheld_code: str) -> None:
super().__init__(f"Session ID {session_id} was withheld ({withheld_code})")
self.withheld_code = withheld_code
class SessionNotFound(DecryptionError):
def __init__(self, session_id: SessionID, sender_key: IdentityKey | None = None) -> None:
super().__init__(
f"Failed to decrypt megolm event: no session with given ID {session_id} found"
)
self.session_id = session_id
self._sender_key = sender_key
@property
def human_message(self) -> str:
return "the bridge hasn't received the decryption keys"
@property
def sender_key(self) -> IdentityKey | None:
"""
.. deprecated:: 0.17.0
Matrix v1.3 deprecated the device_id and sender_key fields in megolm events.
"""
warnings.warn(
"The sender_key field in Megolm events was deprecated in Matrix 1.3",
DeprecationWarning,
)
return self._sender_key
class DuplicateMessageIndex(DecryptionError):
def __init__(self) -> None:
super().__init__("Duplicate message index")
class VerificationError(DecryptionError):
def __init__(self) -> None:
super().__init__("Device keys in session and cached device info do not match")
class DecryptedPayloadError(DecryptionError):
pass
class MismatchingRoomError(DecryptionError):
def __init__(self) -> None:
super().__init__("Encrypted megolm event is not intended for this room")
class DeviceValidationError(EncryptionError):
pass
python-0.20.4/mautrix/errors/request.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from typing import Callable, Type
from .base import MatrixError
class MatrixRequestError(MatrixError):
"""An error that was returned by the homeserver."""
http_status: int
message: str | None
errcode: str
class MatrixUnknownRequestError(MatrixRequestError):
"""An unknown error type returned by the homeserver."""
http_status: int
text: str
errcode: str | None
message: str | None
def __init__(
self,
http_status: int = 0,
text: str = "",
errcode: str | None = None,
message: str | None = None,
) -> None:
super().__init__(f"{http_status}: {text}")
self.http_status = http_status
self.text = text
self.errcode = errcode
self.message = message
class MatrixStandardRequestError(MatrixRequestError):
"""A standard error type returned by the homeserver."""
errcode: str = None
def __init__(self, http_status: int, message: str = "") -> None:
super().__init__(message)
self.http_status: int = http_status
self.message: str = message
MxSRE = Type[MatrixStandardRequestError]
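# Registry of error classes, populated by the @standard_error decorator below:
# ec_map maps stable errcodes (e.g. M_FORBIDDEN) to exception classes, and
# uec_map does the same for unstable (MSC3848) error codes.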
ec_map: dict[str, MxSRE] = {}
uec_map: dict[str, MxSRE] = {}
def standard_error(code: str, unstable: str | None = None) -> Callable[[MxSRE], MxSRE]:
def decorator(cls: MxSRE) -> MxSRE:
cls.errcode = code
ec_map[code] = cls
if unstable:
cls.unstable_errcode = unstable
uec_map[unstable] = cls
return cls
return decorator
def make_request_error(
http_status: int,
text: str,
errcode: str | None,
message: str | None,
unstable_errcode: str | None = None,
) -> MatrixRequestError:
"""
Determine the correct exception class for the error code and create an instance of that class
with the given values.
Args:
http_status: The HTTP status code.
text: The raw response text.
errcode: The errcode field in the response JSON.
message: The error field in the response JSON.
unstable_errcode: The MSC3848 error code field in the response JSON.
"""
if unstable_errcode:
try:
ec_class = uec_map[unstable_errcode]
return ec_class(http_status, message)
except KeyError:
pass
try:
ec_class = ec_map[errcode]
return ec_class(http_status, message)
except KeyError:
return MatrixUnknownRequestError(http_status, text, errcode, message)
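# Illustrative example (not part of the module): a 403 response whose errcode is
# M_FORBIDDEN resolves to the MForbidden class registered below, so
#     make_request_error(403, '{"errcode":"M_FORBIDDEN","error":"denied"}', "M_FORBIDDEN", "denied")
# returns an MForbidden instance; unknown errcodes fall back to MatrixUnknownRequestError.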
# Standard error codes from https://spec.matrix.org/v1.3/client-server-api/#api-standards
# Additionally some combining superclasses for some of the error codes
@standard_error("M_FORBIDDEN")
class MForbidden(MatrixStandardRequestError):
pass
@standard_error("M_ALREADY_JOINED", unstable="ORG.MATRIX.MSC3848.ALREADY_JOINED")
class MAlreadyJoined(MForbidden):
pass
@standard_error("M_NOT_JOINED", unstable="ORG.MATRIX.MSC3848.NOT_JOINED")
class MNotJoined(MForbidden):
pass
@standard_error("M_INSUFFICIENT_POWER", unstable="ORG.MATRIX.MSC3848.INSUFFICIENT_POWER")
class MInsufficientPower(MForbidden):
pass
@standard_error("M_UNKNOWN_ENDPOINT")
class MUnknownEndpoint(MatrixStandardRequestError):
pass
@standard_error("M_USER_DEACTIVATED")
class MUserDeactivated(MForbidden):
pass
class MatrixInvalidToken(MatrixStandardRequestError):
pass
@standard_error("M_UNKNOWN_TOKEN")
class MUnknownToken(MatrixInvalidToken):
pass
@standard_error("M_MISSING_TOKEN")
class MMissingToken(MatrixInvalidToken):
pass
class MatrixBadRequest(MatrixStandardRequestError):
pass
class MatrixBadContent(MatrixBadRequest):
pass
@standard_error("M_BAD_JSON")
class MBadJSON(MatrixBadContent):
pass
@standard_error("M_NOT_JSON")
class MNotJSON(MatrixBadContent):
pass
@standard_error("M_NOT_FOUND")
class MNotFound(MatrixStandardRequestError):
pass
@standard_error("M_LIMIT_EXCEEDED")
class MLimitExceeded(MatrixStandardRequestError):
pass
@standard_error("M_UNKNOWN")
class MUnknown(MatrixStandardRequestError):
pass
@standard_error("M_UNRECOGNIZED")
class MUnrecognized(MatrixStandardRequestError):
pass
@standard_error("M_UNAUTHORIZED")
class MUnauthorized(MatrixStandardRequestError):
pass
@standard_error("M_USER_IN_USE")
class MUserInUse(MatrixStandardRequestError):
pass
@standard_error("M_INVALID_USERNAME")
class MInvalidUsername(MatrixStandardRequestError):
pass
@standard_error("M_ROOM_IN_USE")
class MRoomInUse(MatrixStandardRequestError):
pass
@standard_error("M_INVALID_ROOM_STATE")
class MInvalidRoomState(MatrixStandardRequestError):
pass
# TODO THREEPID_ errors
@standard_error("M_UNSUPPORTED_ROOM_VERSION")
class MUnsupportedRoomVersion(MatrixStandardRequestError):
pass
@standard_error("M_INCOMPATIBLE_ROOM_VERSION")
class MIncompatibleRoomVersion(MatrixStandardRequestError):
pass
@standard_error("M_BAD_STATE")
class MBadState(MatrixStandardRequestError):
pass
@standard_error("M_GUEST_ACCESS_FORBIDDEN")
class MGuestAccessForbidden(MatrixStandardRequestError):
pass
@standard_error("M_CAPTCHA_NEEDED")
class MCaptchaNeeded(MatrixStandardRequestError):
pass
@standard_error("M_CAPTCHA_INVALID")
class MCaptchaInvalid(MatrixStandardRequestError):
pass
@standard_error("M_MISSING_PARAM")
class MMissingParam(MatrixBadRequest):
pass
@standard_error("M_INVALID_PARAM")
class MInvalidParam(MatrixBadRequest):
pass
@standard_error("M_TOO_LARGE")
class MTooLarge(MatrixBadRequest):
pass
@standard_error("M_EXCLUSIVE")
class MExclusive(MatrixStandardRequestError):
pass
python-0.20.4/mautrix/errors/well_known.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from .base import MatrixResponseError
class WellKnownError(MatrixResponseError):
"""
An error that occurred during server discovery.
https://matrix.org/docs/spec/client_server/latest#get-well-known-matrix-client
"""
pass
class WellKnownUnexpectedStatus(WellKnownError):
def __init__(self, status: int) -> None:
super().__init__(f"Unexpected status code {status} when fetching .well-known file")
self.status = status
class WellKnownNotJSON(WellKnownError):
def __init__(self) -> None:
super().__init__(".well-known response was not JSON")
class WellKnownMissingHomeserver(WellKnownError):
def __init__(self) -> None:
super().__init__("No homeserver found in .well-known response")
class WellKnownNotURL(WellKnownError):
def __init__(self) -> None:
super().__init__("Homeserver base URL in .well-known response was not a valid URL")
class WellKnownUnsupportedScheme(WellKnownError):
def __init__(self, scheme: str) -> None:
super().__init__(f"URL in .well-known response has unsupported scheme {scheme}")
class WellKnownInvalidVersionsResponse(WellKnownError):
def __init__(self) -> None:
super().__init__(
"URL in .well-known response didn't respond to versions endpoint properly"
)
python-0.20.4/mautrix/fixmodule.py
# This is a script that fixes the __module__ tags in mautrix-python and some libraries.
# It's used to help Sphinx/autodoc figure out where the things are canonically imported from
# (by default, it shows the exact module they're defined in rather than the top-level import path).
from typing import NewType
from types import FunctionType, ModuleType
import aiohttp
from . import appservice, bridge, client, crypto, errors, types
from .crypto import attachments
from .util import async_db, config, db, formatter, logging
def _fix(obj: ModuleType) -> None:
for item_name in getattr(obj, "__all__", None) or dir(obj):
item = getattr(obj, item_name)
if isinstance(item, (type, FunctionType, NewType)):
# Ignore backwards-compatibility imports like the BridgeState import in mautrix.bridge
if item.__module__.startswith("mautrix") and not item.__module__.startswith(
obj.__name__
):
continue
item.__module__ = obj.__name__
if isinstance(item, NewType):
# By default autodoc makes a blank "Bases:" text,
# so adjust it to show the type as the "base"
item.__bases__ = (item.__supertype__,)
# elif type(item).__module__ == "typing":
# print(obj.__name__, item_name, type(item))
_things_to_fix = [
types,
bridge,
client,
crypto,
attachments,
errors,
appservice,
async_db,
config,
db,
formatter,
logging,
aiohttp,
]
for mod in _things_to_fix:
_fix(mod)
python-0.20.4/mautrix/genall.py
# This script generates the __all__ arrays for types/__init__.py and errors/__init__.py
# to avoid having to manually add both the import and the __all__ entry.
# See https://github.com/mautrix/python/issues/90 for why __all__ is needed at all.
from pathlib import Path
import ast
import black
root_module = Path(__file__).parent
black_cfg = black.parse_pyproject_toml(str(root_module.parent / "pyproject.toml"))
black_mode = black.Mode(
target_versions={black.TargetVersion[ver.upper()] for ver in black_cfg["target_version"]},
line_length=black_cfg["line_length"],
)
def add_imports_to_all(dir: str) -> None:
init_file = root_module / dir / "__init__.py"
with open(init_file) as f:
init_ast = ast.parse(f.read(), filename=f"mautrix/{dir}/__init__.py")
imports: list[str] = []
all_node: ast.List | None = None
for node in ast.iter_child_nodes(init_ast):
if isinstance(node, (ast.Import, ast.ImportFrom)):
imports += (name.name for name in node.names)
elif isinstance(node, ast.Assign) and isinstance(node.value, ast.List):
target = node.targets[0]
if len(node.targets) == 1 and isinstance(target, ast.Name) and target.id == "__all__":
all_node = node.value
all_node.elts = [ast.Constant(name) for name in imports]
with open(init_file, "w") as f:
f.write(black.format_str(ast.unparse(init_ast), mode=black_mode))
add_imports_to_all("types")
add_imports_to_all("errors")
python-0.20.4/mautrix/py.typed
python-0.20.4/mautrix/types/__init__.py
from .auth import (
DiscoveryInformation,
DiscoveryIntegrations,
DiscoveryIntegrationServer,
DiscoveryServer,
LoginFlow,
LoginFlowList,
LoginResponse,
LoginType,
MatrixUserIdentifier,
PhoneIdentifier,
ThirdPartyIdentifier,
UserIdentifier,
UserIdentifierType,
WhoamiResponse,
)
from .crypto import (
ClaimKeysResponse,
CrossSigner,
CrossSigningKeys,
CrossSigningUsage,
DecryptedOlmEvent,
DeviceIdentity,
DeviceKeys,
OlmEventKeys,
QueryKeysResponse,
TOFUSigningKey,
TrustState,
UnsignedDeviceInfo,
)
from .event import (
AccountDataEvent,
AccountDataEventContent,
ASToDeviceEvent,
AudioInfo,
BaseEvent,
BaseFileInfo,
BaseMessageEventContent,
BaseMessageEventContentFuncs,
BaseRoomEvent,
BaseUnsigned,
BatchSendEvent,
BatchSendStateEvent,
BeeperMessageStatusEvent,
BeeperMessageStatusEventContent,
CallAnswerEventContent,
CallCandidate,
CallCandidatesEventContent,
CallData,
CallDataType,
CallEvent,
CallEventContent,
CallHangupEventContent,
CallHangupReason,
CallInviteEventContent,
CallNegotiateEventContent,
CallRejectEventContent,
CallSelectAnswerEventContent,
CanonicalAliasStateEventContent,
EncryptedEvent,
EncryptedEventContent,
EncryptedFile,
EncryptedMegolmEventContent,
EncryptedOlmEventContent,
EncryptionAlgorithm,
EncryptionKeyAlgorithm,
EphemeralEvent,
Event,
EventContent,
EventType,
FileInfo,
Format,
ForwardedRoomKeyEventContent,
GenericEvent,
ImageInfo,
InReplyTo,
JoinRule,
JoinRulesStateEventContent,
JSONWebKey,
KeyID,
KeyRequestAction,
LocationInfo,
LocationMessageEventContent,
MediaInfo,
MediaMessageEventContent,
Membership,
MemberStateEventContent,
MessageEvent,
MessageEventContent,
MessageStatus,
MessageStatusReason,
MessageType,
MessageUnsigned,
OlmCiphertext,
OlmMsgType,
PowerLevelStateEventContent,
PresenceEvent,
PresenceEventContent,
PresenceState,
ReactionEvent,
ReactionEventContent,
ReceiptEvent,
ReceiptEventContent,
ReceiptType,
RedactionEvent,
RedactionEventContent,
RelatesTo,
RelationType,
RequestedKeyInfo,
RoomAvatarStateEventContent,
RoomCreateStateEventContent,
RoomEncryptionStateEventContent,
RoomKeyEventContent,
RoomKeyRequestEventContent,
RoomKeyWithheldCode,
RoomKeyWithheldEventContent,
RoomNameStateEventContent,
RoomPinnedEventsStateEventContent,
RoomPredecessor,
RoomTagAccountDataEventContent,
RoomTagInfo,
RoomTombstoneStateEventContent,
RoomTopicStateEventContent,
RoomType,
SingleReceiptEventContent,
SpaceChildStateEventContent,
SpaceParentStateEventContent,
StateEvent,
StateEventContent,
StateUnsigned,
StrippedStateEvent,
TextMessageEventContent,
ThumbnailInfo,
ToDeviceEvent,
ToDeviceEventContent,
TypingEvent,
TypingEventContent,
VideoInfo,
)
from .filter import EventFilter, Filter, RoomEventFilter, RoomFilter, StateFilter
from .matrixuri import IdentifierType, MatrixURI, MatrixURIError, URIAction
from .media import (
MediaCreateResponse,
MediaRepoConfig,
MXOpenGraph,
OpenGraphAudio,
OpenGraphImage,
OpenGraphVideo,
)
from .misc import (
BatchSendResponse,
BeeperBatchSendResponse,
DeviceLists,
DeviceOTKCount,
DirectoryPaginationToken,
EventContext,
PaginatedMessages,
PaginationDirection,
RoomAliasInfo,
RoomCreatePreset,
RoomDirectoryResponse,
RoomDirectoryVisibility,
)
from .primitive import (
JSON,
BatchID,
ContentURI,
DeviceID,
EventID,
FilterID,
IdentityKey,
RoomAlias,
RoomID,
SessionID,
Signature,
SigningKey,
SyncToken,
UserID,
)
from .push_rules import (
PushAction,
PushActionDict,
PushActionType,
PushCondition,
PushConditionKind,
PushOperator,
PushRule,
PushRuleID,
PushRuleKind,
PushRuleScope,
)
from .users import Member, User, UserSearchResults
from .util import (
ExtensibleEnum,
Lst,
Obj,
Serializable,
SerializableAttrs,
SerializableEnum,
SerializerError,
deserializer,
field,
serializer,
)
from .versions import SpecVersions, Version, VersionFormat, VersionsResponse
__all__ = [
"DiscoveryInformation",
"DiscoveryIntegrations",
"DiscoveryIntegrationServer",
"DiscoveryServer",
"LoginFlow",
"LoginFlowList",
"LoginResponse",
"LoginType",
"MatrixUserIdentifier",
"PhoneIdentifier",
"ThirdPartyIdentifier",
"UserIdentifier",
"UserIdentifierType",
"WhoamiResponse",
"ClaimKeysResponse",
"CrossSigner",
"CrossSigningKeys",
"CrossSigningUsage",
"DecryptedOlmEvent",
"DeviceIdentity",
"DeviceKeys",
"OlmEventKeys",
"QueryKeysResponse",
"TOFUSigningKey",
"TrustState",
"UnsignedDeviceInfo",
"AccountDataEvent",
"AccountDataEventContent",
"ASToDeviceEvent",
"AudioInfo",
"BaseEvent",
"BaseFileInfo",
"BaseMessageEventContent",
"BaseMessageEventContentFuncs",
"BaseRoomEvent",
"BaseUnsigned",
"BatchSendEvent",
"BatchSendStateEvent",
"BeeperMessageStatusEvent",
"BeeperMessageStatusEventContent",
"CallAnswerEventContent",
"CallCandidate",
"CallCandidatesEventContent",
"CallData",
"CallDataType",
"CallEvent",
"CallEventContent",
"CallHangupEventContent",
"CallHangupReason",
"CallInviteEventContent",
"CallNegotiateEventContent",
"CallRejectEventContent",
"CallSelectAnswerEventContent",
"CanonicalAliasStateEventContent",
"EncryptedEvent",
"EncryptedEventContent",
"EncryptedFile",
"EncryptedMegolmEventContent",
"EncryptedOlmEventContent",
"EncryptionAlgorithm",
"EncryptionKeyAlgorithm",
"EphemeralEvent",
"Event",
"EventContent",
"EventType",
"FileInfo",
"Format",
"ForwardedRoomKeyEventContent",
"GenericEvent",
"ImageInfo",
"InReplyTo",
"JoinRule",
"JoinRulesStateEventContent",
"JSONWebKey",
"KeyID",
"KeyRequestAction",
"LocationInfo",
"LocationMessageEventContent",
"MediaInfo",
"MediaMessageEventContent",
"Membership",
"MemberStateEventContent",
"MessageEvent",
"MessageEventContent",
"MessageStatus",
"MessageStatusReason",
"MessageType",
"MessageUnsigned",
"OlmCiphertext",
"OlmMsgType",
"PowerLevelStateEventContent",
"PresenceEvent",
"PresenceEventContent",
"PresenceState",
"ReactionEvent",
"ReactionEventContent",
"ReceiptEvent",
"ReceiptEventContent",
"ReceiptType",
"RedactionEvent",
"RedactionEventContent",
"RelatesTo",
"RelationType",
"RequestedKeyInfo",
"RoomAvatarStateEventContent",
"RoomCreateStateEventContent",
"RoomEncryptionStateEventContent",
"RoomKeyEventContent",
"RoomKeyRequestEventContent",
"RoomKeyWithheldCode",
"RoomKeyWithheldEventContent",
"RoomNameStateEventContent",
"RoomPinnedEventsStateEventContent",
"RoomPredecessor",
"RoomTagAccountDataEventContent",
"RoomTagInfo",
"RoomTombstoneStateEventContent",
"RoomTopicStateEventContent",
"RoomType",
"SingleReceiptEventContent",
"SpaceChildStateEventContent",
"SpaceParentStateEventContent",
"StateEvent",
"StateEventContent",
"StateUnsigned",
"StrippedStateEvent",
"TextMessageEventContent",
"ThumbnailInfo",
"ToDeviceEvent",
"ToDeviceEventContent",
"TypingEvent",
"TypingEventContent",
"VideoInfo",
"EventFilter",
"Filter",
"RoomEventFilter",
"RoomFilter",
"StateFilter",
"IdentifierType",
"MatrixURI",
"MatrixURIError",
"URIAction",
"MediaCreateResponse",
"MediaRepoConfig",
"MXOpenGraph",
"OpenGraphAudio",
"OpenGraphImage",
"OpenGraphVideo",
"BatchSendResponse",
"DeviceLists",
"DeviceOTKCount",
"DirectoryPaginationToken",
"EventContext",
"PaginatedMessages",
"PaginationDirection",
"RoomAliasInfo",
"RoomCreatePreset",
"RoomDirectoryResponse",
"RoomDirectoryVisibility",
"JSON",
"BatchID",
"ContentURI",
"DeviceID",
"EventID",
"FilterID",
"IdentityKey",
"RoomAlias",
"RoomID",
"SessionID",
"Signature",
"SigningKey",
"SyncToken",
"UserID",
"PushAction",
"PushActionDict",
"PushActionType",
"PushCondition",
"PushConditionKind",
"PushOperator",
"PushRule",
"PushRuleID",
"PushRuleKind",
"PushRuleScope",
"Member",
"User",
"UserSearchResults",
"ExtensibleEnum",
"Lst",
"Obj",
"Serializable",
"SerializableAttrs",
"SerializableEnum",
"SerializerError",
"deserializer",
"field",
"serializer",
"SpecVersions",
"Version",
"VersionFormat",
"VersionsResponse",
]
python-0.20.4/mautrix/types/auth.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from typing import List, NewType, Optional, Union
from attr import dataclass
from .primitive import JSON, DeviceID, UserID
from .util import ExtensibleEnum, Obj, SerializableAttrs, deserializer, field
class LoginType(ExtensibleEnum):
"""
A login type, as specified in the `POST /login endpoint`_
.. _POST /login endpoint:
https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3login
"""
PASSWORD: "LoginType" = "m.login.password"
TOKEN: "LoginType" = "m.login.token"
SSO: "LoginType" = "m.login.sso"
APPSERVICE: "LoginType" = "m.login.application_service"
UNSTABLE_JWT: "LoginType" = "org.matrix.login.jwt"
DEVTURE_SHARED_SECRET: "LoginType" = "com.devture.shared_secret_auth"
@dataclass
class LoginFlow(SerializableAttrs):
"""
A login flow, as specified in the `GET /login endpoint`_
.. _GET /login endpoint:
https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3login
"""
type: LoginType
@dataclass
class LoginFlowList(SerializableAttrs):
flows: List[LoginFlow]
def get_first_of_type(self, *types: LoginType) -> Optional[LoginFlow]:
for flow in self.flows:
if flow.type in types:
return flow
return None
def supports_type(self, *types: LoginType) -> bool:
return self.get_first_of_type(*types) is not None
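# Illustrative usage (assumption: a client API method that returns this type,
# such as ClientAPI.get_login_flows()):
#     flows = await client.get_login_flows()
#     if flows.supports_type(LoginType.PASSWORD, LoginType.TOKEN):
#         ...  # log in with the first supported flow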
class UserIdentifierType(ExtensibleEnum):
"""
A user identifier type, as specified in the `Identifier types`_ section of the login spec.
.. _Identifier types:
https://spec.matrix.org/v1.2/client-server-api/#identifier-types
"""
MATRIX_USER: "UserIdentifierType" = "m.id.user"
THIRD_PARTY: "UserIdentifierType" = "m.id.thirdparty"
PHONE: "UserIdentifierType" = "m.id.phone"
@dataclass
class MatrixUserIdentifier(SerializableAttrs):
"""
A client can identify a user using their Matrix ID. This can either be the fully qualified
Matrix user ID, or just the localpart of the user ID.
"""
user: str
"""The Matrix user ID or localpart"""
type: UserIdentifierType = UserIdentifierType.MATRIX_USER
@dataclass
class ThirdPartyIdentifier(SerializableAttrs):
"""
A client can identify a user using a 3PID associated with the user's account on the homeserver,
where the 3PID was previously associated using the `/account/3pid`_ API. See the `3PID Types`_
Appendix for a list of Third-party ID media.
.. _/account/3pid:
https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3account3pid
.. _3PID Types:
https://spec.matrix.org/v1.2/appendices/#3pid-types
"""
medium: str
address: str
type: UserIdentifierType = UserIdentifierType.THIRD_PARTY
@dataclass
class PhoneIdentifier(SerializableAttrs):
"""
A client can identify a user using a phone number associated with the user's account, where the
phone number was previously associated using the `/account/3pid`_ API. The phone number can be
passed in as entered by the user; the homeserver will be responsible for canonicalising it.
If the client wishes to canonicalise the phone number, then it can use the ``m.id.thirdparty``
identifier type with a ``medium`` of ``msisdn`` instead.
.. _/account/3pid:
https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3account3pid
"""
country: str
phone: str
type: UserIdentifierType = UserIdentifierType.PHONE
UserIdentifier = NewType(
"UserIdentifier", Union[MatrixUserIdentifier, ThirdPartyIdentifier, PhoneIdentifier]
)
@deserializer(UserIdentifier)
def deserialize_user_identifier(data: JSON) -> Union[UserIdentifier, Obj]:
try:
identifier_type = UserIdentifierType.deserialize(data["type"])
except KeyError:
return Obj(**data)
if identifier_type == UserIdentifierType.MATRIX_USER:
return MatrixUserIdentifier.deserialize(data)
elif identifier_type == UserIdentifierType.THIRD_PARTY:
return ThirdPartyIdentifier.deserialize(data)
elif identifier_type == UserIdentifierType.PHONE:
return PhoneIdentifier.deserialize(data)
else:
return Obj(**data)
setattr(UserIdentifier, "deserialize", deserialize_user_identifier)
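# Illustrative example (not part of the module): the deserializer dispatches on
# the "type" field, so
#     UserIdentifier.deserialize({"type": "m.id.user", "user": "@alice:example.com"})
# yields a MatrixUserIdentifier, while unknown types come back as a plain Obj.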
@dataclass
class DiscoveryServer(SerializableAttrs):
base_url: Optional[str] = None
@dataclass
class DiscoveryIntegrationServer(SerializableAttrs):
ui_url: Optional[str] = None
api_url: Optional[str] = None
@dataclass
class DiscoveryIntegrations(SerializableAttrs):
managers: List[DiscoveryIntegrationServer] = field(factory=lambda: [])
@dataclass
class DiscoveryInformation(SerializableAttrs):
"""
.well-known discovery information, as specified in the `GET /.well-known/matrix/client endpoint`_
.. _GET /.well-known/matrix/client endpoint:
https://spec.matrix.org/v1.2/client-server-api/#getwell-knownmatrixclient
"""
homeserver: Optional[DiscoveryServer] = field(json="m.homeserver", factory=DiscoveryServer)
identity_server: Optional[DiscoveryServer] = field(
json="m.identity_server", factory=DiscoveryServer
)
    integrations: Optional[DiscoveryIntegrations] = field(
        json="m.integrations", factory=DiscoveryIntegrations
    )
@dataclass
class LoginResponse(SerializableAttrs):
"""
The response for a login request, as specified in the `POST /login endpoint`_
.. _POST /login endpoint:
https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3login
"""
user_id: UserID
device_id: DeviceID
access_token: str
well_known: DiscoveryInformation = field(factory=DiscoveryInformation)
@dataclass
class WhoamiResponse(SerializableAttrs):
"""
The response for a whoami request, as specified in the `GET /account/whoami endpoint`_
.. _GET /account/whoami endpoint:
https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3accountwhoami
"""
user_id: UserID
device_id: Optional[DeviceID] = None
is_guest: bool = False
python-0.20.4/mautrix/types/crypto.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from typing import Any, Dict, List, NamedTuple, Optional
from enum import IntEnum
from attr import dataclass
from .event import EncryptionAlgorithm, EncryptionKeyAlgorithm, KeyID, ToDeviceEvent
from .primitive import DeviceID, IdentityKey, Signature, SigningKey, UserID
from .util import ExtensibleEnum, SerializableAttrs, field
@dataclass
class UnsignedDeviceInfo(SerializableAttrs):
device_display_name: Optional[str] = None
@dataclass
class DeviceKeys(SerializableAttrs):
user_id: UserID
device_id: DeviceID
algorithms: List[EncryptionAlgorithm]
keys: Dict[KeyID, str]
signatures: Dict[UserID, Dict[KeyID, Signature]]
unsigned: UnsignedDeviceInfo = None
def __attrs_post_init__(self) -> None:
if self.unsigned is None:
self.unsigned = UnsignedDeviceInfo()
@property
def ed25519(self) -> Optional[SigningKey]:
try:
return SigningKey(self.keys[KeyID(EncryptionKeyAlgorithm.ED25519, self.device_id)])
except KeyError:
return None
@property
def curve25519(self) -> Optional[IdentityKey]:
try:
return IdentityKey(self.keys[KeyID(EncryptionKeyAlgorithm.CURVE25519, self.device_id)])
except KeyError:
return None
class CrossSigningUsage(ExtensibleEnum):
MASTER = "master"
SELF = "self_signing"
USER = "user_signing"
@dataclass
class CrossSigningKeys(SerializableAttrs):
user_id: UserID
usage: List[CrossSigningUsage]
keys: Dict[KeyID, SigningKey]
signatures: Dict[UserID, Dict[KeyID, Signature]] = field(factory=lambda: {})
@property
def first_key(self) -> Optional[SigningKey]:
try:
return next(iter(self.keys.values()))
except StopIteration:
return None
@property
def first_ed25519_key(self) -> Optional[SigningKey]:
return self.first_key_with_algorithm(EncryptionKeyAlgorithm.ED25519)
def first_key_with_algorithm(self, alg: EncryptionKeyAlgorithm) -> Optional[SigningKey]:
if not self.keys:
return None
try:
return next(key for key_id, key in self.keys.items() if key_id.algorithm == alg)
except StopIteration:
return None
@dataclass
class QueryKeysResponse(SerializableAttrs):
device_keys: Dict[UserID, Dict[DeviceID, DeviceKeys]] = field(factory=lambda: {})
master_keys: Dict[UserID, CrossSigningKeys] = field(factory=lambda: {})
self_signing_keys: Dict[UserID, CrossSigningKeys] = field(factory=lambda: {})
user_signing_keys: Dict[UserID, CrossSigningKeys] = field(factory=lambda: {})
failures: Dict[str, Any] = field(factory=lambda: {})
@dataclass
class ClaimKeysResponse(SerializableAttrs):
one_time_keys: Dict[UserID, Dict[DeviceID, Dict[KeyID, Any]]]
failures: Dict[str, Any] = field(factory=lambda: {})
class TrustState(IntEnum):
BLACKLISTED = -100
UNVERIFIED = 0
UNKNOWN_DEVICE = 10
FORWARDED = 20
CROSS_SIGNED_UNTRUSTED = 50
CROSS_SIGNED_TOFU = 100
CROSS_SIGNED_TRUSTED = 200
VERIFIED = 300
def __str__(self) -> str:
return _trust_state_to_name[self]
@classmethod
def parse(cls, val: str) -> "TrustState":
try:
return _name_to_trust_state[val]
except KeyError as e:
raise ValueError(f"Invalid trust state {val!r}") from e
_trust_state_to_name: Dict[TrustState, str] = {
val: val.name.lower().replace("_", "-") for val in TrustState
}
_name_to_trust_state: Dict[str, TrustState] = {
value: key for key, value in _trust_state_to_name.items()
}
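# Illustrative usage sketch (not part of the library): TrustState round-trips through the
# hyphenated names built above, which is convenient for config values and CLI flags.
def _example_trust_state_roundtrip() -> None:
    assert str(TrustState.CROSS_SIGNED_TOFU) == "cross-signed-tofu"
    assert TrustState.parse("cross-signed-tofu") is TrustState.CROSS_SIGNED_TOFU
    try:
        TrustState.parse("not-a-state")
    except ValueError:
        pass  # unknown names raise ValueError rather than silently defaulting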
@dataclass
class DeviceIdentity:
user_id: UserID
device_id: DeviceID
identity_key: IdentityKey
signing_key: SigningKey
trust: TrustState
deleted: bool
name: str
@dataclass
class OlmEventKeys(SerializableAttrs):
ed25519: SigningKey
@dataclass
class DecryptedOlmEvent(ToDeviceEvent, SerializableAttrs):
keys: OlmEventKeys
recipient: UserID
recipient_keys: OlmEventKeys
sender_device: Optional[DeviceID] = None
sender_key: IdentityKey = field(hidden=True, default=None)
class TOFUSigningKey(NamedTuple):
"""
A tuple representing a single cross-signing key. The first value is the current key, and the
second value is the first seen key. If the values don't match, it means the key is not valid
for trust-on-first-use.
"""
key: SigningKey
first: SigningKey
class CrossSigner(NamedTuple):
"""
A tuple containing a user ID and a signing key they own.
The key can either be a device-owned signing key, or one of the user's cross-signing keys.
"""
user_id: UserID
key: SigningKey
python-0.20.4/mautrix/types/event/__init__.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from .account_data import (
AccountDataEvent,
AccountDataEventContent,
RoomTagAccountDataEventContent,
RoomTagInfo,
)
from .base import BaseEvent, BaseRoomEvent, BaseUnsigned, GenericEvent
from .batch import BatchSendEvent, BatchSendStateEvent
from .beeper import (
BeeperMessageStatusEvent,
BeeperMessageStatusEventContent,
MessageStatus,
MessageStatusReason,
)
from .encrypted import (
EncryptedEvent,
EncryptedEventContent,
EncryptedMegolmEventContent,
EncryptedOlmEventContent,
EncryptionAlgorithm,
EncryptionKeyAlgorithm,
KeyID,
OlmCiphertext,
OlmMsgType,
)
from .ephemeral import (
EphemeralEvent,
PresenceEvent,
PresenceEventContent,
PresenceState,
ReceiptEvent,
ReceiptEventContent,
ReceiptType,
SingleReceiptEventContent,
TypingEvent,
TypingEventContent,
)
from .generic import Event, EventContent
from .message import (
AudioInfo,
BaseFileInfo,
BaseMessageEventContent,
BaseMessageEventContentFuncs,
EncryptedFile,
FileInfo,
Format,
ImageInfo,
InReplyTo,
JSONWebKey,
LocationInfo,
LocationMessageEventContent,
MediaInfo,
MediaMessageEventContent,
MessageEvent,
MessageEventContent,
MessageType,
MessageUnsigned,
RelatesTo,
RelationType,
TextMessageEventContent,
ThumbnailInfo,
VideoInfo,
)
from .reaction import ReactionEvent, ReactionEventContent
from .redaction import RedactionEvent, RedactionEventContent
from .state import (
CanonicalAliasStateEventContent,
JoinRule,
JoinRulesStateEventContent,
Membership,
MemberStateEventContent,
PowerLevelStateEventContent,
RoomAvatarStateEventContent,
RoomCreateStateEventContent,
RoomEncryptionStateEventContent,
RoomNameStateEventContent,
RoomPinnedEventsStateEventContent,
RoomPredecessor,
RoomTombstoneStateEventContent,
RoomTopicStateEventContent,
RoomType,
SpaceChildStateEventContent,
SpaceParentStateEventContent,
StateEvent,
StateEventContent,
StateUnsigned,
StrippedStateEvent,
)
from .to_device import (
ASToDeviceEvent,
ForwardedRoomKeyEventContent,
KeyRequestAction,
RequestedKeyInfo,
RoomKeyEventContent,
RoomKeyRequestEventContent,
RoomKeyWithheldCode,
RoomKeyWithheldEventContent,
ToDeviceEvent,
ToDeviceEventContent,
)
from .type import EventType
from .voip import (
CallAnswerEventContent,
CallCandidate,
CallCandidatesEventContent,
CallData,
CallDataType,
CallEvent,
CallEventContent,
CallHangupEventContent,
CallHangupReason,
CallInviteEventContent,
CallNegotiateEventContent,
CallRejectEventContent,
CallSelectAnswerEventContent,
)
python-0.20.4/mautrix/types/event/account_data.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from typing import Dict, List, Union
from attr import dataclass
import attr
from ..primitive import JSON, RoomID, UserID
from ..util import Obj, SerializableAttrs, deserializer
from .base import BaseEvent, EventType
@dataclass
class RoomTagInfo(SerializableAttrs):
order: Union[int, float, str] = None
@dataclass
class RoomTagAccountDataEventContent(SerializableAttrs):
tags: Dict[str, RoomTagInfo] = attr.ib(default=None, metadata={"json": "tags"})
DirectAccountDataEventContent = Dict[UserID, List[RoomID]]
AccountDataEventContent = Union[RoomTagAccountDataEventContent, DirectAccountDataEventContent, Obj]
account_data_event_content_map = {
EventType.TAG: RoomTagAccountDataEventContent,
# m.direct doesn't really need deserializing
# EventType.DIRECT: DirectAccountDataEventContent,
}
# TODO remaining account data event types
@dataclass
class AccountDataEvent(BaseEvent, SerializableAttrs):
content: AccountDataEventContent
@classmethod
def deserialize(cls, data: JSON) -> "AccountDataEvent":
try:
evt_type = EventType.find(data.get("type"))
data.get("content", {})["__mautrix_event_type"] = evt_type
except ValueError:
return Obj(**data)
evt = super().deserialize(data)
evt.type = evt_type
return evt
@staticmethod
@deserializer(AccountDataEventContent)
def deserialize_content(data: JSON) -> AccountDataEventContent:
evt_type = data.pop("__mautrix_event_type", None)
content_type = account_data_event_content_map.get(evt_type, None)
if not content_type:
return Obj(**data)
return content_type.deserialize(data)
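# Illustrative usage sketch (not part of the library): account data events deserialize into
# typed content when the event type is known and into a generic Obj otherwise. The payload
# below is made-up example data.
def _example_account_data_roundtrip() -> None:
    evt = AccountDataEvent.deserialize(
        {"type": "m.tag", "content": {"tags": {"u.work": {"order": 0.5}}}}
    )
    assert isinstance(evt.content, RoomTagAccountDataEventContent)
    assert evt.content.tags["u.work"].order == 0.5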
python-0.20.4/mautrix/types/event/base.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from typing import Optional
from attr import dataclass
import attr
from ..primitive import EventID, RoomID, UserID
from ..util import Obj, SerializableAttrs
from .type import EventType
@dataclass
class BaseUnsigned:
"""Base unsigned information."""
age: int = None
@dataclass
class BaseEvent:
"""Base event class. The only things an event **must** have are content and event type."""
content: Obj
type: EventType
@dataclass
class BaseRoomEvent(BaseEvent):
"""Base room event class. Room events must have a room ID, event ID, sender and timestamp in
addition to the content and type in the base event."""
room_id: RoomID
event_id: EventID
sender: UserID
timestamp: int = attr.ib(metadata={"json": "origin_server_ts"})
@dataclass
class GenericEvent(BaseEvent, SerializableAttrs):
"""
An event class that contains all possible top-level event keys and uses generic Obj's for object
keys (content and unsigned)
"""
content: Obj
type: EventType
room_id: Optional[RoomID] = None
event_id: Optional[EventID] = None
sender: Optional[UserID] = None
timestamp: Optional[int] = None
state_key: Optional[str] = None
unsigned: Obj = None
redacts: Optional[EventID] = None
python-0.20.4/mautrix/types/event/batch.py
# Copyright (c) 2022 Tulir Asokan, Sumner Evans
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from typing import Any, Optional
from attr import dataclass
import attr
from ..primitive import EventID, UserID
from ..util import SerializableAttrs
from .base import BaseEvent
@dataclass(kw_only=True)
class BatchSendEvent(BaseEvent, SerializableAttrs):
"""Base event class for events sent via a batch send request."""
sender: UserID
timestamp: int = attr.ib(metadata={"json": "origin_server_ts"})
content: Any
# N.B. Overriding event IDs is not allowed in standard room versions
event_id: Optional[EventID] = None
@dataclass(kw_only=True)
class BatchSendStateEvent(BatchSendEvent, SerializableAttrs):
"""
State events to be used as initial state events on batch send events. These never need to be
deserialized.
"""
state_key: str
python-0.20.4/mautrix/types/event/beeper.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from typing import Optional
from attr import dataclass
from ..primitive import EventID, RoomID, SessionID
from ..util import SerializableAttrs, SerializableEnum, field
from .base import BaseRoomEvent
from .message import RelatesTo
class MessageStatusReason(SerializableEnum):
GENERIC_ERROR = "m.event_not_handled"
UNSUPPORTED = "com.beeper.unsupported_event"
UNDECRYPTABLE = "com.beeper.undecryptable_event"
TOO_OLD = "m.event_too_old"
NETWORK_ERROR = "m.foreign_network_error"
NO_PERMISSION = "m.no_permission"
@property
def checkpoint_status(self):
from mautrix.util.message_send_checkpoint import MessageSendCheckpointStatus
if self == MessageStatusReason.UNSUPPORTED:
return MessageSendCheckpointStatus.UNSUPPORTED
elif self == MessageStatusReason.TOO_OLD:
return MessageSendCheckpointStatus.TIMEOUT
return MessageSendCheckpointStatus.PERM_FAILURE
class MessageStatus(SerializableEnum):
SUCCESS = "SUCCESS"
PENDING = "PENDING"
RETRIABLE = "FAIL_RETRIABLE"
FAIL = "FAIL_PERMANENT"
@dataclass(kw_only=True)
class BeeperMessageStatusEventContent(SerializableAttrs):
relates_to: RelatesTo = field(json="m.relates_to")
network: str = ""
status: Optional[MessageStatus] = None
reason: Optional[MessageStatusReason] = None
error: Optional[str] = None
message: Optional[str] = None
last_retry: Optional[EventID] = None
@dataclass
class BeeperMessageStatusEvent(BaseRoomEvent, SerializableAttrs):
content: BeeperMessageStatusEventContent
@dataclass
class BeeperRoomKeyAckEventContent(SerializableAttrs):
room_id: RoomID
session_id: SessionID
first_message_index: int
python-0.20.4/mautrix/types/event/encrypted.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from typing import Dict, NewType, Optional, Union
from enum import IntEnum
import warnings
from attr import dataclass
from ..primitive import JSON, DeviceID, IdentityKey, SessionID
from ..util import ExtensibleEnum, Obj, Serializable, SerializableAttrs, deserializer, field
from .base import BaseRoomEvent, BaseUnsigned
from .message import RelatesTo
class EncryptionAlgorithm(ExtensibleEnum):
OLM_V1: "EncryptionAlgorithm" = "m.olm.v1.curve25519-aes-sha2"
MEGOLM_V1: "EncryptionAlgorithm" = "m.megolm.v1.aes-sha2"
class EncryptionKeyAlgorithm(ExtensibleEnum):
CURVE25519: "EncryptionKeyAlgorithm" = "curve25519"
ED25519: "EncryptionKeyAlgorithm" = "ed25519"
SIGNED_CURVE25519: "EncryptionKeyAlgorithm" = "signed_curve25519"
@dataclass(frozen=True)
class KeyID(Serializable):
algorithm: EncryptionKeyAlgorithm
key_id: str
def serialize(self) -> JSON:
return str(self)
@classmethod
def deserialize(cls, raw: JSON) -> "KeyID":
assert isinstance(raw, str), "key IDs must be strings"
alg, key_id = raw.split(":", 1)
return cls(EncryptionKeyAlgorithm(alg), key_id)
def __str__(self) -> str:
return f"{self.algorithm.value}:{self.key_id}"
class OlmMsgType(Serializable, IntEnum):
PREKEY = 0
MESSAGE = 1
def serialize(self) -> JSON:
return self.value
@classmethod
def deserialize(cls, raw: JSON) -> "OlmMsgType":
return cls(raw)
@dataclass
class OlmCiphertext(SerializableAttrs):
body: str
type: OlmMsgType
@dataclass
class EncryptedOlmEventContent(SerializableAttrs):
ciphertext: Dict[str, OlmCiphertext]
sender_key: IdentityKey
algorithm: EncryptionAlgorithm = EncryptionAlgorithm.OLM_V1
@dataclass
class EncryptedMegolmEventContent(SerializableAttrs):
"""The content of an m.room.encrypted event"""
ciphertext: str
session_id: SessionID
algorithm: EncryptionAlgorithm = EncryptionAlgorithm.MEGOLM_V1
_sender_key: Optional[IdentityKey] = field(default=None, json="sender_key")
_device_id: Optional[DeviceID] = field(default=None, json="device_id")
_relates_to: Optional[RelatesTo] = field(default=None, json="m.relates_to")
@property
def sender_key(self) -> Optional[IdentityKey]:
"""
.. deprecated:: 0.17.0
Matrix v1.3 deprecated the device_id and sender_key fields in megolm events.
"""
warnings.warn(
"The sender_key field in Megolm events was deprecated in Matrix 1.3",
DeprecationWarning,
)
return self._sender_key
@property
def device_id(self) -> Optional[DeviceID]:
"""
.. deprecated:: 0.17.0
Matrix v1.3 deprecated the device_id and sender_key fields in megolm events.
"""
warnings.warn(
"The sender_key field in Megolm events was deprecated in Matrix 1.3",
DeprecationWarning,
)
return self._device_id
@property
def relates_to(self) -> RelatesTo:
if self._relates_to is None:
self._relates_to = RelatesTo()
return self._relates_to
@relates_to.setter
def relates_to(self, relates_to: RelatesTo) -> None:
self._relates_to = relates_to
EncryptedEventContent = NewType(
"EncryptedEventContent", Union[EncryptedOlmEventContent, EncryptedMegolmEventContent]
)
@deserializer(EncryptedEventContent)
def deserialize_encrypted(data: JSON) -> Union[EncryptedEventContent, Obj]:
alg = data.get("algorithm", None)
if alg == EncryptionAlgorithm.MEGOLM_V1.value:
return EncryptedMegolmEventContent.deserialize(data)
elif alg == EncryptionAlgorithm.OLM_V1.value:
return EncryptedOlmEventContent.deserialize(data)
return Obj(**data)
setattr(EncryptedEventContent, "deserialize", deserialize_encrypted)
@dataclass
class EncryptedEvent(BaseRoomEvent, SerializableAttrs):
"""A m.room.encrypted event"""
content: EncryptedEventContent
_unsigned: Optional[BaseUnsigned] = field(default=None, json="unsigned")
@property
def unsigned(self) -> BaseUnsigned:
if not self._unsigned:
self._unsigned = BaseUnsigned()
return self._unsigned
@unsigned.setter
def unsigned(self, value: BaseUnsigned) -> None:
self._unsigned = value
python-0.20.4/mautrix/types/event/ephemeral.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from typing import Dict, List, NewType, Union
from attr import dataclass
from ..primitive import JSON, EventID, RoomID, UserID
from ..util import ExtensibleEnum, SerializableAttrs, SerializableEnum, deserializer
from .base import BaseEvent, GenericEvent
from .type import EventType
@dataclass
class TypingEventContent(SerializableAttrs):
user_ids: List[UserID]
@dataclass
class TypingEvent(BaseEvent, SerializableAttrs):
room_id: RoomID
content: TypingEventContent
class PresenceState(SerializableEnum):
ONLINE = "online"
OFFLINE = "offline"
UNAVAILABLE = "unavailable"
@dataclass
class PresenceEventContent(SerializableAttrs):
presence: PresenceState
last_active_ago: int = None
status_msg: str = None
currently_active: bool = None
@dataclass
class PresenceEvent(BaseEvent, SerializableAttrs):
sender: UserID
content: PresenceEventContent
@dataclass
class SingleReceiptEventContent(SerializableAttrs):
ts: int
class ReceiptType(ExtensibleEnum):
READ = "m.read"
READ_PRIVATE = "m.read.private"
ReceiptEventContent = Dict[EventID, Dict[ReceiptType, Dict[UserID, SingleReceiptEventContent]]]
@dataclass
class ReceiptEvent(BaseEvent, SerializableAttrs):
room_id: RoomID
content: ReceiptEventContent
EphemeralEvent = NewType("EphemeralEvent", Union[PresenceEvent, TypingEvent, ReceiptEvent])
@deserializer(EphemeralEvent)
def deserialize_ephemeral_event(data: JSON) -> EphemeralEvent:
event_type = EventType.find(data.get("type", None))
if event_type == EventType.RECEIPT:
evt = ReceiptEvent.deserialize(data)
elif event_type == EventType.TYPING:
evt = TypingEvent.deserialize(data)
elif event_type == EventType.PRESENCE:
evt = PresenceEvent.deserialize(data)
else:
evt = GenericEvent.deserialize(data)
evt.type = event_type
return evt
setattr(EphemeralEvent, "deserialize", deserialize_ephemeral_event)
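# Illustrative usage sketch (not part of the library): the deserializer hooked in above picks
# the concrete ephemeral event class based on the event type. The payload is made-up data.
def _example_ephemeral_deserialize() -> None:
    evt = EphemeralEvent.deserialize(
        {"type": "m.typing", "room_id": "!room:example.com", "content": {"user_ids": []}}
    )
    assert isinstance(evt, TypingEvent)
    assert evt.content.user_ids == []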
python-0.20.4/mautrix/types/event/generic.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from typing import NewType, Union
from ..primitive import JSON
from ..util import Obj, deserializer
from .account_data import AccountDataEvent, AccountDataEventContent
from .base import EventType, GenericEvent
from .beeper import BeeperMessageStatusEvent, BeeperMessageStatusEventContent
from .encrypted import EncryptedEvent, EncryptedEventContent
from .ephemeral import (
EphemeralEvent,
PresenceEvent,
ReceiptEvent,
ReceiptEventContent,
TypingEvent,
TypingEventContent,
)
from .message import MessageEvent, MessageEventContent
from .reaction import ReactionEvent, ReactionEventContent
from .redaction import RedactionEvent, RedactionEventContent
from .state import StateEvent, StateEventContent
from .to_device import ASToDeviceEvent, ToDeviceEvent, ToDeviceEventContent
from .voip import CallEvent, CallEventContent, type_to_class as voip_types
Event = NewType(
"Event",
Union[
MessageEvent,
ReactionEvent,
RedactionEvent,
StateEvent,
TypingEvent,
ReceiptEvent,
PresenceEvent,
EncryptedEvent,
ToDeviceEvent,
ASToDeviceEvent,
CallEvent,
BeeperMessageStatusEvent,
GenericEvent,
],
)
EventContent = Union[
MessageEventContent,
RedactionEventContent,
ReactionEventContent,
StateEventContent,
AccountDataEventContent,
ReceiptEventContent,
TypingEventContent,
EncryptedEventContent,
ToDeviceEventContent,
CallEventContent,
BeeperMessageStatusEventContent,
Obj,
]
@deserializer(Event)
def deserialize_event(data: JSON) -> Event:
event_type = EventType.find(data.get("type", None))
if event_type == EventType.ROOM_MESSAGE:
return MessageEvent.deserialize(data)
elif event_type == EventType.STICKER:
data.get("content", {})["msgtype"] = "m.sticker"
return MessageEvent.deserialize(data)
elif event_type == EventType.REACTION:
return ReactionEvent.deserialize(data)
elif event_type == EventType.ROOM_REDACTION:
return RedactionEvent.deserialize(data)
elif event_type == EventType.ROOM_ENCRYPTED:
return EncryptedEvent.deserialize(data)
elif event_type in voip_types.keys():
return CallEvent.deserialize(data, event_type=event_type)
elif event_type.is_to_device:
return ToDeviceEvent.deserialize(data)
elif event_type.is_state:
return StateEvent.deserialize(data)
elif event_type.is_account_data:
return AccountDataEvent.deserialize(data)
elif event_type.is_ephemeral:
return EphemeralEvent.deserialize(data)
elif event_type == EventType.BEEPER_MESSAGE_STATUS:
return BeeperMessageStatusEvent.deserialize(data)
else:
return GenericEvent.deserialize(data)
setattr(Event, "deserialize", deserialize_event)
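# Illustrative usage sketch (not part of the library): Event.deserialize is a single entry
# point for parsing arbitrary events (e.g. from /sync); it routes to the matching event class
# by type. The payload below is made-up example data.
def _example_event_deserialize() -> None:
    evt = Event.deserialize(
        {
            "type": "m.room.message",
            "room_id": "!room:example.com",
            "event_id": "$event",
            "sender": "@alice:example.com",
            "origin_server_ts": 1640000000000,
            "content": {"msgtype": "m.text", "body": "hello"},
        }
    )
    assert isinstance(evt, MessageEvent)
    assert evt.content.body == "hello"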
python-0.20.4/mautrix/types/event/message.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from typing import Dict, List, Optional, Pattern, Union
from html import escape
import re
from attr import dataclass
import attr
from ..primitive import JSON, ContentURI, EventID
from ..util import ExtensibleEnum, Obj, SerializableAttrs, deserializer, field
from .base import BaseRoomEvent, BaseUnsigned
# region Message types
class Format(ExtensibleEnum):
"""A message format. Currently only ``org.matrix.custom.html`` is available.
This will probably be deprecated when extensible events are implemented."""
HTML: "Format" = "org.matrix.custom.html"
TEXT_MESSAGE_TYPES = ("m.text", "m.emote", "m.notice")
MEDIA_MESSAGE_TYPES = ("m.image", "m.sticker", "m.video", "m.audio", "m.file")
class MessageType(ExtensibleEnum):
"""A message type."""
TEXT: "MessageType" = "m.text"
EMOTE: "MessageType" = "m.emote"
NOTICE: "MessageType" = "m.notice"
IMAGE: "MessageType" = "m.image"
STICKER: "MessageType" = "m.sticker"
VIDEO: "MessageType" = "m.video"
AUDIO: "MessageType" = "m.audio"
FILE: "MessageType" = "m.file"
LOCATION: "MessageType" = "m.location"
@property
def is_text(self) -> bool:
return self.value in TEXT_MESSAGE_TYPES
@property
def is_media(self) -> bool:
return self.value in MEDIA_MESSAGE_TYPES
# endregion
# region Relations
@dataclass
class InReplyTo(SerializableAttrs):
event_id: EventID
class RelationType(ExtensibleEnum):
ANNOTATION: "RelationType" = "m.annotation"
REFERENCE: "RelationType" = "m.reference"
REPLACE: "RelationType" = "m.replace"
THREAD: "RelationType" = "m.thread"
@dataclass
class RelatesTo(SerializableAttrs):
"""Message relations. Used for reactions, edits and replies."""
rel_type: RelationType = None
event_id: Optional[EventID] = None
key: Optional[str] = None
is_falling_back: Optional[bool] = None
in_reply_to: Optional[InReplyTo] = field(default=None, json="m.in_reply_to")
def __bool__(self) -> bool:
return (bool(self.rel_type) and bool(self.event_id)) or bool(self.in_reply_to)
def serialize(self) -> JSON:
if not self:
return attr.NOTHING
return super().serialize()
# endregion
# region Base event content
class BaseMessageEventContentFuncs:
"""Base class for the contents of all message-type events (currently m.room.message and
m.sticker). Contains relation helpers."""
body: str
_relates_to: Optional[RelatesTo]
def set_reply(self, reply_to: Union[EventID, "MessageEvent"], **kwargs) -> None:
self.relates_to.in_reply_to = InReplyTo(
event_id=reply_to if isinstance(reply_to, str) else reply_to.event_id
)
def set_thread_parent(
self,
thread_parent: Union[EventID, "MessageEvent"],
last_event_in_thread: Union[EventID, "MessageEvent", None] = None,
disable_reply_fallback: bool = False,
**kwargs,
) -> None:
self.relates_to.rel_type = RelationType.THREAD
self.relates_to.event_id = (
thread_parent if isinstance(thread_parent, str) else thread_parent.event_id
)
if isinstance(thread_parent, MessageEvent) and isinstance(
thread_parent.content, BaseMessageEventContentFuncs
):
self.relates_to.event_id = (
thread_parent.content.get_thread_parent() or self.relates_to.event_id
)
if not disable_reply_fallback:
self.set_reply(last_event_in_thread or thread_parent, disable_fallback=True, **kwargs)
self.relates_to.is_falling_back = True
def set_edit(self, edits: Union[EventID, "MessageEvent"]) -> None:
self.relates_to.rel_type = RelationType.REPLACE
self.relates_to.event_id = edits if isinstance(edits, str) else edits.event_id
# Library consumers may create message content by setting a reply first,
# then later marking it as an edit. As edits can't change the reply, just remove
        # the reply metadata when the content is marked as an edit.
if self.relates_to.in_reply_to:
self.relates_to.in_reply_to = None
self.relates_to.is_falling_back = None
def serialize(self) -> JSON:
data = SerializableAttrs.serialize(self)
evt = self.get_edit()
if evt:
new_content = {**data}
del new_content["m.relates_to"]
data["m.new_content"] = new_content
if "body" in data:
data["body"] = f"* {data['body']}"
if "formatted_body" in data:
data["formatted_body"] = f"* {data['formatted_body']}"
return data
@property
def relates_to(self) -> RelatesTo:
if self._relates_to is None:
self._relates_to = RelatesTo()
return self._relates_to
@relates_to.setter
def relates_to(self, relates_to: RelatesTo) -> None:
self._relates_to = relates_to
def get_reply_to(self) -> Optional[EventID]:
if self._relates_to and self._relates_to.in_reply_to:
return self._relates_to.in_reply_to.event_id
return None
def get_edit(self) -> Optional[EventID]:
if self._relates_to and self._relates_to.rel_type == RelationType.REPLACE:
return self._relates_to.event_id
return None
def get_thread_parent(self) -> Optional[EventID]:
if self._relates_to and self._relates_to.rel_type == RelationType.THREAD:
return self._relates_to.event_id
return None
def trim_reply_fallback(self) -> None:
pass
@dataclass
class BaseMessageEventContent(BaseMessageEventContentFuncs):
"""Base event content for all m.room.message-type events."""
msgtype: MessageType = None
body: str = ""
external_url: str = None
_relates_to: Optional[RelatesTo] = attr.ib(default=None, metadata={"json": "m.relates_to"})
# endregion
# region Media info
@dataclass
class JSONWebKey(SerializableAttrs):
key: str = attr.ib(metadata={"json": "k"})
algorithm: str = attr.ib(default="A256CTR", metadata={"json": "alg"})
extractable: bool = attr.ib(default=True, metadata={"json": "ext"})
key_type: str = attr.ib(default="oct", metadata={"json": "kty"})
key_ops: List[str] = attr.ib(factory=lambda: ["encrypt", "decrypt"])
@dataclass
class EncryptedFile(SerializableAttrs):
key: JSONWebKey
iv: str
hashes: Dict[str, str]
url: Optional[ContentURI] = None
version: str = attr.ib(default="v2", metadata={"json": "v"})
@dataclass
class BaseFileInfo(SerializableAttrs):
mimetype: str = None
size: int = None
@dataclass
class ThumbnailInfo(BaseFileInfo, SerializableAttrs):
"""Information about the thumbnail for a document, video, image or location."""
height: int = attr.ib(default=None, metadata={"json": "h"})
width: int = attr.ib(default=None, metadata={"json": "w"})
orientation: int = None
@dataclass
class FileInfo(BaseFileInfo, SerializableAttrs):
"""Information about a document message."""
thumbnail_info: Optional[ThumbnailInfo] = None
thumbnail_file: Optional[EncryptedFile] = None
thumbnail_url: Optional[ContentURI] = None
@dataclass
class ImageInfo(FileInfo, SerializableAttrs):
"""Information about an image message."""
height: int = attr.ib(default=None, metadata={"json": "h"})
width: int = attr.ib(default=None, metadata={"json": "w"})
orientation: int = None
@dataclass
class VideoInfo(ImageInfo, SerializableAttrs):
"""Information about a video message."""
duration: int = None
orientation: int = None
@dataclass
class AudioInfo(BaseFileInfo, SerializableAttrs):
"""Information about an audio message."""
duration: int = None
MediaInfo = Union[ImageInfo, VideoInfo, AudioInfo, FileInfo, Obj]
@dataclass
class LocationInfo(SerializableAttrs):
"""Information about a location message."""
thumbnail_url: Optional[ContentURI] = None
thumbnail_info: Optional[ThumbnailInfo] = None
thumbnail_file: Optional[EncryptedFile] = None
# endregion
# region Event content
@dataclass
class MediaMessageEventContent(BaseMessageEventContent, SerializableAttrs):
"""The content of a media message event (m.image, m.audio, m.video, m.file)"""
url: Optional[ContentURI] = None
info: Optional[MediaInfo] = None
file: Optional[EncryptedFile] = None
@staticmethod
@deserializer(MediaInfo)
@deserializer(Optional[MediaInfo])
def deserialize_info(data: JSON) -> MediaInfo:
if not isinstance(data, dict):
return Obj()
msgtype = data.pop("__mautrix_msgtype", None)
if msgtype == "m.image" or msgtype == "m.sticker":
return ImageInfo.deserialize(data)
elif msgtype == "m.video":
return VideoInfo.deserialize(data)
elif msgtype == "m.audio":
return AudioInfo.deserialize(data)
elif msgtype == "m.file":
return FileInfo.deserialize(data)
else:
return Obj(**data)
@dataclass
class LocationMessageEventContent(BaseMessageEventContent, SerializableAttrs):
geo_uri: str = None
info: LocationInfo = None
html_reply_fallback_regex: Pattern = re.compile(r"^<mx-reply>[\s\S]+?</mx-reply>")
@dataclass
class TextMessageEventContent(BaseMessageEventContent, SerializableAttrs):
"""The content of a text message event (m.text, m.notice, m.emote)"""
format: Format = None
formatted_body: str = None
def set_reply(
self,
reply_to: Union["MessageEvent", EventID],
*,
displayname: Optional[str] = None,
disable_fallback: bool = False,
) -> None:
super().set_reply(reply_to)
if isinstance(reply_to, str):
return
if isinstance(reply_to, MessageEvent) and not disable_fallback:
self.ensure_has_html()
if isinstance(reply_to.content, TextMessageEventContent):
reply_to.content.trim_reply_fallback()
self.formatted_body = (
reply_to.make_reply_fallback_html(displayname) + self.formatted_body
)
self.body = reply_to.make_reply_fallback_text(displayname) + self.body
def ensure_has_html(self) -> None:
if not self.formatted_body or self.format != Format.HTML:
self.format = Format.HTML
            self.formatted_body = escape(self.body).replace("\n", "<br/>")
def formatted(self, format: Format) -> Optional[str]:
if self.format == format:
return self.formatted_body
return None
def trim_reply_fallback(self) -> None:
if self.get_reply_to() and not getattr(self, "__reply_fallback_trimmed", False):
self._trim_reply_fallback_text()
self._trim_reply_fallback_html()
setattr(self, "__reply_fallback_trimmed", True)
def _trim_reply_fallback_text(self) -> None:
if (
not self.body.startswith("> <") and not self.body.startswith("> * <")
) or "\n" not in self.body:
return
lines = self.body.split("\n")
while len(lines) > 0 and lines[0].startswith("> "):
lines.pop(0)
self.body = "\n".join(lines).strip()
def _trim_reply_fallback_html(self) -> None:
if self.formatted_body and self.format == Format.HTML:
self.formatted_body = html_reply_fallback_regex.sub("", self.formatted_body)
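# Illustrative usage sketch (not part of the library): marking content as an edit. serialize()
# wraps the edited content in m.new_content and prefixes the visible fallback body with "* ".
# The event ID below is made up.
def _example_edit_content() -> None:
    content = TextMessageEventContent(msgtype=MessageType.TEXT, body="fixed typo")
    content.set_edit(EventID("$original_event"))
    assert content.get_edit() == "$original_event"
    serialized = content.serialize()
    assert serialized["body"] == "* fixed typo"
    assert serialized["m.new_content"]["body"] == "fixed typo"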
MessageEventContent = Union[
TextMessageEventContent, MediaMessageEventContent, LocationMessageEventContent, Obj
]
# endregion
@dataclass
class MessageUnsigned(BaseUnsigned, SerializableAttrs):
"""Unsigned information sent with message events."""
transaction_id: str = None
html_reply_fallback_format = (
    "<mx-reply><blockquote>"
    "<a href='https://matrix.to/#/{room_id}/{event_id}'>In reply to</a> "
    "<a href='https://matrix.to/#/{sender}'>{displayname}</a><br/>"
    "{content}"
    "</blockquote></mx-reply>"
)
media_reply_fallback_body_map = {
MessageType.IMAGE: "an image",
MessageType.STICKER: "a sticker",
MessageType.AUDIO: "audio",
MessageType.VIDEO: "a video",
MessageType.FILE: "a file",
MessageType.LOCATION: "a location",
}
@dataclass
class MessageEvent(BaseRoomEvent, SerializableAttrs):
"""An m.room.message event"""
content: MessageEventContent
unsigned: Optional[MessageUnsigned] = field(factory=lambda: MessageUnsigned())
@staticmethod
@deserializer(MessageEventContent)
def deserialize_content(data: JSON) -> MessageEventContent:
if not isinstance(data, dict):
return Obj()
rel = data.get("m.relates_to", None) or {}
if rel.get("rel_type", None) == RelationType.REPLACE.value:
data = data.get("m.new_content", data)
data["m.relates_to"] = rel
msgtype = data.get("msgtype", None)
if msgtype in TEXT_MESSAGE_TYPES:
return TextMessageEventContent.deserialize(data)
elif msgtype in MEDIA_MESSAGE_TYPES:
data.get("info", {})["__mautrix_msgtype"] = msgtype
return MediaMessageEventContent.deserialize(data)
elif msgtype == "m.location":
return LocationMessageEventContent.deserialize(data)
else:
return Obj(**data)
def make_reply_fallback_html(self, displayname: Optional[str] = None) -> str:
"""Generate the HTML fallback for messages replying to this event."""
if self.content.msgtype.is_text:
            body = self.content.formatted_body or escape(self.content.body).replace("\n", "<br/>")
else:
            sent_type = media_reply_fallback_body_map.get(self.content.msgtype) or "a message"
body = f"sent {sent_type}"
displayname = escape(displayname) if displayname else self.sender
return html_reply_fallback_format.format(
room_id=self.room_id,
event_id=self.event_id,
sender=self.sender,
displayname=displayname,
content=body,
)
def make_reply_fallback_text(self, displayname: Optional[str] = None) -> str:
"""Generate the plaintext fallback for messages replying to this event."""
if self.content.msgtype.is_text:
body = self.content.body
else:
try:
body = media_reply_fallback_body_map[self.content.msgtype]
except KeyError:
body = "an unknown message type"
lines = body.strip().split("\n")
first_line, lines = lines[0], lines[1:]
fallback_text = f"> <{displayname or self.sender}> {first_line}"
for line in lines:
fallback_text += f"\n> {line}"
fallback_text += "\n\n"
return fallback_text
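# Illustrative usage sketch (not part of the library): building the plaintext reply fallback
# for an incoming message. The event is constructed by hand here; real events normally come
# from Event.deserialize.
def _example_reply_fallback() -> None:
    from .type import EventType  # local import to keep the sketch self-contained

    evt = MessageEvent(
        type=EventType.ROOM_MESSAGE,
        room_id="!room:example.com",
        event_id=EventID("$event"),
        sender="@alice:example.com",
        timestamp=1640000000000,
        content=TextMessageEventContent(msgtype=MessageType.TEXT, body="first line\nsecond line"),
    )
    assert evt.make_reply_fallback_text("Alice") == "> <Alice> first line\n> second line\n\n"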
python-0.20.4/mautrix/types/event/reaction.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from typing import Optional
from attr import dataclass
import attr
from ..util import SerializableAttrs
from .base import BaseRoomEvent, BaseUnsigned
from .message import RelatesTo
@dataclass
class ReactionEventContent(SerializableAttrs):
"""The content of an m.reaction event"""
_relates_to: Optional[RelatesTo] = attr.ib(default=None, metadata={"json": "m.relates_to"})
@property
def relates_to(self) -> RelatesTo:
if self._relates_to is None:
self._relates_to = RelatesTo()
return self._relates_to
@relates_to.setter
def relates_to(self, relates_to: RelatesTo) -> None:
self._relates_to = relates_to
@dataclass
class ReactionEvent(BaseRoomEvent, SerializableAttrs):
"""A m.reaction event"""
content: ReactionEventContent
_unsigned: Optional[BaseUnsigned] = attr.ib(default=None, metadata={"json": "unsigned"})
@property
def unsigned(self) -> BaseUnsigned:
if not self._unsigned:
self._unsigned = BaseUnsigned()
return self._unsigned
@unsigned.setter
def unsigned(self, value: BaseUnsigned) -> None:
self._unsigned = value
python-0.20.4/mautrix/types/event/redaction.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from typing import Optional
from attr import dataclass
import attr
from ..primitive import EventID
from ..util import SerializableAttrs
from .base import BaseRoomEvent, BaseUnsigned
@dataclass
class RedactionEventContent(SerializableAttrs):
"""The content of an m.room.redaction event"""
reason: str = None
@dataclass
class RedactionEvent(BaseRoomEvent, SerializableAttrs):
"""A m.room.redaction event"""
content: RedactionEventContent
redacts: EventID
_unsigned: Optional[BaseUnsigned] = attr.ib(default=None, metadata={"json": "unsigned"})
@property
def unsigned(self) -> BaseUnsigned:
if not self._unsigned:
self._unsigned = BaseUnsigned()
return self._unsigned
@unsigned.setter
def unsigned(self, value: BaseUnsigned) -> None:
self._unsigned = value
python-0.20.4/mautrix/types/event/state.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from typing import Dict, List, Optional, Union
from attr import dataclass
import attr
from ..primitive import JSON, ContentURI, EventID, RoomAlias, RoomID, UserID
from ..util import Obj, SerializableAttrs, SerializableEnum, deserializer, field
from .base import BaseRoomEvent, BaseUnsigned
from .encrypted import EncryptionAlgorithm
from .type import EventType, RoomType
@dataclass
class NotificationPowerLevels(SerializableAttrs):
room: int = 50
@dataclass
class PowerLevelStateEventContent(SerializableAttrs):
"""The content of a power level event."""
users: Dict[UserID, int] = attr.ib(default=attr.Factory(dict), metadata={"omitempty": False})
users_default: int = 0
events: Dict[EventType, int] = attr.ib(
default=attr.Factory(dict), metadata={"omitempty": False}
)
events_default: int = 0
notifications: NotificationPowerLevels = attr.ib(factory=lambda: NotificationPowerLevels())
state_default: int = 50
invite: int = 50
kick: int = 50
ban: int = 50
redact: int = 50
def get_user_level(self, user_id: UserID) -> int:
return int(self.users.get(user_id, self.users_default))
def set_user_level(self, user_id: UserID, level: int) -> None:
if level == self.users_default:
del self.users[user_id]
else:
self.users[user_id] = level
def ensure_user_level(self, user_id: UserID, level: int) -> bool:
if self.get_user_level(user_id) != level:
self.set_user_level(user_id, level)
return True
return False
def get_event_level(self, event_type: EventType) -> int:
return int(
self.events.get(
event_type, (self.state_default if event_type.is_state else self.events_default)
)
)
def set_event_level(self, event_type: EventType, level: int) -> None:
        if level == (self.state_default if event_type.is_state else self.events_default):
del self.events[event_type]
else:
self.events[event_type] = level
def ensure_event_level(self, event_type: EventType, level: int) -> bool:
if self.get_event_level(event_type) != level:
self.set_event_level(event_type, level)
return True
return False
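# Illustrative usage sketch (not part of the library): the ensure_* helpers report whether
# anything changed, which callers typically use to decide whether to send an updated
# m.room.power_levels state event.
def _example_power_levels() -> None:
    levels = PowerLevelStateEventContent()
    assert levels.get_user_level("@alice:example.com") == 0  # falls back to users_default
    assert levels.ensure_user_level("@alice:example.com", 50) is True
    # A second call with the same level is a no-op and reports no change.
    assert levels.ensure_user_level("@alice:example.com", 50) is False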
class Membership(SerializableEnum):
"""
The membership state of a user in a room as specified in section `8.4 Room membership`_ of the
spec.
.. _8.4 Room membership: https://spec.matrix.org/v1.2/client-server-api/#room-membership
"""
JOIN = "join"
LEAVE = "leave"
INVITE = "invite"
BAN = "ban"
KNOCK = "knock"
@dataclass
class MemberStateEventContent(SerializableAttrs):
"""The content of a membership event. `Spec link`_
.. _Spec link: https://spec.matrix.org/v1.2/client-server-api/#mroommember"""
membership: Membership = Membership.LEAVE
avatar_url: ContentURI = None
displayname: str = None
is_direct: bool = False
reason: str = None
third_party_invite: JSON = None
@dataclass
class CanonicalAliasStateEventContent(SerializableAttrs):
"""
The content of a ``m.room.canonical_alias`` event (:attr:`EventType.ROOM_CANONICAL_ALIAS`).
This event is used to inform the room about which alias should be considered the canonical one,
and which other aliases point to the room. This could be for display purposes or as suggestion
to users which alias to use to advertise and access the room.
See also: `m.room.canonical_alias in the spec`_
.. _m.room.canonical_alias in the spec: https://spec.matrix.org/v1.2/client-server-api/#mroomcanonical_alias
"""
canonical_alias: RoomAlias = attr.ib(default=None, metadata={"json": "alias"})
alt_aliases: List[RoomAlias] = attr.ib(factory=lambda: [])
@dataclass
class RoomNameStateEventContent(SerializableAttrs):
name: str = None
@dataclass
class RoomTopicStateEventContent(SerializableAttrs):
topic: str = None
@dataclass
class RoomAvatarStateEventContent(SerializableAttrs):
url: Optional[ContentURI] = None
class JoinRule(SerializableEnum):
PUBLIC = "public"
KNOCK = "knock"
RESTRICTED = "restricted"
INVITE = "invite"
PRIVATE = "private"
KNOCK_RESTRICTED = "knock_restricted"
class JoinRestrictionType(SerializableEnum):
ROOM_MEMBERSHIP = "m.room_membership"
@dataclass
class JoinRestriction(SerializableAttrs):
type: JoinRestrictionType
room_id: Optional[RoomID] = None
@dataclass
class JoinRulesStateEventContent(SerializableAttrs):
join_rule: JoinRule
allow: Optional[List[JoinRestriction]] = None
@dataclass
class RoomPinnedEventsStateEventContent(SerializableAttrs):
pinned: List[EventID] = None
@dataclass
class RoomTombstoneStateEventContent(SerializableAttrs):
body: str = None
replacement_room: RoomID = None
@dataclass
class RoomEncryptionStateEventContent(SerializableAttrs):
algorithm: EncryptionAlgorithm = None
rotation_period_ms: int = 604800000
rotation_period_msgs: int = 100
@dataclass
class RoomPredecessor(SerializableAttrs):
room_id: RoomID = None
event_id: EventID = None
@dataclass
class RoomCreateStateEventContent(SerializableAttrs):
room_version: str = "1"
federate: bool = field(json="m.federate", omit_default=True, default=True)
predecessor: Optional[RoomPredecessor] = None
type: Optional[RoomType] = None
@dataclass
class SpaceChildStateEventContent(SerializableAttrs):
via: List[str] = None
order: str = ""
suggested: bool = False
@dataclass
class SpaceParentStateEventContent(SerializableAttrs):
via: List[str] = None
canonical: bool = False
StateEventContent = Union[
PowerLevelStateEventContent,
MemberStateEventContent,
CanonicalAliasStateEventContent,
RoomNameStateEventContent,
RoomAvatarStateEventContent,
RoomTopicStateEventContent,
RoomPinnedEventsStateEventContent,
RoomTombstoneStateEventContent,
RoomEncryptionStateEventContent,
RoomCreateStateEventContent,
SpaceChildStateEventContent,
SpaceParentStateEventContent,
JoinRulesStateEventContent,
Obj,
]
@dataclass
class StrippedStateUnsigned(BaseUnsigned, SerializableAttrs):
"""Unsigned information sent with state events."""
prev_content: StateEventContent = None
prev_sender: UserID = None
replaces_state: EventID = None
@dataclass
class StrippedStateEvent(SerializableAttrs):
"""Stripped state events included with some invite events."""
content: StateEventContent = None
room_id: RoomID = None
sender: UserID = None
type: EventType = None
state_key: str = None
unsigned: Optional[StrippedStateUnsigned] = None
@property
def prev_content(self) -> StateEventContent:
if self.unsigned and self.unsigned.prev_content:
return self.unsigned.prev_content
return state_event_content_map.get(self.type, Obj)()
@classmethod
def deserialize(cls, data: JSON) -> "StrippedStateEvent":
try:
event_type = EventType.find(data.get("type", None))
data.get("content", {})["__mautrix_event_type"] = event_type
(data.get("unsigned") or {}).get("prev_content", {})[
"__mautrix_event_type"
] = event_type
except ValueError:
pass
return super().deserialize(data)
@dataclass
class StateUnsigned(StrippedStateUnsigned, SerializableAttrs):
invite_room_state: Optional[List[StrippedStateEvent]] = None
state_event_content_map = {
EventType.ROOM_CREATE: RoomCreateStateEventContent,
EventType.ROOM_POWER_LEVELS: PowerLevelStateEventContent,
EventType.ROOM_MEMBER: MemberStateEventContent,
EventType.ROOM_PINNED_EVENTS: RoomPinnedEventsStateEventContent,
EventType.ROOM_CANONICAL_ALIAS: CanonicalAliasStateEventContent,
EventType.ROOM_NAME: RoomNameStateEventContent,
EventType.ROOM_AVATAR: RoomAvatarStateEventContent,
EventType.ROOM_TOPIC: RoomTopicStateEventContent,
EventType.ROOM_JOIN_RULES: JoinRulesStateEventContent,
EventType.ROOM_TOMBSTONE: RoomTombstoneStateEventContent,
EventType.ROOM_ENCRYPTION: RoomEncryptionStateEventContent,
EventType.SPACE_CHILD: SpaceChildStateEventContent,
EventType.SPACE_PARENT: SpaceParentStateEventContent,
}
@dataclass
class StateEvent(BaseRoomEvent, SerializableAttrs):
"""A room state event."""
state_key: str
content: StateEventContent
unsigned: Optional[StateUnsigned] = field(factory=lambda: StateUnsigned())
@property
def prev_content(self) -> StateEventContent:
if self.unsigned and self.unsigned.prev_content:
return self.unsigned.prev_content
return state_event_content_map.get(self.type, Obj)()
@classmethod
def deserialize(cls, data: JSON) -> "StateEvent":
try:
event_type = EventType.find(data.get("type"), t_class=EventType.Class.STATE)
data.get("content", {})["__mautrix_event_type"] = event_type
if "prev_content" in data and "prev_content" not in (data.get("unsigned") or {}):
# This if is a workaround for Conduit being extremely dumb
if data.get("unsigned", {}) is None:
data["unsigned"] = {}
data.setdefault("unsigned", {})["prev_content"] = data["prev_content"]
data.get("unsigned", {}).get("prev_content", {})["__mautrix_event_type"] = event_type
except ValueError:
return Obj(**data)
evt = super().deserialize(data)
evt.type = event_type
return evt
@staticmethod
@deserializer(StateEventContent)
def deserialize_content(data: JSON) -> StateEventContent:
evt_type = data.pop("__mautrix_event_type", None)
content_type = state_event_content_map.get(evt_type, None)
if not content_type:
return Obj(**data)
return content_type.deserialize(data)
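# Illustrative usage sketch (not part of the library): state events deserialize their content
# based on the event type, and prev_content falls back to an empty typed object when the
# unsigned data has none. The payload below is made-up example data.
def _example_state_event_deserialize() -> None:
    evt = StateEvent.deserialize(
        {
            "type": "m.room.member",
            "state_key": "@alice:example.com",
            "room_id": "!room:example.com",
            "event_id": "$event",
            "sender": "@alice:example.com",
            "origin_server_ts": 1640000000000,
            "content": {"membership": "join", "displayname": "Alice"},
        }
    )
    assert isinstance(evt.content, MemberStateEventContent)
    assert evt.content.membership == Membership.JOIN
    assert isinstance(evt.prev_content, MemberStateEventContent)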
python-0.20.4/mautrix/types/event/to_device.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from typing import List, Optional, Union
from attr import dataclass
import attr
from ..primitive import JSON, DeviceID, IdentityKey, RoomID, SessionID, SigningKey, UserID
from ..util import ExtensibleEnum, Obj, SerializableAttrs, deserializer, field
from .base import BaseEvent, EventType
from .beeper import BeeperRoomKeyAckEventContent
from .encrypted import EncryptedOlmEventContent, EncryptionAlgorithm
class RoomKeyWithheldCode(ExtensibleEnum):
BLACKLISTED: "RoomKeyWithheldCode" = "m.blacklisted"
UNVERIFIED: "RoomKeyWithheldCode" = "m.unverified"
UNAUTHORIZED: "RoomKeyWithheldCode" = "m.unauthorised"
UNAVAILABLE: "RoomKeyWithheldCode" = "m.unavailable"
NO_OLM_SESSION: "RoomKeyWithheldCode" = "m.no_olm"
BEEPER_REDACTED: "RoomKeyWithheldCode" = "com.beeper.redacted"
@dataclass
class RoomKeyWithheldEventContent(SerializableAttrs):
algorithm: EncryptionAlgorithm
sender_key: IdentityKey
code: RoomKeyWithheldCode
reason: Optional[str] = None
room_id: Optional[RoomID] = None
session_id: Optional[SessionID] = None
@dataclass
class RoomKeyEventContent(SerializableAttrs):
algorithm: EncryptionAlgorithm
room_id: RoomID
session_id: SessionID
session_key: str
beeper_max_age_ms: Optional[int] = field(json="com.beeper.max_age_ms", default=None)
beeper_max_messages: Optional[int] = field(json="com.beeper.max_messages", default=None)
beeper_is_scheduled: Optional[bool] = field(json="com.beeper.is_scheduled", default=False)
class KeyRequestAction(ExtensibleEnum):
REQUEST: "KeyRequestAction" = "request"
CANCEL: "KeyRequestAction" = "request_cancellation"
@dataclass
class RequestedKeyInfo(SerializableAttrs):
algorithm: EncryptionAlgorithm
room_id: RoomID
sender_key: IdentityKey
session_id: SessionID
@dataclass
class RoomKeyRequestEventContent(SerializableAttrs):
action: KeyRequestAction
requesting_device_id: DeviceID
request_id: str
body: Optional[RequestedKeyInfo] = None
@dataclass(kw_only=True)
class ForwardedRoomKeyEventContent(RoomKeyEventContent, SerializableAttrs):
sender_key: IdentityKey
signing_key: SigningKey = attr.ib(metadata={"json": "sender_claimed_ed25519_key"})
forwarding_key_chain: List[str] = attr.ib(metadata={"json": "forwarding_curve25519_key_chain"})
ToDeviceEventContent = Union[
Obj,
EncryptedOlmEventContent,
RoomKeyWithheldEventContent,
RoomKeyEventContent,
RoomKeyRequestEventContent,
ForwardedRoomKeyEventContent,
BeeperRoomKeyAckEventContent,
]
to_device_event_content_map = {
EventType.TO_DEVICE_ENCRYPTED: EncryptedOlmEventContent,
EventType.ROOM_KEY_WITHHELD: RoomKeyWithheldEventContent,
EventType.ROOM_KEY_REQUEST: RoomKeyRequestEventContent,
EventType.ROOM_KEY: RoomKeyEventContent,
EventType.FORWARDED_ROOM_KEY: ForwardedRoomKeyEventContent,
EventType.BEEPER_ROOM_KEY_ACK: BeeperRoomKeyAckEventContent,
}
@dataclass
class ToDeviceEvent(BaseEvent, SerializableAttrs):
sender: UserID
content: ToDeviceEventContent
@classmethod
def deserialize(cls, data: JSON) -> "ToDeviceEvent":
try:
evt_type = EventType.find(data.get("type"), t_class=EventType.Class.TO_DEVICE)
data.setdefault("content", {})["__mautrix_event_type"] = evt_type
except ValueError:
return Obj(**data)
evt = super().deserialize(data)
evt.type = evt_type
return evt
@staticmethod
@deserializer(ToDeviceEventContent)
def deserialize_content(data: JSON) -> ToDeviceEventContent:
evt_type = data.pop("__mautrix_event_type", None)
content_type = to_device_event_content_map.get(evt_type, None)
if not content_type:
return Obj(**data)
return content_type.deserialize(data)
@dataclass
class ASToDeviceEvent(ToDeviceEvent, SerializableAttrs):
to_user_id: UserID
to_device_id: DeviceID
python-0.20.4/mautrix/types/event/type.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from typing import Any, Optional
import json
from ..primitive import JSON
from ..util import ExtensibleEnum, Serializable, SerializableEnum
class RoomType(ExtensibleEnum):
SPACE = "m.space"
class EventType(Serializable):
"""
An immutable enum-like class that represents a specific Matrix event type.
In addition to the plain event type string, this also includes the context that the event is
used in (see: :class:`Class`). Comparing ``EventType`` instances for equality will check both
the type string and the class.
The idea behind the wrapper is that incoming event parsers will always create an ``EventType``
instance with the correct class, regardless of what the usual context for the event is. Then
when the event is being handled, the type will not be equal to ``EventType`` instances with a
different class. For example, if someone sends a non-state ``m.room.name`` event, checking
``if event.type == EventType.ROOM_NAME`` would return ``False``, because the class would be
different. Bugs caused by not checking the context of an event (especially state event vs
message event) were very common in the past, and using a wrapper like this helps prevent them.
"""
_by_event_type = {}
class Class(SerializableEnum):
"""The context that an event type is used in."""
UNKNOWN = "unknown"
STATE = "state"
"""Room state events"""
MESSAGE = "message"
"""Room message events, i.e. room events that are not state events"""
ACCOUNT_DATA = "account_data"
"""Account data events, user-specific storage used for synchronizing info between clients.
Can be global or room-specific."""
EPHEMERAL = "ephemeral"
"""Ephemeral events. Currently only typing notifications, read receipts and presence are
in this class, as custom ephemeral events are not yet possible."""
TO_DEVICE = "to_device"
"""Device-to-device events, primarily used for exchanging encryption keys"""
__slots__ = ("t", "t_class")
t: str
"""The type string of the event."""
t_class: Class
"""The context where the event appeared."""
def __init__(self, t: str, t_class: Class) -> None:
object.__setattr__(self, "t", t)
object.__setattr__(self, "t_class", t_class)
if t not in self._by_event_type:
self._by_event_type[t] = self
def serialize(self) -> JSON:
return self.t
@classmethod
def deserialize(cls, raw: JSON) -> Any:
return cls.find(raw)
@classmethod
def find(cls, t: str, t_class: Optional[Class] = None) -> "EventType":
"""
Create a new ``EventType`` instance with the given type and class.
If an ``EventType`` instance with the same type string and class has been created before,
or if no class is specified here, this will return the same instance instead of making a
new one.
Examples:
>>> from mautrix.client import Client
>>> from mautrix.types import EventType
>>> MY_CUSTOM_TYPE = EventType.find("com.example.custom_event", EventType.Class.STATE)
>>> client = Client(...)
>>> @client.on(MY_CUSTOM_TYPE)
... async def handle_event(evt): ...
Args:
t: The type string.
t_class: The class of the event type.
Returns:
An ``EventType`` instance with the given parameters.
"""
try:
return cls._by_event_type[t].with_class(t_class)
except KeyError:
return EventType(t, t_class=t_class or cls.Class.UNKNOWN)
def json(self) -> str:
return json.dumps(self.serialize())
@classmethod
def parse_json(cls, data: str) -> "EventType":
return cls.deserialize(json.loads(data))
def __setattr__(self, *args, **kwargs) -> None:
raise TypeError("EventTypes are frozen")
def __delattr__(self, *args, **kwargs) -> None:
raise TypeError("EventTypes are frozen")
def __str__(self):
return self.t
def __repr__(self):
return f'EventType("{self.t}", EventType.Class.{self.t_class.name})'
def __hash__(self):
return hash(self.t) ^ hash(self.t_class)
def __eq__(self, other: Any) -> bool:
if not isinstance(other, EventType):
return False
return self.t == other.t and self.t_class == other.t_class
def with_class(self, t_class: Optional[Class]) -> "EventType":
"""Return a copy of this ``EventType`` with the given class. If the given class is the
same as what this instance has, or if the given class is ``None``, this returns ``self``
instead of making a copy."""
if t_class is None or self.t_class == t_class:
return self
return EventType(t=self.t, t_class=t_class)
@property
def is_message(self) -> bool:
"""A shortcut for ``type.t_class == EventType.Class.MESSAGE``"""
return self.t_class == EventType.Class.MESSAGE
@property
def is_state(self) -> bool:
"""A shortcut for ``type.t_class == EventType.Class.STATE``"""
return self.t_class == EventType.Class.STATE
@property
def is_ephemeral(self) -> bool:
"""A shortcut for ``type.t_class == EventType.Class.EPHEMERAL``"""
return self.t_class == EventType.Class.EPHEMERAL
@property
def is_account_data(self) -> bool:
"""A shortcut for ``type.t_class == EventType.Class.ACCOUNT_DATA``"""
return self.t_class == EventType.Class.ACCOUNT_DATA
@property
def is_to_device(self) -> bool:
"""A shortcut for ``type.t_class == EventType.Class.TO_DEVICE``"""
return self.t_class == EventType.Class.TO_DEVICE
_standard_types = {
EventType.Class.STATE: {
"m.room.canonical_alias": "ROOM_CANONICAL_ALIAS",
"m.room.create": "ROOM_CREATE",
"m.room.join_rules": "ROOM_JOIN_RULES",
"m.room.member": "ROOM_MEMBER",
"m.room.power_levels": "ROOM_POWER_LEVELS",
"m.room.history_visibility": "ROOM_HISTORY_VISIBILITY",
"m.room.name": "ROOM_NAME",
"m.room.topic": "ROOM_TOPIC",
"m.room.avatar": "ROOM_AVATAR",
"m.room.pinned_events": "ROOM_PINNED_EVENTS",
"m.room.tombstone": "ROOM_TOMBSTONE",
"m.room.encryption": "ROOM_ENCRYPTION",
"m.space.child": "SPACE_CHILD",
"m.space.parent": "SPACE_PARENT",
},
EventType.Class.MESSAGE: {
"m.room.redaction": "ROOM_REDACTION",
"m.room.message": "ROOM_MESSAGE",
"m.room.encrypted": "ROOM_ENCRYPTED",
"m.sticker": "STICKER",
"m.reaction": "REACTION",
"m.call.invite": "CALL_INVITE",
"m.call.candidates": "CALL_CANDIDATES",
"m.call.select_answer": "CALL_SELECT_ANSWER",
"m.call.answer": "CALL_ANSWER",
"m.call.hangup": "CALL_HANGUP",
"m.call.reject": "CALL_REJECT",
"m.call.negotiate": "CALL_NEGOTIATE",
"com.beeper.message_send_status": "BEEPER_MESSAGE_STATUS",
},
EventType.Class.EPHEMERAL: {
"m.receipt": "RECEIPT",
"m.typing": "TYPING",
"m.presence": "PRESENCE",
},
EventType.Class.ACCOUNT_DATA: {
"m.direct": "DIRECT",
"m.push_rules": "PUSH_RULES",
"m.tag": "TAG",
"m.ignored_user_list": "IGNORED_USER_LIST",
},
EventType.Class.TO_DEVICE: {
"m.room.encrypted": "TO_DEVICE_ENCRYPTED",
"m.room_key": "ROOM_KEY",
"m.room_key.withheld": "ROOM_KEY_WITHHELD",
"org.matrix.room_key.withheld": "ORG_MATRIX_ROOM_KEY_WITHHELD",
"m.room_key_request": "ROOM_KEY_REQUEST",
"m.forwarded_room_key": "FORWARDED_ROOM_KEY",
"m.dummy": "TO_DEVICE_DUMMY",
"com.beeper.room_key.ack": "BEEPER_ROOM_KEY_ACK",
},
EventType.Class.UNKNOWN: {
"__ALL__": "ALL", # This is not a real event type
},
}
for _t_class, _types in _standard_types.items():
for _t, _name in _types.items():
_event_type = EventType(t=_t, t_class=_t_class)
setattr(EventType, _name, _event_type)
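# Usage sketch: how the class-aware equality described in the EventType
# docstring plays out, using only the names defined above in this module.
_state_name = EventType.find("m.room.name", EventType.Class.STATE)
assert _state_name == EventType.ROOM_NAME
# The same type string sent as a non-state event compares unequal, because
# the Class differs even though the type string matches.
_rogue_name = EventType.find("m.room.name", EventType.Class.MESSAGE)
assert _rogue_name != EventType.ROOM_NAME
assert _rogue_name.t == EventType.ROOM_NAME.t and not _rogue_name.is_state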
python-0.20.4/mautrix/types/event/type.pyi
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from typing import Any, ClassVar, Optional
from mautrix.types import JSON, ExtensibleEnum, Serializable, SerializableEnum
class RoomType(ExtensibleEnum):
SPACE: "RoomType"
class EventType(Serializable):
class Class(SerializableEnum):
UNKNOWN = "unknown"
STATE = "state"
MESSAGE = "message"
ACCOUNT_DATA = "account_data"
EPHEMERAL = "ephemeral"
TO_DEVICE = "to_device"
_by_event_type: ClassVar[dict[str, EventType]]
ROOM_CANONICAL_ALIAS: "EventType"
ROOM_CREATE: "EventType"
ROOM_JOIN_RULES: "EventType"
ROOM_MEMBER: "EventType"
ROOM_POWER_LEVELS: "EventType"
ROOM_HISTORY_VISIBILITY: "EventType"
ROOM_NAME: "EventType"
ROOM_TOPIC: "EventType"
ROOM_AVATAR: "EventType"
ROOM_PINNED_EVENTS: "EventType"
ROOM_TOMBSTONE: "EventType"
ROOM_ENCRYPTION: "EventType"
SPACE_CHILD: "EventType"
SPACE_PARENT: "EventType"
ROOM_REDACTION: "EventType"
ROOM_MESSAGE: "EventType"
ROOM_ENCRYPTED: "EventType"
STICKER: "EventType"
REACTION: "EventType"
CALL_INVITE: "EventType"
CALL_CANDIDATES: "EventType"
CALL_SELECT_ANSWER: "EventType"
CALL_ANSWER: "EventType"
CALL_HANGUP: "EventType"
CALL_REJECT: "EventType"
CALL_NEGOTIATE: "EventType"
BEEPER_MESSAGE_STATUS: "EventType"
RECEIPT: "EventType"
TYPING: "EventType"
PRESENCE: "EventType"
DIRECT: "EventType"
PUSH_RULES: "EventType"
TAG: "EventType"
IGNORED_USER_LIST: "EventType"
TO_DEVICE_ENCRYPTED: "EventType"
TO_DEVICE_DUMMY: "EventType"
ROOM_KEY: "EventType"
ROOM_KEY_WITHHELD: "EventType"
ORG_MATRIX_ROOM_KEY_WITHHELD: "EventType"
ROOM_KEY_REQUEST: "EventType"
FORWARDED_ROOM_KEY: "EventType"
BEEPER_ROOM_KEY_ACK: "EventType"
ALL: "EventType"
is_message: bool
is_state: bool
is_ephemeral: bool
is_account_data: bool
is_to_device: bool
t: str
t_class: Class
def __init__(self, t: str, t_class: Class) -> None: ...
@classmethod
def find(cls, t: str, t_class: Optional[Class] = None) -> "EventType": ...
def serialize(self) -> JSON: ...
@classmethod
def deserialize(cls, raw: JSON) -> Any: ...
def with_class(self, t_class: Optional[Class]) -> "EventType": ...
python-0.20.4/mautrix/types/event/voip.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from typing import Generic, List, Optional, TypeVar, Union
from attr import dataclass
import attr
from ..primitive import JSON, UserID
from ..util import ExtensibleEnum, SerializableAttrs
from .base import BaseRoomEvent
from .type import EventType
class CallDataType(ExtensibleEnum):
OFFER = "offer"
ANSWER = "answer"
class CallHangupReason(ExtensibleEnum):
ICE_FAILED = "ice_failed"
INVITE_TIMEOUT = "invite_timeout"
USER_HANGUP = "user_hangup"
USER_MEDIA_FAILED = "user_media_failed"
UNKNOWN_ERROR = "unknown_error"
@dataclass
class CallData(SerializableAttrs):
sdp: str
type: CallDataType
@dataclass
class CallCandidate(SerializableAttrs):
candidate: str
sdp_m_line_index: int = attr.ib(metadata={"json": "sdpMLineIndex"}, default=None)
sdp_mid: str = attr.ib(metadata={"json": "sdpMid"}, default=None)
@dataclass
class CallInviteEventContent(SerializableAttrs):
call_id: str
lifetime: int
version: int
offer: CallData
party_id: Optional[str] = None
invitee: Optional[UserID] = None
@dataclass
class CallCandidatesEventContent(SerializableAttrs):
call_id: str
version: int
candidates: List[CallCandidate]
party_id: Optional[str] = None
@dataclass
class CallSelectAnswerEventContent(SerializableAttrs):
call_id: str
version: int
party_id: str
selected_party_id: str
@dataclass
class CallAnswerEventContent(SerializableAttrs):
call_id: str
version: int
answer: CallData
party_id: Optional[str] = None
@dataclass
class CallHangupEventContent(SerializableAttrs):
call_id: str
version: int
reason: CallHangupReason = CallHangupReason.USER_HANGUP
party_id: Optional[str] = None
@dataclass
class CallRejectEventContent(SerializableAttrs):
call_id: str
version: int
party_id: str
@dataclass
class CallNegotiateEventContent(SerializableAttrs):
call_id: str
version: int
lifetime: int
party_id: str
description: CallData
type_to_class = {
EventType.CALL_INVITE: CallInviteEventContent,
EventType.CALL_CANDIDATES: CallCandidatesEventContent,
EventType.CALL_SELECT_ANSWER: CallSelectAnswerEventContent,
EventType.CALL_ANSWER: CallAnswerEventContent,
EventType.CALL_HANGUP: CallHangupEventContent,
EventType.CALL_NEGOTIATE: CallNegotiateEventContent,
EventType.CALL_REJECT: CallRejectEventContent,
}
CallEventContent = Union[
CallInviteEventContent,
CallCandidatesEventContent,
CallAnswerEventContent,
CallSelectAnswerEventContent,
CallHangupEventContent,
CallNegotiateEventContent,
CallRejectEventContent,
]
T = TypeVar("T", bound=CallEventContent)
@dataclass
class CallEvent(BaseRoomEvent, SerializableAttrs, Generic[T]):
content: T
@classmethod
def deserialize(cls, data: JSON, event_type: Optional[EventType] = None) -> "CallEvent":
event_type = event_type or EventType.find(data.get("type"))
data["content"] = type_to_class[event_type].deserialize(data["content"])
return super().deserialize(data)
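# Usage sketch: type_to_class pairs each call event type with its content
# dataclass, which is how CallEvent.deserialize above picks the right class.
# The call ID below is a placeholder value.
_hangup_content = type_to_class[EventType.CALL_HANGUP].deserialize(
    {"call_id": "placeholder-call-id", "version": 1}
)
assert isinstance(_hangup_content, CallHangupEventContent)
assert _hangup_content.reason == CallHangupReason.USER_HANGUP  # default applies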
python-0.20.4/mautrix/types/filter.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from typing import List
from attr import dataclass
from .event import EventType
from .primitive import RoomID, UserID
from .util import SerializableAttrs, SerializableEnum
class EventFormat(SerializableEnum):
"""
Federation event format enum, as specified in the `create filter endpoint`_.
.. _create filter endpoint:
https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3useruseridfilter
"""
CLIENT = "client"
FEDERATION = "federation"
@dataclass
class EventFilter(SerializableAttrs):
"""
Event filter object, as specified in the `create filter endpoint`_.
.. _create filter endpoint:
https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3useruseridfilter
"""
limit: int = None
"""The maximum number of events to return."""
not_senders: List[UserID] = None
"""A list of sender IDs to exclude. If this list is absent then no senders are excluded
A matching sender will be excluded even if it is listed in the :attr:`senders` filter."""
not_types: List[EventType] = None
"""A list of event types to exclude. If this list is absent then no event types are excluded.
A matching type will be excluded even if it is listed in the :attr:`types` filter.
A ``'*'`` can be used as a wildcard to match any sequence of characters."""
senders: List[UserID] = None
"""A list of senders IDs to include. If this list is absent then all senders are included."""
types: List[EventType] = None
"""A list of event types to include. If this list is absent then all event types are included.
A ``'*'`` can be used as a wildcard to match any sequence of characters."""
@dataclass
class RoomEventFilter(EventFilter, SerializableAttrs):
"""
Room event filter object, as specified in the `create filter endpoint`_.
.. _create filter endpoint:
https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3useruseridfilter
"""
lazy_load_members: bool = False
"""
If ``True``, enables lazy-loading of membership events. See `Lazy-loading room members`_ for more information.
.. _Lazy-loading room members:
https://matrix.org/docs/spec/client_server/r0.5.0#lazy-loading-room-members
"""
include_redundant_members: bool = False
"""
If ``True``, sends all membership events for all events, even if they have already been sent
to the client. Does not apply unless :attr:`lazy_load_members` is true.
See `Lazy-loading room members`_ for more information."""
not_rooms: List[RoomID] = None
"""A list of room IDs to exclude. If this list is absent then no rooms are excluded.
A matching room will be excluded even if it is listed in the :attr:`rooms` filter."""
rooms: List[RoomID] = None
"""A list of room IDs to include. If this list is absent then all rooms are included."""
contains_url: bool = None
"""If ``True``, includes only events with a url key in their content. If ``False``, excludes
those events. If omitted, ``url`` key is not considered for filtering."""
@dataclass
class StateFilter(RoomEventFilter, SerializableAttrs):
"""
State event filter object, as specified in the `create filter endpoint`_. Currently this is the
same as :class:`RoomEventFilter`.
.. _create filter endpoint:
https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3useruseridfilter
"""
pass
@dataclass
class RoomFilter(SerializableAttrs):
"""
Room filter object, as specified in the `create filter endpoint`_.
.. _create filter endpoint:
https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3useruseridfilter
"""
not_rooms: List[RoomID] = None
"""A list of room IDs to exclude. If this list is absent then no rooms are excluded.
A matching room will be excluded even if it is listed in the ``'rooms'`` filter.
This filter is applied before the filters in :attr:`ephemeral`, :attr:`state`,
:attr:`timeline` or :attr:`account_data`."""
rooms: List[RoomID] = None
"""A list of room IDs to include. If this list is absent then all rooms are included.
This filter is applied before the filters in :attr:`ephemeral`, :attr:`state`,
:attr:`timeline` or :attr:`account_data`."""
ephemeral: RoomEventFilter = None
"""The events that aren't recorded in the room history, e.g. typing and receipts,
to include for rooms."""
include_leave: bool = False
"""Include rooms that the user has left in the sync."""
state: StateFilter = None
"""The state events to include for rooms."""
timeline: RoomEventFilter = None
"""The message and state update events to include for rooms."""
account_data: RoomEventFilter = None
"""The per user account data to include for rooms."""
@dataclass
class Filter(SerializableAttrs):
"""
Base filter object, as specified in the `create filter endpoint`_.
.. _create filter endpoint:
https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3useruseridfilter
"""
event_fields: List[str] = None
"""List of event fields to include. If this list is absent then all fields are included.
The entries may include ``.`` characters to indicate sub-fields. So ``['content.body']`` will
include the ``body`` field of the ``content`` object. A literal ``.`` character in a field name
may be escaped using a ``\\``. A server may include more fields than were requested."""
event_format: EventFormat = None
"""The format to use for events. ``'client'`` will return the events in a format suitable for
clients. ``'federation'`` will return the raw event as received over federation. The default
is :attr:`~EventFormat.CLIENT`."""
presence: EventFilter = None
"""The presence updates to include."""
account_data: EventFilter = None
"""The user account data that isn't associated with rooms to include."""
room: RoomFilter = None
"""Filters to be applied to room data."""
python-0.20.4/mautrix/types/matrixuri.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from typing import ClassVar, NamedTuple
from enum import Enum
import urllib.parse
from yarl import URL
from .primitive import EventID, RoomAlias, RoomID, UserID
from .util import ExtensibleEnum
class IdentifierType(Enum):
"""The type qualifier for entities in a Matrix URI."""
EVENT = "$"
USER = "@"
ROOM_ALIAS = "#"
ROOM_ID = "!"
@property
def sigil(self) -> str:
"""The sigil of the identifier, used in Matrix events, matrix.to URLs and other places"""
return _type_to_sigil[self]
@property
def uri_type_qualifier(self) -> str:
"""The type qualifier of the identifier, only used in ``matrix:`` URIs."""
return _type_to_path[self]
@classmethod
def from_sigil(cls, sigil: str) -> IdentifierType:
"""Get the IdentifierType corresponding to the given sigil."""
return _sigil_to_type[sigil]
@classmethod
def from_uri_type_qualifier(cls, uri_type_qualifier: str) -> IdentifierType:
"""Get the IdentifierType corresponding to the given ``matrix:`` URI type qualifier."""
return _path_to_type[uri_type_qualifier]
def __repr__(self) -> str:
return f"{self.__class__.__name__}.{self.name}"
class URIAction(ExtensibleEnum):
"""Represents an intent for what the client should do with a Matrix URI."""
JOIN = "join"
CHAT = "chat"
_type_to_path: dict[IdentifierType, str] = {
IdentifierType.EVENT: "e",
IdentifierType.USER: "u",
IdentifierType.ROOM_ALIAS: "r",
IdentifierType.ROOM_ID: "roomid",
}
_path_to_type: dict[str, IdentifierType] = {v: k for k, v in _type_to_path.items()}
_type_to_sigil: dict[IdentifierType, str] = {it: it.value for it in IdentifierType}
_sigil_to_type: dict[str, IdentifierType] = {v: k for k, v in _type_to_sigil.items()}
class _PathPart(NamedTuple):
type: IdentifierType
identifier: str
@classmethod
def from_mxid(cls, mxid: UserID | RoomID | EventID | RoomAlias | str) -> _PathPart:
return _PathPart(type=IdentifierType.from_sigil(mxid[0]), identifier=mxid[1:])
@property
def mxid(self) -> str:
return f"{self.type.sigil}{self.identifier}"
def __str__(self) -> str:
return self.mxid
def __repr__(self) -> str:
return f"_PathPart({self.type!r}, {self.identifier!r})"
def __eq__(self, other: _PathPart) -> bool:
if not isinstance(other, _PathPart):
return False
return other.type == self.type and other.identifier == self.identifier
_uri_base = URL.build(scheme="matrix")
class MatrixURIError(ValueError):
"""Raised by :meth:`MatrixURI.parse` when parsing a URI fails."""
class MatrixURI:
"""
A container for Matrix URI data. Supports parsing and generating both ``matrix:`` URIs
and ``https://matrix.to`` URLs with the same interface.
"""
URI_BY_DEFAULT: ClassVar[bool] = False
"""Whether :meth:`__str__` should return the matrix: URI instead of matrix.to URL."""
_part1: _PathPart
_part2: _PathPart | None
via: list[str] | None
"""Servers that know about the resource. Important for room ID links."""
action: URIAction | None
"""The intent for what clients should do with the URI."""
def __init__(self) -> None:
"""Internal initializer for MatrixURI, external users should use
either :meth:`build` or :meth:`parse`."""
self._part2 = None
self.via = None
self.action = None
@classmethod
def build(
cls,
part1: RoomID | UserID | RoomAlias,
part2: EventID | None = None,
via: list[str] | None = None,
action: URIAction | None = None,
) -> MatrixURI:
"""
Construct a MatrixURI instance using an identifier.
Args:
part1: The first part of the URI, a user ID, room ID, or room alias.
part2: The second part of the URI. Only event IDs are allowed,
and only allowed when the first part is a room ID or alias.
via: Servers that know about the resource. Important for room ID links.
action: The intent for what clients should do with the URI.
Returns:
The constructed MatrixURI.
Raises:
ValueError: if one of the identifiers doesn't have a valid sigil.
Examples:
>>> from mautrix.types import MatrixURI, UserID, RoomAlias, EventID
>>> MatrixURI.build(UserID("@user:example.com")).matrix_to_url
'https://matrix.to/#/%40user%3Aexample.com'
>>> MatrixURI.build(UserID("@user:example.com")).matrix_uri
'matrix:u/user:example.com'
>>> # Picks the format based on the URI_BY_DEFAULT field.
>>> # The default value will be changed to True in a later release.
>>> str(MatrixURI.build(UserID("@user:example.com")))
'https://matrix.to/#/%40user%3Aexample.com'
>>> MatrixURI.build(RoomAlias("#room:example.com"), EventID("$abc123")).matrix_uri
'matrix:r/room:example.com/e/abc123'
"""
uri = cls()
try:
uri._part1 = _PathPart.from_mxid(part1)
except KeyError as e:
raise ValueError(f"Invalid sigil in part 1 '{part1[0]}'") from e
if uri._part1.type == IdentifierType.EVENT:
raise ValueError(f"Event ID links must have a room ID or alias too")
if part2:
try:
uri._part2 = _PathPart.from_mxid(part2)
except KeyError as e:
raise ValueError(f"Invalid sigil in part 2 '{part2[0]}'") from e
if uri._part2.type != IdentifierType.EVENT:
raise ValueError("The second part of Matrix URIs can only be an event ID")
if uri._part1.type not in (IdentifierType.ROOM_ID, IdentifierType.ROOM_ALIAS):
raise ValueError("Can't create an event ID link without a room link as the base")
uri.via = via
uri.action = action
return uri
@classmethod
def try_parse(cls, url: str | URL) -> MatrixURI | None:
"""
Try to parse a ``matrix:`` URI or ``https://matrix.to`` URL into parts.
If parsing fails, return ``None`` instead of throwing an error.
Args:
url: The URI to parse, either as a string or a :class:`yarl.URL` instance.
Returns:
The parsed data, or ``None`` if parsing failed.
"""
try:
return cls.parse(url)
except ValueError:
return None
@classmethod
def parse(cls, url: str | URL) -> MatrixURI:
"""
Parse a ``matrix:`` URI or ``https://matrix.to`` URL into parts.
Args:
url: The URI to parse, either as a string or a :class:`yarl.URL` instance.
Returns:
The parsed data.
Raises:
ValueError: if yarl fails to parse the given URL string.
MatrixURIError: if the URL isn't valid in the Matrix spec.
Examples:
>>> from mautrix.types import MatrixURI
>>> MatrixURI.parse("https://matrix.to/#/@user:example.com").user_id
'@user:example.com'
>>> MatrixURI.parse("https://matrix.to/#/#room:example.com/$abc123").event_id
'$abc123'
>>> MatrixURI.parse("matrix:r/room:example.com/e/abc123").event_id
'$abc123'
"""
url = URL(url)
if url.scheme == "matrix":
return cls._parse_matrix_uri(url)
elif url.scheme == "https" and url.host == "matrix.to":
return cls._parse_matrix_to_url(url)
else:
raise MatrixURIError("Invalid URI (not matrix: nor https://matrix.to)")
@classmethod
def _parse_matrix_to_url(cls, url: URL) -> MatrixURI:
path, *rest = url.raw_fragment.split("?", 1)
path_parts = path.split("/")
if len(path_parts) < 2:
raise MatrixURIError("matrix.to URL doesn't have enough parts")
# The first component is the blank part between the # and /
if path_parts[0] != "":
raise MatrixURIError("first component of matrix.to URL is not empty as expected")
query = urllib.parse.parse_qs(rest[0] if len(rest) > 0 else "")
uri = cls()
part1 = urllib.parse.unquote(path_parts[1])
if len(part1) < 2:
raise MatrixURIError(f"Invalid first entity '{part1}' in matrix.to URL")
try:
uri._part1 = _PathPart.from_mxid(part1)
except KeyError as e:
raise MatrixURIError(
f"Invalid sigil '{part1[0]}' in first entity of matrix.to URL"
) from e
if len(path_parts) > 2 and len(path_parts[2]) > 0:
part2 = urllib.parse.unquote(path_parts[2])
if len(part2) < 2:
raise MatrixURIError(f"Invalid second entity '{part2}' in matrix.to URL")
try:
uri._part2 = _PathPart.from_mxid(part2)
except KeyError as e:
raise MatrixURIError(
f"Invalid sigil '{part2[0]}' in second entity of matrix.to URL"
) from e
uri.via = query.get("via", None)
try:
uri.action = URIAction(query["action"])
except KeyError:
pass
return uri
@classmethod
def _parse_matrix_uri(cls, url: URL) -> MatrixURI:
components = url.raw_path.split("/")
if len(components) < 2:
raise MatrixURIError("URI doesn't contain enough parts")
try:
type1 = IdentifierType.from_uri_type_qualifier(components[0])
except KeyError as e:
raise MatrixURIError(
f"Invalid type qualifier '{components[0]}' in first entity of matrix: URI"
) from e
if not components[1]:
raise MatrixURIError("Identifier in first entity of matrix: URI is empty")
uri = cls()
uri._part1 = _PathPart(type1, components[1])
if len(components) >= 3 and components[2]:
try:
type2 = IdentifierType.from_uri_type_qualifier(components[2])
except KeyError as e:
raise MatrixURIError(
f"Invalid type qualifier '{components[2]}' in second entity of matrix: URI"
) from e
if len(components) < 4 or not components[3]:
raise MatrixURIError("Identifier in second entity of matrix: URI is empty")
uri._part2 = _PathPart(type2, components[3])
uri.via = url.query.getall("via", None)
try:
uri.action = URIAction(url.query["action"])
except KeyError:
pass
return uri
@property
def user_id(self) -> UserID | None:
"""
Get the user ID from this parsed URI.
Returns:
The user ID in this URI, or ``None`` if this is not a link to a user.
"""
if self._part1.type == IdentifierType.USER:
return UserID(self._part1.mxid)
return None
@property
def room_id(self) -> RoomID | None:
"""
Get the room ID from this parsed URI.
Returns:
The room ID in this URI, or ``None`` if this is not a link to a room (or event).
"""
if self._part1.type == IdentifierType.ROOM_ID:
return RoomID(self._part1.mxid)
return None
@property
def room_alias(self) -> RoomAlias | None:
"""
Get the room alias from this parsed URI.
Returns:
The room alias in this URI, or ``None`` if this is not a link to a room (or event).
"""
if self._part1.type == IdentifierType.ROOM_ALIAS:
return RoomAlias(self._part1.mxid)
return None
@property
def event_id(self) -> EventID | None:
"""
Get the event ID from this parsed URI.
Returns:
The event ID in this URI, or ``None`` if this is not a link to an event in a room.
"""
if (
self._part2
and (self.room_id or self.room_alias)
and self._part2.type == IdentifierType.EVENT
):
return EventID(self._part2.mxid)
return None
@property
def matrix_to_url(self) -> str:
"""
Convert this parsed URI into a ``https://matrix.to`` URL.
Returns:
The link as a matrix.to URL.
"""
url = f"https://matrix.to/#/{urllib.parse.quote(self._part1.mxid)}"
if self._part2:
url += f"/{urllib.parse.quote(self._part2.mxid)}"
qp = []
if self.via:
qp += (("via", server) for server in self.via)
if self.action:
qp.append(("action", self.action))
if qp:
url += f"?{urllib.parse.urlencode(qp)}"
return url
@property
def matrix_uri(self) -> str:
"""
Convert this parsed URI into a ``matrix:`` URI.
Returns:
The link as a ``matrix:`` URI.
"""
u = _uri_base / self._part1.type.uri_type_qualifier / self._part1.identifier
if self._part2:
u = u / self._part2.type.uri_type_qualifier / self._part2.identifier
if self.via:
u = u.update_query({"via": self.via})
if self.action:
u = u.update_query({"action": self.action.value})
return str(u)
def __str__(self) -> str:
if self.URI_BY_DEFAULT:
return self.matrix_uri
else:
return self.matrix_to_url
def __repr__(self) -> str:
parts = ", ".join(f"{part!r}" for part in (self._part1, self._part2) if part)
return f"MatrixURI({parts}, via={self.via!r}, action={self.action!r})"
def __eq__(self, other: MatrixURI) -> bool:
"""
Checks equality between two parsed Matrix URIs. The order of the via parameters is ignored,
but otherwise everything has to match exactly.
"""
if not isinstance(other, MatrixURI):
return False
return (
self._part1 == other._part1
and self._part2 == other._part2
and set(self.via or []) == set(other.via or [])
and self.action == other.action
)
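# Usage sketch: round-tripping a room link with routing servers through both
# output formats defined above. The room ID matches the one used in the tests.
_link = MatrixURI.build(RoomID("!7NdBVvkd4aLSbgKt9RXl:example.org"), via=["example.org"])
assert MatrixURI.parse(_link.matrix_uri) == _link
assert MatrixURI.parse(_link.matrix_to_url) == _link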
python-0.20.4/mautrix/types/matrixuri_test.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from typing import NamedTuple
import pytest
from .matrixuri import IdentifierType, MatrixURI, MatrixURIError, URIAction, _PathPart
from .primitive import EventID, RoomAlias, RoomID, UserID
def test_basic_parse_uri() -> None:
for test in basic_tests:
assert MatrixURI.parse(test.uri) == test.parsed
def test_basic_stringify_uri() -> None:
for test in basic_tests:
assert test.uri == test.parsed.matrix_uri
def test_basic_parse_url() -> None:
for test in basic_tests:
assert MatrixURI.parse(test.url) == test.parsed
def test_basic_stringify_url() -> None:
for test in basic_tests:
assert test.url == test.parsed.matrix_to_url
def test_basic_build() -> None:
for test in basic_tests:
assert MatrixURI.build(*test.params) == test.parsed
def test_parse_unescaped() -> None:
assert MatrixURI.parse("https://matrix.to/#/#hello:world").room_alias == "#hello:world"
def test_parse_trailing_slash() -> None:
assert MatrixURI.parse("https://matrix.to/#/#hello:world/").room_alias == "#hello:world"
assert MatrixURI.parse("matrix:r/hello:world/").room_alias == "#hello:world"
def test_parse_errors() -> None:
tests = [
"https://example.com",
"matrix:invalid/foo",
"matrix:hello world",
"matrix:/roomid",
"matrix:roomid/",
"matrix:roomid/foo/e/",
"matrix:roomid/foo/e",
"https://matrix.to",
"https://matrix.to/#/",
"https://matrix.to/#foo/#hello:world",
"https://matrix.to/#/#hello:world/hmm",
]
for test in tests:
with pytest.raises(MatrixURIError):
print(MatrixURI.parse(test))
def test_build_errors() -> None:
with pytest.raises(ValueError):
MatrixURI.build("hello world")
with pytest.raises(ValueError):
MatrixURI.build(EventID("$uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s"))
with pytest.raises(ValueError):
MatrixURI.build(
UserID("@user:example.org"),
EventID("$uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s"),
)
with pytest.raises(ValueError):
MatrixURI.build(
RoomID("!room:example.org"),
RoomID("!anotherroom:example.com"),
)
with pytest.raises(ValueError):
MatrixURI.build(
RoomID("!room:example.org"),
"hmm",
)
def _make_parsed(
part1: _PathPart,
part2: _PathPart | None = None,
via: list[str] | None = None,
action: URIAction | None = None,
) -> MatrixURI:
uri = MatrixURI()
uri._part1 = part1
uri._part2 = part2
uri.via = via
uri.action = action
return uri
class BasicTestItems(NamedTuple):
url: str
uri: str
parsed: MatrixURI
params: tuple[RoomID | UserID | RoomAlias, EventID | None, list[str] | None, URIAction | None]
basic_tests = [
BasicTestItems(
"https://matrix.to/#/%217NdBVvkd4aLSbgKt9RXl%3Aexample.org",
"matrix:roomid/7NdBVvkd4aLSbgKt9RXl:example.org",
_make_parsed(_PathPart(IdentifierType.ROOM_ID, "7NdBVvkd4aLSbgKt9RXl:example.org")),
(RoomID("!7NdBVvkd4aLSbgKt9RXl:example.org"), None, None, None),
),
BasicTestItems(
"https://matrix.to/#/%217NdBVvkd4aLSbgKt9RXl%3Aexample.org?via=maunium.net&via=matrix.org",
"matrix:roomid/7NdBVvkd4aLSbgKt9RXl:example.org?via=maunium.net&via=matrix.org",
_make_parsed(
_PathPart(IdentifierType.ROOM_ID, "7NdBVvkd4aLSbgKt9RXl:example.org"),
via=["maunium.net", "matrix.org"],
),
(RoomID("!7NdBVvkd4aLSbgKt9RXl:example.org"), None, ["maunium.net", "matrix.org"], None),
),
BasicTestItems(
"https://matrix.to/#/%23someroom%3Aexample.org",
"matrix:r/someroom:example.org",
_make_parsed(_PathPart(IdentifierType.ROOM_ALIAS, "someroom:example.org")),
(RoomAlias("#someroom:example.org"), None, None, None),
),
BasicTestItems(
"https://matrix.to/#/%217NdBVvkd4aLSbgKt9RXl%3Aexample.org/%24uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s",
"matrix:roomid/7NdBVvkd4aLSbgKt9RXl:example.org/e/uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s",
_make_parsed(
_PathPart(IdentifierType.ROOM_ID, "7NdBVvkd4aLSbgKt9RXl:example.org"),
_PathPart(IdentifierType.EVENT, "uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s"),
),
(
RoomID("!7NdBVvkd4aLSbgKt9RXl:example.org"),
EventID("$uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s"),
None,
None,
),
),
BasicTestItems(
"https://matrix.to/#/%23someroom%3Aexample.org/%24uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s",
"matrix:r/someroom:example.org/e/uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s",
_make_parsed(
_PathPart(IdentifierType.ROOM_ALIAS, "someroom:example.org"),
_PathPart(IdentifierType.EVENT, "uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s"),
),
(
RoomAlias("#someroom:example.org"),
EventID("$uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s"),
None,
None,
),
),
BasicTestItems(
"https://matrix.to/#/%40user%3Aexample.org",
"matrix:u/user:example.org",
_make_parsed(_PathPart(IdentifierType.USER, "user:example.org")),
(UserID("@user:example.org"), None, None, None),
),
]
python-0.20.4/mautrix/types/media.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from typing import Optional
from attr import dataclass
from .primitive import ContentURI
from .util import SerializableAttrs, field
@dataclass
class MediaRepoConfig(SerializableAttrs):
"""
Matrix media repo config. See `GET /_matrix/media/v3/config`_.
.. _GET /_matrix/media/v3/config:
https://spec.matrix.org/v1.2/client-server-api/#get_matrixmediav3config
"""
upload_size: int = field(default=50 * 1024 * 1024, json="m.upload.size")
@dataclass
class OpenGraphImage(SerializableAttrs):
url: ContentURI = field(default=None, json="og:image")
mimetype: str = field(default=None, json="og:image:type")
height: int = field(default=None, json="og:image:height")
width: int = field(default=None, json="og:image:width")
size: int = field(default=None, json="matrix:image:size")
@dataclass
class OpenGraphVideo(SerializableAttrs):
url: ContentURI = field(default=None, json="og:video")
mimetype: str = field(default=None, json="og:video:type")
height: int = field(default=None, json="og:video:height")
width: int = field(default=None, json="og:video:width")
size: int = field(default=None, json="matrix:video:size")
@dataclass
class OpenGraphAudio(SerializableAttrs):
url: ContentURI = field(default=None, json="og:audio")
mimetype: str = field(default=None, json="og:audio:type")
@dataclass
class MXOpenGraph(SerializableAttrs):
"""
Matrix URL preview response. See `GET /_matrix/media/v3/preview_url`_.
.. _GET /_matrix/media/v3/preview_url:
https://spec.matrix.org/v1.2/client-server-api/#get_matrixmediav3preview_url
"""
title: str = field(default=None, json="og:title")
description: str = field(default=None, json="og:description")
image: OpenGraphImage = field(default=None, flatten=True)
video: OpenGraphVideo = field(default=None, flatten=True)
audio: OpenGraphAudio = field(default=None, flatten=True)
@dataclass
class MediaCreateResponse(SerializableAttrs):
"""
Matrix media create response including MSC3870
"""
content_uri: ContentURI
unused_expires_at: Optional[int] = None
unstable_upload_url: Optional[str] = field(default=None, json="com.beeper.msc3870.upload_url")
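# Usage sketch: the json= metadata above maps namespaced wire keys to plain
# attribute names, and flatten=True lets MXOpenGraph absorb the og:* keys of
# its sub-objects directly from the top-level response.
_repo_config = MediaRepoConfig.deserialize({"m.upload.size": 10 * 1024 * 1024})
assert _repo_config.upload_size == 10 * 1024 * 1024
_preview = MXOpenGraph.deserialize(
    {"og:title": "Example page", "og:image": "mxc://example.com/abcdef"}
)
assert _preview.title == "Example page"
assert _preview.image.url == "mxc://example.com/abcdef"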
python-0.20.4/mautrix/types/misc.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from typing import List, NamedTuple, NewType, Optional
from enum import Enum
from attr import dataclass
import attr
from .event import Event, StateEvent
from .primitive import BatchID, ContentURI, EventID, RoomAlias, RoomID, SyncToken, UserID
from .util import SerializableAttrs
@dataclass
class DeviceLists(SerializableAttrs):
changed: List[UserID] = attr.ib(factory=lambda: [])
left: List[UserID] = attr.ib(factory=lambda: [])
def __bool__(self) -> bool:
return bool(self.changed or self.left)
@dataclass
class DeviceOTKCount(SerializableAttrs):
signed_curve25519: int = 0
curve25519: int = 0
class RoomCreatePreset(Enum):
"""
Room creation preset, as specified in the `createRoom endpoint`_
.. _createRoom endpoint:
https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3createroom
"""
PRIVATE = "private_chat"
TRUSTED_PRIVATE = "trusted_private_chat"
PUBLIC = "public_chat"
class RoomDirectoryVisibility(Enum):
"""
Room directory visibility, as specified in the `createRoom endpoint`_
.. _createRoom endpoint:
https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3createroom
"""
PRIVATE = "private"
PUBLIC = "public"
class PaginationDirection(Enum):
"""Pagination direction used in various endpoints that support pagination."""
FORWARD = "f"
BACKWARD = "b"
@dataclass
class RoomAliasInfo(SerializableAttrs):
"""
Room alias query result, as specified in the `alias resolve endpoint`_
.. _alias resolve endpoint:
https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3directoryroomroomalias
"""
room_id: RoomID = None
"""The room ID for this room alias."""
servers: List[str] = None
"""A list of servers that are aware of this room alias."""
DirectoryPaginationToken = NewType("DirectoryPaginationToken", str)
@dataclass
class PublicRoomInfo(SerializableAttrs):
room_id: RoomID
num_joined_members: int
world_readable: bool
guest_can_join: bool
name: str = None
topic: str = None
avatar_url: ContentURI = None
aliases: List[RoomAlias] = None
canonical_alias: RoomAlias = None
@dataclass
class RoomDirectoryResponse(SerializableAttrs):
chunk: List[PublicRoomInfo]
next_batch: DirectoryPaginationToken = None
prev_batch: DirectoryPaginationToken = None
total_room_count_estimate: int = None
PaginatedMessages = NamedTuple(
"PaginatedMessages", start=SyncToken, end=SyncToken, events=List[Event]
)
@dataclass
class EventContext(SerializableAttrs):
end: SyncToken
start: SyncToken
event: Event
events_after: List[Event]
events_before: List[Event]
state: List[StateEvent]
@dataclass
class BatchSendResponse(SerializableAttrs):
state_event_ids: List[EventID]
event_ids: List[EventID]
insertion_event_id: EventID
batch_event_id: EventID
next_batch_id: BatchID
base_insertion_event_id: Optional[EventID] = None
@dataclass
class BeeperBatchSendResponse(SerializableAttrs):
event_ids: List[EventID]
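# Usage sketch: DeviceLists overrides __bool__, so an empty device list delta
# is falsy while one with changes is truthy. The user ID is a placeholder.
assert not DeviceLists()
assert DeviceLists(changed=[UserID("@user:example.com")])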
python-0.20.4/mautrix/types/primitive.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from typing import Dict, List, NewType, Union
JSON = NewType("JSON", Union[str, int, float, bool, None, Dict[str, "JSON"], List["JSON"]])
JSON.__doc__ = "A union type that covers all JSON-serializable data."
UserID = NewType("UserID", str)
UserID.__doc__ = "A Matrix user ID (``@user:example.com``)"
EventID = NewType("EventID", str)
EventID.__doc__ = "A Matrix event ID (``$base64`` or ``$legacyid:example.com``)"
RoomID = NewType("RoomID", str)
RoomID.__doc__ = "An internal Matrix room ID (``!randomstring:example.com``)"
RoomAlias = NewType("RoomAlias", str)
RoomAlias.__doc__ = "A Matrix room address (``#alias:example.com``)"
FilterID = NewType("FilterID", str)
FilterID.__doc__ = """
A filter ID returned by ``POST /filter`` (:meth:`mautrix.client.ClientAPI.create_filter`)
"""
BatchID = NewType("BatchID", str)
BatchID.__doc__ = """
A message batch ID returned by ``POST /batch_send`` (:meth:`mautrix.appservice.IntentAPI.batch_send`)
"""
ContentURI = NewType("ContentURI", str)
ContentURI.__doc__ = """
A Matrix `content URI`_, used by the content repository.
.. _content URI:
https://spec.matrix.org/v1.2/client-server-api/#matrix-content-mxc-uris
"""
SyncToken = NewType("SyncToken", str)
SyncToken.__doc__ = """
A ``next_batch`` token from a ``/sync`` response (:meth:`mautrix.client.ClientAPI.sync`)
"""
DeviceID = NewType("DeviceID", str)
DeviceID.__doc__ = "A Matrix device ID. Arbitrary, potentially client-specified string."
SessionID = NewType("SessionID", str)
SessionID.__doc__ = """
A `Megolm`_ session ID.
.. _Megolm:
https://gitlab.matrix.org/matrix-org/olm/-/blob/master/docs/megolm.md
"""
SigningKey = NewType("SigningKey", str)
SigningKey.__doc__ = "A ed25519 public key as unpadded base64"
IdentityKey = NewType("IdentityKey", str)
IdentityKey.__doc__ = "A curve25519 public key as unpadded base64"
Signature = NewType("Signature", str)
Signature.__doc__ = "An ed25519 signature as unpadded base64"
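# Usage sketch: these NewTypes are plain strings at runtime and only add
# meaning for type checkers and documentation.
_user_id = UserID("@user:example.com")
assert isinstance(_user_id, str) and _user_id.startswith("@")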
python-0.20.4/mautrix/types/push_rules.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from typing import List, Optional, Union
from attr import dataclass
import attr
from .primitive import JSON, RoomID, UserID
from .util import ExtensibleEnum, SerializableAttrs, deserializer
PushRuleID = Union[RoomID, UserID, str]
class PushActionType(ExtensibleEnum):
NOTIFY = "notify"
DONT_NOTIFY = "dont_notify"
COALESCE = "coalesce"
@dataclass
class PushActionDict(SerializableAttrs):
set_tweak: Optional[str] = None
value: Optional[str] = None
PushAction = Union[PushActionDict, PushActionType]
@deserializer(PushAction)
def deserialize_push_action(data: JSON) -> PushAction:
if isinstance(data, str):
return PushActionType(data)
else:
return PushActionDict.deserialize(data)
class PushOperator(ExtensibleEnum):
EQUALS = "=="
LESS_THAN = "<"
GREATER_THAN = ">"
LESS_THAN_OR_EQUAL = "<="
GREATER_THAN_OR_EQUAL = ">="
class PushRuleScope(ExtensibleEnum):
GLOBAL = "global"
class PushConditionKind(ExtensibleEnum):
EVENT_MATCH = "event_match"
CONTAINS_DISPLAY_NAME = "contains_display_name"
ROOM_MEMBER_COUNT = "room_member_count"
SENDER_NOTIFICATION_PERMISSION = "sender_notification_permission"
class PushRuleKind(ExtensibleEnum):
OVERRIDE = "override"
SENDER = "sender"
ROOM = "room"
CONTENT = "content"
UNDERRIDE = "underride"
@dataclass
class PushCondition(SerializableAttrs):
kind: PushConditionKind
key: Optional[str] = None
pattern: Optional[str] = None
operator: PushOperator = attr.ib(
default=PushOperator.EQUALS, metadata={"json": "is", "omitdefault": True}
)
@dataclass
class PushRule(SerializableAttrs):
rule_id: PushRuleID
default: bool
enabled: bool
actions: List[PushAction]
conditions: List[PushCondition] = attr.ib(factory=lambda: [])
pattern: Optional[str] = None
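# Usage sketch: PushAction is a union, so the deserializer above picks the
# representation based on the raw JSON shape.
assert deserialize_push_action("notify") == PushActionType.NOTIFY
_tweak = deserialize_push_action({"set_tweak": "sound", "value": "default"})
assert isinstance(_tweak, PushActionDict) and _tweak.set_tweak == "sound"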
python-0.20.4/mautrix/types/users.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from typing import List, NamedTuple
from attr import dataclass
from .event import Membership
from .primitive import ContentURI, UserID
from .util import SerializableAttrs
@dataclass
class Member(SerializableAttrs):
membership: Membership = None
avatar_url: ContentURI = None
displayname: str = None
@dataclass
class User(SerializableAttrs):
user_id: UserID
avatar_url: ContentURI = None
displayname: str = None
class UserSearchResults(NamedTuple):
results: List[User]
limit: int
python-0.20.4/mautrix/types/util/
python-0.20.4/mautrix/types/util/__init__.py
from .enum import ExtensibleEnum
from .obj import Lst, Obj
from .serializable import Serializable, SerializableEnum, SerializerError
from .serializable_attrs import SerializableAttrs, deserializer, field, serializer
python-0.20.4/mautrix/types/util/enum.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from typing import Any, Iterable, Type, cast
from ..primitive import JSON
from .serializable import Serializable
def _is_descriptor(obj):
return hasattr(obj, "__get__") or hasattr(obj, "__set__") or hasattr(obj, "__delete__")
class ExtensibleEnumMeta(type):
_by_value: dict[Any, ExtensibleEnum]
_by_key: dict[str, ExtensibleEnum]
def __new__(
mcs: Type[ExtensibleEnumMeta],
name: str,
bases: tuple[Type, ...],
classdict: dict[str, Any],
) -> Type[ExtensibleEnum]:
create = [
(key, val)
for key, val in classdict.items()
if not key.startswith("_") and not _is_descriptor(val)
]
classdict = {
key: val
for key, val in classdict.items()
if key.startswith("_") or _is_descriptor(val)
}
classdict["_by_value"] = {}
classdict["_by_key"] = {}
enum_class = cast(Type["ExtensibleEnum"], super().__new__(mcs, name, bases, classdict))
for key, val in create:
ExtensibleEnum.__new__(enum_class, val).key = key
return enum_class
def __bool__(cls: Type["ExtensibleEnum"]) -> bool:
return True
def __contains__(cls: Type["ExtensibleEnum"], value: Any) -> bool:
if isinstance(value, cls):
return value in cls._by_value.values()
else:
return value in cls._by_value.keys()
def __getattr__(cls: Type["ExtensibleEnum"], name: Any) -> "ExtensibleEnum":
try:
return cls._by_key[name]
except KeyError:
raise AttributeError(name) from None
def __setattr__(cls: Type["ExtensibleEnum"], key: str, value: Any) -> None:
if key.startswith("_"):
return super().__setattr__(key, value)
if not isinstance(value, cls):
value = cls(value)
value.key = key
def __getitem__(cls: Type["ExtensibleEnum"], name: str) -> "ExtensibleEnum":
try:
return cls._by_key[name]
except KeyError:
raise KeyError(name) from None
def __setitem__(cls: Type["ExtensibleEnum"], key: str, value: Any) -> None:
return cls.__setattr__(cls, key, value)
def __iter__(cls: Type["ExtensibleEnum"]) -> Iterable["ExtensibleEnum"]:
return cls._by_key.values().__iter__()
def __len__(cls: Type["ExtensibleEnum"]) -> int:
return len(cls._by_key)
def __repr__(cls: Type["ExtensibleEnum"]) -> str:
return f""
class ExtensibleEnum(Serializable, metaclass=ExtensibleEnumMeta):
_by_value: dict[Any, ExtensibleEnum]
_by_key: dict[str, ExtensibleEnum]
_inited: bool
_key: str | None
value: Any
def __init__(self, value: Any) -> None:
if getattr(self, "_inited", False):
return
self.value = value
self._key = None
self._inited = True
def __new__(cls: Type[ExtensibleEnum], value: Any) -> ExtensibleEnum:
try:
return cls._by_value[value]
except KeyError as e:
self = super().__new__(cls)
self.__objclass__ = cls
self.__init__(value)
cls._by_value[value] = self
return self
def __str__(self) -> str:
return str(self.value)
def __repr__(self) -> str:
if self._key:
return f"<{self.__class__.__name__}.{self._key}: {self.value!r}>"
else:
return f"{self.__class__.__name__}({self.value!r})"
@property
def key(self) -> str:
return self._key
@key.setter
def key(self, key: str) -> None:
self._key = key
self._by_key[key] = self
def serialize(self) -> JSON:
return self.value
@classmethod
def deserialize(cls, raw: JSON) -> Any:
return cls(raw)
python-0.20.4/mautrix/types/util/enum_test.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from attr import dataclass
from .enum import ExtensibleEnum
from .serializable_attrs import SerializableAttrs
def test_extensible_enum_int():
class Hello(ExtensibleEnum):
HI = 1
HMM = 2
assert Hello.HI.value == 1
assert Hello.HI.key == "HI"
assert 1 in Hello
assert Hello(1) == Hello.HI
assert Hello["HMM"] == Hello.HMM
assert len(Hello) == 2
hello3 = Hello(3)
assert hello3.value == 3
assert not hello3.key
Hello.YAY = hello3
assert len(Hello) == 3
assert hello3.key == "YAY"
@dataclass
class Wrapper(SerializableAttrs):
hello: Hello
assert Wrapper.deserialize({"hello": 1}).hello == Hello.HI
assert Wrapper.deserialize({"hello": 2}).hello == Hello.HMM
assert Wrapper.deserialize({"hello": 3}).hello == hello3
assert Wrapper.deserialize({"hello": 4}).hello.value == 4
def test_extensible_enum_str():
class Hello(ExtensibleEnum):
HI = "hi"
HMM = "🤔"
assert Hello.HI.value == "hi"
assert Hello.HI.key == "HI"
assert "🤔" in Hello
assert Hello("🤔") == Hello.HMM
assert Hello["HI"] == Hello.HI
assert len(Hello) == 2
hello3 = Hello("yay")
assert hello3.value == "yay"
assert not hello3.key
Hello.YAY = hello3
assert len(Hello) == 3
assert hello3.key == "YAY"
@dataclass
class Wrapper(SerializableAttrs):
hello: Hello
assert Wrapper.deserialize({"hello": "hi"}).hello == Hello.HI
assert Wrapper.deserialize({"hello": "🤔"}).hello == Hello.HMM
assert Wrapper.deserialize({"hello": "yay"}).hello == hello3
assert Wrapper.deserialize({"hello": "thonk"}).hello.value == "thonk"
python-0.20.4/mautrix/types/util/obj.py
# From https://github.com/Lonami/dumbot/blob/master/dumbot.py
# Modified to add Serializable base
from __future__ import annotations
from ..primitive import JSON
from .serializable import AbstractSerializable, Serializable
class Obj(AbstractSerializable):
""""""
def __init__(self, **kwargs):
self.__dict__ = {
k: Obj(**v) if isinstance(v, dict) else (Lst(v) if isinstance(v, list) else v)
for k, v in kwargs.items()
}
def __getattr__(self, name):
name = name.rstrip("_")
obj = self.__dict__.get(name)
if obj is None:
obj = Obj()
self.__dict__[name] = obj
return obj
def __getitem__(self, name):
return self.__dict__.get(name)
def __setitem__(self, key, value):
self.__dict__[key] = value
def __str__(self):
return str(self.serialize())
def __repr__(self):
return repr(self.serialize())
def __getstate__(self):
return self.__dict__
def __setstate__(self, state):
self.__dict__.update(state)
def __bool__(self):
return bool(self.__dict__)
def __contains__(self, item):
return item in self.__dict__
def popitem(self):
return self.__dict__.popitem()
def get(self, key, default=None):
obj = self.__dict__.get(key)
if obj is None:
return default
else:
return obj
def serialize(self) -> dict[str, JSON]:
return {
k: v.serialize() if isinstance(v, Serializable) else v
for k, v in self.__dict__.items()
}
@classmethod
def deserialize(cls, data: dict[str, JSON]) -> Obj:
return cls(**data)
class Lst(list, AbstractSerializable):
def __init__(self, iterable=()):
list.__init__(
self,
(
Obj(**x) if isinstance(x, dict) else (Lst(x) if isinstance(x, list) else x)
for x in iterable
),
)
def __repr__(self):
return super().__repr__()
def serialize(self) -> list[JSON]:
return [v.serialize() if isinstance(v, Serializable) else v for v in self]
@classmethod
def deserialize(cls, data: list[JSON]) -> Lst:
return cls(data)
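# Usage sketch: Obj wraps nested dicts for attribute-style access and
# serializes back to plain JSON-compatible values.
_wrapped = Obj(**{"content": {"body": "hi"}, "tags": [1, 2]})
assert _wrapped.content.body == "hi"
assert isinstance(_wrapped.tags, Lst)
assert _wrapped.serialize() == {"content": {"body": "hi"}, "tags": [1, 2]}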
python-0.20.4/mautrix/types/util/serializable.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from typing import Type, TypeVar, Union
from abc import ABC, abstractmethod
from enum import Enum
import json
from ..primitive import JSON
SerializableSubtype = TypeVar("SerializableSubtype", bound="SerializableAttrs")
class Serializable:
"""Serializable is the base class for types with custom JSON serializers."""
def serialize(self) -> JSON:
"""Convert this object into objects directly serializable with `json`."""
raise NotImplementedError()
@classmethod
def deserialize(cls: Type[SerializableSubtype], raw: JSON) -> SerializableSubtype:
"""Convert the given data parsed from JSON into an object of this type."""
raise NotImplementedError()
def json(self) -> str:
"""Serialize this object and dump the output as JSON."""
return json.dumps(self.serialize())
@classmethod
def parse_json(cls: Type[SerializableSubtype], data: Union[str, bytes]) -> SerializableSubtype:
"""Parse the given string as JSON and deserialize the result into this type."""
return cls.deserialize(json.loads(data))
class SerializerError(Exception):
"""
SerializerErrors are raised if something goes wrong during serialization or deserialization.
"""
pass
class UnknownSerializationError(SerializerError):
def __init__(self) -> None:
super().__init__("Unknown serialization error")
class AbstractSerializable(ABC, Serializable):
"""
An abstract Serializable that adds ``@abstractmethod`` decorators.
"""
@abstractmethod
def serialize(self) -> JSON:
pass
@classmethod
@abstractmethod
def deserialize(cls: Type[SerializableSubtype], raw: JSON) -> SerializableSubtype:
pass
class SerializableEnum(Serializable, Enum):
"""
A simple Serializable implementation for Enums.
Examples:
>>> class MyEnum(SerializableEnum):
... FOO = "foo value"
... BAR = "hmm"
>>> MyEnum.FOO.serialize()
"foo value"
>>> MyEnum.BAR.json()
'"hmm"'
"""
def __init__(self, _) -> None:
"""
A fake ``__init__`` to stop the type checker from complaining.
Enum's ``__new__`` overrides this.
"""
super().__init__()
def serialize(self) -> str:
"""
Convert this object into objects directly serializable with `json`, i.e. return the value
set to this enum value.
"""
return self.value
@classmethod
def deserialize(cls: Type[SerializableSubtype], raw: str) -> SerializableSubtype:
"""
Convert the given data parsed from JSON into an object of this type, i.e. find the enum
value for the given string using ``cls(raw)``.
"""
try:
return cls(raw)
except ValueError as e:
raise SerializerError() from e
def __str__(self):
return str(self.value)
def __repr__(self):
return f"{self.__class__.__name__}.{self.name}"
python-0.20.4/mautrix/types/util/serializable_attrs.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from typing import Any, Callable, Dict, Iterator, NewType, Optional, Tuple, Type, TypeVar, Union
from uuid import UUID
import copy
import logging
import attr
from ..primitive import JSON
from .obj import Lst, Obj
from .serializable import (
AbstractSerializable,
Serializable,
SerializableSubtype,
SerializerError,
UnknownSerializationError,
)
T = TypeVar("T")
T2 = TypeVar("T2")
Serializer = NewType("Serializer", Callable[[T], JSON])
Deserializer = NewType("Deserializer", Callable[[JSON], T])
serializer_map: Dict[Type[T], Serializer] = {
UUID: str,
}
deserializer_map: Dict[Type[T], Deserializer] = {
UUID: UUID,
}
META_JSON = "json"
META_FLATTEN = "flatten"
META_HIDDEN = "hidden"
META_IGNORE_ERRORS = "ignore_errors"
META_OMIT_EMPTY = "omitempty"
META_OMIT_DEFAULT = "omitdefault"
log = logging.getLogger("mau.attrs")
def field(
default: Any = attr.NOTHING,
factory: Optional[Callable[[], Any]] = None,
json: Optional[str] = None,
flatten: bool = False,
hidden: bool = False,
ignore_errors: bool = False,
omit_empty: bool = True,
omit_default: bool = False,
metadata: Optional[Dict[str, Any]] = None,
**kwargs,
):
"""
A wrapper around :meth:`attr.ib` to conveniently add SerializableAttrs metadata fields.
Args:
default: Same as attr.ib, the default value for the field.
factory: Same as attr.ib, a factory function that creates the default value.
json: The JSON key used for de/serializing the object.
flatten: Set to flatten subfields inside this field to be a part of the parent object in
serialized objects. When deserializing, the input data will be deserialized into both
the parent and child fields, so the classes should ignore unknown keys.
hidden: Set to always omit the key from serialized objects.
ignore_errors: Set to ignore type errors while deserializing.
omit_empty: Set to omit the key from serialized objects if the value is ``None``.
omit_default: Set to omit the key from serialized objects if the value is equal to the
default.
metadata: Additional metadata for attr.ib.
**kwargs: Additional keyword arguments for attr.ib.
Returns:
The decorator function returned by attr.ib.
Examples:
>>> from attr import dataclass
>>> from mautrix.types import SerializableAttrs, field
>>> @dataclass
... class SomeData(SerializableAttrs):
... my_field: str = field(json="com.example.namespaced_field", default="hi")
...
>>> SomeData().serialize()
{'com.example.namespaced_field': 'hi'}
>>> SomeData.deserialize({"com.example.namespaced_field": "hmm"})
SomeData(my_field='hmm')
"""
custom_meta = {
META_JSON: json,
META_FLATTEN: flatten,
META_HIDDEN: hidden,
META_IGNORE_ERRORS: ignore_errors,
META_OMIT_EMPTY: omit_empty,
META_OMIT_DEFAULT: omit_default,
}
metadata = metadata or {}
metadata.update({k: v for k, v in custom_meta.items() if v is not None})
return attr.ib(default=default, factory=factory, metadata=metadata, **kwargs)
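# Usage sketch (wrapped in a function so nothing runs at import time): the
# omit_default flag documented above drops values equal to their default when
# serializing, assuming the SerializableAttrs machinery below honors the
# "omitdefault" metadata as documented.
def _field_options_example() -> None:
    @attr.dataclass
    class _Example(SerializableAttrs):
        count: int = field(default=0, omit_default=True)

    assert _Example().serialize() == {}
    assert _Example(count=5).serialize() == {"count": 5}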
def serializer(elem_type: Type[T]) -> Callable[[Serializer], Serializer]:
"""
Define a custom serialization function for the given type.
Args:
elem_type: The type to define the serializer for.
Returns:
Decorator for the function. The decorator will simply add the function to a map of
deserializers and return the function.
Examples:
>>> from datetime import datetime
>>> from mautrix.types import serializer, JSON
>>> @serializer(datetime)
... def serialize_datetime(dt: datetime) -> JSON:
... return dt.timestamp()
"""
def decorator(func: Serializer) -> Serializer:
serializer_map[elem_type] = func
return func
return decorator
def deserializer(elem_type: Type[T]) -> Callable[[Deserializer], Deserializer]:
"""
Define a custom deserialization function for a given type hint.
Args:
elem_type: The type hint to define the deserializer for.
Returns:
Decorator for the function. The decorator will simply add the function to a map of
deserializers and return the function.
Examples:
>>> from datetime import datetime
>>> from mautrix.types import deserializer, JSON
>>> @deserializer(datetime)
... def deserialize_datetime(data: JSON) -> datetime:
... return datetime.fromtimestamp(data)
"""
def decorator(func: Deserializer) -> Deserializer:
deserializer_map[elem_type] = func
return func
return decorator
def _fields(
    attrs_type: Type[T], only_if_flatten: Optional[bool] = None
) -> Iterator[Tuple[str, Type[T2]]]:
for field in attr.fields(attrs_type):
if field.metadata.get(META_HIDDEN, False):
continue
if only_if_flatten is None or field.metadata.get(META_FLATTEN, False) == only_if_flatten:
yield field.metadata.get(META_JSON, field.name), field
immutable = int, str, float, bool, type(None)
def _safe_default(val: T) -> T:
if isinstance(val, immutable):
return val
elif val is attr.NOTHING:
return None
elif isinstance(val, attr.Factory):
if val.takes_self:
# TODO implement?
return None
else:
return val.factory()
return copy.copy(val)
def _dict_to_attrs(
attrs_type: Type[T], data: JSON, default: Optional[T] = None, default_if_empty: bool = False
) -> T:
data = data or {}
unrecognized = {}
new_items = {
field_meta.name.lstrip("_"): _try_deserialize(field_meta, data)
for _, field_meta in _fields(attrs_type, only_if_flatten=True)
}
fields = dict(_fields(attrs_type, only_if_flatten=False))
for key, value in data.items():
try:
field_meta = fields[key]
except KeyError:
unrecognized[key] = value
continue
name = field_meta.name.lstrip("_")
try:
new_items[name] = _try_deserialize(field_meta, value)
except UnknownSerializationError as e:
raise SerializerError(
f"Failed to deserialize {value} into key {name} of {attrs_type.__name__}"
) from e
except SerializerError:
raise
except Exception as e:
raise SerializerError(
f"Failed to deserialize {value} into key {name} of {attrs_type.__name__}"
) from e
if len(new_items) == 0 and default_if_empty and default is not attr.NOTHING:
return _safe_default(default)
try:
obj = attrs_type(**new_items)
except TypeError as e:
for key, field_meta in _fields(attrs_type):
if field_meta.default is attr.NOTHING and field_meta.name.lstrip("_") not in new_items:
log.debug("Failed to deserialize %s into %s", data, attrs_type.__name__)
json_key = field_meta.metadata.get(META_JSON, key)
raise SerializerError(
f"Missing value for required key {json_key} in {attrs_type.__name__}"
) from e
raise UnknownSerializationError() from e
if len(unrecognized) > 0:
obj.unrecognized_ = unrecognized
return obj
def _try_deserialize(field, value: JSON) -> T:
try:
return _deserialize(field.type, value, field.default)
except SerializerError:
if not field.metadata.get(META_IGNORE_ERRORS, False):
raise
except (TypeError, ValueError, KeyError) as e:
if not field.metadata.get(META_IGNORE_ERRORS, False):
raise UnknownSerializationError() from e
def _has_custom_deserializer(cls) -> bool:
return issubclass(cls, Serializable) and getattr(cls.deserialize, "__func__") != getattr(
SerializableAttrs.deserialize, "__func__"
)
def _deserialize(cls: Type[T], value: JSON, default: Optional[T] = None) -> T:
if value is None:
return _safe_default(default)
try:
deser = deserializer_map[cls]
except KeyError:
pass
else:
return deser(value)
supertype = getattr(cls, "__supertype__", None)
if supertype:
cls = supertype
try:
deser = deserializer_map[supertype]
except KeyError:
pass
else:
return deser(value)
if attr.has(cls):
if _has_custom_deserializer(cls):
return cls.deserialize(value)
return _dict_to_attrs(cls, value, default, default_if_empty=True)
elif cls == Any or cls == JSON:
return value
elif isinstance(cls, type) and issubclass(cls, Serializable):
return cls.deserialize(value)
type_class = getattr(cls, "__origin__", None)
args = getattr(cls, "__args__", None)
if type_class is Union:
if len(args) == 2 and isinstance(None, args[1]):
return _deserialize(args[0], value, default)
elif type_class == list:
(item_cls,) = args
return [_deserialize(item_cls, item) for item in value]
elif type_class == set:
(item_cls,) = args
return {_deserialize(item_cls, item) for item in value}
elif type_class == dict:
key_cls, val_cls = args
return {
_deserialize(key_cls, key): _deserialize(val_cls, item) for key, item in value.items()
}
if isinstance(value, list):
return Lst(value)
elif isinstance(value, dict):
return Obj(**value)
return value
def _actual_type(cls: Type[T]) -> Type[T]:
if cls is None:
return cls
if getattr(cls, "__origin__", None) is Union:
if len(cls.__args__) == 2 and isinstance(None, cls.__args__[1]):
return cls.__args__[0]
return cls
def _get_serializer(cls: Type[T]) -> Serializer:
return serializer_map.get(_actual_type(cls), _serialize)
def _serialize_attrs_field(data: T, field: T2) -> JSON:
field_val = getattr(data, field.name)
if field_val is None:
if not field.metadata.get(META_OMIT_EMPTY, True):
if field.default is not attr.NOTHING:
field_val = _safe_default(field.default)
else:
return attr.NOTHING
if field.metadata.get(META_OMIT_DEFAULT, False) and field_val == field.default:
return attr.NOTHING
return _get_serializer(field.type)(field_val)
def _attrs_to_dict(data: T) -> JSON:
new_dict = {}
for json_name, field in _fields(data.__class__):
if not json_name:
continue
serialized = _serialize_attrs_field(data, field)
if serialized is not attr.NOTHING:
if field.metadata.get(META_FLATTEN, False) and isinstance(serialized, dict):
new_dict.update(serialized)
else:
new_dict[json_name] = serialized
try:
new_dict.update(data.unrecognized_)
except (AttributeError, TypeError):
pass
return new_dict
def _serialize(val: Any) -> JSON:
if isinstance(val, Serializable):
return val.serialize()
elif isinstance(val, (tuple, list, set)):
return [_serialize(subval) for subval in val]
elif isinstance(val, dict):
return {_serialize(subkey): _serialize(subval) for subkey, subval in val.items()}
elif attr.has(val.__class__):
return _attrs_to_dict(val)
return val
class SerializableAttrs(AbstractSerializable):
"""
An abstract :class:`Serializable` that assumes the subclass is an attrs dataclass.
Examples:
>>> from attr import dataclass
>>> from mautrix.types import SerializableAttrs
>>> @dataclass
... class Foo(SerializableAttrs):
... index: int
... field: Optional[str] = None
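A minimal continuation sketch (assuming the ``Foo`` class above; the
``com.example.extra`` key is made up): unknown keys are stored in
``unrecognized_`` and included again when serializing.
>>> foo = Foo.deserialize({"index": 1, "field": "hi", "com.example.extra": "x"})
>>> foo["com.example.extra"]
'x'
>>> foo.serialize()
{'index': 1, 'field': 'hi', 'com.example.extra': 'x'}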
"""
unrecognized_: Dict[str, JSON]
def __init__(self):
self.unrecognized_ = {}
@classmethod
def deserialize(cls: Type[SerializableSubtype], data: JSON) -> SerializableSubtype:
return _dict_to_attrs(cls, data)
def serialize(self) -> JSON:
return _attrs_to_dict(self)
def get(self, item: str, default: Any = None) -> Any:
try:
return self[item]
except KeyError:
return default
def __contains__(self, item: str) -> bool:
return hasattr(self, item) or item in getattr(self, "unrecognized_", {})
def __getitem__(self, item: str) -> Any:
try:
return getattr(self, item)
except AttributeError:
try:
return self.unrecognized_[item]
except AttributeError:
self.unrecognized_ = {}
raise KeyError(item)
def __setitem__(self, item: str, value: Any) -> None:
if hasattr(self, item):
setattr(self, item, value)
else:
try:
self.unrecognized_[item] = value
except AttributeError:
self.unrecognized_ = {item: value}
python-0.20.4/mautrix/types/util/serializable_attrs_test.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from typing import List, Optional
from attr import dataclass
import pytest
from ..primitive import JSON
from .serializable_attrs import Serializable, SerializableAttrs, SerializerError, field
def test_simple_class():
@dataclass
class Foo(SerializableAttrs):
hello: int
world: str
serialized = {"hello": 5, "world": "hi"}
deserialized = Foo.deserialize(serialized)
assert deserialized == Foo(5, "hi")
assert deserialized.serialize() == serialized
with pytest.raises(SerializerError):
Foo.deserialize({"world": "hi"})
def test_default():
@dataclass
class Default(SerializableAttrs):
no_default: int
defaultful_value: int = 5
d1 = Default.deserialize({"no_default": 3})
assert d1.no_default == 3
assert d1.defaultful_value == 5
d2 = Default.deserialize({"no_default": 4, "defaultful_value": 6})
assert d2.no_default == 4
assert d2.defaultful_value == 6
def test_factory():
@dataclass
class Factory(SerializableAttrs):
manufactured_value: List[str] = field(factory=lambda: ["hi"])
assert Factory.deserialize({}).manufactured_value == ["hi"]
assert Factory.deserialize({"manufactured_value": None}).manufactured_value == ["hi"]
factory1 = Factory.deserialize({})
factory2 = Factory.deserialize({})
assert factory1.manufactured_value is not factory2.manufactured_value
assert Factory.deserialize({"manufactured_value": ["bye"]}).manufactured_value == ["bye"]
def test_hidden():
@dataclass
class HiddenField(SerializableAttrs):
visible: str
hidden: int = field(hidden=True, default=5)
deserialized_hidden = HiddenField.deserialize({"visible": "yay", "hidden": 4})
assert deserialized_hidden.hidden == 5
assert deserialized_hidden.unrecognized_["hidden"] == 4
assert HiddenField("hmm", 5).serialize() == {"visible": "hmm"}
def test_ignore_errors():
@dataclass
class Something(SerializableAttrs):
required: bool
@dataclass
class Wrapper(SerializableAttrs):
something: Optional[Something] = field(ignore_errors=True)
@dataclass
class ErroringWrapper(SerializableAttrs):
something: Optional[Something] = field(ignore_errors=False)
assert Wrapper.deserialize({"something": {"required": True}}) == Wrapper(Something(True))
assert Wrapper.deserialize({"something": {}}) == Wrapper(None)
with pytest.raises(SerializerError):
ErroringWrapper.deserialize({"something": 5})
with pytest.raises(SerializerError):
ErroringWrapper.deserialize({"something": {}})
def test_json_key_override():
@dataclass
class Meow(SerializableAttrs):
meow: int = field(json="fi.mau.namespaced_meow")
serialized = {"fi.mau.namespaced_meow": 123}
deserialized = Meow.deserialize(serialized)
assert deserialized == Meow(123)
assert deserialized.serialize() == serialized
def test_omit_empty():
@dataclass
class OmitEmpty(SerializableAttrs):
omitted: Optional[int] = field(omit_empty=True)
@dataclass
class DontOmitEmpty(SerializableAttrs):
not_omitted: Optional[int] = field(omit_empty=False)
assert OmitEmpty(None).serialize() == {}
assert OmitEmpty(0).serialize() == {"omitted": 0}
assert DontOmitEmpty(None).serialize() == {"not_omitted": None}
assert DontOmitEmpty(0).serialize() == {"not_omitted": 0}
def test_omit_default():
@dataclass
class OmitDefault(SerializableAttrs):
omitted: int = field(default=5, omit_default=True)
@dataclass
class DontOmitDefault(SerializableAttrs):
not_omitted: int = 5
assert OmitDefault().serialize() == {}
assert OmitDefault(5).serialize() == {}
assert OmitDefault(6).serialize() == {"omitted": 6}
assert DontOmitDefault().serialize() == {"not_omitted": 5}
assert DontOmitDefault(5).serialize() == {"not_omitted": 5}
assert DontOmitDefault(6).serialize() == {"not_omitted": 6}
def test_flatten():
from mautrix.types import ContentURI
@dataclass
class OpenGraphImage(SerializableAttrs):
url: ContentURI = field(default=None, json="og:image")
mimetype: str = field(default=None, json="og:image:type")
height: int = field(default=None, json="og:image:width")
width: int = field(default=None, json="og:image:height")
size: int = field(default=None, json="matrix:image:size")
@dataclass
class OpenGraphVideo(SerializableAttrs):
url: ContentURI = field(default=None, json="og:video")
mimetype: str = field(default=None, json="og:video:type")
height: int = field(default=None, json="og:video:width")
width: int = field(default=None, json="og:video:height")
size: int = field(default=None, json="matrix:video:size")
@dataclass
class OpenGraphAudio(SerializableAttrs):
url: ContentURI = field(default=None, json="og:audio")
mimetype: str = field(default=None, json="og:audio:type")
@dataclass
class MXOpenGraph(SerializableAttrs):
title: str = field(default=None, json="og:title")
description: str = field(default=None, json="og:description")
image: OpenGraphImage = field(default=None, flatten=True)
video: OpenGraphVideo = field(default=None, flatten=True)
audio: OpenGraphAudio = field(default=None, flatten=True)
example_com_preview = {
"og:title": "Example Domain",
"og:description": "Example Domain\n\nThis domain is for use in illustrative examples in "
"documents. You may use this domain in literature without prior "
"coordination or asking for permission.\n\nMore information...",
}
google_com_preview = {
"og:title": "Google",
"og:image": "mxc://maunium.net/2021-06-20_jkscuJXkHjvzNaUJ",
"og:description": "Search\n\nImages\n\nMaps\n\nPlay\n\nYouTube\n\nNews\n\nGmail\n\n"
"Drive\n\nMore\n\n\u00bb\n\nWeb History\n\n|\n\nSettings\n\n|\n\n"
"Sign in\n\nAdvanced search\n\nGoogle offered in:\n\nDeutsch\n\n"
"AdvertisingPrograms\n\nBusiness Solutions",
"og:image:width": 128,
"og:image:height": 128,
"og:image:type": "image/png",
"matrix:image:size": 3428,
}
example_com_deserialized = MXOpenGraph.deserialize(example_com_preview)
assert example_com_deserialized.title == "Example Domain"
assert example_com_deserialized.image is None
assert example_com_deserialized.video is None
assert example_com_deserialized.audio is None
google_com_deserialized = MXOpenGraph.deserialize(google_com_preview)
assert google_com_deserialized.title == "Google"
assert google_com_deserialized.image is not None
assert google_com_deserialized.image.url == "mxc://maunium.net/2021-06-20_jkscuJXkHjvzNaUJ"
assert google_com_deserialized.image.width == 128
assert google_com_deserialized.image.height == 128
assert google_com_deserialized.image.mimetype == "image/png"
assert google_com_deserialized.image.size == 3428
assert google_com_deserialized.video is None
assert google_com_deserialized.audio is None
def test_flatten_arbitrary_serializable():
@dataclass
class CustomSerializable(Serializable):
is_hello: bool = True
def serialize(self) -> JSON:
if self.is_hello:
return {"hello": "world"}
return {}
@classmethod
def deserialize(cls, raw: JSON) -> "CustomSerializable":
return CustomSerializable(is_hello=raw.get("hello") == "world")
@dataclass
class Thing(SerializableAttrs):
custom: CustomSerializable = field(flatten=True)
another: int = field(default=5)
thing_1 = {
"hello": "world",
"another": 6,
}
thing_1_deserialized = Thing.deserialize(thing_1)
assert thing_1_deserialized.custom.is_hello is True
assert thing_1_deserialized.another == 6
thing_2 = {
"another": 4,
}
thing_2_deserialized = Thing.deserialize(thing_2)
assert thing_2_deserialized.custom.is_hello is False
assert thing_2_deserialized.another == 4
assert Thing(custom=CustomSerializable(is_hello=True)).serialize() == {
"hello": "world",
"another": 5,
}
def test_flatten_optional():
@dataclass
class OptionalThing(SerializableAttrs):
key: str
@classmethod
def deserialize(cls, data: JSON) -> Optional["OptionalThing"]:
if "key" in data:
return super().deserialize(data)
return None
@dataclass
class ThingWithOptional(SerializableAttrs):
optional: OptionalThing = field(flatten=True)
another_field: int = 2
assert ThingWithOptional.deserialize({}).optional is None
assert ThingWithOptional.deserialize({"key": "hi"}).optional.key == "hi"
python-0.20.4/mautrix/types/versions.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from typing import Dict, List, NamedTuple, Optional, Union
from enum import IntEnum
import re
from attr import dataclass
import attr
from . import JSON
from .util import Serializable, SerializableAttrs
class VersionFormat(IntEnum):
UNKNOWN = -1
LEGACY = 0
MODERN = 1
def __repr__(self) -> str:
return f"VersionFormat.{self.name}"
legacy_version_regex = re.compile(r"^r(\d+)\.(\d+)\.(\d+)$")
modern_version_regex = re.compile(r"^v(\d+)\.(\d+)$")
@attr.dataclass(frozen=True)
class Version(Serializable):
format: VersionFormat
major: int
minor: int
patch: int
raw: str
def __str__(self) -> str:
if self.format == VersionFormat.MODERN:
return f"v{self.major}.{self.minor}"
elif self.format == VersionFormat.LEGACY:
return f"r{self.major}.{self.minor}.{self.patch}"
else:
return self.raw
def serialize(self) -> JSON:
return str(self)
@classmethod
def deserialize(cls, raw: JSON) -> "Version":
assert isinstance(raw, str), "versions must be strings"
if modern := modern_version_regex.fullmatch(raw):
major, minor = modern.groups()
return Version(VersionFormat.MODERN, int(major), int(minor), 0, raw)
elif legacy := legacy_version_regex.fullmatch(raw):
major, minor, patch = legacy.groups()
return Version(VersionFormat.LEGACY, int(major), int(minor), int(patch), raw)
else:
return Version(VersionFormat.UNKNOWN, 0, 0, 0, raw)
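# Illustrative sketch: parsing and comparing versions. Modern (vX.Y) versions always sort
# above legacy (rX.Y.Z) ones because the format field is compared first.
#
#   assert Version.deserialize("v1.2") > Version.deserialize("r0.6.1")
#   assert str(Version.deserialize("v1.2")) == "v1.2"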
class SpecVersions:
R010 = Version.deserialize("r0.1.0")
R020 = Version.deserialize("r0.2.0")
R030 = Version.deserialize("r0.3.0")
R040 = Version.deserialize("r0.4.0")
R050 = Version.deserialize("r0.5.0")
R060 = Version.deserialize("r0.6.0")
R061 = Version.deserialize("r0.6.1")
V11 = Version.deserialize("v1.1")
V12 = Version.deserialize("v1.2")
V13 = Version.deserialize("v1.3")
V14 = Version.deserialize("v1.4")
V15 = Version.deserialize("v1.5")
V16 = Version.deserialize("v1.6")
V17 = Version.deserialize("v1.7")
@dataclass
class VersionsResponse(SerializableAttrs):
versions: List[Version]
unstable_features: Dict[str, bool] = attr.ib(factory=lambda: {})
def supports(self, thing: Union[Version, str]) -> Optional[bool]:
"""
Check if the versions response contains the given spec version or unstable feature.
Args:
thing: The spec version (as a :class:`Version` or string)
or unstable feature name (as a string) to check.
Returns:
``True`` if the exact version or unstable feature is supported,
``False`` if it's not supported,
``None`` for unstable features which are not included in the response at all.
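Examples:
    Illustrative sketch (the unstable feature name is made up):
    >>> resp = VersionsResponse.deserialize(
    ...     {"versions": ["v1.1", "v1.2"], "unstable_features": {"org.example.feature": True}}
    ... )
    >>> resp.supports("v1.1")
    True
    >>> resp.supports("org.example.feature")
    True
    >>> resp.supports("v1.5")
    False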
"""
if isinstance(thing, Version):
return thing in self.versions
elif (parsed_version := Version.deserialize(thing)).format != VersionFormat.UNKNOWN:
return parsed_version in self.versions
return self.unstable_features.get(thing)
def supports_at_least(self, version: Union[Version, str]) -> bool:
"""
Check if the versions response contains the given spec version or any higher version.
Args:
version: The spec version as a :class:`Version` or a string.
Returns:
``True`` if a version equal to or higher than the given version is found,
``False`` otherwise.
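Examples:
    Illustrative sketch:
    >>> resp = VersionsResponse.deserialize({"versions": ["v1.2", "v1.4"]})
    >>> resp.supports_at_least("v1.3")
    True
    >>> resp.supports_at_least("v1.5")
    False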
"""
if isinstance(version, str):
version = Version.deserialize(version)
return any(v >= version for v in self.versions)
@property
def latest_version(self) -> Version:
return max(self.versions)
@property
def has_legacy_versions(self) -> bool:
"""
Check if the response contains any legacy (r0.x.y) versions.
.. deprecated:: 0.16.10
:meth:`supports_at_least` and :meth:`supports` methods are now preferred.
"""
return any(v.format == VersionFormat.LEGACY for v in self.versions)
@property
def has_modern_versions(self) -> bool:
"""
Check if the response contains any modern (v1.1 or higher) versions.
.. deprecated:: 0.16.10
:meth:`supports_at_least` and :meth:`supports` methods are now preferred.
"""
return self.supports_at_least(SpecVersions.V11)
python-0.20.4/mautrix/util/__init__.py
__all__ = [
# Directory modules
"async_db",
"config",
"db",
"formatter",
"logging",
# File modules
"async_body",
"async_getter_lock",
"background_task",
"bridge_state",
"color_log",
"ffmpeg",
"file_store",
"format_duration",
"magic",
"manhole",
"markdown",
"message_send_checkpoint",
"opt_prometheus",
"program",
"signed_token",
"simple_lock",
"simple_template",
"utf16_surrogate",
"variation_selector",
]
python-0.20.4/mautrix/util/async_body.py
# Copyright (c) 2023 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from typing import AsyncGenerator, Union
import logging
import aiohttp
AsyncBody = AsyncGenerator[Union[bytes, bytearray, memoryview], None]
async def async_iter_bytes(data: bytearray | bytes, chunk_size: int = 1024**2) -> AsyncBody:
"""
Return memory views into a byte array in chunks. This is used to prevent aiohttp from copying
the entire request body.
Args:
data: The underlying data to iterate through.
chunk_size: How big each returned chunk should be.
Returns:
An async generator that yields the given data in chunks.
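Examples:
    A rough upload sketch (the URL is a placeholder):
    >>> async def upload(url: str, body: bytes) -> None:
    ...     async with aiohttp.ClientSession() as sess:
    ...         await sess.put(url, data=async_iter_bytes(body))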
"""
with memoryview(data) as mv:
for i in range(0, len(data), chunk_size):
yield mv[i : i + chunk_size]
class FileTooLargeError(Exception):
def __init__(self, max_size: int) -> None:
super().__init__(f"File size larger than maximum ({max_size / 1024 / 1024} MiB)")
_default_dl_log = logging.getLogger("mau.util.download")
async def read_response_chunks(
resp: aiohttp.ClientResponse, max_size: int, log: logging.Logger = _default_dl_log
) -> bytearray:
"""
Read the body from an aiohttp response in chunks into a mutable bytearray.
Args:
resp: The aiohttp response object to read the body from.
max_size: The maximum size to read. FileTooLargeError will be raised if the Content-Length
is higher than this, or if the body exceeds this size during reading.
log: A logger for logging download status.
Returns:
The body data as a byte array.
Raises:
FileTooLargeError: if the body is larger than the provided max_size.
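Examples:
    A rough download sketch (the URL and 10 MiB limit are placeholders):
    >>> async def download(url: str) -> bytearray:
    ...     async with aiohttp.ClientSession() as sess, sess.get(url) as resp:
    ...         return await read_response_chunks(resp, max_size=10 * 1024**2)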
"""
content_length = int(resp.headers.get("Content-Length", "0"))
if 0 < max_size < content_length:
raise FileTooLargeError(max_size)
size_str = "unknown length" if content_length == 0 else f"{content_length} bytes"
log.info(f"Reading file download response with {size_str} (max: {max_size})")
data = bytearray(content_length)
mv = memoryview(data) if content_length > 0 else None
read_size = 0
max_size += 1
while True:
block = await resp.content.readany()
if not block:
break
max_size -= len(block)
if max_size <= 0:
raise FileTooLargeError(max_size)
if len(data) >= read_size + len(block):
mv[read_size : read_size + len(block)] = block
elif len(data) > read_size:
log.warning("File being downloaded is bigger than expected")
mv[read_size:] = block[: len(data) - read_size]
mv.release()
mv = None
data.extend(block[len(data) - read_size :])
else:
if mv is not None:
mv.release()
mv = None
data.extend(block)
read_size += len(block)
if mv is not None:
mv.release()
log.info(f"Successfully read {read_size} bytes of file download response")
return data
__all__ = ["AsyncBody", "FileTooLargeError", "async_iter_bytes", "async_read_bytes"]
python-0.20.4/mautrix/util/async_db/__init__.py
from mautrix import __optional_imports__
from .connection import LoggingConnection as Connection
from .database import Database
from .errors import (
DatabaseException,
DatabaseNotOwned,
ForeignTablesFound,
UnsupportedDatabaseVersion,
)
from .scheme import Scheme
from .upgrade import UpgradeTable, register_upgrade
try:
from .asyncpg import PostgresDatabase
except ImportError:
if __optional_imports__:
raise
PostgresDatabase = None
try:
from aiosqlite import Cursor as SQLiteCursor
from .aiosqlite import SQLiteDatabase
except ImportError:
if __optional_imports__:
raise
SQLiteDatabase = None
SQLiteCursor = None
__all__ = [
"Database",
"UpgradeTable",
"register_upgrade",
"PostgresDatabase",
"SQLiteDatabase",
"SQLiteCursor",
"Connection",
"Scheme",
"DatabaseException",
"DatabaseNotOwned",
"UnsupportedDatabaseVersion",
"ForeignTablesFound",
]
python-0.20.4/mautrix/util/async_db/aiosqlite.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from typing import Any, AsyncContextManager
from contextlib import asynccontextmanager
import asyncio
import logging
import os
import re
import sqlite3
from yarl import URL
import aiosqlite
from .connection import LoggingConnection
from .database import Database
from .scheme import Scheme
from .upgrade import UpgradeTable
POSITIONAL_PARAM_PATTERN = re.compile(r"\$(\d+)")
class TxnConnection(aiosqlite.Connection):
def __init__(self, path: str, **kwargs) -> None:
def connector() -> sqlite3.Connection:
return sqlite3.connect(
path, detect_types=sqlite3.PARSE_DECLTYPES, isolation_level=None, **kwargs
)
super().__init__(connector, iter_chunk_size=64)
@asynccontextmanager
async def transaction(self) -> None:
await self.execute("BEGIN TRANSACTION")
try:
yield
except Exception:
await self.rollback()
raise
else:
await self.commit()
def __execute(self, query: str, *args: Any):
query = POSITIONAL_PARAM_PATTERN.sub(r"?\1", query)
return super().execute(query, args)
async def execute(
self, query: str, *args: Any, timeout: float | None = None
) -> aiosqlite.Cursor:
return await self.__execute(query, *args)
async def executemany(
self, query: str, *args: Any, timeout: float | None = None
) -> aiosqlite.Cursor:
query = POSITIONAL_PARAM_PATTERN.sub(r"?\1", query)
return await super().executemany(query, *args)
async def fetch(
self, query: str, *args: Any, timeout: float | None = None
) -> list[sqlite3.Row]:
async with self.__execute(query, *args) as cursor:
return list(await cursor.fetchall())
async def fetchrow(
self, query: str, *args: Any, timeout: float | None = None
) -> sqlite3.Row | None:
async with self.__execute(query, *args) as cursor:
return await cursor.fetchone()
async def fetchval(
self, query: str, *args: Any, column: int = 0, timeout: float | None = None
) -> Any:
row = await self.fetchrow(query, *args)
if row is None:
return None
return row[column]
class SQLiteDatabase(Database):
scheme = Scheme.SQLITE
_parent: SQLiteDatabase | None
_pool: asyncio.Queue[TxnConnection]
_stopped: bool
_conns: int
_init_commands: list[str]
def __init__(
self,
url: URL,
upgrade_table: UpgradeTable,
db_args: dict[str, Any] | None = None,
log: logging.Logger | None = None,
owner_name: str | None = None,
ignore_foreign_tables: bool = True,
) -> None:
super().__init__(
url,
db_args=db_args,
upgrade_table=upgrade_table,
log=log,
owner_name=owner_name,
ignore_foreign_tables=ignore_foreign_tables,
)
self._parent = None
self._path = url.path
self._pool = asyncio.Queue(self._db_args.pop("min_size", 1))
self._db_args.pop("max_size", None)
self._stopped = False
self._conns = 0
self._init_commands = self._add_missing_pragmas(self._db_args.pop("init_commands", []))
@staticmethod
def _add_missing_pragmas(init_commands: list[str]) -> list[str]:
has_foreign_keys = False
has_journal_mode = False
has_synchronous = False
has_busy_timeout = False
for cmd in init_commands:
if "PRAGMA" not in cmd:
continue
if "foreign_keys" in cmd:
has_foreign_keys = True
elif "journal_mode" in cmd:
has_journal_mode = True
elif "synchronous" in cmd:
has_synchronous = True
elif "busy_timeout" in cmd:
has_busy_timeout = True
if not has_foreign_keys:
init_commands.append("PRAGMA foreign_keys = ON")
if not has_journal_mode:
init_commands.append("PRAGMA journal_mode = WAL")
if not has_synchronous and "PRAGMA journal_mode = WAL" in init_commands:
init_commands.append("PRAGMA synchronous = NORMAL")
if not has_busy_timeout:
init_commands.append("PRAGMA busy_timeout = 5000")
return init_commands
def override_pool(self, db: Database) -> None:
assert isinstance(db, SQLiteDatabase)
self._parent = db
async def start(self) -> None:
if self._parent:
await super().start()
return
if self._conns:
raise RuntimeError("database pool has already been started")
elif self._stopped:
raise RuntimeError("database pool can't be restarted")
self.log.debug(f"Connecting to {self.url}")
self.log.debug(f"Database connection init commands: {self._init_commands}")
if os.path.exists(self._path):
if not os.access(self._path, os.W_OK):
self.log.warning("Database file doesn't seem writable")
elif not os.access(os.path.dirname(os.path.abspath(self._path)), os.W_OK):
self.log.warning("Database file doesn't exist and directory doesn't seem writable")
for _ in range(self._pool.maxsize):
conn = await TxnConnection(self._path, **self._db_args)
if self._init_commands:
cur = await conn.cursor()
for command in self._init_commands:
self.log.trace("Executing init command: %s", command)
await cur.execute(command)
await conn.commit()
conn.row_factory = sqlite3.Row
self._pool.put_nowait(conn)
self._conns += 1
await super().start()
async def stop(self) -> None:
if self._parent:
return
self._stopped = True
while self._conns > 0:
conn = await self._pool.get()
self._conns -= 1
await conn.close()
def acquire(self) -> AsyncContextManager[LoggingConnection]:
if self._parent:
return self._parent.acquire()
return self._acquire()
@asynccontextmanager
async def _acquire(self) -> LoggingConnection:
if self._stopped:
raise RuntimeError("database pool has been stopped")
conn = await self._pool.get()
try:
yield LoggingConnection(self.scheme, conn, self.log)
finally:
self._pool.put_nowait(conn)
Database.schemes["sqlite"] = SQLiteDatabase
Database.schemes["sqlite3"] = SQLiteDatabase
python-0.20.4/mautrix/util/async_db/asyncpg.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from typing import Any
from contextlib import asynccontextmanager
import asyncio
import logging
import sys
import traceback
from yarl import URL
import asyncpg
from .connection import LoggingConnection
from .database import Database
from .scheme import Scheme
from .upgrade import UpgradeTable
class PostgresDatabase(Database):
scheme = Scheme.POSTGRES
_pool: asyncpg.pool.Pool | None
_pool_override: bool
_exit_on_ice: bool
def __init__(
self,
url: URL,
upgrade_table: UpgradeTable,
db_args: dict[str, Any] = None,
log: logging.Logger | None = None,
owner_name: str | None = None,
ignore_foreign_tables: bool = True,
) -> None:
if url.scheme in ("cockroach", "cockroachdb"):
self.scheme = Scheme.COCKROACH
# Send postgres scheme to asyncpg
url = url.with_scheme("postgres")
self._exit_on_ice = True
if db_args:
self._exit_on_ice = db_args.pop("meow_exit_on_ice", True)
db_args.pop("init_commands", None)
super().__init__(
url,
db_args=db_args,
upgrade_table=upgrade_table,
log=log,
owner_name=owner_name,
ignore_foreign_tables=ignore_foreign_tables,
)
self._pool = None
self._pool_override = False
def override_pool(self, db: PostgresDatabase) -> None:
self._pool = db._pool
self._pool_override = True
async def start(self) -> None:
if not self._pool_override:
if self._pool:
raise RuntimeError("Database has already been started")
self._db_args["loop"] = asyncio.get_running_loop()
log_url = self.url
if log_url.password:
log_url = log_url.with_password("password-redacted")
self.log.debug(f"Connecting to {log_url}")
self._pool = await asyncpg.create_pool(str(self.url), **self._db_args)
await super().start()
@property
def pool(self) -> asyncpg.pool.Pool:
if not self._pool:
raise RuntimeError("Database has not been started")
return self._pool
async def stop(self) -> None:
if not self._pool_override and self._pool is not None:
await self._pool.close()
async def _handle_exception(self, err: Exception) -> None:
if self._exit_on_ice and isinstance(err, asyncpg.InternalClientError):
pre_stack = traceback.format_stack()[:-2]
post_stack = traceback.format_exception(err)
header = post_stack[0]
post_stack = post_stack[1:]
self.log.critical(
"Got asyncpg internal client error, exiting...\n%s%s%s",
header,
"".join(pre_stack),
"".join(post_stack),
)
sys.exit(26)
@asynccontextmanager
async def acquire(self) -> LoggingConnection:
async with self.pool.acquire() as conn:
yield LoggingConnection(
self.scheme, conn, self.log, handle_exception=self._handle_exception
)
Database.schemes["postgres"] = PostgresDatabase
Database.schemes["postgresql"] = PostgresDatabase
Database.schemes["cockroach"] = PostgresDatabase
Database.schemes["cockroachdb"] = PostgresDatabase
python-0.20.4/mautrix/util/async_db/connection.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from typing import Any, Awaitable, Callable, TypeVar
from contextlib import asynccontextmanager
from logging import WARNING
import functools
import time
from mautrix import __optional_imports__
from mautrix.util.logging import SILLY, TraceLogger
from .scheme import Scheme
if __optional_imports__:
from sqlite3 import Row
from aiosqlite import Cursor
from asyncpg import Record
import asyncpg
from . import aiosqlite
Decorated = TypeVar("Decorated", bound=Callable[..., Any])
LOG_MESSAGE = "%s(%r) took %.3f seconds"
def log_duration(func: Decorated) -> Decorated:
func_name = func.__name__
@functools.wraps(func)
async def wrapper(self: LoggingConnection, arg: str, *args: Any, **kwargs: str) -> Any:
start = time.monotonic()
ret = await func(self, arg, *args, **kwargs)
duration = time.monotonic() - start
self.log.log(WARNING if duration > 1 else SILLY, LOG_MESSAGE, func_name, arg, duration)
return ret
return wrapper
async def handle_exception_noop(_: Exception) -> None:
pass
class LoggingConnection:
def __init__(
self,
scheme: Scheme,
wrapped: aiosqlite.TxnConnection | asyncpg.Connection,
log: TraceLogger,
handle_exception: Callable[[Exception], Awaitable[None]] = handle_exception_noop,
) -> None:
self.scheme = scheme
self.wrapped = wrapped
self.log = log
self._handle_exception = handle_exception
self._inited = True
def __setattr__(self, key: str, value: Any) -> None:
if getattr(self, "_inited", False):
raise RuntimeError("LoggingConnection fields are frozen")
super().__setattr__(key, value)
@asynccontextmanager
async def transaction(self) -> None:
try:
async with self.wrapped.transaction():
yield
except Exception as e:
await self._handle_exception(e)
raise
@log_duration
async def execute(self, query: str, *args: Any, timeout: float | None = None) -> str | Cursor:
try:
return await self.wrapped.execute(query, *args, timeout=timeout)
except Exception as e:
await self._handle_exception(e)
raise
@log_duration
async def executemany(
self, query: str, *args: Any, timeout: float | None = None
) -> str | Cursor:
try:
return await self.wrapped.executemany(query, *args, timeout=timeout)
except Exception as e:
await self._handle_exception(e)
raise
@log_duration
async def fetch(
self, query: str, *args: Any, timeout: float | None = None
) -> list[Row | Record]:
try:
return await self.wrapped.fetch(query, *args, timeout=timeout)
except Exception as e:
await self._handle_exception(e)
raise
@log_duration
async def fetchval(
self, query: str, *args: Any, column: int = 0, timeout: float | None = None
) -> Any:
try:
return await self.wrapped.fetchval(query, *args, column=column, timeout=timeout)
except Exception as e:
await self._handle_exception(e)
raise
@log_duration
async def fetchrow(
self, query: str, *args: Any, timeout: float | None = None
) -> Row | Record | None:
try:
return await self.wrapped.fetchrow(query, *args, timeout=timeout)
except Exception as e:
await self._handle_exception(e)
raise
async def table_exists(self, name: str) -> bool:
if self.scheme == Scheme.SQLITE:
return await self.fetchval(
"SELECT EXISTS(SELECT 1 FROM sqlite_master WHERE type='table' AND name=?1)", name
)
elif self.scheme in (Scheme.POSTGRES, Scheme.COCKROACH):
return await self.fetchval(
"SELECT EXISTS(SELECT 1 FROM information_schema.tables WHERE table_name=$1)", name
)
else:
raise RuntimeError(f"Unknown scheme {self.scheme}")
@log_duration
async def copy_records_to_table(
self,
table_name: str,
*,
records: list[tuple[Any, ...]],
columns: tuple[str, ...] | list[str],
schema_name: str | None = None,
timeout: float | None = None,
) -> None:
if self.scheme != Scheme.POSTGRES:
raise RuntimeError("copy_records_to_table is only supported on Postgres")
try:
return await self.wrapped.copy_records_to_table(
table_name,
records=records,
columns=columns,
schema_name=schema_name,
timeout=timeout,
)
except Exception as e:
await self._handle_exception(e)
raise
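# Illustrative sketch: using a LoggingConnection acquired from a Database. The table,
# columns and values are placeholders.
#
#   async with db.acquire() as conn:
#       async with conn.transaction():
#           await conn.execute("INSERT INTO foo (id, value) VALUES ($1, $2)", "a", 1)
#       exists = await conn.table_exists("foo")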
python-0.20.4/mautrix/util/async_db/connection.pyi
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from typing import Any, AsyncContextManager, Awaitable, Callable
from sqlite3 import Row
from asyncpg import Record
import asyncpg
from mautrix.util.logging import TraceLogger
from . import aiosqlite
from .scheme import Scheme
class LoggingConnection:
scheme: Scheme
wrapped: aiosqlite.TxnConnection | asyncpg.Connection
_handle_exception: Callable[[Exception], Awaitable[None]]
log: TraceLogger
def __init__(
self,
scheme: Scheme,
wrapped: aiosqlite.TxnConnection | asyncpg.Connection,
log: TraceLogger,
handle_exception: Callable[[Exception], Awaitable[None]] = None,
) -> None: ...
async def transaction(self) -> AsyncContextManager[None]: ...
async def execute(self, query: str, *args: Any, timeout: float | None = None) -> str: ...
async def executemany(self, query: str, *args: Any, timeout: float | None = None) -> str: ...
async def fetch(
self, query: str, *args: Any, timeout: float | None = None
) -> list[Row | Record]: ...
async def fetchval(
self, query: str, *args: Any, column: int = 0, timeout: float | None = None
) -> Any: ...
async def fetchrow(
self, query: str, *args: Any, timeout: float | None = None
) -> Row | Record | None: ...
async def table_exists(self, name: str) -> bool: ...
async def copy_records_to_table(
self,
table_name: str,
*,
records: list[tuple[Any, ...]],
columns: tuple[str, ...] | list[str],
schema_name: str | None = None,
timeout: float | None = None,
) -> None: ...
python-0.20.4/mautrix/util/async_db/database.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from typing import Any, AsyncContextManager, Type
from abc import ABC, abstractmethod
import logging
from yarl import URL
from mautrix import __optional_imports__
from mautrix.util.logging import TraceLogger
from .connection import LoggingConnection
from .errors import DatabaseNotOwned, ForeignTablesFound
from .scheme import Scheme
from .upgrade import UpgradeTable, upgrade_tables
if __optional_imports__:
from aiosqlite import Cursor
from asyncpg import Record
class Database(ABC):
schemes: dict[str, Type[Database]] = {}
log: TraceLogger
scheme: Scheme
url: URL
_db_args: dict[str, Any]
upgrade_table: UpgradeTable | None
owner_name: str | None
ignore_foreign_tables: bool
def __init__(
self,
url: URL,
upgrade_table: UpgradeTable | None,
db_args: dict[str, Any] | None = None,
log: TraceLogger | None = None,
owner_name: str | None = None,
ignore_foreign_tables: bool = True,
) -> None:
self.url = url
self._db_args = {**db_args} if db_args else {}
self.upgrade_table = upgrade_table
self.owner_name = owner_name
self.ignore_foreign_tables = ignore_foreign_tables
self.log = log or logging.getLogger("mau.db")
assert isinstance(self.log, TraceLogger)
@classmethod
def create(
cls,
url: str | URL,
*,
db_args: dict[str, Any] | None = None,
upgrade_table: UpgradeTable | str | None = None,
log: logging.Logger | TraceLogger | None = None,
owner_name: str | None = None,
ignore_foreign_tables: bool = True,
) -> Database:
url = URL(url)
try:
impl = cls.schemes[url.scheme]
except KeyError as e:
if url.scheme in ("postgres", "postgresql"):
raise RuntimeError(
f"Unknown database scheme {url.scheme}."
" Perhaps you forgot to install asyncpg?"
) from e
elif url.scheme in ("sqlite", "sqlite3"):
raise RuntimeError(
f"Unknown database scheme {url.scheme}."
" Perhaps you forgot to install aiosqlite?"
) from e
raise RuntimeError(f"Unknown database scheme {url.scheme}") from e
if isinstance(upgrade_table, str):
upgrade_table = upgrade_tables[upgrade_table]
elif upgrade_table is None:
upgrade_table = UpgradeTable()
elif not isinstance(upgrade_table, UpgradeTable):
raise ValueError(f"Can't use {type(upgrade_table)} as the upgrade table")
return impl(
url,
db_args=db_args,
upgrade_table=upgrade_table,
log=log,
owner_name=owner_name,
ignore_foreign_tables=ignore_foreign_tables,
)
def override_pool(self, db: Database) -> None:
pass
async def start(self) -> None:
if not self.ignore_foreign_tables:
await self._check_foreign_tables()
if self.owner_name:
await self._check_owner()
if self.upgrade_table and len(self.upgrade_table.upgrades) > 0:
await self.upgrade_table.upgrade(self)
async def _check_foreign_tables(self) -> None:
if await self.table_exists("state_groups_state"):
raise ForeignTablesFound("found state_groups_state likely belonging to Synapse")
elif await self.table_exists("roomserver_rooms"):
raise ForeignTablesFound("found roomserver_rooms likely belonging to Dendrite")
async def _check_owner(self) -> None:
await self.execute(
"""CREATE TABLE IF NOT EXISTS database_owner (
key INTEGER PRIMARY KEY DEFAULT 0,
owner TEXT NOT NULL
)"""
)
owner = await self.fetchval("SELECT owner FROM database_owner WHERE key=0")
if not owner:
await self.execute("INSERT INTO database_owner (owner) VALUES ($1)", self.owner_name)
elif owner != self.owner_name:
raise DatabaseNotOwned(owner)
@abstractmethod
async def stop(self) -> None:
pass
@abstractmethod
def acquire(self) -> AsyncContextManager[LoggingConnection]:
pass
async def execute(self, query: str, *args: Any, timeout: float | None = None) -> str | Cursor:
async with self.acquire() as conn:
return await conn.execute(query, *args, timeout=timeout)
async def executemany(
self, query: str, *args: Any, timeout: float | None = None
) -> str | Cursor:
async with self.acquire() as conn:
return await conn.executemany(query, *args, timeout=timeout)
async def fetch(self, query: str, *args: Any, timeout: float | None = None) -> list[Record]:
async with self.acquire() as conn:
return await conn.fetch(query, *args, timeout=timeout)
async def fetchval(
self, query: str, *args: Any, column: int = 0, timeout: float | None = None
) -> Any:
async with self.acquire() as conn:
return await conn.fetchval(query, *args, column=column, timeout=timeout)
async def fetchrow(
self, query: str, *args: Any, timeout: float | None = None
) -> Record | None:
async with self.acquire() as conn:
return await conn.fetchrow(query, *args, timeout=timeout)
async def table_exists(self, name: str) -> bool:
async with self.acquire() as conn:
return await conn.table_exists(name)
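# Illustrative sketch: creating and starting a database. The URL, owner name and query are
# placeholders; upgrade_table would be an UpgradeTable defined elsewhere.
#
#   db = Database.create(
#       "postgres://user:password@localhost/bridge",
#       upgrade_table=upgrade_table,
#       owner_name="bridge",
#   )
#   await db.start()
#   current_version = await db.fetchval("SELECT version FROM version LIMIT 1")
#   await db.stop()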
python-0.20.4/mautrix/util/async_db/errors.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
class DatabaseException(RuntimeError):
@property
def explanation(self) -> str | None:
return None
class UnsupportedDatabaseVersion(DatabaseException):
def __init__(self, name: str, version: int, latest: int) -> None:
super().__init__(
f"Unsupported {name} schema version v{version} (latest known is v{latest})"
)
@property
def explanation(self) -> str:
return "Downgrading is not supported"
class ForeignTablesFound(DatabaseException):
def __init__(self, explanation: str) -> None:
super().__init__(f"The database contains foreign tables ({explanation})")
@property
def explanation(self) -> str:
return "You can use --ignore-foreign-tables to ignore this error"
class DatabaseNotOwned(DatabaseException):
def __init__(self, owner: str) -> None:
super().__init__(f"The database is owned by {owner}")
@property
def explanation(self) -> str:
return "Sharing the same database with different programs is not supported"
python-0.20.4/mautrix/util/async_db/scheme.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from enum import Enum
class Scheme(Enum):
POSTGRES = "postgres"
COCKROACH = "cockroach"
SQLITE = "sqlite"
def __eq__(self, other: Scheme | str) -> bool:
if isinstance(other, str):
return self.value == other
else:
return super().__eq__(other)
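# Illustrative sketch: Scheme members compare equal to their plain string values.
#
#   assert Scheme.POSTGRES == "postgres"
#   assert Scheme.SQLITE != "postgres"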
python-0.20.4/mautrix/util/async_db/upgrade.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from typing import Awaitable, Callable, Optional, cast
import functools
import inspect
import logging
from mautrix.util.logging import TraceLogger
from .. import async_db
from .connection import LoggingConnection
from .errors import UnsupportedDatabaseVersion
from .scheme import Scheme
Upgrade = Callable[[LoggingConnection, Scheme], Awaitable[Optional[int]]]
UpgradeWithoutScheme = Callable[[LoggingConnection], Awaitable[Optional[int]]]
async def noop_upgrade(_: LoggingConnection, _2: Scheme) -> None:
pass
def _wrap_upgrade(fn: UpgradeWithoutScheme | Upgrade) -> Upgrade:
params = inspect.signature(fn).parameters
if len(params) == 1:
_wrapped: UpgradeWithoutScheme = cast(UpgradeWithoutScheme, fn)
@functools.wraps(_wrapped)
async def _wrapper(conn: LoggingConnection, _: Scheme) -> Optional[int]:
return await _wrapped(conn)
return _wrapper
else:
return fn
class UpgradeTable:
upgrades: list[Upgrade]
allow_unsupported: bool
database_name: str
version_table_name: str
log: TraceLogger
def __init__(
self,
allow_unsupported: bool = False,
version_table_name: str = "version",
database_name: str = "database",
log: logging.Logger | TraceLogger | None = None,
) -> None:
self.upgrades = []
self.allow_unsupported = allow_unsupported
self.version_table_name = version_table_name
self.database_name = database_name
self.log = log or logging.getLogger("mau.db.upgrade")
def register(
self,
_outer_fn: Upgrade | UpgradeWithoutScheme | None = None,
*,
index: int = -1,
description: str = "",
transaction: bool = True,
upgrades_to: int | Upgrade | None = None,
) -> Upgrade | Callable[[Upgrade | UpgradeWithoutScheme], Upgrade]:
if isinstance(index, str):
description = index
index = -1
def actually_register(fn: Upgrade | UpgradeWithoutScheme) -> Upgrade:
fn = _wrap_upgrade(fn)
fn.__mau_db_upgrade_description__ = description
fn.__mau_db_upgrade_transaction__ = transaction
fn.__mau_db_upgrade_destination__ = (
upgrades_to
if not upgrades_to or isinstance(upgrades_to, int)
else _wrap_upgrade(upgrades_to)
)
if index == -1 or index == len(self.upgrades):
self.upgrades.append(fn)
else:
if len(self.upgrades) <= index:
self.upgrades += [noop_upgrade] * (index - len(self.upgrades) + 1)
self.upgrades[index] = fn
return fn
return actually_register(_outer_fn) if _outer_fn else actually_register
async def _save_version(self, conn: LoggingConnection, version: int) -> None:
self.log.trace(f"Saving current version (v{version}) to database")
await conn.execute(f"DELETE FROM {self.version_table_name}")
await conn.execute(f"INSERT INTO {self.version_table_name} (version) VALUES ($1)", version)
async def upgrade(self, db: async_db.Database) -> None:
await db.execute(
f"""CREATE TABLE IF NOT EXISTS {self.version_table_name} (
version INTEGER PRIMARY KEY
)"""
)
row = await db.fetchrow(f"SELECT version FROM {self.version_table_name} LIMIT 1")
version = row["version"] if row else 0
if len(self.upgrades) < version:
unsupported_version_error = UnsupportedDatabaseVersion(
self.database_name, version, len(self.upgrades)
)
if not self.allow_unsupported:
raise unsupported_version_error
else:
self.log.warning(str(unsupported_version_error))
return
elif len(self.upgrades) == version:
self.log.debug(f"Database at v{version}, not upgrading")
return
async with db.acquire() as conn:
while version < len(self.upgrades):
old_version = version
upgrade = self.upgrades[version]
new_version = (
getattr(upgrade, "__mau_db_upgrade_destination__", None) or version + 1
)
if callable(new_version):
new_version = await new_version(conn, db.scheme)
desc = getattr(upgrade, "__mau_db_upgrade_description__", None)
suffix = f": {desc}" if desc else ""
self.log.debug(
f"Upgrading {self.database_name} from v{old_version} to v{new_version}{suffix}"
)
if getattr(upgrade, "__mau_db_upgrade_transaction__", True):
async with conn.transaction():
version = await upgrade(conn, db.scheme) or new_version
await self._save_version(conn, version)
else:
version = await upgrade(conn, db.scheme) or new_version
await self._save_version(conn, version)
if version != new_version:
self.log.warning(
f"Upgrading {self.database_name} actually went from v{old_version} "
f"to v{version}"
)
upgrade_tables: dict[str, UpgradeTable] = {}
def register_upgrade_table_parent_module(name: str) -> None:
upgrade_tables[name] = UpgradeTable()
def _find_upgrade_table(fn: Upgrade) -> UpgradeTable:
try:
module = fn.__module__
except AttributeError as e:
raise ValueError(
"Registering upgrades without an UpgradeTable requires the function "
"to have the __module__ attribute."
) from e
parts = module.split(".")
used_parts = []
last_error = None
for part in parts:
used_parts.append(part)
try:
return upgrade_tables[".".join(used_parts)]
except KeyError as e:
last_error = e
raise KeyError(
"Registering upgrades without an UpgradeTable requires you to register a parent "
"module with register_upgrade_table_parent_module first."
) from last_error
def register_upgrade(index: int = -1, description: str = "") -> Callable[[Upgrade], Upgrade]:
def actually_register(fn: Upgrade) -> Upgrade:
return _find_upgrade_table(fn).register(fn, index=index, description=description)
return actually_register
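# Illustrative sketch: registering schema upgrades on an UpgradeTable. The table name and
# SQL are placeholders; upgrade functions may take only the connection, or the connection
# and the scheme.
#
#   upgrade_table = UpgradeTable()
#
#   @upgrade_table.register(description="Initial revision")
#   async def upgrade_v1(conn: LoggingConnection) -> None:
#       await conn.execute("CREATE TABLE foo (id TEXT PRIMARY KEY)")
#
#   @upgrade_table.register(description="Add bar column")
#   async def upgrade_v2(conn: LoggingConnection, scheme: Scheme) -> None:
#       await conn.execute("ALTER TABLE foo ADD COLUMN bar TEXT")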
python-0.20.4/mautrix/util/async_getter_lock.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from typing import Any
import functools
from mautrix import __optional_imports__
if __optional_imports__:
from typing import Awaitable, Callable, ParamSpec
Param = ParamSpec("Param")
Func = Callable[Param, Awaitable[Any]]
def async_getter_lock(fn: Func) -> Func:
"""
A utility decorator for locking async getters that have caches
(preventing race conditions between cache check and e.g. async database actions).
The class must have an ``_async_get_locks`` defaultdict that contains :class:`asyncio.Lock`s
(see example for exact definition). Non-cache-affecting arguments should only be passed as
keyword args.
Args:
fn: The function to decorate.
Returns:
The decorated function.
Examples:
>>> import asyncio
>>> from collections import defaultdict
>>> class User:
... _async_get_locks: dict[Any, asyncio.Lock] = defaultdict(lambda: asyncio.Lock())
... db: Any
... cache: dict[str, User]
... @classmethod
... @async_getter_lock
... async def get(cls, id: str, *, create: bool = False) -> User | None:
... try:
... return cls.cache[id]
... except KeyError:
... pass
... user = await cls.db.fetch_user(id)
... if user:
... return user
... elif create:
... return await cls.db.create_user(id)
... return None
"""
@functools.wraps(fn)
async def wrapper(cls, *args, **kwargs) -> Any:
async with cls._async_get_locks[args]:
return await fn(cls, *args, **kwargs)
return wrapper
python-0.20.4/mautrix/util/background_task.py
# Copyright (c) 2023 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from typing import Coroutine
import asyncio
import logging
_tasks = set()
log = logging.getLogger("mau.background_task")
async def catch(coro: Coroutine, caller: str) -> None:
try:
await coro
except Exception:
log.exception(f"Uncaught error in background task (created in {caller})")
# Logger.findCaller finds the 3rd stack frame, so add an intermediate function
# to get the caller of create().
def _find_caller() -> tuple[str, int, str, None]:
return log.findCaller()
def create(coro: Coroutine, *, name: str | None = None, catch_errors: bool = True) -> asyncio.Task:
"""
Create a background asyncio task safely, ensuring a reference is kept until the task completes.
It also catches and logs uncaught errors (unless disabled via the parameter).
Args:
coro: The coroutine to wrap in a task and execute.
name: An optional name for the created task.
catch_errors: Should the task be wrapped in a try-except block to log any uncaught errors?
Returns:
An asyncio Task object wrapping the given coroutine.
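Examples:
    A rough sketch; the coroutine and task name are made up.
    >>> import asyncio
    >>> async def periodic_sync() -> None:
    ...     await asyncio.sleep(60)
    >>> def start_sync() -> asyncio.Task:
    ...     return create(periodic_sync(), name="periodic-sync")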
"""
if catch_errors:
try:
file_name, line_number, function_name, _ = _find_caller()
caller = f"{function_name} at {file_name}:{line_number}"
except ValueError:
caller = "unknown function"
task = asyncio.create_task(catch(coro, caller), name=name)
else:
task = asyncio.create_task(coro, name=name)
_tasks.add(task)
task.add_done_callback(_tasks.discard)
return task
python-0.20.4/mautrix/util/bridge_state.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from typing import Any, ClassVar, Dict, Optional
import logging
import time
from attr import dataclass
import aiohttp
from mautrix.api import HTTPAPI
from mautrix.types import SerializableAttrs, SerializableEnum, UserID, field
class BridgeStateEvent(SerializableEnum):
#####################################
# Global state events, no remote ID #
#####################################
# Bridge process is starting up
STARTING = "STARTING"
# Bridge has started but has no valid credentials
UNCONFIGURED = "UNCONFIGURED"
# Bridge is running
RUNNING = "RUNNING"
# The server was unable to reach the bridge
BRIDGE_UNREACHABLE = "BRIDGE_UNREACHABLE"
################################################
# Remote state events, should have a remote ID #
################################################
# Bridge has credentials and has started connecting to a remote network
CONNECTING = "CONNECTING"
# Bridge has begun backfilling
BACKFILLING = "BACKFILLING"
# Bridge has happily connected and is bridging messages
CONNECTED = "CONNECTED"
# Bridge has temporarily disconnected, expected to reconnect automatically
TRANSIENT_DISCONNECT = "TRANSIENT_DISCONNECT"
# Bridge has disconnected, will require user to log in again
BAD_CREDENTIALS = "BAD_CREDENTIALS"
# Bridge has disconnected for an unknown/unexpected reason - we should investigate
UNKNOWN_ERROR = "UNKNOWN_ERROR"
# User has logged out - stop tracking this remote
LOGGED_OUT = "LOGGED_OUT"
ok_ish_states = (
BridgeStateEvent.STARTING,
BridgeStateEvent.UNCONFIGURED,
BridgeStateEvent.RUNNING,
BridgeStateEvent.CONNECTING,
BridgeStateEvent.CONNECTED,
BridgeStateEvent.BACKFILLING,
)
@dataclass(kw_only=True)
class BridgeState(SerializableAttrs):
human_readable_errors: ClassVar[Dict[Optional[str], str]] = {}
default_source: ClassVar[str] = "bridge"
default_error_ttl: ClassVar[int] = 3600
default_ok_ttl: ClassVar[int] = 21600
state_event: BridgeStateEvent
user_id: Optional[UserID] = None
remote_id: Optional[str] = None
remote_name: Optional[str] = None
timestamp: Optional[int] = None
ttl: int = 0
source: Optional[str] = None
error: Optional[str] = None
message: Optional[str] = None
info: Optional[Dict[str, Any]] = None
reason: Optional[str] = None
send_attempts_: int = field(default=0, hidden=True)
def fill(self) -> "BridgeState":
self.timestamp = self.timestamp or int(time.time())
self.source = self.source or self.default_source
if not self.ttl:
self.ttl = (
self.default_ok_ttl
if self.state_event in ok_ish_states
else self.default_error_ttl
)
if self.error:
try:
msg = self.human_readable_errors[self.error]
except KeyError:
pass
else:
self.message = msg.format(message=self.message) if self.message else msg
return self
def should_deduplicate(self, prev_state: Optional["BridgeState"]) -> bool:
if (
not prev_state
or prev_state.state_event != self.state_event
or prev_state.error != self.error
or prev_state.info != self.info
):
# If there's no previous state or the state was different, send this one.
return False
# If the previous state is recent, drop this one
return prev_state.timestamp + prev_state.ttl > self.timestamp
async def send(self, url: str, token: str, log: logging.Logger, log_sent: bool = True) -> bool:
if not url:
return True
self.send_attempts_ += 1
headers = {"Authorization": f"Bearer {token}", "User-Agent": HTTPAPI.default_ua}
try:
async with aiohttp.ClientSession() as sess, sess.post(
url, json=self.serialize(), headers=headers
) as resp:
if not 200 <= resp.status < 300:
text = await resp.text()
text = text.replace("\n", "\\n")
log.warning(
f"Unexpected status code {resp.status} "
f"sending bridge state update: {text}"
)
return False
elif log_sent:
log.debug(f"Sent new bridge state {self}")
except Exception as e:
log.warning(f"Failed to send updated bridge state: {e}")
return False
return True
@dataclass(kw_only=True)
class GlobalBridgeState(SerializableAttrs):
remote_states: Optional[Dict[str, BridgeState]] = field(json="remoteState", default=None)
bridge_state: BridgeState = field(json="bridgeState")
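
# Editor's note: hedged usage sketch. fill() populates timestamp, source and a
# state-appropriate TTL; should_deduplicate() drops repeats of a recent
# identical state; send() POSTs the serialized state with a Bearer token and
# returns False on failure instead of raising. The remote ID, URL and token
# are illustrative assumptions.
import asyncio
import logging

from mautrix.util.bridge_state import BridgeState, BridgeStateEvent


async def report_connected() -> None:
    prev = BridgeState(state_event=BridgeStateEvent.CONNECTING, remote_id="remote-1").fill()
    cur = BridgeState(state_event=BridgeStateEvent.CONNECTED, remote_id="remote-1").fill()
    if not cur.should_deduplicate(prev):
        await cur.send(
            url="https://example.com/bridge-status",  # illustrative endpoint
            token="secret-token",
            log=logging.getLogger("example.bridge_state"),
        )


if __name__ == "__main__":
    asyncio.run(report_connected())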
python-0.20.4/mautrix/util/color_log.py
# This only exists for compatibility with old log configs
from .logging import ColorFormatter
python-0.20.4/mautrix/util/config/__init__.py
from .base import BaseConfig, BaseMissingError, ConfigUpdateHelper
from .file import BaseFileConfig, yaml
from .proxy import BaseProxyConfig
from .recursive_dict import RecursiveDict
from .string import BaseStringConfig
from .validation import BaseValidatableConfig, ConfigValueError, ForbiddenDefault, ForbiddenKey
__all__ = [
"BaseConfig",
"BaseMissingError",
"ConfigUpdateHelper",
"BaseFileConfig",
"yaml",
"BaseProxyConfig",
"RecursiveDict",
"BaseStringConfig",
"BaseValidatableConfig",
"ConfigValueError",
"ForbiddenDefault",
"ForbiddenKey",
]
python-0.20.4/mautrix/util/config/base.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from abc import ABC, abstractmethod
from ruamel.yaml.comments import Comment, CommentedBase, CommentedMap
from .recursive_dict import RecursiveDict
class BaseMissingError(ValueError):
pass
class ConfigUpdateHelper:
base: RecursiveDict[CommentedMap]
def __init__(self, base: RecursiveDict, config: RecursiveDict) -> None:
self.base = base
self.source = config
def copy(self, from_path: str, to_path: str | None = None) -> None:
if from_path in self.source:
val = self.source[from_path]
# Small hack to make sure comments from the user config don't
# partially leak into the updated version.
if isinstance(val, CommentedBase):
setattr(val, Comment.attrib, Comment())
self.base[to_path or from_path] = val
def copy_dict(
self,
from_path: str,
to_path: str | None = None,
override_existing_map: bool = True,
) -> None:
if from_path in self.source:
to_path = to_path or from_path
if override_existing_map or to_path not in self.base:
self.base[to_path] = CommentedMap()
for key, value in self.source[from_path].items():
self.base[to_path][key] = value
def __iter__(self):
yield self.copy
yield self.copy_dict
yield self.base
class BaseConfig(ABC, RecursiveDict[CommentedMap]):
@abstractmethod
def load(self) -> None:
pass
@abstractmethod
def load_base(self) -> RecursiveDict[CommentedMap] | None:
pass
def load_and_update(self) -> None:
self.load()
self.update()
@abstractmethod
def save(self) -> None:
pass
def update(self, save: bool = True) -> None:
base = self.load_base()
if not base:
raise BaseMissingError("Can't update() without base config")
self.do_update(ConfigUpdateHelper(base, self))
self._data = base._data
if save:
self.save()
@abstractmethod
def do_update(self, helper: ConfigUpdateHelper) -> None:
pass
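
# Editor's note: hedged sketch of a do_update() implementation. The helper
# unpacks to (copy, copy_dict, base): copy() carries a single value from the
# user's config into the freshly loaded base config (stripping stray
# comments), copy_dict() carries a whole map. The key names are illustrative
# assumptions, and BaseFileConfig (defined in file.py below) supplies the
# load/save plumbing.
from mautrix.util.config import BaseFileConfig, ConfigUpdateHelper


class ExampleConfig(BaseFileConfig):
    def do_update(self, helper: ConfigUpdateHelper) -> None:
        copy, copy_dict, base = helper
        copy("homeserver.address")
        copy("homeserver.domain")
        # A renamed key: read the old path from the user config,
        # write it to the new path in the base config.
        copy("appservice.bot_user", "appservice.bot_username")
        # Free-form maps keep all user-added keys.
        copy_dict("bridge.permissions")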
python-0.20.4/mautrix/util/config/file.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from abc import ABC
import logging
import os
import pkgutil
import tempfile
from ruamel.yaml import YAML
from ruamel.yaml.comments import CommentedMap
from yarl import URL
from .base import BaseConfig
from .recursive_dict import RecursiveDict
yaml = YAML()
yaml.indent(4)
yaml.width = 200
log: logging.Logger = logging.getLogger("mau.util.config")
class BaseFileConfig(BaseConfig, ABC):
def __init__(self, path: str, base_path: str) -> None:
super().__init__()
self._data = CommentedMap()
self.path: str = path
self.base_path: str = base_path
def load(self) -> None:
with open(self.path, "r") as stream:
self._data = yaml.load(stream)
def load_base(self) -> RecursiveDict[CommentedMap] | None:
if self.base_path.startswith("pkg://"):
url = URL(self.base_path)
return RecursiveDict(yaml.load(pkgutil.get_data(url.host, url.path)), CommentedMap)
try:
with open(self.base_path, "r") as stream:
return RecursiveDict(yaml.load(stream), CommentedMap)
except OSError:
pass
return None
def save(self) -> None:
try:
tf = tempfile.NamedTemporaryFile(
mode="w", delete=False, suffix=".yaml", dir=os.path.dirname(self.path)
)
except OSError as e:
log.warning(f"Failed to create tempfile to write updated config to disk: {e}")
return
try:
yaml.dump(self._data, tf)
except OSError as e:
log.warning(f"Failed to write updated config to tempfile: {e}")
tf.file.close()
os.remove(tf.name)
return
tf.file.close()
try:
os.rename(tf.name, self.path)
except OSError as e:
log.warning(f"Failed to rename tempfile with updated config to {self.path}: {e}")
try:
os.remove(tf.name)
except FileNotFoundError:
pass
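
# Editor's note: hedged usage sketch. BaseFileConfig loads the user's YAML
# from `path`, reads the bundled defaults from `base_path` (a plain file path
# or a pkg://package/resource URL), and save() writes atomically via a
# tempfile + rename. The file names and package name are illustrative.
from mautrix.util.config import BaseFileConfig, ConfigUpdateHelper


class FileConfig(BaseFileConfig):
    def do_update(self, helper: ConfigUpdateHelper) -> None:
        copy, copy_dict, base = helper
        copy("bridge.command_prefix")


config = FileConfig("config.yaml", "pkg://example_bridge/example-config.yaml")
config.load_and_update()  # load() the user file, merge into the base, save()
print(config["bridge.command_prefix"])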
python-0.20.4/mautrix/util/config/proxy.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from typing import Callable
from abc import ABC
from ruamel.yaml.comments import CommentedMap
from .base import BaseConfig
from .recursive_dict import RecursiveDict
class BaseProxyConfig(BaseConfig, ABC):
def __init__(
self,
load: Callable[[], CommentedMap],
load_base: Callable[[], RecursiveDict[CommentedMap] | None],
save: Callable[[RecursiveDict[CommentedMap]], None],
) -> None:
super().__init__()
self._data = CommentedMap()
self._load_proxy = load
self._load_base_proxy = load_base
self._save_proxy = save
def load(self) -> None:
self._data = self._load_proxy() or CommentedMap()
def load_base(self) -> RecursiveDict[CommentedMap] | None:
return self._load_base_proxy()
def save(self) -> None:
self._save_proxy(self._data)
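
# Editor's note: hedged sketch. BaseProxyConfig delegates load/save to
# callables supplied by the host application (e.g. a plugin runtime) instead
# of touching files itself. The in-memory host side here is illustrative.
from ruamel.yaml.comments import CommentedMap

from mautrix.util.config import BaseProxyConfig, ConfigUpdateHelper, RecursiveDict


class PluginConfig(BaseProxyConfig):
    def do_update(self, helper: ConfigUpdateHelper) -> None:
        copy, copy_dict, base = helper
        copy("command_prefix")


saved = {}
config = PluginConfig(
    load=lambda: CommentedMap({"command_prefix": "!demo"}),
    load_base=lambda: RecursiveDict(CommentedMap({"command_prefix": "!default"}), CommentedMap),
    save=lambda data: saved.update(data),
)
config.load_and_update()
print(config["command_prefix"])  # -> !demo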
python-0.20.4/mautrix/util/config/recursive_dict.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from typing import Any, Generic, Type, TypeVar
import copy
from ruamel.yaml.comments import CommentedMap
T = TypeVar("T")
class RecursiveDict(Generic[T]):
def __init__(self, data: T | None = None, dict_factory: Type[T] | None = None) -> None:
self._dict_factory = dict_factory or dict
self._data: CommentedMap = data or self._dict_factory()
def clone(self) -> RecursiveDict:
return RecursiveDict(data=copy.deepcopy(self._data), dict_factory=self._dict_factory)
@staticmethod
def parse_key(key: str) -> tuple[str, str | None]:
if "." not in key:
return key, None
key, next_key = key.split(".", 1)
if len(key) > 0 and key[0] == "[":
end_index = next_key.index("]")
key = key[1:] + "." + next_key[:end_index]
next_key = next_key[end_index + 2 :] if len(next_key) > end_index + 1 else None
return key, next_key
def _recursive_get(self, data: T, key: str, default_value: Any) -> Any:
key, next_key = self.parse_key(key)
if next_key is not None:
next_data = data.get(key, self._dict_factory())
return self._recursive_get(next_data, next_key, default_value)
try:
return data[key]
except (AttributeError, KeyError):
return default_value
def get(self, key: str, default_value: Any, allow_recursion: bool = True) -> Any:
if allow_recursion and "." in key:
return self._recursive_get(self._data, key, default_value)
return self._data.get(key, default_value)
def __getitem__(self, key: str) -> Any:
return self.get(key, None)
def __contains__(self, key: str) -> bool:
return self.get(key, None) is not None
def _recursive_set(self, data: T, key: str, value: Any) -> None:
key, next_key = self.parse_key(key)
if next_key is not None:
if key not in data:
data[key] = self._dict_factory()
next_data = data.get(key, self._dict_factory())
return self._recursive_set(next_data, next_key, value)
data[key] = value
def set(self, key: str, value: Any, allow_recursion: bool = True) -> None:
if allow_recursion and "." in key:
self._recursive_set(self._data, key, value)
return
self._data[key] = value
def __setitem__(self, key: str, value: Any) -> None:
self.set(key, value)
def _recursive_del(self, data: T, key: str) -> None:
key, next_key = self.parse_key(key)
if next_key is not None:
if key not in data:
return
next_data = data[key]
return self._recursive_del(next_data, next_key)
try:
del data[key]
del data.ca.items[key]
except KeyError:
pass
def delete(self, key: str, allow_recursion: bool = True) -> None:
if allow_recursion and "." in key:
self._recursive_del(self._data, key)
return
try:
del self._data[key]
del self._data.ca.items[key]
except KeyError:
pass
def __delitem__(self, key: str) -> None:
self.delete(key)
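
# Editor's note: hedged sketch of the dotted-key access RecursiveDict adds on
# top of a plain mapping. Keys recurse on ".", and a "[...]" segment (per
# parse_key above) escapes a key that itself contains dots. Values are
# illustrative.
from ruamel.yaml.comments import CommentedMap

from mautrix.util.config import RecursiveDict

d = RecursiveDict(CommentedMap(), CommentedMap)
d["homeserver.address"] = "https://matrix.example.com"  # creates the nested map
assert d["homeserver"]["address"] == "https://matrix.example.com"
assert d.get("homeserver.port", 443) == 443             # missing key -> default
d["bridge.[example.com].enabled"] = True                # escaped dotted key
assert d["bridge"]["example.com"]["enabled"] is True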
python-0.20.4/mautrix/util/config/string.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from abc import ABC
import io
from ruamel.yaml import YAML
from ruamel.yaml.comments import CommentedMap
from .base import BaseConfig
from .recursive_dict import RecursiveDict
yaml = YAML()
yaml.indent(4)
yaml.width = 200
class BaseStringConfig(BaseConfig, ABC):
def __init__(self, data: str, base_data: str) -> None:
super().__init__()
self._data = yaml.load(data)
self._base = RecursiveDict(yaml.load(base_data), CommentedMap)
def load(self) -> None:
pass
def load_base(self) -> RecursiveDict[CommentedMap] | None:
return self._base
def save(self) -> str:
buf = io.StringIO()
yaml.dump(self._data, buf)
return buf.getvalue()
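
# Editor's note: hedged usage sketch. BaseStringConfig works on in-memory YAML
# strings instead of files, and save() returns the merged YAML as a string.
# The YAML content is illustrative.
from mautrix.util.config import BaseStringConfig, ConfigUpdateHelper


class StringConfig(BaseStringConfig):
    def do_update(self, helper: ConfigUpdateHelper) -> None:
        copy, copy_dict, base = helper
        copy("greeting")


config = StringConfig("greeting: hi", "greeting: hello\nfarewell: bye")
config.update(save=False)   # merge the user value into the base definition
print(config["greeting"])   # -> hi
print(config["farewell"])   # -> bye (from the base config)
print(config.save())        # the merged YAML document as a string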
python-0.20.4/mautrix/util/config/validation.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from typing import Any
from abc import ABC, abstractmethod
from attr import dataclass
import attr
from .base import BaseConfig
class ConfigValueError(ValueError):
def __init__(self, key: str, message: str) -> None:
super().__init__(
f"{key} not configured. {message}" if message else f"{key} not configured"
)
class ForbiddenKey(str):
pass
@dataclass
class ForbiddenDefault:
key: str
value: Any
error: str | None = None
condition: str | None = attr.ib(default=None, kw_only=True)
def check(self, config: BaseConfig) -> bool:
if self.condition and not config[self.condition]:
return False
elif isinstance(self.value, ForbiddenKey):
return str(self.value) in config[self.key]
else:
return config[self.key] == self.value
@property
def exception(self) -> ConfigValueError:
return ConfigValueError(self.key, self.error)
class BaseValidatableConfig(BaseConfig, ABC):
@property
@abstractmethod
def forbidden_defaults(self) -> list[ForbiddenDefault]:
pass
def check_default_values(self) -> None:
for default in self.forbidden_defaults:
if default.check(self):
raise default.exception
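
# Editor's note: hedged sketch of forbidden-default validation. The config
# class lists example values that must be changed before the config is
# accepted; check_default_values() raises ConfigValueError for the first one
# still at its default. Keys and defaults are illustrative assumptions.
from mautrix.util.config import (
    BaseStringConfig,
    BaseValidatableConfig,
    ConfigUpdateHelper,
    ConfigValueError,
    ForbiddenDefault,
)


class ValidatedConfig(BaseStringConfig, BaseValidatableConfig):
    def do_update(self, helper: ConfigUpdateHelper) -> None:
        copy, copy_dict, base = helper
        copy("homeserver.address")

    @property
    def forbidden_defaults(self) -> list[ForbiddenDefault]:
        return [
            ForbiddenDefault("homeserver.address", "https://example.com"),
        ]


yaml_doc = "homeserver:\n    address: https://example.com"
config = ValidatedConfig(yaml_doc, yaml_doc)
try:
    config.check_default_values()
except ConfigValueError as e:
    print(e)  # -> "homeserver.address not configured"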
python-0.20.4/mautrix/util/db/__init__.py
from .base import Base, BaseClass
__all__ = ["Base", "BaseClass"]
python-0.20.4/mautrix/util/db/base.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Type, TypeVar, cast
from contextlib import contextmanager
from sqlalchemy import Constraint, Table
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.engine.base import Connection, Engine
from sqlalchemy.ext.declarative import as_declarative, declarative_base
from sqlalchemy.sql.base import ImmutableColumnCollection
from sqlalchemy.sql.expression import ClauseElement, Select, and_
if TYPE_CHECKING:
from sqlalchemy.engine.result import ResultProxy, RowProxy
T = TypeVar("T", bound="BaseClass")
class BaseClass:
"""
Base class for SQLAlchemy models. Provides SQLAlchemy declarative base features and some
additional utilities.
.. deprecated:: 0.15.0
The :mod:`mautrix.util.async_db` utility is now recommended over SQLAlchemy.
"""
__tablename__: str
db: Engine
t: Table
__table__: Table
c: ImmutableColumnCollection
column_names: List[str]
@classmethod
def bind(cls, db_engine: Engine) -> None:
cls.db = db_engine
cls.t = cls.__table__
cls.c = cls.t.columns
cls.column_names = cls.c.keys()
@classmethod
def copy(
cls, bind: Optional[Engine] = None, rebase: Optional[declarative_base] = None
) -> Type[T]:
copy = cast(Type[T], type(cls.__name__, (cls, rebase) if rebase else (cls,), {}))
if bind is not None:
copy.bind(db_engine=bind)
return copy
@classmethod
def _one_or_none(cls: Type[T], rows: "ResultProxy") -> Optional[T]:
"""
Try scanning one row from a ResultProxy and return ``None`` if it fails.
Args:
rows: The SQLAlchemy result to scan.
Returns:
The scanned object, or ``None`` if there were no rows.
"""
try:
return cls.scan(next(rows))
except StopIteration:
return None
@classmethod
def _all(cls: Type[T], rows: "ResultProxy") -> Iterator[T]:
"""
Scan all rows from a ResultProxy.
Args:
rows: The SQLAlchemy result to scan.
Yields:
Each row scanned with :meth:`scan`
"""
for row in rows:
yield cls.scan(row)
@classmethod
def scan(cls: Type[T], row: "RowProxy") -> T:
"""
Read the data from a row into an object.
Args:
row: The RowProxy object.
Returns:
An object containing the information in the row.
"""
return cls(**dict(zip(cls.column_names, row)))
@classmethod
def _make_simple_select(cls: Type[T], *args: ClauseElement) -> Select:
"""
Create a simple ``SELECT * FROM table WHERE `` statement.
Args:
*args: The WHERE clauses. If there are many elements, they're joined with AND.
Returns:
The SQLAlchemy SELECT statement object.
"""
if len(args) > 1:
return cls.t.select().where(and_(*args))
elif len(args) == 1:
return cls.t.select().where(args[0])
else:
return cls.t.select()
@classmethod
def _select_all(cls: Type[T], *args: ClauseElement) -> Iterator[T]:
"""
Select all rows with given conditions. This is intended to be used by table-specific
select methods.
Args:
*args: The WHERE clauses. If there are many elements, they're joined with AND.
Yields:
The objects representing the rows read with :meth:`scan`
"""
yield from cls._all(cls.db.execute(cls._make_simple_select(*args)))
@classmethod
def _select_one_or_none(cls: Type[T], *args: ClauseElement) -> T:
"""
Select one row with given conditions. If no row is found, return ``None``. This is intended
to be used by table-specific select methods.
Args:
*args: The WHERE clauses. If there are many elements, they're joined with AND.
Returns:
The object representing the matched row read with :meth:`scan`, or ``None`` if no rows
matched.
"""
return cls._one_or_none(cls.db.execute(cls._make_simple_select(*args)))
def _constraint_to_clause(self, constraint: Constraint) -> ClauseElement:
return and_(
*[column == self.__dict__[name] for name, column in constraint.columns.items()]
)
@property
def _edit_identity(self: T) -> ClauseElement:
"""The SQLAlchemy WHERE clause used for editing and deleting individual rows.
Usually AND of primary keys."""
return self._constraint_to_clause(self.t.primary_key)
def edit(self: T, *, _update_values: bool = True, **values) -> None:
"""
Edit this row.
Args:
_update_values: Whether or not the values in memory should be updated as well as the
values in the database.
**values: The values to change.
"""
with self.db.begin() as conn:
conn.execute(self.t.update().where(self._edit_identity).values(**values))
if _update_values:
for key, value in values.items():
setattr(self, key, value)
@contextmanager
def edit_mode(self: T) -> None:
"""
Edit this row in a fancy context manager way. This stores the current edit identity, then
yields to the context manager and finally puts the new values into the row using the old
edit identity in the WHERE clause.
>>> class TableClass(Base):
... ...
>>> db_instance = TableClass(id="something")
>>> with db_instance.edit_mode():
... db_instance.id = "new_id"
"""
old_identity = self._edit_identity
yield old_identity
with self.db.begin() as conn:
conn.execute(self.t.update().where(old_identity).values(**self._insert_values))
def delete(self: T) -> None:
"""Delete this row."""
with self.db.begin() as conn:
conn.execute(self.t.delete().where(self._edit_identity))
@property
def _insert_values(self: T) -> Dict[str, Any]:
"""Values for inserts. Generally you want all the values in the table."""
return {
column_name: self.__dict__[column_name]
for column_name in self.column_names
if column_name in self.__dict__
}
def insert(self) -> None:
with self.db.begin() as conn:
conn.execute(self.t.insert().values(**self._insert_values))
@property
def _upsert_values(self: T) -> Dict[str, Any]:
"""The values to set when an upsert-insert conflicts and moves to the update part."""
return self._insert_values
def _upsert_postgres(self: T, conn: Connection) -> None:
conn.execute(
pg_insert(self.t)
.values(**self._insert_values)
.on_conflict_do_update(constraint=self.t.primary_key, set_=self._upsert_values)
)
def _upsert_sqlite(self: T, conn: Connection) -> None:
conn.execute(self.t.insert().values(**self._insert_values).prefix_with("OR REPLACE"))
def _upsert_generic(self: T, conn: Connection):
conn.execute(self.t.delete().where(self._edit_identity))
conn.execute(self.t.insert().values(**self._insert_values))
def upsert(self: T) -> None:
with self.db.begin() as conn:
if self.db.dialect.name == "postgresql":
self._upsert_postgres(conn)
elif self.db.dialect.name == "sqlite":
self._upsert_sqlite(conn)
else:
self._upsert_generic(conn)
def __iter__(self):
for key in self.column_names:
yield self.__dict__[key]
@as_declarative()
class Base(BaseClass):
"""
.. deprecated:: 0.15.0
The :mod:`mautrix.util.async_db` utility is now recommended over SQLAlchemy.
"""
pass
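
# Editor's note: hedged sketch of the deprecated SQLAlchemy helper above,
# assuming SQLAlchemy 1.x (which this module targets). A table class inherits
# Base, gets bound to an engine, and then the upsert/edit/delete helpers work
# on instances. The table layout is an illustrative assumption; new code
# should prefer mautrix.util.async_db instead.
from sqlalchemy import Column, Text, create_engine

from mautrix.util.db import Base


class Portal(Base):
    __tablename__ = "portal"

    mxid = Column(Text, primary_key=True)
    name = Column(Text, nullable=True)


engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)
Portal.bind(engine)  # attaches the db/t/c/column_names class attributes

portal = Portal(mxid="!room:example.com", name="Example room")
portal.upsert()              # INSERT ... OR REPLACE on the SQLite dialect
portal.edit(name="Renamed")  # UPDATE ... WHERE the primary key matches
portal.delete()              # DELETE ... WHERE the primary key matches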
python-0.20.4/mautrix/util/ffmpeg.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from typing import Any, Iterable
from pathlib import Path
import asyncio
import json
import logging
import mimetypes
import os
import shutil
import tempfile
try:
from . import magic
except ImportError:
magic = None
def _abswhich(program: str) -> str | None:
path = shutil.which(program)
return os.path.abspath(path) if path else None
class ConverterError(ChildProcessError):
pass
class NotInstalledError(ConverterError):
def __init__(self) -> None:
super().__init__("failed to transcode media: ffmpeg is not installed")
ffmpeg_path = _abswhich("ffmpeg")
ffmpeg_default_params = ("-hide_banner", "-loglevel", "warning", "-y")
ffprobe_path = _abswhich("ffprobe")
ffprobe_default_params = (
"-loglevel",
"quiet",
"-print_format",
"json",
"-show_optional_fields",
"1",
"-show_format",
"-show_streams",
)
async def probe_path(
input_file: os.PathLike[str] | str,
logger: logging.Logger | None = None,
) -> Any:
"""
Probes a media file on the disk using ffprobe.
Args:
input_file: The full path to the file.
Returns:
A Python object containing the parsed JSON response from ffprobe
Raises:
ConverterError: if ffprobe returns a non-zero exit code.
"""
if ffprobe_path is None:
raise NotInstalledError()
input_file = Path(input_file)
proc = await asyncio.create_subprocess_exec(
ffprobe_path,
*ffprobe_default_params,
str(input_file),
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
stdin=asyncio.subprocess.PIPE,
)
stdout, stderr = await proc.communicate()
if proc.returncode != 0:
err_text = stderr.decode("utf-8") if stderr else f"unknown ({proc.returncode})"
raise ConverterError(f"ffprobe error: {err_text}")
elif stderr and logger:
logger.warning(f"ffprobe warning: {stderr.decode('utf-8')}")
return json.loads(stdout)
async def probe_bytes(
data: bytes,
input_mime: str | None = None,
logger: logging.Logger | None = None,
) -> Any:
"""
Probe media file data using ffprobe.
Args:
data: The bytes of the file to probe.
input_mime: The mime type of the input data. If not specified, will be guessed using magic.
Returns:
A Python object containing the parsed JSON response from ffprobe
Raises:
ConverterError: if ffprobe returns a non-zero exit code.
"""
if ffprobe_path is None:
raise NotInstalledError()
if input_mime is None:
if magic is None:
raise ValueError("input_mime was not specified and magic is not installed")
input_mime = magic.mimetype(data)
input_extension = mimetypes.guess_extension(input_mime)
with tempfile.TemporaryDirectory(prefix="mautrix_ffmpeg_") as tmpdir:
input_file = Path(tmpdir) / f"data{input_extension}"
with open(input_file, "wb") as file:
file.write(data)
return await probe_path(input_file=input_file, logger=logger)
async def convert_path(
input_file: os.PathLike[str] | str,
output_extension: str | None,
input_args: Iterable[str] | None = None,
output_args: Iterable[str] | None = None,
remove_input: bool = False,
output_path_override: os.PathLike[str] | str | None = None,
logger: logging.Logger | None = None,
) -> Path | bytes:
"""
Convert a media file on the disk using ffmpeg.
Args:
input_file: The full path to the file.
output_extension: The extension that the output file should be.
input_args: Arguments to tell ffmpeg how to parse the input file.
output_args: Arguments to tell ffmpeg how to convert the file to reach the wanted output.
remove_input: Whether the input file should be removed after converting.
Not compatible with ``output_path_override``.
output_path_override: A custom output path to use
(instead of using the input path with a different extension).
Returns:
The path to the converted file, or the stdout if ``output_path_override`` was set to ``-``.
Raises:
ConverterError: if ffmpeg returns a non-zero exit code.
"""
if ffmpeg_path is None:
raise NotInstalledError()
if output_path_override:
output_file = output_path_override
if remove_input:
raise ValueError("remove_input can't be specified with output_path_override")
elif not output_extension:
raise ValueError("output_extension or output_path_override is required")
else:
input_file = Path(input_file)
output_file = input_file.parent / f"{input_file.stem}{output_extension}"
if input_file == output_file:
output_file = Path(output_file)
output_file = output_file.parent / f"{output_file.stem}-new{output_extension}"
proc = await asyncio.create_subprocess_exec(
ffmpeg_path,
*ffmpeg_default_params,
*(input_args or ()),
"-i",
str(input_file),
*(output_args or ()),
str(output_file),
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
stdin=asyncio.subprocess.PIPE,
)
stdout, stderr = await proc.communicate()
if proc.returncode != 0:
err_text = stderr.decode("utf-8") if stderr else f"unknown ({proc.returncode})"
raise ConverterError(f"ffmpeg error: {err_text}")
elif stderr and logger:
logger.warning(f"ffmpeg warning: {stderr.decode('utf-8')}")
if remove_input and isinstance(input_file, Path):
input_file.unlink(missing_ok=True)
return stdout if output_file == "-" else output_file
async def convert_bytes(
data: bytes,
output_extension: str,
input_args: Iterable[str] | None = None,
output_args: Iterable[str] | None = None,
input_mime: str | None = None,
logger: logging.Logger | None = None,
) -> bytes:
"""
Convert media file data using ffmpeg.
Args:
data: The bytes of the file to convert.
output_extension: The extension that the output file should be.
input_args: Arguments to tell ffmpeg how to parse the input file.
output_args: Arguments to tell ffmpeg how to convert the file to reach the wanted output.
input_mime: The mime type of the input data. If not specified, will be guessed using magic.
Returns:
The converted file as bytes.
Raises:
ConverterError: if ffmpeg returns a non-zero exit code.
"""
if ffmpeg_path is None:
raise NotInstalledError()
if input_mime is None:
if magic is None:
raise ValueError("input_mime was not specified and magic is not installed")
input_mime = magic.mimetype(data)
input_extension = mimetypes.guess_extension(input_mime)
with tempfile.TemporaryDirectory(prefix="mautrix_ffmpeg_") as tmpdir:
input_file = Path(tmpdir) / f"data{input_extension}"
with open(input_file, "wb") as file:
file.write(data)
output_file = await convert_path(
input_file=input_file,
output_extension=output_extension,
input_args=input_args,
output_args=output_args,
logger=logger,
)
with open(output_file, "rb") as file:
return file.read()
__all__ = [
"ffmpeg_path",
"ffmpeg_default_params",
"ConverterError",
"NotInstalledError",
"convert_bytes",
"convert_path",
"probe_bytes",
"probe_path",
]
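
# Editor's note: hedged usage sketch. convert_bytes() writes the data to a
# temporary file, runs ffmpeg with the given arguments and returns the output
# bytes; probe_bytes() runs ffprobe and returns its parsed JSON. The file
# name, mime type and codec arguments are illustrative assumptions.
import asyncio

from mautrix.util import ffmpeg


async def voice_to_m4a(path: str) -> bytes:
    with open(path, "rb") as f:
        data = f.read()
    info = await ffmpeg.probe_bytes(data, input_mime="audio/ogg")
    print(info.get("format", {}).get("duration"))  # from ffprobe -show_format
    # Transcode the audio stream to AAC in an M4A container.
    return await ffmpeg.convert_bytes(
        data,
        output_extension=".m4a",
        output_args=("-c:a", "aac"),
        input_mime="audio/ogg",
    )


if __name__ == "__main__":
    asyncio.run(voice_to_m4a("voice_message.ogg"))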
python-0.20.4/mautrix/util/file_store.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from typing import IO, Any, Protocol
from abc import ABC, abstractmethod
from pathlib import Path
import json
import pickle
import time
class Filer(Protocol):
def dump(self, obj: Any, file: IO) -> None:
pass
def load(self, file: IO) -> Any:
pass
class FileStore(ABC):
path: str | Path | IO
filer: Filer
binary: bool
save_interval: float
_last_save: float
def __init__(
self,
path: str | Path | IO,
filer: Filer | None = None,
binary: bool = True,
save_interval: float = 60.0,
) -> None:
self.path = path
self.filer = filer or (pickle if binary else json)
self.binary = binary
self.save_interval = save_interval
self._last_save = time.monotonic()
@abstractmethod
def serialize(self) -> Any:
pass
@abstractmethod
def deserialize(self, data: Any) -> None:
pass
def _save(self) -> None:
if isinstance(self.path, IO):
file = self.path
close = False
else:
file = open(self.path, "wb" if self.binary else "w")
close = True
try:
self.filer.dump(self.serialize(), file)
finally:
if close:
file.close()
def _load(self) -> None:
if isinstance(self.path, IO):
file = self.path
close = False
else:
try:
file = open(self.path, "rb" if self.binary else "r")
except FileNotFoundError:
return
close = True
try:
self.deserialize(self.filer.load(file))
finally:
if close:
file.close()
async def flush(self) -> None:
self._save()
async def open(self) -> None:
self._load()
def _time_limited_flush(self) -> None:
if self._last_save + self.save_interval < time.monotonic():
self._save()
self._last_save = time.monotonic()
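
# Editor's note: hedged sketch of a FileStore subclass. The subclass only
# decides *what* to persist (serialize/deserialize); the base class picks
# pickle or JSON, handles missing files and rate-limits flushes. The stored
# structure and file name are illustrative assumptions.
import asyncio
from typing import Any

from mautrix.util.file_store import FileStore


class TokenStore(FileStore):
    def __init__(self, path: str) -> None:
        super().__init__(path, binary=False)  # JSON on disk instead of pickle
        self.tokens: dict[str, str] = {}

    def serialize(self) -> Any:
        return {"tokens": self.tokens}

    def deserialize(self, data: Any) -> None:
        self.tokens = data.get("tokens", {})


async def main() -> None:
    store = TokenStore("tokens.json")
    await store.open()   # loads the file if it already exists
    store.tokens["@user:example.com"] = "syt_access_token"
    await store.flush()  # writes to disk immediately via _save()


if __name__ == "__main__":
    asyncio.run(main())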
python-0.20.4/mautrix/util/format_duration.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
def _pluralize(count: int, singular: str) -> str:
return singular if count == 1 else f"{singular}s"
def _include_if_positive(count: int, word: str) -> str:
return f"{count} {_pluralize(count, word)}" if count > 0 else ""
def format_duration(seconds: int) -> str:
"""
Format seconds as a simple duration in weeks/days/hours/minutes/seconds.
Args:
seconds: The number of seconds as an integer. Must be positive.
Returns:
The formatted duration.
Examples:
>>> from mautrix.util.format_duration import format_duration
>>> format_duration(1234)
'20 minutes and 34 seconds'
>>> format_duration(987654)
'1 week, 4 days, 10 hours, 20 minutes and 54 seconds'
>>> format_duration(60)
'1 minute'
Raises:
ValueError: if the duration is not positive.
"""
if seconds <= 0:
raise ValueError("format_duration only accepts positive values")
minutes, seconds = divmod(seconds, 60)
hours, minutes = divmod(minutes, 60)
days, hours = divmod(hours, 24)
weeks, days = divmod(days, 7)
parts = [
_include_if_positive(weeks, "week"),
_include_if_positive(days, "day"),
_include_if_positive(hours, "hour"),
_include_if_positive(minutes, "minute"),
_include_if_positive(seconds, "second"),
]
parts = [part for part in parts if part]
if len(parts) > 2:
parts = [", ".join(parts[:-1]), parts[-1]]
return " and ".join(parts)
python-0.20.4/mautrix/util/format_duration_test.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
import pytest
from .format_duration import format_duration
tests = {
1234: "20 minutes and 34 seconds",
987654: "1 week, 4 days, 10 hours, 20 minutes and 54 seconds",
694861: "1 week, 1 day, 1 hour, 1 minute and 1 second",
1: "1 second",
59: "59 seconds",
60: "1 minute",
120: "2 minutes",
}
def test_format_duration() -> None:
for seconds, formatted in tests.items():
assert format_duration(seconds) == formatted
def test_non_positive_error() -> None:
with pytest.raises(ValueError):
format_duration(0)
with pytest.raises(ValueError):
format_duration(-123)
python-0.20.4/mautrix/util/formatter/__init__.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from .entity_string import AbstractEntity, EntityString, SemiAbstractEntity, SimpleEntity
from .formatted_string import EntityType, FormattedString
from .html_reader import HTMLNode, read_html
from .markdown_string import MarkdownString
from .parser import MatrixParser, RecursionContext
async def parse_html(input_html: str) -> str:
return (await MatrixParser().parse(input_html)).text
__all__ = [
"AbstractEntity",
"EntityString",
"SemiAbstractEntity",
"SimpleEntity",
"EntityType",
"FormattedString",
"HTMLNode",
"read_html",
"MarkdownString",
"MatrixParser",
"RecursionContext",
"parse_html",
]
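
# Editor's note: hedged usage sketch. parse_html() runs the default
# MatrixParser over a Matrix HTML body and returns the markdown-ish plain
# text built by MarkdownString. The sample HTML is illustrative.
import asyncio

from mautrix.util.formatter import parse_html


async def demo() -> None:
    html = '<blockquote><b>Hello,</b> <a href="https://example.com">world</a>!</blockquote>'
    print(await parse_html(html))
    # -> "> **Hello,** [world](https://example.com)!"


if __name__ == "__main__":
    asyncio.run(demo())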
python-0.20.4/mautrix/util/formatter/entity_string.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from typing import Any, Generic, Iterable, Sequence, Type, TypeVar
from abc import ABC, abstractmethod
from itertools import chain
from attr import dataclass
import attr
from .formatted_string import EntityType, FormattedString
class AbstractEntity(ABC):
def __init__(
self, type: EntityType, offset: int, length: int, extra_info: dict[str, Any]
) -> None:
pass
@abstractmethod
def copy(self) -> AbstractEntity:
pass
@abstractmethod
def adjust_offset(self, offset: int, max_length: int = -1) -> AbstractEntity | None:
pass
class SemiAbstractEntity(AbstractEntity, ABC):
offset: int
length: int
def adjust_offset(self, offset: int, max_length: int = -1) -> SemiAbstractEntity | None:
entity = self.copy()
entity.offset += offset
if entity.offset < 0:
entity.length += entity.offset
if entity.length < 0:
return None
entity.offset = 0
elif entity.offset > max_length > -1:
return None
elif entity.offset + entity.length > max_length > -1:
entity.length = max_length - entity.offset
return entity
@dataclass
class SimpleEntity(SemiAbstractEntity):
type: EntityType
offset: int
length: int
extra_info: dict[str, Any] = attr.ib(factory=dict)
def copy(self) -> SimpleEntity:
return attr.evolve(self)
TEntity = TypeVar("TEntity", bound=AbstractEntity)
TEntityType = TypeVar("TEntityType")
class EntityString(Generic[TEntity, TEntityType], FormattedString):
text: str
_entities: list[TEntity]
entity_class: Type[AbstractEntity] = SimpleEntity
def __init__(self, text: str = "", entities: list[TEntity] = None) -> None:
self.text = text
self._entities = entities or []
def __repr__(self) -> str:
return f"{self.__class__.__name__}(text='{self.text}', entities={self.entities})"
def __str__(self) -> str:
return self.text
@property
def entities(self) -> list[TEntity]:
return self._entities
@entities.setter
def entities(self, val: Iterable[TEntity]) -> None:
self._entities = [entity for entity in val if entity is not None]
def _offset_entities(self, offset: int) -> EntityString:
self.entities = (entity.adjust_offset(offset, len(self.text)) for entity in self.entities)
return self
def append(self, *args: str | FormattedString) -> EntityString:
for msg in args:
if isinstance(msg, EntityString):
self.entities += (entity.adjust_offset(len(self.text)) for entity in msg.entities)
self.text += msg.text
else:
self.text += str(msg)
return self
def prepend(self, *args: str | FormattedString) -> EntityString:
for msg in args:
if isinstance(msg, EntityString):
self.text = msg.text + self.text
self.entities = chain(
msg.entities, (entity.adjust_offset(len(msg.text)) for entity in self.entities)
)
else:
text = str(msg)
self.text = text + self.text
self.entities = (entity.adjust_offset(len(text)) for entity in self.entities)
return self
def format(
self, entity_type: TEntityType, offset: int = None, length: int = None, **kwargs
) -> EntityString:
self.entities.append(
self.entity_class(
type=entity_type,
offset=offset or 0,
length=length or len(self.text),
extra_info=kwargs,
)
)
return self
def trim(self) -> EntityString:
orig_len = len(self.text)
self.text = self.text.lstrip()
diff = orig_len - len(self.text)
self.text = self.text.rstrip()
self._offset_entities(-diff)
return self
def split(self, separator, max_items: int = -1) -> list[EntityString]:
text_parts = self.text.split(separator, max_items - 1)
output: list[EntityString] = []
offset = 0
for part in text_parts:
msg = type(self)(part)
msg.entities = (entity.adjust_offset(-offset, len(part)) for entity in self.entities)
output.append(msg)
offset += len(part)
offset += len(separator)
return output
@classmethod
def join(cls, items: Sequence[str | EntityString], separator: str = " ") -> EntityString:
main = cls()
for msg in items:
if not isinstance(msg, EntityString):
msg = cls(text=str(msg))
main.entities += [entity.adjust_offset(len(main.text)) for entity in msg.entities]
main.text += msg.text + separator
if len(separator) > 0:
main.text = main.text[: -len(separator)]
return main
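
# Editor's note: hedged sketch. EntityString keeps formatting as
# (type, offset, length) entities instead of inline markup, and the
# append/prepend/split helpers keep those offsets consistent. Values are
# illustrative.
from mautrix.util.formatter import EntityString, EntityType

msg = EntityString("Hello, world").format(EntityType.BOLD, offset=7, length=5)
msg = msg.prepend(">> ")
print(msg.text)      # ">> Hello, world"
print(msg.entities)  # one SimpleEntity: BOLD, offset moved from 7 to 10, length 5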
python-0.20.4/mautrix/util/formatter/formatted_string.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from typing import Sequence
from abc import ABC, abstractmethod
from enum import Enum, auto
class EntityType(Enum):
"""EntityType is a Matrix formatting entity type."""
BOLD = auto()
ITALIC = auto()
STRIKETHROUGH = auto()
UNDERLINE = auto()
URL = auto()
EMAIL = auto()
USER_MENTION = auto()
ROOM_MENTION = auto()
PREFORMATTED = auto()
INLINE_CODE = auto()
BLOCKQUOTE = auto()
HEADER = auto()
COLOR = auto()
SPOILER = auto()
class FormattedString(ABC):
"""FormattedString is an abstract HTML parsing target."""
@abstractmethod
def append(self, *args: str | FormattedString) -> FormattedString:
"""
Append strings to this FormattedString.
This method may mutate the source object, but it is not required to do so.
Make sure to always use the return value when mutating and to duplicate strings if you don't
want the original to change.
Args:
*args: The strings to append.
Returns:
A FormattedString that is a concatenation of this string and the given strings.
"""
pass
@abstractmethod
def prepend(self, *args: str | FormattedString) -> FormattedString:
"""
Prepend strings to this FormattedString.
This method may mutate the source object, but it is not required to do so.
Make sure to always use the return value when mutating and to duplicate strings if you don't
want the original to change.
Args:
*args: The strings to prepend.
Returns:
A FormattedString that is a concatenation of the given strings and this string.
"""
pass
@abstractmethod
def format(self, entity_type: EntityType, **kwargs) -> FormattedString:
"""
Apply formatting to this FormattedString.
This method may mutate the source object, but it is not required to do so.
Make sure to always use the return value when mutating and to duplicate strings if you don't
want the original to change.
Args:
entity_type: The type of formatting to apply to this string.
**kwargs: Additional metadata required by the formatting type.
Returns:
A FormattedString with the given formatting applied.
"""
pass
@abstractmethod
def trim(self) -> FormattedString:
"""
Trim surrounding whitespace from this FormattedString.
This method may mutate the source object, but it is not required to do so.
Make sure to always use the return value when mutating and to duplicate strings if you don't
want the original to change.
Returns:
A FormattedString without surrounding whitespace.
"""
pass
@abstractmethod
def split(self, separator, max_items: int = -1) -> list[FormattedString]:
"""
Split this FormattedString by the given separator.
Args:
separator: The separator to split by.
max_items: The maximum number of items to return. If the limit is reached, the remaining
string will be returned as one even if it contains the separator.
Returns:
The split strings.
"""
pass
@classmethod
def concat(cls, *args: str | FormattedString) -> FormattedString:
"""
Concatenate many FormattedStrings.
Args:
*args: The strings to concatenate.
Returns:
A FormattedString that is a concatenation of the given strings.
"""
return cls.join(items=args, separator="")
@classmethod
@abstractmethod
def join(cls, items: Sequence[str | FormattedString], separator: str = " ") -> FormattedString:
"""
Join a list of FormattedStrings with the given separator.
Args:
items: The strings to join.
separator: The separator to join them with.
Returns:
A FormattedString that is a combination of the given strings with the given separator
between each one.
"""
pass
python-0.20.4/mautrix/util/formatter/html_reader.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from html.parser import HTMLParser
class HTMLNode(list):
tag: str
text: str
tail: str
attrib: dict[str, str]
def __repr__(self) -> str:
return (
f"HTMLNode(tag='{self.tag}', attrs={self.attrib}, text='{self.text}', "
f"tail='{self.tail}', children={list(self)})"
)
def __init__(self, tag: str, attrs: list[tuple[str, str]]) -> None:
super().__init__()
self.tag = tag
self.text = ""
self.tail = ""
self.attrib = dict(attrs)
class NodeifyingParser(HTMLParser):
# From https://www.w3.org/TR/html5/syntax.html#writing-html-documents-elements
void_tags = (
"area",
"base",
"br",
"col",
"command",
"embed",
"hr",
"img",
"input",
"link",
"meta",
"param",
"source",
"track",
"wbr",
)
stack: list[HTMLNode]
def __init__(self) -> None:
super().__init__()
self.stack = [HTMLNode("html", [])]
def handle_starttag(self, tag: str, attrs: list[tuple[str, str]]) -> None:
node = HTMLNode(tag, attrs)
self.stack[-1].append(node)
if tag not in self.void_tags:
self.stack.append(node)
def handle_startendtag(self, tag, attrs):
self.stack[-1].append(HTMLNode(tag, attrs))
def handle_endtag(self, tag: str) -> None:
if tag == self.stack[-1].tag:
self.stack.pop()
def handle_data(self, data: str) -> None:
if len(self.stack[-1]) > 0:
self.stack[-1][-1].tail += data
else:
self.stack[-1].text += data
def error(self, message: str) -> None:
pass
def read_html(data: str) -> HTMLNode:
parser = NodeifyingParser()
parser.feed(data)
return parser.stack[0]
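
# Editor's note: hedged sketch. read_html() returns a tree of HTMLNode objects
# (tag, attrib, text, tail, children) rooted at a synthetic "html" node; each
# node is also a list of its children. The markup is illustrative.
from mautrix.util.formatter import read_html

root = read_html('<p>Hello <a href="https://example.com">world</a>!</p>')
p = root[0]
link = p[0]
print(p.tag, repr(p.text))    # p 'Hello '
print(link.attrib["href"])    # https://example.com
print(repr(link.tail))        # '!' (the text between </a> and </p>)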
python-0.20.4/mautrix/util/formatter/html_reader.pyi
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
class HTMLNode(list[HTMLNode]):
tag: str
text: str
tail: str
attrib: dict[str, str]
def __init__(self, tag: str, attrs: list[tuple[str, str]]) -> None: ...
def read_html(data: str) -> HTMLNode: ...
python-0.20.4/mautrix/util/formatter/markdown_string.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from typing import List, Sequence, Union
from .formatted_string import EntityType, FormattedString
class MarkdownString(FormattedString):
text: str
def __init__(self, text: str = "") -> None:
self.text = text
def __str__(self) -> str:
return self.text
def append(self, *args: Union[str, FormattedString]) -> MarkdownString:
self.text += "".join(str(arg) for arg in args)
return self
def prepend(self, *args: Union[str, FormattedString]) -> MarkdownString:
self.text = "".join(str(arg) for arg in args + (self.text,))
return self
def format(self, entity_type: EntityType, **kwargs) -> MarkdownString:
if entity_type == EntityType.BOLD:
self.text = f"**{self.text}**"
elif entity_type == EntityType.ITALIC:
self.text = f"_{self.text}_"
elif entity_type == EntityType.STRIKETHROUGH:
self.text = f"~~{self.text}~~"
elif entity_type == EntityType.SPOILER:
reason = kwargs.get("reason", "")
if reason:
self.text = f"{reason}|{self.text}"
self.text = f"||{self.text}||"
elif entity_type == EntityType.URL:
if kwargs["url"] != self.text:
self.text = f"[{self.text}]({kwargs['url']})"
elif entity_type == EntityType.PREFORMATTED:
self.text = f"```{kwargs['language']}\n{self.text}\n```"
elif entity_type == EntityType.INLINE_CODE:
self.text = f"`{self.text}`"
elif entity_type == EntityType.BLOCKQUOTE:
children = self.trim().split("\n")
children = [child.prepend("> ") for child in children]
self.text = self.join(children, "\n").text
elif entity_type == EntityType.HEADER:
prefix = "#" * kwargs["size"]
self.text = f"{prefix} {self.text}"
return self
def trim(self) -> MarkdownString:
self.text = self.text.strip()
return self
def split(self, separator, max_items: int = -1) -> List[MarkdownString]:
return [MarkdownString(text) for text in self.text.split(separator, max_items)]
@classmethod
def join(
cls, items: Sequence[Union[str, FormattedString]], separator: str = " "
) -> MarkdownString:
return cls(separator.join(str(item) for item in items))
python-0.20.4/mautrix/util/formatter/parser.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from typing import Any, Callable, Generic, Type, TypeVar
import re
from ...types import EventID, MatrixURI, RoomAlias, RoomID, UserID
from .formatted_string import EntityType, FormattedString
from .html_reader import HTMLNode, read_html
from .markdown_string import MarkdownString
class RecursionContext:
preserve_whitespace: bool
ul_depth: int
_inited: bool
def __init__(self, preserve_whitespace: bool = False, ul_depth: int = 0) -> None:
self.preserve_whitespace = preserve_whitespace
self.ul_depth = ul_depth
self._inited = True
def __setattr__(self, key: str, value: Any) -> None:
if getattr(self, "_inited", False) is True:
raise TypeError("'RecursionContext' object is immutable")
super(RecursionContext, self).__setattr__(key, value)
def enter_list(self) -> RecursionContext:
return RecursionContext(
preserve_whitespace=self.preserve_whitespace, ul_depth=self.ul_depth + 1
)
def enter_code_block(self) -> RecursionContext:
return RecursionContext(preserve_whitespace=True, ul_depth=self.ul_depth)
T = TypeVar("T", bound=FormattedString)
spaces = re.compile(r"\s+")
space = " "
class MatrixParser(Generic[T]):
block_tags: tuple[str, ...] = (
"p",
"pre",
"blockquote",
"ol",
"ul",
"li",
"h1",
"h2",
"h3",
"h4",
"h5",
"h6",
"div",
"hr",
"table",
)
list_bullets: tuple[str, ...] = ("●", "○", "■", "‣")
e: Type[EntityType] = EntityType
fs: Type[T] = MarkdownString
read_html: Callable[[str], HTMLNode] = staticmethod(read_html)
ignore_less_relevant_links: bool = True
exclude_plaintext_attrib: str = "data-mautrix-exclude-plaintext"
def list_bullet(self, depth: int) -> str:
return self.list_bullets[(depth - 1) % len(self.list_bullets)] + " "
async def list_to_fstring(self, node: HTMLNode, ctx: RecursionContext) -> T:
ordered: bool = node.tag == "ol"
tagged_children: list[tuple[T, str]] = await self.node_to_tagged_fstrings(node, ctx)
counter: int = 1
indent_length: int = 0
if ordered:
try:
counter = int(node.attrib.get("start", "1"))
except ValueError:
counter = 1
longest_index = counter - 1 + len(tagged_children)
indent_length = len(str(longest_index))
indent: str = (indent_length + 2) * " "
children: list[T] = []
for child, tag in tagged_children:
if tag != "li":
continue
if ordered:
prefix = f"{counter}. "
counter += 1
else:
prefix = self.list_bullet(ctx.ul_depth)
child = child.prepend(prefix)
parts = child.split("\n")
parts = parts[:1] + [part.prepend(indent) for part in parts[1:]]
child = self.fs.join(parts, "\n")
children.append(child)
return self.fs.join(children, "\n")
async def blockquote_to_fstring(self, node: HTMLNode, ctx: RecursionContext) -> T:
msg = await self.tag_aware_parse_node(node, ctx)
return msg.format(self.e.BLOCKQUOTE)
async def hr_to_fstring(self, node: HTMLNode, ctx: RecursionContext) -> T:
return self.fs("---")
async def header_to_fstring(self, node: HTMLNode, ctx: RecursionContext) -> T:
children = await self.node_to_fstrings(node, ctx)
length = int(node.tag[1])
return self.fs.join(children, "").format(self.e.HEADER, size=length)
async def basic_format_to_fstring(self, node: HTMLNode, ctx: RecursionContext) -> T:
msg = await self.tag_aware_parse_node(node, ctx)
if self.exclude_plaintext_attrib in node.attrib:
return msg
if node.tag in ("b", "strong"):
msg = msg.format(self.e.BOLD)
elif node.tag in ("i", "em"):
msg = msg.format(self.e.ITALIC)
elif node.tag in ("s", "strike", "del"):
msg = msg.format(self.e.STRIKETHROUGH)
elif node.tag in ("u", "ins"):
msg = msg.format(self.e.UNDERLINE)
return msg
async def link_to_fstring(self, node: HTMLNode, ctx: RecursionContext) -> T:
msg = await self.tag_aware_parse_node(node, ctx)
href = node.attrib.get("href", "")
if not href:
return msg
if href.startswith("mailto:"):
return self.fs(href[len("mailto:") :]).format(self.e.EMAIL)
matrix_uri = MatrixURI.try_parse(href)
if matrix_uri:
if matrix_uri.user_id:
new_msg = await self.user_pill_to_fstring(msg, matrix_uri.user_id)
elif matrix_uri.event_id:
new_msg = await self.event_link_to_fstring(
msg, matrix_uri.room_id or matrix_uri.room_alias, matrix_uri.event_id
)
elif matrix_uri.room_alias:
new_msg = await self.room_pill_to_fstring(msg, matrix_uri.room_alias)
elif matrix_uri.room_id:
new_msg = await self.room_id_link_to_fstring(msg, matrix_uri.room_id)
else:
new_msg = None
if new_msg:
return new_msg
# Custom attribute to tell the parser that the link isn't relevant and
# shouldn't be included in plaintext representation.
if self.ignore_less_relevant_links and self.exclude_plaintext_attrib in node.attrib:
return msg
return await self.url_to_fstring(msg, href)
async def url_to_fstring(self, msg: T, url: str) -> T | None:
return msg.format(self.e.URL, url=url)
async def user_pill_to_fstring(self, msg: T, user_id: UserID) -> T | None:
return msg.format(self.e.USER_MENTION, user_id=user_id)
async def room_pill_to_fstring(self, msg: T, room_alias: RoomAlias) -> T | None:
return None
async def room_id_link_to_fstring(self, msg: T, room_id: RoomID) -> T | None:
return None
async def event_link_to_fstring(
self, msg: T, room: RoomID | RoomAlias, event_id: EventID
) -> T | None:
return None
async def img_to_fstring(self, node: HTMLNode, ctx: RecursionContext) -> T:
return self.fs(node.attrib.get("alt") or node.attrib.get("title") or "")
async def custom_node_to_fstring(self, node: HTMLNode, ctx: RecursionContext) -> T | None:
return None
async def color_to_fstring(self, msg: T, color: str) -> T:
return msg.format(self.e.COLOR, color=color)
async def spoiler_to_fstring(self, msg: T, reason: str) -> T:
return msg.format(self.e.SPOILER, reason=reason)
async def node_to_fstring(self, node: HTMLNode, ctx: RecursionContext) -> T:
custom = await self.custom_node_to_fstring(node, ctx)
if custom:
return custom
elif node.tag == "mx-reply":
return self.fs("")
elif node.tag == "blockquote":
return await self.blockquote_to_fstring(node, ctx)
elif node.tag == "hr":
return await self.hr_to_fstring(node, ctx)
elif node.tag == "ol":
return await self.list_to_fstring(node, ctx)
elif node.tag == "ul":
return await self.list_to_fstring(node, ctx.enter_list())
elif node.tag in ("h1", "h2", "h3", "h4", "h5", "h6"):
return await self.header_to_fstring(node, ctx)
elif node.tag == "br":
return self.fs("\n")
elif node.tag in ("b", "strong", "i", "em", "s", "del", "u", "ins"):
return await self.basic_format_to_fstring(node, ctx)
elif node.tag == "a":
return await self.link_to_fstring(node, ctx)
elif node.tag == "img":
return await self.img_to_fstring(node, ctx)
elif node.tag == "p":
return (await self.tag_aware_parse_node(node, ctx)).append("\n")
elif node.tag in ("font", "span"):
msg = await self.tag_aware_parse_node(node, ctx)
try:
spoiler = node.attrib["data-mx-spoiler"]
except KeyError:
pass
else:
msg = await self.spoiler_to_fstring(msg, spoiler)
try:
color = node.attrib["color"]
except KeyError:
try:
color = node.attrib["data-mx-color"]
except KeyError:
color = None
if color:
msg = await self.color_to_fstring(msg, color)
return msg
elif node.tag == "pre":
lang = ""
try:
if node[0].tag == "code":
node = node[0]
lang = node.attrib["class"][len("language-") :]
except (IndexError, KeyError):
pass
return (await self.parse_node(node, ctx.enter_code_block())).format(
self.e.PREFORMATTED, language=lang
)
elif node.tag == "code":
return (await self.parse_node(node, ctx.enter_code_block())).format(self.e.INLINE_CODE)
return await self.tag_aware_parse_node(node, ctx)
async def text_to_fstring(
self, text: str, ctx: RecursionContext, strip_leading_whitespace: bool = False
) -> T:
if not ctx.preserve_whitespace:
text = spaces.sub(space, text.lstrip() if strip_leading_whitespace else text)
return self.fs(text)
async def node_to_tagged_fstrings(
self, node: HTMLNode, ctx: RecursionContext
) -> list[tuple[T, str]]:
output = []
if node.text:
output.append((await self.text_to_fstring(node.text, ctx), "text"))
for child in node:
output.append((await self.node_to_fstring(child, ctx), child.tag))
if child.tail:
# For text following a block tag, the leading whitespace is meaningless (there'll
# be a newline added later), but for other tags it can be interpreted as a space.
text = await self.text_to_fstring(
child.tail, ctx, strip_leading_whitespace=child.tag in self.block_tags
)
output.append((text, "text"))
return output
async def node_to_fstrings(self, node: HTMLNode, ctx: RecursionContext) -> list[T]:
return [msg for (msg, tag) in await self.node_to_tagged_fstrings(node, ctx)]
async def tag_aware_parse_node(self, node: HTMLNode, ctx: RecursionContext) -> T:
msgs = await self.node_to_tagged_fstrings(node, ctx)
output = self.fs()
prev_was_block = False
for msg, tag in msgs:
if tag in self.block_tags:
msg = msg.append("\n")
if not prev_was_block:
msg = msg.prepend("\n")
prev_was_block = True
output = output.append(msg)
return output.trim()
async def parse_node(self, node: HTMLNode, ctx: RecursionContext) -> T:
return self.fs.join(await self.node_to_fstrings(node, ctx))
async def parse(self, data: str) -> T:
msg = await self.node_to_fstring(
self.read_html(f"{data}"), RecursionContext()
)
return msg
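
# Editor's note: hedged sketch of customizing MatrixParser. Bridges usually
# subclass it and override the *_to_fstring hooks, for example to turn Matrix
# user pills into a remote network's own mention syntax. The mention format
# and sample input are illustrative assumptions.
import asyncio

from mautrix.types import UserID
from mautrix.util.formatter import MarkdownString, MatrixParser


class CustomParser(MatrixParser[MarkdownString]):
    fs = MarkdownString

    async def user_pill_to_fstring(self, msg: MarkdownString, user_id: UserID) -> MarkdownString:
        # Render https://matrix.to/#/@user:server pills as "@localpart".
        return MarkdownString("@" + user_id[1:].split(":")[0])


async def demo() -> None:
    html = 'ping <a href="https://matrix.to/#/@alice:example.com">Alice</a>'
    print((await CustomParser().parse(html)).text)  # -> "ping @alice"


if __name__ == "__main__":
    asyncio.run(demo())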
python-0.20.4/mautrix/util/formatter/parser_test.py
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
import pytest
from . import parse_html
async def test_basic_markdown() -> None:
tests = {
"test": "**test**",
"test!": "**t_e~~s~~t_!**",
"example": "[example](https://example.com)",
"div {\n display: none;\n}
": "```css\ndiv {\n display: none;\n}\n```",
"hello
": "`hello`",
"Testing
123
": "> Testing\n> 123",
"": "● test\n● foo\n● bar",
"- test
\n- foo
\n- bar
\n
": "123. test\n124. foo\n125. bar",
"header
": "#### header",
"spoiler?": "||spoiler?||",
"not really": "||SPOILER!|not really||",
}
for html, markdown_ish in tests.items():
assert await parse_html(html) == markdown_ish
async def test_nested_markdown() -> None:
input_html = """
<h1>Hello, World!</h1>
<blockquote>
<ol>
<li><a href="https://example.com">example</a></li>
<li>
<ul>
<li>item 1</li>
<li>item 2</li>
</ul>
</li>
<li>
<pre><code class="language-python">def random() -> int:
    if 4 is 1:
        return 5
    return 4</code></pre>
</li>
<li><strong>Just some text</strong></li>
</ol>
</blockquote>
""".strip()
expected_output = """
# Hello, World!
> 1. [example](https://example.com)
> 2. ● item 1
> ● item 2
> 3. ```python
> def random() -> int:
> if 4 is 1:
> return 5
> return 4
> ```
> 4. **Just some text**
""".strip()
assert await parse_html(input_html) == expected_output
python-0.20.4/mautrix/util/logging/ 0000775 0000000 0000000 00000000000 14547234302 0017216 5 ustar 00root root 0000000 0000000 python-0.20.4/mautrix/util/logging/__init__.py 0000664 0000000 0000000 00000000216 14547234302 0021326 0 ustar 00root root 0000000 0000000 from .color import ColorFormatter
from .trace import SILLY, TRACE, TraceLogger
__all__ = ["ColorFormatter", "TraceLogger", "SILLY", "TRACE"]
python-0.20.4/mautrix/util/logging/color.py 0000664 0000000 0000000 00000003476 14547234302 0020720 0 ustar 00root root 0000000 0000000 # Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from copy import copy
from logging import Formatter, LogRecord
PREFIX = "\033["
RESET = PREFIX + "0m"
MAU_COLOR = PREFIX + "32m" # green
AIOHTTP_COLOR = PREFIX + "36m" # cyan
MXID_COLOR = PREFIX + "33m" # yellow
LEVEL_COLORS = {
"DEBUG": "37m", # white
"INFO": "36m", # cyan
"WARNING": "33;1m", # yellow
"ERROR": "31;1m", # red
"CRITICAL": f"37;1m{PREFIX}41m", # white on red bg
}
LEVELNAME_OVERRIDE = {
name: f"{PREFIX}{color}{name}{RESET}" for name, color in LEVEL_COLORS.items()
}
class ColorFormatter(Formatter):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def _color_name(self, module: str) -> str:
as_api = "mau.as.api"
if module.startswith(as_api):
return f"{MAU_COLOR}{as_api}{RESET}.{MXID_COLOR}{module[len(as_api) + 1:]}{RESET}"
elif module.startswith("mau."):
try:
next_dot = module.index(".", len("mau."))
return (
f"{MAU_COLOR}{module[:next_dot]}{RESET}"
f".{MXID_COLOR}{module[next_dot+1:]}{RESET}"
)
except ValueError:
return MAU_COLOR + module + RESET
elif module.startswith("aiohttp"):
return AIOHTTP_COLOR + module + RESET
return module
def format(self, record: LogRecord):
colored_record: LogRecord = copy(record)
colored_record.name = self._color_name(record.name)
colored_record.levelname = LEVELNAME_OVERRIDE.get(record.levelname, record.levelname)
return super().format(colored_record)
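# A minimal usage sketch, assuming a plain StreamHandler and an illustrative
# logger name; the "mau."/"aiohttp" prefixes get colored as implemented above:
if __name__ == "__main__":
    import logging
    handler = logging.StreamHandler()
    handler.setFormatter(ColorFormatter("[%(levelname)s@%(name)s] %(message)s"))
    log = logging.getLogger("mau.as.api.@user:example.com")
    log.addHandler(handler)
    log.setLevel(logging.DEBUG)
    log.info("the 'mau.as.api' prefix is colored green, the rest yellow")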
python-0.20.4/mautrix/util/logging/trace.py 0000664 0000000 0000000 00000001615 14547234302 0020671 0 ustar 00root root 0000000 0000000 # Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from typing import Type, cast
import logging
TRACE = logging.TRACE = 5
logging.addLevelName(TRACE, "TRACE")
SILLY = logging.SILLY = 1
logging.addLevelName(SILLY, "SILLY")
OldLogger: Type[logging.Logger] = cast(Type[logging.Logger], logging.getLoggerClass())
class TraceLogger(OldLogger):
def trace(self, msg, *args, **kwargs) -> None:
self.log(TRACE, msg, *args, **kwargs)
def silly(self, msg, *args, **kwargs) -> None:
self.log(SILLY, msg, *args, **kwargs)
def getChild(self, suffix: str) -> TraceLogger:
return cast(TraceLogger, super().getChild(suffix))
logging.setLoggerClass(TraceLogger)
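# A minimal usage sketch (the logger name is illustrative): once the custom
# logger class is registered above, loggers expose the extra trace/silly levels.
if __name__ == "__main__":
    logging.basicConfig(level=SILLY)
    log = cast(TraceLogger, logging.getLogger("mau.example"))
    log.trace("logged at level %d", TRACE)
    log.silly("logged at level %d", SILLY)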
python-0.20.4/mautrix/util/magic.py 0000664 0000000 0000000 00000002661 14547234302 0017227 0 ustar 00root root 0000000 0000000 # Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
import functools
import magic
try:
_from_buffer = functools.partial(magic.from_buffer, mime=True)
_from_filename = functools.partial(magic.from_file, mime=True)
except AttributeError:
_from_buffer = lambda data: magic.detect_from_content(data).mime_type
_from_filename = lambda file: magic.detect_from_filename(file).mime_type
def mimetype(data: bytes | bytearray | str) -> str:
"""
Uses magic to determine the mimetype of a file on disk or in memory.
Supports both libmagic's Python bindings and the python-magic package.
Args:
data: The file data, either as in-memory bytes or a path to the file as a string.
Returns:
The mime type as a string.
"""
if isinstance(data, str):
return _from_filename(data)
elif isinstance(data, bytes):
return _from_buffer(data)
elif isinstance(data, bytearray):
# Magic doesn't like bytearrays directly, so just copy the first 1024 bytes for it.
return _from_buffer(bytes(data[:1024]))
else:
raise TypeError(
f"mimetype() argument must be a string or bytes, not {type(data).__name__!r}"
)
__all__ = ["mimetype"]
python-0.20.4/mautrix/util/manhole.py 0000664 0000000 0000000 00000024765 14547234302 0017603 0 ustar 00root root 0000000 0000000 # Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
#
# Based on https://github.com/nhoad/aiomanhole Copyright (c) 2014, Nathan Hoad
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from abc import ABC, abstractmethod
from io import BytesIO, StringIO
from socket import SOL_SOCKET
from types import CodeType
import ast
import asyncio
import codeop
import contextlib
import functools
import inspect
import logging
import os
import pwd
import struct
import sys
import traceback
try:
from socket import SO_PEERCRED
except ImportError:
SO_PEERCRED = None
log = logging.getLogger("mau.manhole")
TOP_LEVEL_AWAIT = sys.version_info >= (3, 8)
ASYNC_EVAL_WRAPPER: str = """
async def __eval_async_expr():
try:
pass
finally:
globals().update(locals())
"""
def compile_async(tree: ast.AST) -> CodeType:
flags = 0
if TOP_LEVEL_AWAIT:
flags += ast.PyCF_ALLOW_TOP_LEVEL_AWAIT
node_to_compile = tree
else:
insert_returns(tree.body)
wrapper_node: ast.AST = ast.parse(ASYNC_EVAL_WRAPPER, "", "single")
method_stmt = wrapper_node.body[0]
try_stmt = method_stmt.body[0]
try_stmt.body = tree.body
node_to_compile = wrapper_node
return compile(node_to_compile, "", "single", flags=flags)
# From https://gist.github.com/nitros12/2c3c265813121492655bc95aa54da6b9
def insert_returns(body: List[ast.AST]) -> None:
if isinstance(body[-1], ast.Expr):
body[-1] = ast.Return(body[-1].value)
ast.fix_missing_locations(body[-1])
elif isinstance(body[-1], ast.If):
insert_returns(body[-1].body)
insert_returns(body[-1].orelse)
elif isinstance(body[-1], (ast.With, ast.AsyncWith)):
insert_returns(body[-1].body)
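# Rough illustration (as a comment only, not executed): for a hypothetical REPL
# input such as
#     x = 40
#     x + 2
# insert_returns rewrites the final expression statement into "return x + 2", so
# that once the body is spliced into __eval_async_expr above, the value of the
# last expression is returned and can be printed by the manhole.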
class StatefulCommandCompiler(codeop.CommandCompiler):
"""A command compiler that buffers input until a full command is available."""
buf: BytesIO
def __init__(self) -> None:
super().__init__()
self.compiler = functools.partial(
compile,
optimize=1,
flags=(
ast.PyCF_ONLY_AST
| codeop.PyCF_DONT_IMPLY_DEDENT
| codeop.PyCF_ALLOW_INCOMPLETE_INPUT
),
)
self.buf = BytesIO()
def is_partial_command(self) -> bool:
return bool(self.buf.getvalue())
def __call__(self, source: bytes, **kwargs: Any) -> Optional[CodeType]:
buf = self.buf
if self.is_partial_command():
buf.write(b"\n")
buf.write(source)
code = self.buf.getvalue().decode("utf-8")
codeobj = super().__call__(code, **kwargs)
if codeobj:
self.reset()
return compile_async(codeobj)
return None
def reset(self) -> None:
self.buf.seek(0)
self.buf.truncate(0)
class Interpreter(ABC):
@abstractmethod
def __init__(self, namespace: Dict[str, Any], banner: Union[bytes, str]) -> None:
pass
@abstractmethod
def close(self) -> None:
pass
@abstractmethod
async def __call__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
pass
class AsyncInterpreter(Interpreter):
"""An interactive asynchronous interpreter."""
reader: asyncio.StreamReader
writer: asyncio.StreamWriter
namespace: Dict[str, Any]
banner: bytes
compiler: StatefulCommandCompiler
running: bool
def __init__(self, namespace: Dict[str, Any], banner: Union[bytes, str]) -> None:
super().__init__(namespace, banner)
self.namespace = namespace
self.banner = banner if isinstance(banner, bytes) else str(banner).encode("utf-8")
self.compiler = StatefulCommandCompiler()
async def send_exception(self) -> None:
"""When an exception has occurred, write the traceback to the user."""
self.compiler.reset()
exc = traceback.format_exc()
self.writer.write(exc.encode("utf-8"))
await self.writer.drain()
async def execute(self, codeobj: CodeType) -> Tuple[Any, str]:
with contextlib.redirect_stdout(StringIO()) as buf:
if TOP_LEVEL_AWAIT:
value = eval(codeobj, self.namespace)
if codeobj.co_flags & inspect.CO_COROUTINE:
value = await value
else:
exec(codeobj, self.namespace)
value = await eval("__eval_async_expr()", self.namespace)
return value, buf.getvalue()
async def handle_one_command(self) -> None:
"""Process a single command. May have many lines."""
while True:
await self.write_prompt()
codeobj = await self.read_command()
if codeobj is not None:
await self.run_command(codeobj)
return
async def run_command(self, codeobj: CodeType) -> None:
"""Execute a compiled code object, and write the output back to the client."""
try:
value, stdout = await self.execute(codeobj)
except Exception:
await self.send_exception()
return
else:
await self.send_output(value, stdout)
async def write_prompt(self) -> None:
writer = self.writer
if self.compiler.is_partial_command():
writer.write(b"... ")
else:
writer.write(b">>> ")
await writer.drain()
async def read_command(self) -> Optional[CodeType]:
"""Read a command from the user line by line.
Returns a code object suitable for execution.
"""
reader = self.reader
line = await reader.readline()
if line == b"":
raise ConnectionResetError()
try:
# skip the newline to make CommandCompiler work as advertised
codeobj = self.compiler(line.rstrip(b"\n"))
except SyntaxError:
await self.send_exception()
return None
return codeobj
async def send_output(self, value: str, stdout: str) -> None:
"""Write the output or value of the expression back to user.
>>> 5
5
>>> print('cash rules everything around me')
cash rules everything around me
"""
writer = self.writer
if value is not None:
writer.write(f"{value!r}\n".encode("utf-8"))
if stdout:
writer.write(stdout.encode("utf-8"))
await writer.drain()
def close(self) -> None:
if self.running:
self.writer.close()
self.running = False
async def __call__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
"""Main entry point for an interpreter session with a single client."""
self.reader = reader
self.writer = writer
self.running = True
if self.banner:
writer.write(self.banner)
await writer.drain()
while self.running:
try:
await self.handle_one_command()
except ConnectionResetError:
writer.close()
self.running = False
break
except Exception:
log.exception("Exception in manhole REPL")
                self.writer.write(traceback.format_exc().encode("utf-8"))
await self.writer.drain()
class InterpreterFactory:
namespace: Dict[str, Any]
banner: bytes
interpreter_class: Type[Interpreter]
clients: List[Interpreter]
whitelist: Set[int]
_conn_id: int
def __init__(
self,
namespace: Dict[str, Any],
banner: Union[bytes, str],
interpreter_class: Type[Interpreter],
whitelist: Set[int],
) -> None:
self.namespace = namespace or {}
self.banner = banner
self.interpreter_class = interpreter_class
self.clients = []
self.whitelist = whitelist
self._conn_id = 0
@property
def conn_id(self) -> int:
self._conn_id += 1
return self._conn_id
async def __call__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
sock = writer.transport.get_extra_info("socket")
# TODO support non-linux OSes
# I think FreeBSD uses SCM_CREDS
creds = sock.getsockopt(SOL_SOCKET, SO_PEERCRED, struct.calcsize("3i"))
pid, uid, gid = struct.unpack("3i", creds)
user_info = pwd.getpwuid(uid)
username = f"{user_info.pw_name} ({uid})" if user_info and user_info.pw_name else uid
if len(self.whitelist) > 0 and uid not in self.whitelist:
writer.write(b"You are not whitelisted to use the manhole.")
log.warning(f"Non-whitelisted user {username} tried to connect from PID {pid}")
await writer.drain()
writer.close()
return
namespace = {**self.namespace}
interpreter = self.interpreter_class(namespace=namespace, banner=self.banner)
namespace["exit"] = interpreter.close
self.clients.append(interpreter)
conn_id = self.conn_id
log.info(f"Manhole connection OPENED: {conn_id} from PID {pid} by {username}")
await asyncio.create_task(interpreter(reader, writer))
log.info(f"Manhole connection CLOSED: {conn_id} from PID {pid} by {username}")
self.clients.remove(interpreter)
async def start_manhole(
path: str,
banner: str = "",
namespace: Optional[Dict[str, Any]] = None,
loop: asyncio.AbstractEventLoop = None,
whitelist: Set[int] = None,
) -> Tuple[asyncio.AbstractServer, Callable[[], None]]:
"""
Starts a manhole server on a given UNIX address.
Args:
path: The path to create the UNIX socket at.
banner: The banner to show when clients connect.
namespace: The globals to provide to connected clients.
loop: The asyncio event loop to use.
whitelist: List of user IDs to allow connecting.
"""
if not SO_PEERCRED:
raise ValueError("SO_PEERCRED is not supported on this platform")
factory = InterpreterFactory(
namespace=namespace,
banner=banner,
interpreter_class=AsyncInterpreter,
whitelist=whitelist,
)
server = await asyncio.start_unix_server(factory, path=path)
os.chmod(path, 0o666)
def stop():
for client in factory.clients:
client.close()
server.close()
return server, stop
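# A minimal usage sketch (the socket path and whitelist are illustrative); on
# Linux you could then connect with e.g. `socat - UNIX-CONNECT:/tmp/demo.manhole`:
if __name__ == "__main__":
    async def _demo() -> None:
        server, stop = await start_manhole(
            path="/tmp/demo.manhole",
            banner="demo manhole\n",
            namespace={"answer": 42},
            whitelist={os.getuid()},
        )
        try:
            await server.wait_closed()  # serve until the socket is closed
        finally:
            stop()
    asyncio.run(_demo())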
python-0.20.4/mautrix/util/markdown.py 0000664 0000000 0000000 00000002105 14547234302 0017762 0 ustar 00root root 0000000 0000000 # Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
import commonmark
class HtmlEscapingRenderer(commonmark.HtmlRenderer):
def __init__(self, allow_html: bool = False):
super().__init__()
self.allow_html = allow_html
def lit(self, s):
if self.allow_html:
return super().lit(s)
return super().lit(s.replace("<", "<").replace(">", ">"))
def image(self, node, entering):
prev = self.allow_html
self.allow_html = True
super().image(node, entering)
self.allow_html = prev
md_parser = commonmark.Parser()
yes_html_renderer = commonmark.HtmlRenderer()
no_html_renderer = HtmlEscapingRenderer()
def render(message: str, allow_html: bool = False) -> str:
parsed = md_parser.parse(message)
if allow_html:
return yes_html_renderer.render(parsed)
else:
return no_html_renderer.render(parsed)
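# A minimal usage sketch; the exact tag wrapping comes from commonmark's renderer:
if __name__ == "__main__":
    print(render("**bold** and <u>raw html</u>"))  # raw tags get escaped
    print(render("**bold** and <u>raw html</u>", allow_html=True))  # raw tags pass through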
python-0.20.4/mautrix/util/message_send_checkpoint.py 0000664 0000000 0000000 00000006210 14547234302 0023005 0 ustar 00root root 0000000 0000000 # Copyright (c) 2022 Sumner Evans
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from typing import Optional
import logging
from aiohttp.client import ClientTimeout
from attr import dataclass
import aiohttp
from mautrix.api import HTTPAPI
from mautrix.types import EventType, MessageType, SerializableAttrs, SerializableEnum
class MessageSendCheckpointStep(SerializableEnum):
CLIENT = "CLIENT"
HOMESERVER = "HOMESERVER"
BRIDGE = "BRIDGE"
DECRYPTED = "DECRYPTED"
REMOTE = "REMOTE"
COMMAND = "COMMAND"
class MessageSendCheckpointStatus(SerializableEnum):
SUCCESS = "SUCCESS"
WILL_RETRY = "WILL_RETRY"
PERM_FAILURE = "PERM_FAILURE"
UNSUPPORTED = "UNSUPPORTED"
TIMEOUT = "TIMEOUT"
DELIVERY_FAILED = "DELIVERY_FAILED"
class MessageSendCheckpointReportedBy(SerializableEnum):
ASMUX = "ASMUX"
BRIDGE = "BRIDGE"
@dataclass
class MessageSendCheckpoint(SerializableAttrs):
event_id: str
room_id: str
step: MessageSendCheckpointStep
timestamp: int
status: MessageSendCheckpointStatus
event_type: EventType
reported_by: MessageSendCheckpointReportedBy
retry_num: int = 0
message_type: Optional[MessageType] = None
info: Optional[str] = None
client_type: Optional[str] = None
client_version: Optional[str] = None
async def send(self, endpoint: str, as_token: str, log: logging.Logger) -> None:
if not endpoint:
return
try:
headers = {"Authorization": f"Bearer {as_token}", "User-Agent": HTTPAPI.default_ua}
async with aiohttp.ClientSession() as sess, sess.post(
endpoint,
json={"checkpoints": [self.serialize()]},
headers=headers,
timeout=ClientTimeout(30),
) as resp:
if not 200 <= resp.status < 300:
text = await resp.text()
text = text.replace("\n", "\\n")
log.warning(
f"Unexpected status code {resp.status} sending checkpoint "
f"for {self.event_id} ({self.step}/{self.status}): {text}"
)
else:
log.info(
f"Successfully sent checkpoint for {self.event_id} "
f"({self.step}/{self.status})"
)
except Exception as e:
log.warning(
f"Failed to send checkpoint for {self.event_id} ({self.step}/{self.status}): "
f"{type(e).__name__}: {e}"
)
CHECKPOINT_TYPES = {
EventType.ROOM_REDACTION,
EventType.ROOM_MESSAGE,
EventType.ROOM_ENCRYPTED,
EventType.ROOM_MEMBER,
EventType.ROOM_NAME,
EventType.ROOM_AVATAR,
EventType.ROOM_TOPIC,
EventType.STICKER,
EventType.REACTION,
EventType.CALL_INVITE,
EventType.CALL_CANDIDATES,
EventType.CALL_SELECT_ANSWER,
EventType.CALL_ANSWER,
EventType.CALL_HANGUP,
EventType.CALL_REJECT,
EventType.CALL_NEGOTIATE,
}
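# A minimal usage sketch; the endpoint, token and Matrix IDs below are
# illustrative placeholders, and send() only logs a warning if the POST fails:
if __name__ == "__main__":
    import asyncio
    import time
    async def _demo() -> None:
        checkpoint = MessageSendCheckpoint(
            event_id="$event:example.com",
            room_id="!room:example.com",
            step=MessageSendCheckpointStep.BRIDGE,
            timestamp=int(time.time() * 1000),
            status=MessageSendCheckpointStatus.SUCCESS,
            event_type=EventType.ROOM_MESSAGE,
            reported_by=MessageSendCheckpointReportedBy.BRIDGE,
        )
        await checkpoint.send(
            endpoint="https://checkpoints.example.com/report",
            as_token="example-as-token",
            log=logging.getLogger("demo"),
        )
    asyncio.run(_demo())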
python-0.20.4/mautrix/util/opt_prometheus.py 0000664 0000000 0000000 00000003622 14547234302 0021222 0 ustar 00root root 0000000 0000000 # Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from typing import Any, cast
class _NoopPrometheusEntity:
"""NoopPrometheusEntity is a class that can be used as a no-op placeholder for prometheus
metrics objects when prometheus_client isn't installed."""
def __init__(self, *args, **kwargs):
pass
def __call__(self, *args, **kwargs):
if not kwargs and len(args) == 1 and callable(args[0]):
return args[0]
return self
def __enter__(self):
pass
def __exit__(self, exc_type, exc_val, exc_tb):
pass
def __getattr__(self, item):
return self
try:
from prometheus_client import Counter, Enum, Gauge, Histogram, Info, Summary
is_installed = True
except ImportError:
Counter = Gauge = Summary = Histogram = Info = Enum = cast(Any, _NoopPrometheusEntity)
is_installed = False
def async_time(metric: Gauge | Summary | Histogram):
"""
Measure the time that each execution of the decorated async function takes.
This is equivalent to the ``time`` method-decorator in the metrics, but
supports async functions.
Args:
metric: The metric instance to store the measures in.
"""
if not hasattr(metric, "time") or not callable(metric.time):
raise ValueError("async_time only supports metrics that support timing")
def decorator(fn):
async def wrapper(*args, **kwargs):
with metric.time():
return await fn(*args, **kwargs)
return wrapper if is_installed else fn
return decorator
__all__ = [
"Counter",
"Gauge",
"Summary",
"Histogram",
"Info",
"Enum",
"async_time",
"is_installed",
]
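# A minimal usage sketch (the metric name is illustrative): when prometheus_client
# is not installed, the no-op placeholders make the decorator a pass-through.
if __name__ == "__main__":
    import asyncio
    REQUEST_TIME = Summary("demo_request_seconds", "Time spent handling a demo request")
    @async_time(REQUEST_TIME)
    async def handle_request() -> str:
        await asyncio.sleep(0.01)
        return "ok"
    print(asyncio.run(handle_request()))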
python-0.20.4/mautrix/util/opt_prometheus.pyi 0000664 0000000 0000000 00000006002 14547234302 0021366 0 ustar 00root root 0000000 0000000 # Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from typing import Any, Callable, Generic, Iterable, TypeVar
T = TypeVar("T")
Number = int | float
class Metric:
name: str
documentation: str
unit: str
typ: str
samples: list[Any]
def add_sample(
self,
name: str,
labels: Iterable[str],
value: Any,
timestamp: Any = None,
exemplar: Any = None,
) -> None: ...
class MetricWrapperBase(Generic[T]):
def __init__(
self,
name: str,
documentation: str,
labelnames: Iterable[str] = (),
namespace: str = "",
subsystem: str = "",
unit: str = "",
registry: Any = None,
labelvalues: Any = None,
) -> None: ...
def describe(self) -> list[Metric]: ...
def collect(self) -> list[Metric]: ...
def labels(self, *labelvalues, **labelkwargs) -> T: ...
def remove(self, *labelvalues) -> None: ...
class ContextManager:
def __enter__(self) -> None: ...
def __exit__(self, exc_type, exc_val, exc_tb) -> None: ...
def __call__(self, f) -> None: ...
class Counter(MetricWrapperBase[Counter]):
def inc(self, amount: Number = 1) -> None: ...
def count_exceptions(self, exception: Exception = Exception) -> ContextManager: ...
class Gauge(MetricWrapperBase[Gauge]):
def inc(self, amount: Number = 1) -> None: ...
def dec(self, amount: Number = 1) -> None: ...
def set(self, value: Number = 1) -> None: ...
def set_to_current_time(self) -> None: ...
def track_inprogress(self) -> ContextManager: ...
def time(self) -> ContextManager: ...
def set_function(self, f: Callable[[], Number]) -> None: ...
class Summary(MetricWrapperBase[Summary]):
def observe(self, amount: Number) -> None: ...
def time(self) -> ContextManager: ...
class Histogram(MetricWrapperBase[Histogram]):
def __init__(
self,
name: str,
documentation: str,
labelnames: Iterable[str] = (),
namespace: str = "",
subsystem: str = "",
unit: str = "",
registry: Any = None,
labelvalues: Any = None,
buckets: Iterable[Number] = (),
) -> None: ...
def observe(self, amount: Number = 1) -> None: ...
def time(self) -> ContextManager: ...
class Info(MetricWrapperBase[Info]):
def info(self, val: dict[str, str]) -> None: ...
class Enum(MetricWrapperBase[Enum]):
def __init__(
self,
name: str,
documentation: str,
labelnames: Iterable[str] = (),
namespace: str = "",
subsystem: str = "",
unit: str = "",
registry: Any = None,
labelvalues: Any = None,
states: Iterable[str] = None,
) -> None: ...
def state(self, state: str) -> None: ...
def async_time(metric: Gauge | Summary | Histogram) -> Callable[[Callable], Callable]: ...
python-0.20.4/mautrix/util/program.py 0000664 0000000 0000000 00000024035 14547234302 0017615 0 ustar 00root root 0000000 0000000 # Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from typing import Any, AsyncIterable, Awaitable, Iterable, Union, cast
from itertools import chain
from time import time
import argparse
import asyncio
import copy
import inspect
import logging
import logging.config
import signal
import sys
from .config import BaseFileConfig, BaseMissingError, BaseValidatableConfig, ConfigValueError
from .logging import TraceLogger
try:
import uvloop
except ImportError:
uvloop = None
try:
import prometheus_client as prometheus
except ImportError:
prometheus = None
NewTask = Union[Awaitable[Any], Iterable[Awaitable[Any]], AsyncIterable[Awaitable[Any]]]
TaskList = Iterable[Awaitable[Any]]
class Program:
"""
A generic main class for programs that handles argument parsing, config loading, logger setup
and general startup/shutdown lifecycle.
"""
loop: asyncio.AbstractEventLoop
log: TraceLogger
parser: argparse.ArgumentParser
args: argparse.Namespace
config_class: type[BaseFileConfig]
config: BaseFileConfig
startup_actions: TaskList
shutdown_actions: TaskList
module: str
name: str
version: str
command: str
description: str
def __init__(
self,
module: str | None = None,
name: str | None = None,
description: str | None = None,
command: str | None = None,
version: str | None = None,
config_class: type[BaseFileConfig] | None = None,
) -> None:
if module:
self.module = module
if name:
self.name = name
if description:
self.description = description
if command:
self.command = command
if version:
self.version = version
if config_class:
self.config_class = config_class
self.startup_actions = []
self.shutdown_actions = []
self._automatic_prometheus = True
def run(self) -> None:
"""
Prepare and run the program. This is the main entrypoint and the only function that should
be called manually.
"""
self._prepare()
self._run()
def _prepare(self) -> None:
start_ts = time()
self.preinit()
self.log.info(f"Initializing {self.name} {self.version}")
try:
self.prepare()
except Exception:
self.log.critical("Unexpected error in initialization", exc_info=True)
sys.exit(1)
end_ts = time()
self.log.info(f"Initialization complete in {round(end_ts - start_ts, 2)} seconds")
def preinit(self) -> None:
"""
First part of startup: parse command-line arguments, load and update config, prepare logger.
Exceptions thrown here will crash the program immediately. Asyncio must not be used at this
stage, as the loop is only initialized later.
"""
self.prepare_arg_parser()
self.args = self.parser.parse_args()
self.prepare_config()
self.prepare_log()
self.check_config()
@property
def base_config_path(self) -> str:
return f"pkg://{self.module}/example-config.yaml"
def prepare_arg_parser(self) -> None:
"""Pre-init lifecycle method. Extend this if you want custom command-line arguments."""
self.parser = argparse.ArgumentParser(description=self.description, prog=self.command)
self.parser.add_argument(
"-c",
"--config",
type=str,
default="config.yaml",
metavar="",
help="the path to your config file",
)
self.parser.add_argument(
"-n", "--no-update", action="store_true", help="Don't save updated config to disk"
)
def prepare_config(self) -> None:
"""Pre-init lifecycle method. Extend this if you want to customize config loading."""
self.config = self.config_class(self.args.config, self.base_config_path)
self.load_and_update_config()
def load_and_update_config(self) -> None:
self.config.load()
try:
self.config.update(save=not self.args.no_update)
except BaseMissingError:
print(
"Failed to read base config from the default path "
f"({self.base_config_path}). Maybe your installation is corrupted?"
)
sys.exit(12)
def check_config(self) -> None:
"""Pre-init lifecycle method. Extend this if you want to customize config validation."""
if not isinstance(self.config, BaseValidatableConfig):
return
try:
self.config.check_default_values()
except ConfigValueError as e:
self.log.fatal(f"Configuration error: {e}")
sys.exit(11)
def prepare_log(self) -> None:
"""Pre-init lifecycle method. Extend this if you want to customize logging setup."""
logging.config.dictConfig(copy.deepcopy(self.config["logging"]))
self.log = cast(TraceLogger, logging.getLogger("mau.init"))
def prepare(self) -> None:
"""
Lifecycle method where the primary program initialization happens.
Use this to fill startup_actions with async startup tasks.
"""
self.prepare_loop()
def prepare_loop(self) -> None:
"""Init lifecycle method where the asyncio event loop is created."""
if uvloop is not None:
uvloop.install()
self.log.debug("Using uvloop for asyncio")
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
def start_prometheus(self) -> None:
try:
enabled = self.config["metrics.enabled"]
listen_port = self.config["metrics.listen_port"]
except KeyError:
return
if not enabled:
return
elif not prometheus:
self.log.warning(
"Metrics are enabled in config, but prometheus_client is not installed"
)
return
prometheus.start_http_server(listen_port)
def _run(self) -> None:
signal.signal(signal.SIGINT, signal.default_int_handler)
signal.signal(signal.SIGTERM, signal.default_int_handler)
self._stop_task = self.loop.create_future()
exit_code = 0
try:
self.log.debug("Running startup actions...")
start_ts = time()
self.loop.run_until_complete(self.start())
end_ts = time()
self.log.info(
f"Startup actions complete in {round(end_ts - start_ts, 2)} seconds, "
"now running forever"
)
exit_code = self.loop.run_until_complete(self._stop_task)
self.log.debug("manual_stop() called, stopping...")
except KeyboardInterrupt:
self.log.debug("Interrupt received, stopping...")
except Exception:
self.log.critical("Unexpected error in main event loop", exc_info=True)
self.loop.run_until_complete(self.system_exit())
sys.exit(2)
except SystemExit:
self.loop.run_until_complete(self.system_exit())
raise
self.prepare_stop()
self.loop.run_until_complete(self.stop())
self.prepare_shutdown()
self.loop.close()
asyncio.set_event_loop(None)
self.log.info("Everything stopped, shutting down")
sys.exit(exit_code)
async def system_exit(self) -> None:
"""Lifecycle method that is called if the main event loop exits using ``sys.exit()``."""
async def start(self) -> None:
"""
First lifecycle method called inside the asyncio event loop. Extend this if you want more
control over startup than just filling startup_actions in the prepare step.
"""
if self._automatic_prometheus:
self.start_prometheus()
await asyncio.gather(*(self.startup_actions or []))
def prepare_stop(self) -> None:
"""
Lifecycle method that is called before awaiting :meth:`stop`.
Useful for filling shutdown_actions.
"""
async def stop(self) -> None:
"""
Lifecycle method used to stop things that need awaiting to stop. Extend this if you want
more control over shutdown than just filling shutdown_actions in the prepare_stop method.
"""
await asyncio.gather(*(self.shutdown_actions or []))
def prepare_shutdown(self) -> None:
"""Lifecycle method that is called right before ``sys.exit(0)``."""
def manual_stop(self, exit_code: int = 0) -> None:
"""Tell the event loop to cleanly stop and run the stop lifecycle steps."""
self._stop_task.set_result(exit_code)
def add_startup_actions(self, *actions: NewTask) -> None:
self.startup_actions = self._add_actions(self.startup_actions, actions)
def add_shutdown_actions(self, *actions: NewTask) -> None:
self.shutdown_actions = self._add_actions(self.shutdown_actions, actions)
@staticmethod
async def _unpack_async_iterator(iterable: AsyncIterable[Awaitable[Any]]) -> None:
tasks = []
async for task in iterable:
if inspect.isawaitable(task):
tasks.append(asyncio.create_task(task))
await asyncio.gather(*tasks)
def _add_actions(self, to: TaskList, add: tuple[NewTask, ...]) -> TaskList:
for item in add:
if inspect.isasyncgen(item):
to.append(self._unpack_async_iterator(item))
elif inspect.isawaitable(item):
if isinstance(to, list):
to.append(item)
else:
to = chain(to, [item])
elif isinstance(item, list):
if isinstance(to, list):
to += item
else:
to = chain(to, item)
else:
to = chain(to, item)
return to
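# A minimal subclass sketch (names are illustrative and assume a config class
# derived from BaseFileConfig exists); shown as a comment only:
#
#     class MyProgram(Program):
#         module = "myprogram"
#         name = "myprogram"
#         version = "0.1.0"
#         command = "python -m myprogram"
#         description = "An example program"
#         config_class = MyConfig  # hypothetical BaseFileConfig subclass
#
#         def prepare(self) -> None:
#             super().prepare()
#             self.add_startup_actions(self.start_something())  # hypothetical coroutine
#
#     MyProgram().run()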
python-0.20.4/mautrix/util/proxy.py 0000664 0000000 0000000 00000007506 14547234302 0017333 0 ustar 00root root 0000000 0000000 from __future__ import annotations
from typing import Awaitable, Callable, TypeVar
import asyncio
import json
import logging
import time
import urllib.request
from aiohttp import ClientConnectionError
from yarl import URL
from mautrix.util.logging import TraceLogger
try:
from aiohttp_socks import ProxyConnectionError, ProxyError, ProxyTimeoutError
except ImportError:
class ProxyError(Exception):
pass
ProxyConnectionError = ProxyTimeoutError = ProxyError
RETRYABLE_PROXY_EXCEPTIONS = (
ProxyError,
ProxyTimeoutError,
ProxyConnectionError,
ClientConnectionError,
ConnectionError,
asyncio.TimeoutError,
)
class ProxyHandler:
current_proxy_url: str | None = None
log = logging.getLogger("mau.proxy")
def __init__(self, api_url: str | None) -> None:
self.api_url = api_url
def get_proxy_url_from_api(self, reason: str | None = None) -> str | None:
assert self.api_url is not None
api_url = str(URL(self.api_url).update_query({"reason": reason} if reason else {}))
        # NOTE: using urllib.request to intentionally block the whole bridge until the proxy change is applied
request = urllib.request.Request(api_url, method="GET")
self.log.debug("Requesting proxy from: %s", api_url)
try:
with urllib.request.urlopen(request) as f:
response = json.loads(f.read().decode())
except Exception:
self.log.exception("Failed to retrieve proxy from API")
return self.current_proxy_url
else:
return response["proxy_url"]
def update_proxy_url(self, reason: str | None = None) -> bool:
old_proxy = self.current_proxy_url
new_proxy = None
if self.api_url is not None:
new_proxy = self.get_proxy_url_from_api(reason)
else:
new_proxy = urllib.request.getproxies().get("http")
if old_proxy != new_proxy:
self.log.debug("Set new proxy URL: %s", new_proxy)
self.current_proxy_url = new_proxy
return True
self.log.debug("Got same proxy URL: %s", new_proxy)
return False
def get_proxy_url(self) -> str | None:
if not self.current_proxy_url:
self.update_proxy_url()
return self.current_proxy_url
T = TypeVar("T")
async def proxy_with_retry(
name: str,
func: Callable[[], Awaitable[T]],
logger: TraceLogger,
proxy_handler: ProxyHandler,
on_proxy_change: Callable[[], Awaitable[None]],
max_retries: int = 10,
min_wait_seconds: int = 0,
max_wait_seconds: int = 60,
multiply_wait_seconds: int = 10,
retryable_exceptions: tuple[Exception] = RETRYABLE_PROXY_EXCEPTIONS,
reset_after_seconds: int | None = None,
) -> T:
errors = 0
last_error = 0
while True:
try:
return await func()
except retryable_exceptions as e:
errors += 1
if errors > max_retries:
raise
wait = errors * multiply_wait_seconds
wait = max(wait, min_wait_seconds)
wait = min(wait, max_wait_seconds)
logger.warning(
"%s while trying to %s, retrying in %d seconds",
e.__class__.__name__,
name,
wait,
)
if errors > 1 and proxy_handler.update_proxy_url(
f"{e.__class__.__name__} while trying to {name}"
):
await on_proxy_change()
# If sufficient time has passed since the previous error, reset the
# error count. Useful for long running tasks with rare failures.
if reset_after_seconds is not None:
now = time.time()
if last_error and now - last_error > reset_after_seconds:
errors = 0
last_error = now
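# A minimal usage sketch (the URL and callbacks are illustrative, and this makes
# a real HTTP request when run); with api_url=None the handler falls back to the
# http_proxy environment variable:
if __name__ == "__main__":
    import aiohttp
    async def _demo() -> None:
        handler = ProxyHandler(api_url=None)
        async def fetch() -> int:
            async with aiohttp.ClientSession() as sess:
                async with sess.get("https://example.com", proxy=handler.get_proxy_url()) as resp:
                    return resp.status
        status = await proxy_with_retry(
            "fetch example.com",
            fetch,
            logger=logging.getLogger("mau.proxy.demo"),
            proxy_handler=handler,
            on_proxy_change=lambda: asyncio.sleep(0),
            max_retries=3,
        )
        print(status)
    asyncio.run(_demo())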
python-0.20.4/mautrix/util/signed_token.py 0000664 0000000 0000000 00000002466 14547234302 0020623 0 ustar 00root root 0000000 0000000 # Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from hashlib import sha256
import base64
import hmac
import json
def _get_checksum(key: str, payload: bytes) -> str:
hasher = hmac.new(key.encode("utf-8"), msg=payload, digestmod=sha256)
checksum = base64.urlsafe_b64encode(hasher.digest())
return checksum.decode("utf-8").rstrip("=")
def sign_token(key: str, payload: dict) -> str:
payload_b64 = base64.urlsafe_b64encode(json.dumps(payload).encode("utf-8"))
checksum = _get_checksum(key, payload_b64)
payload_str = payload_b64.decode("utf-8").rstrip("=")
return f"{checksum}:{payload_str}"
def verify_token(key: str, data: str) -> dict | None:
if not data:
return None
try:
checksum, payload = data.split(":", 1)
except ValueError:
return None
payload += (3 - (len(payload) + 3) % 4) * "="
if checksum != _get_checksum(key, payload.encode("utf-8")):
return None
payload = base64.urlsafe_b64decode(payload).decode("utf-8")
try:
return json.loads(payload)
except json.JSONDecodeError:
return None
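# A minimal usage sketch (the key and payload are illustrative):
if __name__ == "__main__":
    token = sign_token("hunter2", {"mxid": "@user:example.com"})
    print(token)
    print(verify_token("hunter2", token))  # -> the original payload dict
    print(verify_token("wrong-key", token))  # -> None (checksum mismatch)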
python-0.20.4/mautrix/util/simple_lock.py 0000664 0000000 0000000 00000002715 14547234302 0020450 0 ustar 00root root 0000000 0000000 # Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
import asyncio
import logging
class SimpleLock:
_event: asyncio.Event
log: logging.Logger | None
message: str | None
noop_mode: bool
def __init__(
self,
message: str | None = None,
log: logging.Logger | None = None,
noop_mode: bool = False,
) -> None:
self.noop_mode = noop_mode
if not noop_mode:
self._event = asyncio.Event()
self._event.set()
self.log = log
self.message = message
def __enter__(self) -> None:
if not self.noop_mode:
self._event.clear()
async def __aenter__(self) -> None:
self.__enter__()
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
if not self.noop_mode:
self._event.set()
    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
self.__exit__(exc_type, exc_val, exc_tb)
@property
def locked(self) -> bool:
return not self.noop_mode and not self._event.is_set()
async def wait(self, task: str | None = None) -> None:
if self.locked:
if self.log and self.message:
self.log.debug(self.message, task)
await self._event.wait()
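# A minimal usage sketch: one task holds the lock while another waits for it
# (the message and sleep duration are illustrative):
if __name__ == "__main__":
    async def _demo() -> None:
        lock = SimpleLock(message="Waiting for sync (%s)", log=logging.getLogger("mau.demo"))
        async def worker() -> None:
            await lock.wait("worker task")
            print("worker ran after the lock was released")
        with lock:
            task = asyncio.create_task(worker())
            await asyncio.sleep(0.1)  # worker is blocked while the lock is held
        await task
    asyncio.run(_demo())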
python-0.20.4/mautrix/util/simple_template.py 0000664 0000000 0000000 00000003060 14547234302 0021325 0 ustar 00root root 0000000 0000000 # Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from typing import Generic, TypeVar
T = TypeVar("T")
class SimpleTemplate(Generic[T]):
_template: str
_keyword: str
_prefix: str
_suffix: str
_type: type[T]
def __init__(
self, template: str, keyword: str, prefix: str = "", suffix: str = "", type: type[T] = str
) -> None:
self._template = template
self._keyword = keyword
index = self._template.find("{%s}" % keyword)
length = len(keyword) + 2
self._prefix = prefix + self._template[:index]
self._suffix = self._template[index + length :] + suffix
self._type = type
def format(self, arg: T) -> str:
return self._template.format(**{self._keyword: arg})
def format_full(self, arg: T) -> str:
return f"{self._prefix}{arg}{self._suffix}"
def parse(self, val: str) -> T | None:
prefix_ok = val[: len(self._prefix)] == self._prefix
has_suffix = len(self._suffix) > 0
suffix_ok = not has_suffix or val[-len(self._suffix) :] == self._suffix
if prefix_ok and suffix_ok:
start = len(self._prefix)
end = -len(self._suffix) if has_suffix else len(val)
try:
return self._type(val[start:end])
except ValueError:
pass
return None
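# A minimal usage sketch (the username template is illustrative):
if __name__ == "__main__":
    tpl = SimpleTemplate("user_{userid}", "userid", prefix="@", suffix=":example.com", type=int)
    print(tpl.format_full(12345))  # -> "@user_12345:example.com"
    print(tpl.parse("@user_12345:example.com"))  # -> 12345
    print(tpl.parse("@someone_else:example.com"))  # -> None (prefix does not match)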
python-0.20.4/mautrix/util/utf16_surrogate.py 0000664 0000000 0000000 00000002613 14547234302 0021204 0 ustar 00root root 0000000 0000000 # From https://github.com/LonamiWebs/Telethon/blob/v1.24.0/telethon/helpers.py#L38-L62
# Copyright (c) LonamiWebs, MIT license
import struct
def add(text: str) -> str:
"""
Add surrogate pairs to characters in the text. This makes the indices match how most platforms
calculate string length when formatting texts using offset-based entities.
Args:
text: The text to add surrogate pairs to.
Returns:
The text with surrogate pairs.
"""
return "".join(
"".join(chr(y) for y in struct.unpack(" str:
"""
Remove surrogate pairs from text. This does the opposite of :func:`add`.
Args:
text: The text with surrogate pairs.
Returns:
The text without surrogate pairs.
"""
return text.encode("utf-16", "surrogatepass").decode("utf-16")
def is_within(text: str, index: int, *, length: int = None) -> bool:
"""
Returns:
`True` if ``index`` is within a surrogate (before and after it, not at!).
"""
if length is None:
length = len(text)
return (
        1 < index < len(text)  # in bounds
        and "\ud800" <= text[index - 1] <= "\udfff"  # previous is a surrogate
        and "\ud800" <= text[index] <= "\udfff"  # current is a surrogate
)
__all__ = ["add", "remove"]
python-0.20.4/mautrix/util/variation_selector.json 0000664 0000000 0000000 00000015651 14547234302 0022367 0 ustar 00root root 0000000 0000000 {
"0023": "#",
"002A": "*",
"0030": "0",
"0031": "1",
"0032": "2",
"0033": "3",
"0034": "4",
"0035": "5",
"0036": "6",
"0037": "7",
"0038": "8",
"0039": "9",
"00A9": "©",
"00AE": "®",
"203C": "‼",
"2049": "⁉",
"2122": "™",
"2139": "ℹ",
"2194": "↔",
"2195": "↕",
"2196": "↖",
"2197": "↗",
"2198": "↘",
"2199": "↙",
"21A9": "↩",
"21AA": "↪",
"231A": "⌚",
"231B": "⌛",
"2328": "⌨",
"23CF": "⏏",
"23E9": "⏩",
"23EA": "⏪",
"23ED": "⏭",
"23EE": "⏮",
"23EF": "⏯",
"23F1": "⏱",
"23F2": "⏲",
"23F3": "⏳",
"23F8": "⏸",
"23F9": "⏹",
"23FA": "⏺",
"24C2": "Ⓜ",
"25AA": "▪",
"25AB": "▫",
"25B6": "▶",
"25C0": "◀",
"25FB": "◻",
"25FC": "◼",
"25FD": "◽",
"25FE": "◾",
"2600": "☀",
"2601": "☁",
"2602": "☂",
"2603": "☃",
"2604": "☄",
"260E": "☎",
"2611": "☑",
"2614": "☔",
"2615": "☕",
"2618": "☘",
"261D": "☝",
"2620": "☠",
"2622": "☢",
"2623": "☣",
"2626": "☦",
"262A": "☪",
"262E": "☮",
"262F": "☯",
"2638": "☸",
"2639": "☹",
"263A": "☺",
"2640": "♀",
"2642": "♂",
"2648": "♈",
"2649": "♉",
"264A": "♊",
"264B": "♋",
"264C": "♌",
"264D": "♍",
"264E": "♎",
"264F": "♏",
"2650": "♐",
"2651": "♑",
"2652": "♒",
"2653": "♓",
"265F": "♟",
"2660": "♠",
"2663": "♣",
"2665": "♥",
"2666": "♦",
"2668": "♨",
"267B": "♻",
"267E": "♾",
"267F": "♿",
"2692": "⚒",
"2693": "⚓",
"2694": "⚔",
"2695": "⚕",
"2696": "⚖",
"2697": "⚗",
"2699": "⚙",
"269B": "⚛",
"269C": "⚜",
"26A0": "⚠",
"26A1": "⚡",
"26A7": "⚧",
"26AA": "⚪",
"26AB": "⚫",
"26B0": "⚰",
"26B1": "⚱",
"26BD": "⚽",
"26BE": "⚾",
"26C4": "⛄",
"26C5": "⛅",
"26C8": "⛈",
"26CF": "⛏",
"26D1": "⛑",
"26D3": "⛓",
"26D4": "⛔",
"26E9": "⛩",
"26EA": "⛪",
"26F0": "⛰",
"26F1": "⛱",
"26F2": "⛲",
"26F3": "⛳",
"26F4": "⛴",
"26F5": "⛵",
"26F7": "⛷",
"26F8": "⛸",
"26F9": "⛹",
"26FA": "⛺",
"26FD": "⛽",
"2702": "✂",
"2708": "✈",
"2709": "✉",
"270C": "✌",
"270D": "✍",
"270F": "✏",
"2712": "✒",
"2714": "✔",
"2716": "✖",
"271D": "✝",
"2721": "✡",
"2733": "✳",
"2734": "✴",
"2744": "❄",
"2747": "❇",
"2753": "❓",
"2757": "❗",
"2763": "❣",
"2764": "❤",
"27A1": "➡",
"2934": "⤴",
"2935": "⤵",
"2B05": "⬅",
"2B06": "⬆",
"2B07": "⬇",
"2B1B": "⬛",
"2B1C": "⬜",
"2B50": "⭐",
"2B55": "⭕",
"3030": "〰",
"303D": "〽",
"3297": "㊗",
"3299": "㊙",
"1F004": "🀄",
"1F170": "🅰",
"1F171": "🅱",
"1F17E": "🅾",
"1F17F": "🅿",
"1F202": "🈂",
"1F21A": "🈚",
"1F22F": "🈯",
"1F237": "🈷",
"1F30D": "🌍",
"1F30E": "🌎",
"1F30F": "🌏",
"1F315": "🌕",
"1F31C": "🌜",
"1F321": "🌡",
"1F324": "🌤",
"1F325": "🌥",
"1F326": "🌦",
"1F327": "🌧",
"1F328": "🌨",
"1F329": "🌩",
"1F32A": "🌪",
"1F32B": "🌫",
"1F32C": "🌬",
"1F336": "🌶",
"1F378": "🍸",
"1F37D": "🍽",
"1F393": "🎓",
"1F396": "🎖",
"1F397": "🎗",
"1F399": "🎙",
"1F39A": "🎚",
"1F39B": "🎛",
"1F39E": "🎞",
"1F39F": "🎟",
"1F3A7": "🎧",
"1F3AC": "🎬",
"1F3AD": "🎭",
"1F3AE": "🎮",
"1F3C2": "🏂",
"1F3C4": "🏄",
"1F3C6": "🏆",
"1F3CA": "🏊",
"1F3CB": "🏋",
"1F3CC": "🏌",
"1F3CD": "🏍",
"1F3CE": "🏎",
"1F3D4": "🏔",
"1F3D5": "🏕",
"1F3D6": "🏖",
"1F3D7": "🏗",
"1F3D8": "🏘",
"1F3D9": "🏙",
"1F3DA": "🏚",
"1F3DB": "🏛",
"1F3DC": "🏜",
"1F3DD": "🏝",
"1F3DE": "🏞",
"1F3DF": "🏟",
"1F3E0": "🏠",
"1F3ED": "🏭",
"1F3F3": "🏳",
"1F3F5": "🏵",
"1F3F7": "🏷",
"1F408": "🐈",
"1F415": "🐕",
"1F41F": "🐟",
"1F426": "🐦",
"1F43F": "🐿",
"1F441": "👁",
"1F442": "👂",
"1F446": "👆",
"1F447": "👇",
"1F448": "👈",
"1F449": "👉",
"1F44D": "👍",
"1F44E": "👎",
"1F453": "👓",
"1F46A": "👪",
"1F47D": "👽",
"1F4A3": "💣",
"1F4B0": "💰",
"1F4B3": "💳",
"1F4BB": "💻",
"1F4BF": "💿",
"1F4CB": "📋",
"1F4DA": "📚",
"1F4DF": "📟",
"1F4E4": "📤",
"1F4E5": "📥",
"1F4E6": "📦",
"1F4EA": "📪",
"1F4EB": "📫",
"1F4EC": "📬",
"1F4ED": "📭",
"1F4F7": "📷",
"1F4F9": "📹",
"1F4FA": "📺",
"1F4FB": "📻",
"1F4FD": "📽",
"1F508": "🔈",
"1F50D": "🔍",
"1F512": "🔒",
"1F513": "🔓",
"1F549": "🕉",
"1F54A": "🕊",
"1F550": "🕐",
"1F551": "🕑",
"1F552": "🕒",
"1F553": "🕓",
"1F554": "🕔",
"1F555": "🕕",
"1F556": "🕖",
"1F557": "🕗",
"1F558": "🕘",
"1F559": "🕙",
"1F55A": "🕚",
"1F55B": "🕛",
"1F55C": "🕜",
"1F55D": "🕝",
"1F55E": "🕞",
"1F55F": "🕟",
"1F560": "🕠",
"1F561": "🕡",
"1F562": "🕢",
"1F563": "🕣",
"1F564": "🕤",
"1F565": "🕥",
"1F566": "🕦",
"1F567": "🕧",
"1F56F": "🕯",
"1F570": "🕰",
"1F573": "🕳",
"1F574": "🕴",
"1F575": "🕵",
"1F576": "🕶",
"1F577": "🕷",
"1F578": "🕸",
"1F579": "🕹",
"1F587": "🖇",
"1F58A": "🖊",
"1F58B": "🖋",
"1F58C": "🖌",
"1F58D": "🖍",
"1F590": "🖐",
"1F5A5": "🖥",
"1F5A8": "🖨",
"1F5B1": "🖱",
"1F5B2": "🖲",
"1F5BC": "🖼",
"1F5C2": "🗂",
"1F5C3": "🗃",
"1F5C4": "🗄",
"1F5D1": "🗑",
"1F5D2": "🗒",
"1F5D3": "🗓",
"1F5DC": "🗜",
"1F5DD": "🗝",
"1F5DE": "🗞",
"1F5E1": "🗡",
"1F5E3": "🗣",
"1F5E8": "🗨",
"1F5EF": "🗯",
"1F5F3": "🗳",
"1F5FA": "🗺",
"1F610": "😐",
"1F687": "🚇",
"1F68D": "🚍",
"1F691": "🚑",
"1F694": "🚔",
"1F698": "🚘",
"1F6AD": "🚭",
"1F6B2": "🚲",
"1F6B9": "🚹",
"1F6BA": "🚺",
"1F6BC": "🚼",
"1F6CB": "🛋",
"1F6CD": "🛍",
"1F6CE": "🛎",
"1F6CF": "🛏",
"1F6E0": "🛠",
"1F6E1": "🛡",
"1F6E2": "🛢",
"1F6E3": "🛣",
"1F6E4": "🛤",
"1F6E5": "🛥",
"1F6E9": "🛩",
"1F6F0": "🛰",
"1F6F3": "🛳"
}
python-0.20.4/mautrix/util/variation_selector.py 0000664 0000000 0000000 00000007174 14547234302 0022047 0 ustar 00root root 0000000 0000000 # Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
import json
import pkgutil
import aiohttp
EMOJI_VAR_URL = "https://www.unicode.org/Public/14.0.0/ucd/emoji/emoji-variation-sequences.txt"
def read_data() -> dict[str, str]:
"""
Get the list of emoji that need a variation selector. This loads the local data file that was
previously generated from the Unicode spec data files.
Returns:
A dict from hex to the emoji string (you have to bring the variation selectors yourself).
"""
return json.loads(pkgutil.get_data("mautrix.util", "variation_selector.json"))
async def fetch_data() -> dict[str, str]:
"""
Generate the list of emoji that need a variation selector from the Unicode spec data files.
Returns:
A dict from hex to the emoji string (you have to bring the variation selectors yourself).
"""
async with aiohttp.ClientSession() as sess, sess.get(EMOJI_VAR_URL) as resp:
data = await resp.text()
emojis = {}
for line in data.split("\n"):
if "emoji style" in line:
emoji_hex = line.split(" ", 1)[0]
emojis[emoji_hex] = rf"\U{emoji_hex:>08}".encode("ascii").decode("unicode-escape")
return emojis
if __name__ == "__main__":
import asyncio
import sys
import pkg_resources
path = pkg_resources.resource_filename("mautrix.util", "variation_selector.json")
emojis = asyncio.run(fetch_data())
with open(path, "w") as file:
json.dump(emojis, file, indent=" ", ensure_ascii=False)
file.write("\n")
print(f"Wrote {len(emojis)} emojis to {path}")
sys.exit(0)
VARIATION_SELECTOR_16 = "\ufe0f"
ADD_VARIATION_TRANSLATION = str.maketrans(
{ord(emoji): f"{emoji}{VARIATION_SELECTOR_16}" for emoji in read_data().values()}
)
SKIN_TONE_MODIFIERS = ("\U0001F3FB", "\U0001F3FC", "\U0001F3FD", "\U0001F3FE", "\U0001F3FF")
SKIN_TONE_REPLACEMENTS = {f"{VARIATION_SELECTOR_16}{mod}": mod for mod in SKIN_TONE_MODIFIERS}
VARIATION_SELECTOR_REPLACEMENTS = {
**SKIN_TONE_REPLACEMENTS,
"\U0001F408\ufe0f\u200d\u2b1b\ufe0f": "\U0001F408\u200d\u2b1b",
}
def add(val: str) -> str:
r"""
Add emoji variation selectors (16) to all emojis that have multiple forms in the given string.
This will remove all variation selectors first to make sure it doesn't add duplicates.
.. versionadded:: 0.12.5
Examples:
>>> from mautrix.util import variation_selector
>>> variation_selector.add("\U0001f44d")
"\U0001f44d\ufe0f"
>>> variation_selector.add("\U0001f44d\ufe0f")
"\U0001f44d\ufe0f"
>>> variation_selector.add("4\u20e3")
"4\ufe0f\u20e3"
>>> variation_selector.add("\U0001f9d0")
"\U0001f9d0"
Args:
val: The string to add variation selectors to.
Returns:
The string with variation selectors added.
"""
added = remove(val).translate(ADD_VARIATION_TRANSLATION)
for invalid_selector, replacement in VARIATION_SELECTOR_REPLACEMENTS.items():
added = added.replace(invalid_selector, replacement)
return added
def remove(val: str) -> str:
"""
Remove all emoji variation selectors in the given string.
.. versionadded:: 0.12.5
Args:
val: The string to remove variation selectors from.
Returns:
The string with variation selectors removed.
"""
return val.replace(VARIATION_SELECTOR_16, "")
__all__ = ["add", "remove", "read_data", "fetch_data"]
python-0.20.4/optional-requirements.txt 0000664 0000000 0000000 00000000223 14547234302 0020226 0 ustar 00root root 0000000 0000000 python-magic
ruamel.yaml
SQLAlchemy<2
commonmark
lxml
asyncpg
aiosqlite
prometheus_client
setuptools
uvloop
python-olm
unpaddedbase64
pycryptodome
python-0.20.4/pyproject.toml 0000664 0000000 0000000 00000000435 14547234302 0016040 0 ustar 00root root 0000000 0000000 [tool.isort]
profile = "black"
force_to_top = "typing"
from_first = true
combine_as_imports = true
line_length = 99
[tool.black]
line-length = 99
target-version = ["py38"]
[tool.pytest.ini_options]
asyncio_mode = "auto"
addopts = "--ignore mautrix/util/db/ --ignore mautrix/bridge/"
python-0.20.4/requirements.txt 0000664 0000000 0000000 00000000023 14547234302 0016401 0 ustar 00root root 0000000 0000000 aiohttp
attrs
yarl
python-0.20.4/setup.py 0000664 0000000 0000000 00000003327 14547234302 0014641 0 ustar 00root root 0000000 0000000 import setuptools
from mautrix import __version__
encryption_dependencies = ["python-olm", "unpaddedbase64", "pycryptodome"]
test_dependencies = ["aiosqlite", "asyncpg", "ruamel.yaml", *encryption_dependencies]
setuptools.setup(
name="mautrix",
version=__version__,
url="https://github.com/mautrix/python",
project_urls={
"Changelog": "https://github.com/mautrix/python/blob/master/CHANGELOG.md",
},
author="Tulir Asokan",
author_email="tulir@maunium.net",
description="A Python 3 asyncio Matrix framework.",
long_description=open("README.rst").read(),
packages=setuptools.find_packages(),
install_requires=[
"aiohttp>=3,<4",
"attrs>=18.1.0",
"yarl>=1.5,<2",
],
extras_require={
"detect_mimetype": ["python-magic>=0.4.15,<0.5"],
"lint": ["black~=23.1", "isort"],
"test": ["pytest", "pytest-asyncio", *test_dependencies],
"encryption": encryption_dependencies,
},
tests_require=test_dependencies,
python_requires="~=3.10",
classifiers=[
"Development Status :: 4 - Beta",
"License :: OSI Approved :: Mozilla Public License 2.0 (MPL 2.0)",
"Topic :: Communications :: Chat",
"Framework :: AsyncIO",
"Programming Language :: Python",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
],
package_data={
"mautrix": ["py.typed"],
"mautrix.types.event": ["type.pyi"],
"mautrix.util": ["opt_prometheus.pyi", "variation_selector.json"],
"mautrix.util.formatter": ["html_reader.pyi"],
},
)