starlette-0.18.0/LICENSE.md

Copyright © 2018, [Encode OSS Ltd](https://www.encode.io/).
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.

* Neither the name of the copyright holder nor the names of its
  contributors may be used to endorse or promote products derived from
  this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

starlette-0.18.0/MANIFEST.in

include LICENSE.md
global-exclude __pycache__
global-exclude *.py[co]

starlette-0.18.0/PKG-INFO

Metadata-Version: 2.1
Name: starlette
Version: 0.18.0
Summary: The little ASGI library that shines.
Home-page: https://github.com/encode/starlette
Author: Tom Christie
Author-email: tom@tomchristie.com
License: BSD
Platform: UNKNOWN
Classifier: Development Status :: 3 - Alpha
Classifier: Environment :: Web Environment
Classifier: Intended Audience :: Developers
Classifier: License :: OSI Approved :: BSD License
Classifier: Operating System :: OS Independent
Classifier: Topic :: Internet :: WWW/HTTP
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.6
Classifier: Programming Language :: Python :: 3.7
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Requires-Python: >=3.6
Description-Content-Type: text/markdown
Provides-Extra: full

starlette-0.18.0/README.md

starlette

✨ The little ASGI framework that shines. ✨


---

**Documentation**: [https://www.starlette.io/](https://www.starlette.io/)

---

# Starlette

Starlette is a lightweight [ASGI](https://asgi.readthedocs.io/en/latest/) framework/toolkit,
which is ideal for building high performance async services.

It is production-ready, and gives you the following:

* Seriously impressive performance.
* WebSocket support.
* In-process background tasks.
* Startup and shutdown events.
* Test client built on `requests`.
* CORS, GZip, Static Files, Streaming responses.
* Session and Cookie support.
* 100% test coverage.
* 100% type annotated codebase.
* Few hard dependencies.
* Compatible with `asyncio` and `trio` backends.

## Requirements

Python 3.6+

## Installation

```shell
$ pip3 install starlette
```

You'll also want to install an ASGI server, such as [uvicorn](http://www.uvicorn.org/), [daphne](https://github.com/django/daphne/), or [hypercorn](https://pgjones.gitlab.io/hypercorn/).

```shell
$ pip3 install uvicorn
```

## Example

**example.py**:

```python
from starlette.applications import Starlette
from starlette.responses import JSONResponse
from starlette.routing import Route


async def homepage(request):
    return JSONResponse({'hello': 'world'})

routes = [
    Route("/", endpoint=homepage)
]

app = Starlette(debug=True, routes=routes)
```

Then run the application using Uvicorn:

```shell
$ uvicorn example:app
```

For a more complete example, see [encode/starlette-example](https://github.com/encode/starlette-example).

## Dependencies

Starlette only requires `anyio`, and the following are optional:

* [`requests`][requests] - Required if you want to use the `TestClient`.
* [`jinja2`][jinja2] - Required if you want to use `Jinja2Templates`.
* [`python-multipart`][python-multipart] - Required if you want to support form parsing, with `request.form()`.
* [`itsdangerous`][itsdangerous] - Required for `SessionMiddleware` support.
* [`pyyaml`][pyyaml] - Required for `SchemaGenerator` support.

You can install all of the optional dependencies above with `pip3 install starlette[full]`.

## Framework or Toolkit

Starlette is designed to be used either as a complete framework, or as an ASGI toolkit. You can use any of its components independently.

```python
from starlette.responses import PlainTextResponse


async def app(scope, receive, send):
    assert scope['type'] == 'http'
    response = PlainTextResponse('Hello, world!')
    await response(scope, receive, send)
```

Run the `app` application in `example.py`:

```shell
$ uvicorn example:app
INFO: Started server process [11509]
INFO: Uvicorn running on http://127.0.0.1:8000 (Press CTRL+C to quit)
```

Run uvicorn with `--reload` to enable auto-reloading on code changes.

## Modularity

Starlette's modular design promotes building reusable components that can be shared across any ASGI framework. This should enable an ecosystem of shared middleware and mountable applications.

The clean API separation also means it's easier to understand each component in isolation.

## Performance

Independent TechEmpower benchmarks show Starlette applications running under Uvicorn as [one of the fastest Python frameworks available](https://www.techempower.com/benchmarks/#section=data-r17&hw=ph&test=fortune&l=zijzen-1). *(\*)*

For high-throughput loads you should:

* Run using Gunicorn with the `uvicorn` worker class.
* Use one or two workers per CPU core. (You might need to experiment with this.)
* Disable access logging.

For example:

```shell
gunicorn -w 4 -k uvicorn.workers.UvicornWorker --log-level warning example:app
```

Several of the ASGI servers also have pure Python implementations available, so you can also run under `PyPy` if your application code has parts that are CPU constrained.

Either programmatically:

```python
uvicorn.run(..., http='h11', loop='asyncio')
```

Or using Gunicorn:

```shell
gunicorn -k uvicorn.workers.UvicornH11Worker ...
```
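To make the modularity point concrete, here is a minimal sketch of a framework-agnostic component: a pure ASGI middleware wrapping the kind of bare `(scope, receive, send)` app shown under "Framework or Toolkit". The `AddHeaderMiddleware` class, the `x-example` header and the `call` helper are hypothetical, for illustration only. The sketch uses only the standard library and drives the app with hand-built ASGI messages instead of a server:

```python
import asyncio
import typing

# Minimal ASGI type aliases (Starlette defines similar ones in starlette.types).
Scope = typing.MutableMapping[str, typing.Any]
Message = typing.MutableMapping[str, typing.Any]
Receive = typing.Callable[[], typing.Awaitable[Message]]
Send = typing.Callable[[Message], typing.Awaitable[None]]
ASGIApp = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]]


class AddHeaderMiddleware:
    """Hypothetical pure-ASGI middleware that appends one header to HTTP responses."""

    def __init__(self, app: ASGIApp, name: bytes, value: bytes) -> None:
        self.app = app
        self.header = (name, value)

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] != "http":
            # Pass lifespan/websocket traffic through untouched.
            await self.app(scope, receive, send)
            return

        async def send_wrapper(message: Message) -> None:
            if message["type"] == "http.response.start":
                # Build a new message rather than mutating the one the app sent.
                headers = list(message.get("headers", [])) + [self.header]
                message = {**message, "headers": headers}
            await send(message)

        await self.app(scope, receive, send_wrapper)


async def hello_app(scope: Scope, receive: Receive, send: Send) -> None:
    """A bare ASGI app, written without any framework."""
    assert scope["type"] == "http"
    await send(
        {
            "type": "http.response.start",
            "status": 200,
            "headers": [(b"content-type", b"text/plain")],
        }
    )
    await send({"type": "http.response.body", "body": b"Hello, world!"})


async def call(app: ASGIApp) -> typing.List[Message]:
    """Drive `app` once with a fake GET request and collect what it sends."""
    sent: typing.List[Message] = []

    async def receive() -> Message:
        return {"type": "http.request", "body": b"", "more_body": False}

    async def send(message: Message) -> None:
        sent.append(message)

    scope = {"type": "http", "method": "GET", "path": "/", "headers": []}
    await app(scope, receive, send)
    return sent


app = AddHeaderMiddleware(hello_app, name=b"x-example", value=b"1")
messages = asyncio.run(call(app))
print(messages[0]["headers"])
# → [(b'content-type', b'text/plain'), (b'x-example', b'1')]
```

Because the constructor takes the wrapped app as an `app` keyword, the same class should also slot into a Starlette application via `Middleware(AddHeaderMiddleware, name=b"x-example", value=b"1")`, since Starlette instantiates each middleware as `cls(app=app, **options)` when building its middleware stack.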

— ⭐️ —

Starlette is BSD licensed code. Designed & built in Brighton, England.

[requests]: http://docs.python-requests.org/en/master/
[jinja2]: http://jinja.pocoo.org/
[python-multipart]: https://andrew-d.github.io/python-multipart/
[itsdangerous]: https://pythonhosted.org/itsdangerous/
[sqlalchemy]: https://www.sqlalchemy.org
[pyyaml]: https://pyyaml.org/wiki/PyYAMLDocumentation

starlette-0.18.0/setup.cfg

[flake8]
ignore = W503, E203, B305
max-line-length = 88

[mypy]
disallow_untyped_defs = True
ignore_missing_imports = True
show_error_codes = True

[mypy-tests.*]
disallow_untyped_defs = False
check_untyped_defs = True

[tool:isort]
profile = black
combine_as_imports = True

[tool:pytest]
addopts = -rxXs --strict-config --strict-markers
xfail_strict = True
filterwarnings =
    error
    ignore: Using or importing the ABCs from 'collections' instead of from 'collections\.abc' is deprecated.*:DeprecationWarning
    ignore: The 'context' alias has been deprecated. Please use 'context_value' instead\.:DeprecationWarning
    ignore: The 'variables' alias has been deprecated. Please use 'variable_values' instead\.:DeprecationWarning
    ignore:The loop argument is deprecated since Python 3\.8, and scheduled for removal in Python 3\.10\.:DeprecationWarning:asyncio

[coverage:run]
source_pkgs = starlette, tests

[coverage:report]
exclude_lines =
    pragma: no cover
    pragma: nocover
    if typing.TYPE_CHECKING:

[egg_info]
tag_build =
tag_date = 0

starlette-0.18.0/setup.py

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import re

from setuptools import setup, find_packages


def get_version(package):
    """
    Return package version as listed in `__version__` in `__init__.py`.
    """
    with open(os.path.join(package, "__init__.py")) as f:
        return re.search("__version__ = ['\"]([^'\"]+)['\"]", f.read()).group(1)


def get_long_description():
    """
    Return the README.
""" with open("README.md", encoding="utf8") as f: return f.read() setup( name="starlette", python_requires=">=3.6", version=get_version("starlette"), url="https://github.com/encode/starlette", license="BSD", description="The little ASGI library that shines.", long_description=get_long_description(), long_description_content_type="text/markdown", author="Tom Christie", author_email="tom@tomchristie.com", packages=find_packages(exclude=["tests*"]), package_data={"starlette": ["py.typed"]}, include_package_data=True, install_requires=[ "anyio>=3.0.0,<4", "typing_extensions; python_version < '3.10'", "contextlib2 >= 21.6.0; python_version < '3.7'", ], extras_require={ "full": [ "itsdangerous", "jinja2", "python-multipart", "pyyaml", "requests", ] }, classifiers=[ "Development Status :: 3 - Alpha", "Environment :: Web Environment", "Intended Audience :: Developers", "License :: OSI Approved :: BSD License", "Operating System :: OS Independent", "Topic :: Internet :: WWW/HTTP", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", ], zip_safe=False, ) starlette-0.18.0/starlette/0000755000175100001710000000000014173233775016431 5ustar runnerdocker00000000000000starlette-0.18.0/starlette/__init__.py0000644000175100001710000000002714173233741020532 0ustar runnerdocker00000000000000__version__ = "0.18.0" starlette-0.18.0/starlette/_compat.py0000644000175100001710000000217514173233741020423 0ustar runnerdocker00000000000000import hashlib # Compat wrapper to always include the `usedforsecurity=...` parameter, # which is only added from Python 3.9 onwards. # We use this flag to indicate that we use `md5` hashes only for non-security # cases (our ETag checksums). 
# If we don't indicate that we're using MD5 for non-security related reasons,
# then attempting to use this function will raise an error when used in
# environments which enable a strict "FIPS mode".
#
# See issue: https://github.com/encode/starlette/issues/1365
try:
    # check if the Python version supports the parameter
    # using usedforsecurity=False to avoid an exception on FIPS systems
    # that reject usedforsecurity=True
    hashlib.md5(b"data", usedforsecurity=False)  # type: ignore[call-arg]

    def md5_hexdigest(
        data: bytes, *, usedforsecurity: bool = True
    ) -> str:  # pragma: no cover
        return hashlib.md5(  # type: ignore[call-arg]
            data, usedforsecurity=usedforsecurity
        ).hexdigest()

except TypeError:  # pragma: no cover

    def md5_hexdigest(data: bytes, *, usedforsecurity: bool = True) -> str:
        return hashlib.md5(data).hexdigest()

starlette-0.18.0/starlette/applications.py

import typing

from starlette.datastructures import State, URLPath
from starlette.exceptions import ExceptionMiddleware
from starlette.middleware import Middleware
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.errors import ServerErrorMiddleware
from starlette.requests import Request
from starlette.responses import Response
from starlette.routing import BaseRoute, Router
from starlette.types import ASGIApp, Receive, Scope, Send


class Starlette:
    """
    Creates an application instance.

    **Parameters:**

    * **debug** - Boolean indicating if debug tracebacks should be returned on errors.
    * **routes** - A list of routes to serve incoming HTTP and WebSocket requests.
    * **middleware** - A list of middleware to run for every request. A starlette
    application will always automatically include two middleware classes.
    `ServerErrorMiddleware` is added as the very outermost middleware, to handle
    any uncaught errors occurring anywhere in the entire stack.
    `ExceptionMiddleware` is added as the very innermost middleware, to deal
    with handled exception cases occurring in the routing or endpoints.
    * **exception_handlers** - A mapping of either integer status codes,
    or exception class types onto callables which handle the exceptions.
    Exception handler callables should be of the form
    `handler(request, exc) -> response` and may be either standard functions, or
    async functions.
    * **on_startup** - A list of callables to run on application startup.
    Startup handler callables do not take any arguments, and may be either
    standard functions, or async functions.
    * **on_shutdown** - A list of callables to run on application shutdown.
    Shutdown handler callables do not take any arguments, and may be either
    standard functions, or async functions.
    """

    def __init__(
        self,
        debug: bool = False,
        routes: typing.Sequence[BaseRoute] = None,
        middleware: typing.Sequence[Middleware] = None,
        exception_handlers: typing.Mapping[
            typing.Any,
            typing.Callable[
                [Request, Exception], typing.Union[Response, typing.Awaitable[Response]]
            ],
        ] = None,
        on_startup: typing.Sequence[typing.Callable] = None,
        on_shutdown: typing.Sequence[typing.Callable] = None,
        lifespan: typing.Callable[["Starlette"], typing.AsyncContextManager] = None,
    ) -> None:
        # The lifespan context function is a newer style that replaces
        # on_startup / on_shutdown handlers. Use one or the other, not both.
        assert lifespan is None or (
            on_startup is None and on_shutdown is None
        ), "Use either 'lifespan' or 'on_startup'/'on_shutdown', not both."
        self._debug = debug
        self.state = State()
        self.router = Router(
            routes, on_startup=on_startup, on_shutdown=on_shutdown, lifespan=lifespan
        )
        self.exception_handlers = (
            {} if exception_handlers is None else dict(exception_handlers)
        )
        self.user_middleware = [] if middleware is None else list(middleware)
        self.middleware_stack = self.build_middleware_stack()

    def build_middleware_stack(self) -> ASGIApp:
        debug = self.debug
        error_handler = None
        exception_handlers: typing.Dict[
            typing.Any, typing.Callable[[Request, Exception], Response]
        ] = {}

        for key, value in self.exception_handlers.items():
            if key in (500, Exception):
                error_handler = value
            else:
                exception_handlers[key] = value

        middleware = (
            [Middleware(ServerErrorMiddleware, handler=error_handler, debug=debug)]
            + self.user_middleware
            + [
                Middleware(
                    ExceptionMiddleware, handlers=exception_handlers, debug=debug
                )
            ]
        )

        app = self.router
        for cls, options in reversed(middleware):
            app = cls(app=app, **options)
        return app

    @property
    def routes(self) -> typing.List[BaseRoute]:
        return self.router.routes

    @property
    def debug(self) -> bool:
        return self._debug

    @debug.setter
    def debug(self, value: bool) -> None:
        self._debug = value
        self.middleware_stack = self.build_middleware_stack()

    def url_path_for(self, name: str, **path_params: typing.Any) -> URLPath:
        return self.router.url_path_for(name, **path_params)

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        scope["app"] = self
        await self.middleware_stack(scope, receive, send)

    # The following usages are now discouraged in favour of configuration
    # during Starlette.__init__(...)
    def on_event(self, event_type: str) -> typing.Callable:
        return self.router.on_event(event_type)

    def mount(self, path: str, app: ASGIApp, name: str = None) -> None:
        self.router.mount(path, app=app, name=name)

    def host(self, host: str, app: ASGIApp, name: str = None) -> None:
        self.router.host(host, app=app, name=name)

    def add_middleware(self, middleware_class: type, **options: typing.Any) -> None:
        self.user_middleware.insert(0, Middleware(middleware_class, **options))
        self.middleware_stack = self.build_middleware_stack()

    def add_exception_handler(
        self,
        exc_class_or_status_code: typing.Union[int, typing.Type[Exception]],
        handler: typing.Callable,
    ) -> None:
        self.exception_handlers[exc_class_or_status_code] = handler
        self.middleware_stack = self.build_middleware_stack()

    def add_event_handler(self, event_type: str, func: typing.Callable) -> None:
        self.router.add_event_handler(event_type, func)

    def add_route(
        self,
        path: str,
        route: typing.Callable,
        methods: typing.List[str] = None,
        name: str = None,
        include_in_schema: bool = True,
    ) -> None:
        self.router.add_route(
            path, route, methods=methods, name=name, include_in_schema=include_in_schema
        )

    def add_websocket_route(
        self, path: str, route: typing.Callable, name: str = None
    ) -> None:
        self.router.add_websocket_route(path, route, name=name)

    def exception_handler(
        self, exc_class_or_status_code: typing.Union[int, typing.Type[Exception]]
    ) -> typing.Callable:
        def decorator(func: typing.Callable) -> typing.Callable:
            self.add_exception_handler(exc_class_or_status_code, func)
            return func

        return decorator

    def route(
        self,
        path: str,
        methods: typing.List[str] = None,
        name: str = None,
        include_in_schema: bool = True,
    ) -> typing.Callable:
        def decorator(func: typing.Callable) -> typing.Callable:
            self.router.add_route(
                path,
                func,
                methods=methods,
                name=name,
                include_in_schema=include_in_schema,
            )
            return func

        return decorator

    def websocket_route(self, path: str, name: str = None) -> typing.Callable:
        def decorator(func: typing.Callable) -> typing.Callable:
            self.router.add_websocket_route(path, func, name=name)
            return func

        return decorator

    def middleware(self, middleware_type: str) -> typing.Callable:
        assert (
            middleware_type == "http"
        ), 'Currently only middleware("http") is supported.'

        def decorator(func: typing.Callable) -> typing.Callable:
            self.add_middleware(BaseHTTPMiddleware, dispatch=func)
            return func

        return decorator

starlette-0.18.0/starlette/authentication.py

import asyncio
import functools
import inspect
import typing

from starlette.exceptions import HTTPException
from starlette.requests import HTTPConnection, Request
from starlette.responses import RedirectResponse, Response
from starlette.websockets import WebSocket


def has_required_scope(conn: HTTPConnection, scopes: typing.Sequence[str]) -> bool:
    for scope in scopes:
        if scope not in conn.auth.scopes:
            return False
    return True


def requires(
    scopes: typing.Union[str, typing.Sequence[str]],
    status_code: int = 403,
    redirect: str = None,
) -> typing.Callable:
    scopes_list = [scopes] if isinstance(scopes, str) else list(scopes)

    def decorator(func: typing.Callable) -> typing.Callable:
        sig = inspect.signature(func)
        for idx, parameter in enumerate(sig.parameters.values()):
            if parameter.name == "request" or parameter.name == "websocket":
                type_ = parameter.name
                break
        else:
            raise Exception(
                f'No "request" or "websocket" argument on function "{func}"'
            )

        if type_ == "websocket":
            # Handle websocket functions. (Always async)
            @functools.wraps(func)
            async def websocket_wrapper(
                *args: typing.Any, **kwargs: typing.Any
            ) -> None:
                websocket = kwargs.get(
                    "websocket", args[idx] if idx < len(args) else None
                )
                assert isinstance(websocket, WebSocket)

                if not has_required_scope(websocket, scopes_list):
                    await websocket.close()
                else:
                    await func(*args, **kwargs)

            return websocket_wrapper

        elif asyncio.iscoroutinefunction(func):
            # Handle async request/response functions.
            @functools.wraps(func)
            async def async_wrapper(
                *args: typing.Any, **kwargs: typing.Any
            ) -> Response:
                request = kwargs.get("request", args[idx] if idx < len(args) else None)
                assert isinstance(request, Request)

                if not has_required_scope(request, scopes_list):
                    if redirect is not None:
                        return RedirectResponse(
                            url=request.url_for(redirect), status_code=303
                        )
                    raise HTTPException(status_code=status_code)
                return await func(*args, **kwargs)

            return async_wrapper

        else:
            # Handle sync request/response functions.
            @functools.wraps(func)
            def sync_wrapper(*args: typing.Any, **kwargs: typing.Any) -> Response:
                request = kwargs.get("request", args[idx] if idx < len(args) else None)
                assert isinstance(request, Request)

                if not has_required_scope(request, scopes_list):
                    if redirect is not None:
                        return RedirectResponse(
                            url=request.url_for(redirect), status_code=303
                        )
                    raise HTTPException(status_code=status_code)
                return func(*args, **kwargs)

            return sync_wrapper

    return decorator


class AuthenticationError(Exception):
    pass


class AuthenticationBackend:
    async def authenticate(
        self, conn: HTTPConnection
    ) -> typing.Optional[typing.Tuple["AuthCredentials", "BaseUser"]]:
        raise NotImplementedError()  # pragma: no cover


class AuthCredentials:
    def __init__(self, scopes: typing.Sequence[str] = None):
        self.scopes = [] if scopes is None else list(scopes)


class BaseUser:
    @property
    def is_authenticated(self) -> bool:
        raise NotImplementedError()  # pragma: no cover

    @property
    def display_name(self) -> str:
        raise NotImplementedError()  # pragma: no cover

    @property
    def identity(self) -> str:
        raise NotImplementedError()  # pragma: no cover


class SimpleUser(BaseUser):
    def __init__(self, username: str) -> None:
        self.username = username

    @property
    def is_authenticated(self) -> bool:
        return True

    @property
    def display_name(self) -> str:
        return self.username


class UnauthenticatedUser(BaseUser):
    @property
    def is_authenticated(self) -> bool:
        return False

    @property
    def display_name(self) -> str:
        return ""

starlette-0.18.0/starlette/background.py

import asyncio
import sys
import typing

if sys.version_info >= (3, 10):  # pragma: no cover
    from typing import ParamSpec
else:  # pragma: no cover
    from typing_extensions import ParamSpec

from starlette.concurrency import run_in_threadpool

P = ParamSpec("P")


class BackgroundTask:
    def __init__(
        self, func: typing.Callable[P, typing.Any], *args: P.args, **kwargs: P.kwargs
    ) -> None:
        self.func = func
        self.args = args
        self.kwargs = kwargs
        self.is_async = asyncio.iscoroutinefunction(func)

    async def __call__(self) -> None:
        if self.is_async:
            await self.func(*self.args, **self.kwargs)
        else:
            await run_in_threadpool(self.func, *self.args, **self.kwargs)


class BackgroundTasks(BackgroundTask):
    def __init__(self, tasks: typing.Sequence[BackgroundTask] = None):
        self.tasks = list(tasks) if tasks else []

    def add_task(
        self, func: typing.Callable[P, typing.Any], *args: P.args, **kwargs: P.kwargs
    ) -> None:
        task = BackgroundTask(func, *args, **kwargs)
        self.tasks.append(task)

    async def __call__(self) -> None:
        for task in self.tasks:
            await task()

starlette-0.18.0/starlette/concurrency.py

import functools
import sys
import typing

import anyio

if sys.version_info >= (3, 10):  # pragma: no cover
    from typing import ParamSpec
else:  # pragma: no cover
    from typing_extensions import ParamSpec

try:
    import contextvars  # Python 3.7+ only or via contextvars backport.
except ImportError:  # pragma: no cover
    contextvars = None  # type: ignore


T = typing.TypeVar("T")
P = ParamSpec("P")


async def run_until_first_complete(*args: typing.Tuple[typing.Callable, dict]) -> None:
    async with anyio.create_task_group() as task_group:

        async def run(func: typing.Callable[[], typing.Coroutine]) -> None:
            await func()
            task_group.cancel_scope.cancel()

        for func, kwargs in args:
            task_group.start_soon(run, functools.partial(func, **kwargs))


async def run_in_threadpool(
    func: typing.Callable[P, T], *args: P.args, **kwargs: P.kwargs
) -> T:
    if contextvars is not None:  # pragma: no cover
        # Ensure we run in the same context
        child = functools.partial(func, *args, **kwargs)
        context = contextvars.copy_context()
        func = context.run  # type: ignore[assignment]
        args = (child,)  # type: ignore[assignment]
    elif kwargs:  # pragma: no cover
        # run_sync doesn't accept 'kwargs', so bind them in here
        func = functools.partial(func, **kwargs)
    return await anyio.to_thread.run_sync(func, *args)


class _StopIteration(Exception):
    pass


def _next(iterator: typing.Iterator[T]) -> T:
    # We can't raise `StopIteration` from within the threadpool iterator
    # and catch it outside that context, so we coerce them into a different
    # exception type.
    try:
        return next(iterator)
    except StopIteration:
        raise _StopIteration


async def iterate_in_threadpool(
    iterator: typing.Iterator[T],
) -> typing.AsyncIterator[T]:
    while True:
        try:
            yield await anyio.to_thread.run_sync(_next, iterator)
        except _StopIteration:
            break

starlette-0.18.0/starlette/config.py

import os
import typing
from collections.abc import MutableMapping
from pathlib import Path


class undefined:
    pass


class EnvironError(Exception):
    pass


class Environ(MutableMapping):
    def __init__(self, environ: typing.MutableMapping = os.environ):
        self._environ = environ
        self._has_been_read: typing.Set[typing.Any] = set()

    def __getitem__(self, key: typing.Any) -> typing.Any:
        self._has_been_read.add(key)
        return self._environ.__getitem__(key)

    def __setitem__(self, key: typing.Any, value: typing.Any) -> None:
        if key in self._has_been_read:
            raise EnvironError(
                f"Attempting to set environ['{key}'], but the value has already been "
                "read."
            )
        self._environ.__setitem__(key, value)

    def __delitem__(self, key: typing.Any) -> None:
        if key in self._has_been_read:
            raise EnvironError(
                f"Attempting to delete environ['{key}'], but the value has already "
                "been read."
            )
        self._environ.__delitem__(key)

    def __iter__(self) -> typing.Iterator:
        return iter(self._environ)

    def __len__(self) -> int:
        return len(self._environ)


environ = Environ()

T = typing.TypeVar("T")


class Config:
    def __init__(
        self,
        env_file: typing.Union[str, Path] = None,
        environ: typing.Mapping[str, str] = environ,
    ) -> None:
        self.environ = environ
        self.file_values: typing.Dict[str, str] = {}
        if env_file is not None and os.path.isfile(env_file):
            self.file_values = self._read_file(env_file)

    @typing.overload
    def __call__(
        self, key: str, cast: typing.Type[T], default: T = ...
    ) -> T:  # pragma: no cover
        ...

    @typing.overload
    def __call__(
        self, key: str, cast: typing.Type[str] = ..., default: str = ...
    ) -> str:  # pragma: no cover
        ...

    @typing.overload
    def __call__(
        self, key: str, cast: typing.Type[str] = ..., default: T = ...
    ) -> typing.Union[T, str]:  # pragma: no cover
        ...

    def __call__(
        self, key: str, cast: typing.Callable = None, default: typing.Any = undefined
    ) -> typing.Any:
        return self.get(key, cast, default)

    def get(
        self, key: str, cast: typing.Callable = None, default: typing.Any = undefined
    ) -> typing.Any:
        if key in self.environ:
            value = self.environ[key]
            return self._perform_cast(key, value, cast)
        if key in self.file_values:
            value = self.file_values[key]
            return self._perform_cast(key, value, cast)
        if default is not undefined:
            return self._perform_cast(key, default, cast)
        raise KeyError(f"Config '{key}' is missing, and has no default.")

    def _read_file(self, file_name: typing.Union[str, Path]) -> typing.Dict[str, str]:
        file_values: typing.Dict[str, str] = {}
        with open(file_name) as input_file:
            for line in input_file.readlines():
                line = line.strip()
                if "=" in line and not line.startswith("#"):
                    key, value = line.split("=", 1)
                    key = key.strip()
                    value = value.strip().strip("\"'")
                    file_values[key] = value
        return file_values

    def _perform_cast(
        self, key: str, value: typing.Any, cast: typing.Callable = None
    ) -> typing.Any:
        if cast is None or value is None:
            return value
        elif cast is bool and isinstance(value, str):
            mapping = {"true": True, "1": True, "false": False, "0": False}
            value = value.lower()
            if value not in mapping:
                raise ValueError(
                    f"Config '{key}' has value '{value}'. Not a valid bool."
                )
            return mapping[value]
        try:
            return cast(value)
        except (TypeError, ValueError):
            raise ValueError(
                f"Config '{key}' has value '{value}'. Not a valid {cast.__name__}."
) starlette-0.18.0/starlette/convertors.py0000644000175100001710000000402114173233741021175 0ustar runnerdocker00000000000000import math import typing import uuid class Convertor: regex = "" def convert(self, value: str) -> typing.Any: raise NotImplementedError() # pragma: no cover def to_string(self, value: typing.Any) -> str: raise NotImplementedError() # pragma: no cover class StringConvertor(Convertor): regex = "[^/]+" def convert(self, value: str) -> typing.Any: return value def to_string(self, value: typing.Any) -> str: value = str(value) assert "/" not in value, "May not contain path separators" assert value, "Must not be empty" return value class PathConvertor(Convertor): regex = ".*" def convert(self, value: str) -> typing.Any: return str(value) def to_string(self, value: typing.Any) -> str: return str(value) class IntegerConvertor(Convertor): regex = "[0-9]+" def convert(self, value: str) -> typing.Any: return int(value) def to_string(self, value: typing.Any) -> str: value = int(value) assert value >= 0, "Negative integers are not supported" return str(value) class FloatConvertor(Convertor): regex = "[0-9]+(.[0-9]+)?" 

    def convert(self, value: str) -> typing.Any:
        return float(value)

    def to_string(self, value: typing.Any) -> str:
        value = float(value)
        assert value >= 0.0, "Negative floats are not supported"
        assert not math.isnan(value), "NaN values are not supported"
        assert not math.isinf(value), "Infinite values are not supported"
        return ("%0.20f" % value).rstrip("0").rstrip(".")


class UUIDConvertor(Convertor):
    regex = "[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}"

    def convert(self, value: str) -> typing.Any:
        return uuid.UUID(value)

    def to_string(self, value: typing.Any) -> str:
        return str(value)


CONVERTOR_TYPES = {
    "str": StringConvertor(),
    "path": PathConvertor(),
    "int": IntegerConvertor(),
    "float": FloatConvertor(),
    "uuid": UUIDConvertor(),
}

starlette-0.18.0/starlette/datastructures.py
import tempfile
import typing
from collections import namedtuple
from collections.abc import Sequence
from shlex import shlex
from urllib.parse import SplitResult, parse_qsl, urlencode, urlsplit

from starlette.concurrency import run_in_threadpool
from starlette.types import Scope

Address = namedtuple("Address", ["host", "port"])


class URL:
    def __init__(
        self, url: str = "", scope: Scope = None, **components: typing.Any
    ) -> None:
        if scope is not None:
            assert not url, 'Cannot set both "url" and "scope".'
            assert not components, 'Cannot set both "scope" and "**components".'
            scheme = scope.get("scheme", "http")
            server = scope.get("server", None)
            path = scope.get("root_path", "") + scope["path"]
            query_string = scope.get("query_string", b"")

            host_header = None
            for key, value in scope["headers"]:
                if key == b"host":
                    host_header = value.decode("latin-1")
                    break

            if host_header is not None:
                url = f"{scheme}://{host_header}{path}"
            elif server is None:
                url = path
            else:
                host, port = server
                default_port = {"http": 80, "https": 443, "ws": 80, "wss": 443}[scheme]
                if port == default_port:
                    url = f"{scheme}://{host}{path}"
                else:
                    url = f"{scheme}://{host}:{port}{path}"

            if query_string:
                url += "?" + query_string.decode()
        elif components:
            assert not url, 'Cannot set both "url" and "**components".'
            url = URL("").replace(**components).components.geturl()

        self._url = url

    @property
    def components(self) -> SplitResult:
        if not hasattr(self, "_components"):
            self._components = urlsplit(self._url)
        return self._components

    @property
    def scheme(self) -> str:
        return self.components.scheme

    @property
    def netloc(self) -> str:
        return self.components.netloc

    @property
    def path(self) -> str:
        return self.components.path

    @property
    def query(self) -> str:
        return self.components.query

    @property
    def fragment(self) -> str:
        return self.components.fragment

    @property
    def username(self) -> typing.Union[None, str]:
        return self.components.username

    @property
    def password(self) -> typing.Union[None, str]:
        return self.components.password

    @property
    def hostname(self) -> typing.Union[None, str]:
        return self.components.hostname

    @property
    def port(self) -> typing.Optional[int]:
        return self.components.port

    @property
    def is_secure(self) -> bool:
        return self.scheme in ("https", "wss")

    def replace(self, **kwargs: typing.Any) -> "URL":
        if (
            "username" in kwargs
            or "password" in kwargs
            or "hostname" in kwargs
            or "port" in kwargs
        ):
            hostname = kwargs.pop("hostname", self.hostname)
            port = kwargs.pop("port", self.port)
            username = kwargs.pop("username", self.username)
            password = kwargs.pop("password", self.password)

            netloc = hostname
            if port is not None:
                netloc += f":{port}"
            if username is not None:
                userpass = username
                if password is not None:
                    userpass += f":{password}"
                netloc = f"{userpass}@{netloc}"

            kwargs["netloc"] = netloc

        components = self.components._replace(**kwargs)
        return self.__class__(components.geturl())

    def include_query_params(self, **kwargs: typing.Any) -> "URL":
        params = MultiDict(parse_qsl(self.query, keep_blank_values=True))
        params.update({str(key): str(value) for key, value in kwargs.items()})
        query = urlencode(params.multi_items())
        return self.replace(query=query)

    def replace_query_params(self, **kwargs: typing.Any) -> "URL":
        query = urlencode([(str(key), str(value)) for key, value in kwargs.items()])
        return self.replace(query=query)

    def remove_query_params(
        self, keys: typing.Union[str, typing.Sequence[str]]
    ) -> "URL":
        if isinstance(keys, str):
            keys = [keys]
        params = MultiDict(parse_qsl(self.query, keep_blank_values=True))
        for key in keys:
            params.pop(key, None)
        query = urlencode(params.multi_items())
        return self.replace(query=query)

    def __eq__(self, other: typing.Any) -> bool:
        return str(self) == str(other)

    def __str__(self) -> str:
        return self._url

    def __repr__(self) -> str:
        url = str(self)
        if self.password:
            url = str(self.replace(password="********"))
        return f"{self.__class__.__name__}({repr(url)})"


class URLPath(str):
    """
    A URL path string that may also hold an associated protocol and/or host.
    Used by the routing to return `url_path_for` matches.
    """

    def __new__(cls, path: str, protocol: str = "", host: str = "") -> "URLPath":
        assert protocol in ("http", "websocket", "")
        return str.__new__(cls, path)

    def __init__(self, path: str, protocol: str = "", host: str = "") -> None:
        self.protocol = protocol
        self.host = host

    def make_absolute_url(self, base_url: typing.Union[str, URL]) -> str:
        if isinstance(base_url, str):
            base_url = URL(base_url)
        if self.protocol:
            scheme = {
                "http": {True: "https", False: "http"},
                "websocket": {True: "wss", False: "ws"},
            }[self.protocol][base_url.is_secure]
        else:
            scheme = base_url.scheme

        netloc = self.host or base_url.netloc
        path = base_url.path.rstrip("/") + str(self)
        return str(URL(scheme=scheme, netloc=netloc, path=path))


class Secret:
    """
    Holds a string value that should not be revealed in tracebacks etc.
    You should cast the value to `str` at the point it is required.
    """

    def __init__(self, value: str):
        self._value = value

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        return f"{class_name}('**********')"

    def __str__(self) -> str:
        return self._value


class CommaSeparatedStrings(Sequence):
    def __init__(self, value: typing.Union[str, typing.Sequence[str]]):
        if isinstance(value, str):
            splitter = shlex(value, posix=True)
            splitter.whitespace = ","
            splitter.whitespace_split = True
            self._items = [item.strip() for item in splitter]
        else:
            self._items = list(value)

    def __len__(self) -> int:
        return len(self._items)

    def __getitem__(self, index: typing.Union[int, slice]) -> typing.Any:
        return self._items[index]

    def __iter__(self) -> typing.Iterator[str]:
        return iter(self._items)

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        items = [item for item in self]
        return f"{class_name}({items!r})"

    def __str__(self) -> str:
        return ", ".join(repr(item) for item in self)


class ImmutableMultiDict(typing.Mapping):
    def __init__(
        self,
        *args: typing.Union[
            "ImmutableMultiDict",
            typing.Mapping,
            typing.List[typing.Tuple[typing.Any, typing.Any]],
        ],
        **kwargs: typing.Any,
    ) -> None:
        assert len(args) < 2, "Too many arguments."

        value = args[0] if args else []
        if kwargs:
            value = (
                ImmutableMultiDict(value).multi_items()
                + ImmutableMultiDict(kwargs).multi_items()
            )

        if not value:
            _items: typing.List[typing.Tuple[typing.Any, typing.Any]] = []
        elif hasattr(value, "multi_items"):
            value = typing.cast(ImmutableMultiDict, value)
            _items = list(value.multi_items())
        elif hasattr(value, "items"):
            value = typing.cast(typing.Mapping, value)
            _items = list(value.items())
        else:
            value = typing.cast(
                typing.List[typing.Tuple[typing.Any, typing.Any]], value
            )
            _items = list(value)

        self._dict = {k: v for k, v in _items}
        self._list = _items

    def getlist(self, key: typing.Any) -> typing.List[typing.Any]:
        return [item_value for item_key, item_value in self._list if item_key == key]

    def keys(self) -> typing.KeysView:
        return self._dict.keys()

    def values(self) -> typing.ValuesView:
        return self._dict.values()

    def items(self) -> typing.ItemsView:
        return self._dict.items()

    def multi_items(self) -> typing.List[typing.Tuple[str, str]]:
        return list(self._list)

    def get(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
        if key in self._dict:
            return self._dict[key]
        return default

    def __getitem__(self, key: typing.Any) -> str:
        return self._dict[key]

    def __contains__(self, key: typing.Any) -> bool:
        return key in self._dict

    def __iter__(self) -> typing.Iterator[typing.Any]:
        return iter(self.keys())

    def __len__(self) -> int:
        return len(self._dict)

    def __eq__(self, other: typing.Any) -> bool:
        if not isinstance(other, self.__class__):
            return False
        return sorted(self._list) == sorted(other._list)

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        items = self.multi_items()
        return f"{class_name}({items!r})"


class MultiDict(ImmutableMultiDict):
    def __setitem__(self, key: typing.Any, value: typing.Any) -> None:
        self.setlist(key, [value])

    def __delitem__(self, key: typing.Any) -> None:
        self._list = [(k, v) for k, v in self._list if k != key]
        del self._dict[key]

    def pop(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
        self._list = [(k, v) for k, v in self._list if k != key]
        return self._dict.pop(key, default)

    def popitem(self) -> typing.Tuple:
        key, value = self._dict.popitem()
        self._list = [(k, v) for k, v in self._list if k != key]
        return key, value

    def poplist(self, key: typing.Any) -> typing.List:
        values = [v for k, v in self._list if k == key]
        self.pop(key)
        return values

    def clear(self) -> None:
        self._dict.clear()
        self._list.clear()

    def setdefault(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
        if key not in self:
            self._dict[key] = default
            self._list.append((key, default))

        return self[key]

    def setlist(self, key: typing.Any, values: typing.List) -> None:
        if not values:
            self.pop(key, None)
        else:
            existing_items = [(k, v) for (k, v) in self._list if k != key]
            self._list = existing_items + [(key, value) for value in values]
            self._dict[key] = values[-1]

    def append(self, key: typing.Any, value: typing.Any) -> None:
        self._list.append((key, value))
        self._dict[key] = value

    def update(
        self,
        *args: typing.Union[
            "MultiDict",
            typing.Mapping,
            typing.List[typing.Tuple[typing.Any, typing.Any]],
        ],
        **kwargs: typing.Any,
    ) -> None:
        value = MultiDict(*args, **kwargs)
        existing_items = [(k, v) for (k, v) in self._list if k not in value.keys()]
        self._list = existing_items + value.multi_items()
        self._dict.update(value)


class QueryParams(ImmutableMultiDict):
    """
    An immutable multidict.
    """

    def __init__(
        self,
        *args: typing.Union[
            "ImmutableMultiDict",
            typing.Mapping,
            typing.List[typing.Tuple[typing.Any, typing.Any]],
            str,
            bytes,
        ],
        **kwargs: typing.Any,
    ) -> None:
        assert len(args) < 2, "Too many arguments."

        value = args[0] if args else []

        if isinstance(value, str):
            super().__init__(parse_qsl(value, keep_blank_values=True), **kwargs)
        elif isinstance(value, bytes):
            super().__init__(
                parse_qsl(value.decode("latin-1"), keep_blank_values=True), **kwargs
            )
        else:
            super().__init__(*args, **kwargs)  # type: ignore
        self._list = [(str(k), str(v)) for k, v in self._list]
        self._dict = {str(k): str(v) for k, v in self._dict.items()}

    def __str__(self) -> str:
        return urlencode(self._list)

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        query_string = str(self)
        return f"{class_name}({query_string!r})"


class UploadFile:
    """
    An uploaded file included as part of the request data.
    """

    spool_max_size = 1024 * 1024
    file: typing.BinaryIO
    headers: "Headers"

    def __init__(
        self,
        filename: str,
        file: typing.Optional[typing.BinaryIO] = None,
        content_type: str = "",
        *,
        headers: "typing.Optional[Headers]" = None,
    ) -> None:
        self.filename = filename
        self.content_type = content_type
        if file is None:
            self.file = tempfile.SpooledTemporaryFile(max_size=self.spool_max_size)  # type: ignore  # noqa: E501
        else:
            self.file = file
        self.headers = headers or Headers()

    @property
    def _in_memory(self) -> bool:
        rolled_to_disk = getattr(self.file, "_rolled", True)
        return not rolled_to_disk

    async def write(self, data: bytes) -> None:
        if self._in_memory:
            self.file.write(data)
        else:
            await run_in_threadpool(self.file.write, data)

    async def read(self, size: int = -1) -> bytes:
        if self._in_memory:
            return self.file.read(size)
        return await run_in_threadpool(self.file.read, size)

    async def seek(self, offset: int) -> None:
        if self._in_memory:
            self.file.seek(offset)
        else:
            await run_in_threadpool(self.file.seek, offset)

    async def close(self) -> None:
        if self._in_memory:
            self.file.close()
        else:
            await run_in_threadpool(self.file.close)


class FormData(ImmutableMultiDict):
    """
    An immutable multidict, containing both file uploads and text input.
    """

    def __init__(
        self,
        *args: typing.Union[
            "FormData",
            typing.Mapping[str, typing.Union[str, UploadFile]],
            typing.List[typing.Tuple[str, typing.Union[str, UploadFile]]],
        ],
        **kwargs: typing.Union[str, UploadFile],
    ) -> None:
        super().__init__(*args, **kwargs)

    async def close(self) -> None:
        for key, value in self.multi_items():
            if isinstance(value, UploadFile):
                await value.close()


class Headers(typing.Mapping[str, str]):
    """
    An immutable, case-insensitive multidict.
    """

    def __init__(
        self,
        headers: typing.Mapping[str, str] = None,
        raw: typing.List[typing.Tuple[bytes, bytes]] = None,
        scope: Scope = None,
    ) -> None:
        self._list: typing.List[typing.Tuple[bytes, bytes]] = []
        if headers is not None:
            assert raw is None, 'Cannot set both "headers" and "raw".'
            assert scope is None, 'Cannot set both "headers" and "scope".'
            self._list = [
                (key.lower().encode("latin-1"), value.encode("latin-1"))
                for key, value in headers.items()
            ]
        elif raw is not None:
            assert scope is None, 'Cannot set both "raw" and "scope".'
            self._list = raw
        elif scope is not None:
            self._list = scope["headers"]

    @property
    def raw(self) -> typing.List[typing.Tuple[bytes, bytes]]:
        return list(self._list)

    def keys(self) -> typing.List[str]:  # type: ignore
        return [key.decode("latin-1") for key, value in self._list]

    def values(self) -> typing.List[str]:  # type: ignore
        return [value.decode("latin-1") for key, value in self._list]

    def items(self) -> typing.List[typing.Tuple[str, str]]:  # type: ignore
        return [
            (key.decode("latin-1"), value.decode("latin-1"))
            for key, value in self._list
        ]

    def get(self, key: str, default: typing.Any = None) -> typing.Any:
        try:
            return self[key]
        except KeyError:
            return default

    def getlist(self, key: str) -> typing.List[str]:
        get_header_key = key.lower().encode("latin-1")
        return [
            item_value.decode("latin-1")
            for item_key, item_value in self._list
            if item_key == get_header_key
        ]

    def mutablecopy(self) -> "MutableHeaders":
        return MutableHeaders(raw=self._list[:])

    def __getitem__(self, key: str) -> str:
        get_header_key = key.lower().encode("latin-1")
        for header_key, header_value in self._list:
            if header_key == get_header_key:
                return header_value.decode("latin-1")
        raise KeyError(key)

    def __contains__(self, key: typing.Any) -> bool:
        get_header_key = key.lower().encode("latin-1")
        for header_key, header_value in self._list:
            if header_key == get_header_key:
                return True
        return False

    def __iter__(self) -> typing.Iterator[typing.Any]:
        return iter(self.keys())

    def __len__(self) -> int:
        return len(self._list)

    def __eq__(self, other: typing.Any) -> bool:
        if not isinstance(other, Headers):
            return False
        return sorted(self._list) == sorted(other._list)

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        as_dict = dict(self.items())
        if len(as_dict) == len(self):
            return f"{class_name}({as_dict!r})"
        return f"{class_name}(raw={self.raw!r})"


class MutableHeaders(Headers):
    def __setitem__(self, key: str, value: str) -> None:
        """
        Set the header `key` to `value`, removing any duplicate entries.
        Retains insertion order.
        """
        set_key = key.lower().encode("latin-1")
        set_value = value.encode("latin-1")

        found_indexes = []
        for idx, (item_key, item_value) in enumerate(self._list):
            if item_key == set_key:
                found_indexes.append(idx)

        for idx in reversed(found_indexes[1:]):
            del self._list[idx]

        if found_indexes:
            idx = found_indexes[0]
            self._list[idx] = (set_key, set_value)
        else:
            self._list.append((set_key, set_value))

    def __delitem__(self, key: str) -> None:
        """
        Remove the header `key`.
        """
        del_key = key.lower().encode("latin-1")

        pop_indexes = []
        for idx, (item_key, item_value) in enumerate(self._list):
            if item_key == del_key:
                pop_indexes.append(idx)

        for idx in reversed(pop_indexes):
            del self._list[idx]

    @property
    def raw(self) -> typing.List[typing.Tuple[bytes, bytes]]:
        return self._list

    def setdefault(self, key: str, value: str) -> str:
        """
        If the header `key` does not exist, then set it to `value`.
        Returns the header value.
        """
        set_key = key.lower().encode("latin-1")
        set_value = value.encode("latin-1")

        for idx, (item_key, item_value) in enumerate(self._list):
            if item_key == set_key:
                return item_value.decode("latin-1")
        self._list.append((set_key, set_value))
        return value

    def update(self, other: dict) -> None:
        for key, val in other.items():
            self[key] = val

    def append(self, key: str, value: str) -> None:
        """
        Append a header, preserving any duplicate entries.
        """
        append_key = key.lower().encode("latin-1")
        append_value = value.encode("latin-1")
        self._list.append((append_key, append_value))

    def add_vary_header(self, vary: str) -> None:
        existing = self.get("vary")
        if existing is not None:
            vary = ", ".join([existing, vary])
        self["vary"] = vary


class State:
    """
    An object that can be used to store arbitrary state.

    Used for `request.state` and `app.state`.
    """

    def __init__(self, state: typing.Dict = None):
        if state is None:
            state = {}
        super().__setattr__("_state", state)

    def __setattr__(self, key: typing.Any, value: typing.Any) -> None:
        self._state[key] = value

    def __getattr__(self, key: typing.Any) -> typing.Any:
        try:
            return self._state[key]
        except KeyError:
            message = "'{}' object has no attribute '{}'"
            raise AttributeError(message.format(self.__class__.__name__, key))

    def __delattr__(self, key: typing.Any) -> None:
        del self._state[key]

starlette-0.18.0/starlette/endpoints.py
import asyncio
import json
import typing

from starlette import status
from starlette.concurrency import run_in_threadpool
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.responses import PlainTextResponse, Response
from starlette.types import Message, Receive, Scope, Send
from starlette.websockets import WebSocket


class HTTPEndpoint:
    def __init__(self, scope: Scope, receive: Receive, send: Send) -> None:
        assert scope["type"] == "http"
        self.scope = scope
        self.receive = receive
        self.send = send

    def __await__(self) -> typing.Generator:
        return self.dispatch().__await__()

    async def dispatch(self) -> None:
        request = Request(self.scope, receive=self.receive)
        handler_name = (
            "get"
            if request.method == "HEAD" and not hasattr(self, "head")
            else request.method.lower()
        )

        handler: typing.Callable[[Request], typing.Any] = getattr(
            self, handler_name, self.method_not_allowed
        )
        is_async = asyncio.iscoroutinefunction(handler)
        if is_async:
            response = await handler(request)
        else:
            response = await run_in_threadpool(handler, request)
        await response(self.scope, self.receive, self.send)

    async def method_not_allowed(self, request: Request) -> Response:
        # If we're running inside a starlette application then raise an
        # exception, so that the configurable exception handler can deal with
        # returning the response. For plain ASGI apps, just return the response.
        if "app" in self.scope:
            raise HTTPException(status_code=405)
        return PlainTextResponse("Method Not Allowed", status_code=405)


class WebSocketEndpoint:
    encoding: typing.Optional[str] = None  # May be "text", "bytes", or "json".

    def __init__(self, scope: Scope, receive: Receive, send: Send) -> None:
        assert scope["type"] == "websocket"
        self.scope = scope
        self.receive = receive
        self.send = send

    def __await__(self) -> typing.Generator:
        return self.dispatch().__await__()

    async def dispatch(self) -> None:
        websocket = WebSocket(self.scope, receive=self.receive, send=self.send)
        await self.on_connect(websocket)

        close_code = status.WS_1000_NORMAL_CLOSURE

        try:
            while True:
                message = await websocket.receive()
                if message["type"] == "websocket.receive":
                    data = await self.decode(websocket, message)
                    await self.on_receive(websocket, data)
                elif message["type"] == "websocket.disconnect":
                    close_code = int(message.get("code", status.WS_1000_NORMAL_CLOSURE))
                    break
        except Exception as exc:
            close_code = status.WS_1011_INTERNAL_ERROR
            raise exc
        finally:
            await self.on_disconnect(websocket, close_code)

    async def decode(self, websocket: WebSocket, message: Message) -> typing.Any:
        if self.encoding == "text":
            if "text" not in message:
                await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA)
                raise RuntimeError("Expected text websocket messages, but got bytes")
            return message["text"]

        elif self.encoding == "bytes":
            if "bytes" not in message:
                await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA)
                raise RuntimeError("Expected bytes websocket messages, but got text")
            return message["bytes"]

        elif self.encoding == "json":
            if message.get("text") is not None:
                text = message["text"]
            else:
                text = message["bytes"].decode("utf-8")

            try:
                return json.loads(text)
            except json.decoder.JSONDecodeError:
                await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA)
                raise RuntimeError("Malformed JSON data received.")

        assert (
            self.encoding is None
        ), f"Unsupported 'encoding' attribute {self.encoding}"
        return message["text"] if message.get("text") else message["bytes"]

    async def on_connect(self, websocket: WebSocket) -> None:
        """Override to handle an incoming websocket connection"""
        await websocket.accept()

    async def on_receive(self, websocket: WebSocket, data: typing.Any) -> None:
        """Override to handle an incoming websocket message"""

    async def on_disconnect(self, websocket: WebSocket, close_code: int) -> None:
        """Override to handle a disconnecting websocket"""

starlette-0.18.0/starlette/exceptions.py
import asyncio
import http
import typing

from starlette.concurrency import run_in_threadpool
from starlette.requests import Request
from starlette.responses import PlainTextResponse, Response
from starlette.types import ASGIApp, Message, Receive, Scope, Send


class HTTPException(Exception):
    def __init__(self, status_code: int, detail: str = None) -> None:
        if detail is None:
            detail = http.HTTPStatus(status_code).phrase
        self.status_code = status_code
        self.detail = detail

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        return f"{class_name}(status_code={self.status_code!r}, detail={self.detail!r})"


class ExceptionMiddleware:
    def __init__(
        self,
        app: ASGIApp,
        handlers: typing.Mapping[
            typing.Any, typing.Callable[[Request, Exception], Response]
        ] = None,
        debug: bool = False,
    ) -> None:
        self.app = app
        self.debug = debug  # TODO: We ought to handle 404 cases if debug is set.
        self._status_handlers: typing.Dict[int, typing.Callable] = {}
        self._exception_handlers: typing.Dict[
            typing.Type[Exception], typing.Callable
        ] = {HTTPException: self.http_exception}
        if handlers is not None:
            for key, value in handlers.items():
                self.add_exception_handler(key, value)

    def add_exception_handler(
        self,
        exc_class_or_status_code: typing.Union[int, typing.Type[Exception]],
        handler: typing.Callable[[Request, Exception], Response],
    ) -> None:
        if isinstance(exc_class_or_status_code, int):
            self._status_handlers[exc_class_or_status_code] = handler
        else:
            assert issubclass(exc_class_or_status_code, Exception)
            self._exception_handlers[exc_class_or_status_code] = handler

    def _lookup_exception_handler(
        self, exc: Exception
    ) -> typing.Optional[typing.Callable]:
        for cls in type(exc).__mro__:
            if cls in self._exception_handlers:
                return self._exception_handlers[cls]
        return None

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        response_started = False

        async def sender(message: Message) -> None:
            nonlocal response_started

            if message["type"] == "http.response.start":
                response_started = True
            await send(message)

        try:
            await self.app(scope, receive, sender)
        except Exception as exc:
            handler = None

            if isinstance(exc, HTTPException):
                handler = self._status_handlers.get(exc.status_code)

            if handler is None:
                handler = self._lookup_exception_handler(exc)

            if handler is None:
                raise exc

            if response_started:
                msg = "Caught handled exception, but response already started."
                raise RuntimeError(msg) from exc

            request = Request(scope, receive=receive)
            if asyncio.iscoroutinefunction(handler):
                response = await handler(request, exc)
            else:
                response = await run_in_threadpool(handler, request, exc)
            await response(scope, receive, sender)

    def http_exception(self, request: Request, exc: HTTPException) -> Response:
        if exc.status_code in {204, 304}:
            return Response(b"", status_code=exc.status_code)
        return PlainTextResponse(exc.detail, status_code=exc.status_code)

starlette-0.18.0/starlette/formparsers.py
import typing
from enum import Enum
from urllib.parse import unquote_plus

from starlette.datastructures import FormData, Headers, UploadFile

try:
    import multipart
    from multipart.multipart import parse_options_header
except ImportError:  # pragma: nocover
    parse_options_header = None
    multipart = None


class FormMessage(Enum):
    FIELD_START = 1
    FIELD_NAME = 2
    FIELD_DATA = 3
    FIELD_END = 4
    END = 5


class MultiPartMessage(Enum):
    PART_BEGIN = 1
    PART_DATA = 2
    PART_END = 3
    HEADER_FIELD = 4
    HEADER_VALUE = 5
    HEADER_END = 6
    HEADERS_FINISHED = 7
    END = 8


def _user_safe_decode(src: bytes, codec: str) -> str:
    try:
        return src.decode(codec)
    except (UnicodeDecodeError, LookupError):
        return src.decode("latin-1")


class FormParser:
    def __init__(
        self, headers: Headers, stream: typing.AsyncGenerator[bytes, None]
    ) -> None:
        assert (
            multipart is not None
        ), "The `python-multipart` library must be installed to use form parsing."
        self.headers = headers
        self.stream = stream
        self.messages: typing.List[typing.Tuple[FormMessage, bytes]] = []

    def on_field_start(self) -> None:
        message = (FormMessage.FIELD_START, b"")
        self.messages.append(message)

    def on_field_name(self, data: bytes, start: int, end: int) -> None:
        message = (FormMessage.FIELD_NAME, data[start:end])
        self.messages.append(message)

    def on_field_data(self, data: bytes, start: int, end: int) -> None:
        message = (FormMessage.FIELD_DATA, data[start:end])
        self.messages.append(message)

    def on_field_end(self) -> None:
        message = (FormMessage.FIELD_END, b"")
        self.messages.append(message)

    def on_end(self) -> None:
        message = (FormMessage.END, b"")
        self.messages.append(message)

    async def parse(self) -> FormData:
        # Callbacks dictionary.
        callbacks = {
            "on_field_start": self.on_field_start,
            "on_field_name": self.on_field_name,
            "on_field_data": self.on_field_data,
            "on_field_end": self.on_field_end,
            "on_end": self.on_end,
        }

        # Create the parser.
        parser = multipart.QuerystringParser(callbacks)
        field_name = b""
        field_value = b""

        items: typing.List[typing.Tuple[str, typing.Union[str, UploadFile]]] = []

        # Feed the parser with data from the request.
        async for chunk in self.stream:
            if chunk:
                parser.write(chunk)
            else:
                parser.finalize()
            messages = list(self.messages)
            self.messages.clear()
            for message_type, message_bytes in messages:
                if message_type == FormMessage.FIELD_START:
                    field_name = b""
                    field_value = b""
                elif message_type == FormMessage.FIELD_NAME:
                    field_name += message_bytes
                elif message_type == FormMessage.FIELD_DATA:
                    field_value += message_bytes
                elif message_type == FormMessage.FIELD_END:
                    name = unquote_plus(field_name.decode("latin-1"))
                    value = unquote_plus(field_value.decode("latin-1"))
                    items.append((name, value))

        return FormData(items)


class MultiPartParser:
    def __init__(
        self, headers: Headers, stream: typing.AsyncGenerator[bytes, None]
    ) -> None:
        assert (
            multipart is not None
        ), "The `python-multipart` library must be installed to use form parsing."
        self.headers = headers
        self.stream = stream
        self.messages: typing.List[typing.Tuple[MultiPartMessage, bytes]] = []

    def on_part_begin(self) -> None:
        message = (MultiPartMessage.PART_BEGIN, b"")
        self.messages.append(message)

    def on_part_data(self, data: bytes, start: int, end: int) -> None:
        message = (MultiPartMessage.PART_DATA, data[start:end])
        self.messages.append(message)

    def on_part_end(self) -> None:
        message = (MultiPartMessage.PART_END, b"")
        self.messages.append(message)

    def on_header_field(self, data: bytes, start: int, end: int) -> None:
        message = (MultiPartMessage.HEADER_FIELD, data[start:end])
        self.messages.append(message)

    def on_header_value(self, data: bytes, start: int, end: int) -> None:
        message = (MultiPartMessage.HEADER_VALUE, data[start:end])
        self.messages.append(message)

    def on_header_end(self) -> None:
        message = (MultiPartMessage.HEADER_END, b"")
        self.messages.append(message)

    def on_headers_finished(self) -> None:
        message = (MultiPartMessage.HEADERS_FINISHED, b"")
        self.messages.append(message)

    def on_end(self) -> None:
        message = (MultiPartMessage.END, b"")
        self.messages.append(message)

    async def parse(self) -> FormData:
        # Parse the Content-Type header to get the multipart boundary.
        content_type, params = parse_options_header(self.headers["Content-Type"])
        charset = params.get(b"charset", "utf-8")
        if type(charset) == bytes:
            charset = charset.decode("latin-1")
        boundary = params.get(b"boundary")

        # Callbacks dictionary.
        callbacks = {
            "on_part_begin": self.on_part_begin,
            "on_part_data": self.on_part_data,
            "on_part_end": self.on_part_end,
            "on_header_field": self.on_header_field,
            "on_header_value": self.on_header_value,
            "on_header_end": self.on_header_end,
            "on_headers_finished": self.on_headers_finished,
            "on_end": self.on_end,
        }

        # Create the parser.
        parser = multipart.MultipartParser(boundary, callbacks)
        header_field = b""
        header_value = b""
        content_disposition = None
        content_type = b""
        field_name = ""
        data = b""
        file: typing.Optional[UploadFile] = None

        items: typing.List[typing.Tuple[str, typing.Union[str, UploadFile]]] = []
        item_headers: typing.List[typing.Tuple[bytes, bytes]] = []

        # Feed the parser with data from the request.
        async for chunk in self.stream:
            parser.write(chunk)
            messages = list(self.messages)
            self.messages.clear()
            for message_type, message_bytes in messages:
                if message_type == MultiPartMessage.PART_BEGIN:
                    content_disposition = None
                    content_type = b""
                    data = b""
                    item_headers = []
                elif message_type == MultiPartMessage.HEADER_FIELD:
                    header_field += message_bytes
                elif message_type == MultiPartMessage.HEADER_VALUE:
                    header_value += message_bytes
                elif message_type == MultiPartMessage.HEADER_END:
                    field = header_field.lower()
                    if field == b"content-disposition":
                        content_disposition = header_value
                    elif field == b"content-type":
                        content_type = header_value
                    item_headers.append((field, header_value))
                    header_field = b""
                    header_value = b""
                elif message_type == MultiPartMessage.HEADERS_FINISHED:
                    disposition, options = parse_options_header(content_disposition)
                    field_name = _user_safe_decode(options[b"name"], charset)
                    if b"filename" in options:
                        filename = _user_safe_decode(options[b"filename"], charset)
                        file = UploadFile(
                            filename=filename,
                            content_type=content_type.decode("latin-1"),
                            headers=Headers(raw=item_headers),
                        )
                    else:
                        file = None
                elif message_type == MultiPartMessage.PART_DATA:
                    if file is None:
                        data += message_bytes
                    else:
                        await file.write(message_bytes)
                elif message_type == MultiPartMessage.PART_END:
                    if file is None:
                        items.append((field_name, _user_safe_decode(data, charset)))
                    else:
                        await file.seek(0)
                        items.append((field_name, file))

        parser.finalize()
        return FormData(items)

starlette-0.18.0/starlette/middleware/
starlette-0.18.0/starlette/middleware/__init__.py
import typing


class Middleware:
    def __init__(self, cls: type, **options: typing.Any) -> None:
        self.cls = cls
        self.options = options

    def __iter__(self) -> typing.Iterator:
        as_tuple = (self.cls, self.options)
        return iter(as_tuple)

    def __repr__(self) -> str:
class_name = self.__class__.__name__ option_strings = [f"{key}={value!r}" for key, value in self.options.items()] args_repr = ", ".join([self.cls.__name__] + option_strings) return f"{class_name}({args_repr})" starlette-0.18.0/starlette/middleware/authentication.py0000644000175100001710000000335214173233741024133 0ustar runnerdocker00000000000000import typing from starlette.authentication import ( AuthCredentials, AuthenticationBackend, AuthenticationError, UnauthenticatedUser, ) from starlette.requests import HTTPConnection from starlette.responses import PlainTextResponse, Response from starlette.types import ASGIApp, Receive, Scope, Send class AuthenticationMiddleware: def __init__( self, app: ASGIApp, backend: AuthenticationBackend, on_error: typing.Callable[ [HTTPConnection, AuthenticationError], Response ] = None, ) -> None: self.app = app self.backend = backend self.on_error: typing.Callable[ [HTTPConnection, AuthenticationError], Response ] = (on_error if on_error is not None else self.default_on_error) async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: if scope["type"] not in ["http", "websocket"]: await self.app(scope, receive, send) return conn = HTTPConnection(scope) try: auth_result = await self.backend.authenticate(conn) except AuthenticationError as exc: response = self.on_error(conn, exc) if scope["type"] == "websocket": await send({"type": "websocket.close", "code": 1000}) else: await response(scope, receive, send) return if auth_result is None: auth_result = AuthCredentials(), UnauthenticatedUser() scope["auth"], scope["user"] = auth_result await self.app(scope, receive, send) @staticmethod def default_on_error(conn: HTTPConnection, exc: Exception) -> Response: return PlainTextResponse(str(exc), status_code=400) starlette-0.18.0/starlette/middleware/base.py0000644000175100001710000000477314173233741022036 0ustar runnerdocker00000000000000import typing import anyio from starlette.requests import Request from 
starlette.responses import Response, StreamingResponse from starlette.types import ASGIApp, Receive, Scope, Send RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]] DispatchFunction = typing.Callable[ [Request, RequestResponseEndpoint], typing.Awaitable[Response] ] class BaseHTTPMiddleware: def __init__(self, app: ASGIApp, dispatch: DispatchFunction = None) -> None: self.app = app self.dispatch_func = self.dispatch if dispatch is None else dispatch async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: if scope["type"] != "http": await self.app(scope, receive, send) return async def call_next(request: Request) -> Response: app_exc: typing.Optional[Exception] = None send_stream, recv_stream = anyio.create_memory_object_stream() async def coro() -> None: nonlocal app_exc async with send_stream: try: await self.app(scope, request.receive, send_stream.send) except Exception as exc: app_exc = exc task_group.start_soon(coro) try: message = await recv_stream.receive() except anyio.EndOfStream: if app_exc is not None: raise app_exc raise RuntimeError("No response returned.") assert message["type"] == "http.response.start" async def body_stream() -> typing.AsyncGenerator[bytes, None]: async with recv_stream: async for message in recv_stream: assert message["type"] == "http.response.body" yield message.get("body", b"") response = StreamingResponse( status_code=message["status"], content=body_stream() ) response.raw_headers = message["headers"] return response async with anyio.create_task_group() as task_group: request = Request(scope, receive=receive) response = await self.dispatch_func(request, call_next) await response(scope, receive, send) task_group.cancel_scope.cancel() async def dispatch( self, request: Request, call_next: RequestResponseEndpoint ) -> Response: raise NotImplementedError() # pragma: no cover starlette-0.18.0/starlette/middleware/cors.py0000644000175100001710000001562314173233741022066 0ustar 
runnerdocker00000000000000import functools import re import typing from starlette.datastructures import Headers, MutableHeaders from starlette.responses import PlainTextResponse, Response from starlette.types import ASGIApp, Message, Receive, Scope, Send ALL_METHODS = ("DELETE", "GET", "HEAD", "OPTIONS", "PATCH", "POST", "PUT") SAFELISTED_HEADERS = {"Accept", "Accept-Language", "Content-Language", "Content-Type"} class CORSMiddleware: def __init__( self, app: ASGIApp, allow_origins: typing.Sequence[str] = (), allow_methods: typing.Sequence[str] = ("GET",), allow_headers: typing.Sequence[str] = (), allow_credentials: bool = False, allow_origin_regex: str = None, expose_headers: typing.Sequence[str] = (), max_age: int = 600, ) -> None: if "*" in allow_methods: allow_methods = ALL_METHODS compiled_allow_origin_regex = None if allow_origin_regex is not None: compiled_allow_origin_regex = re.compile(allow_origin_regex) allow_all_origins = "*" in allow_origins allow_all_headers = "*" in allow_headers preflight_explicit_allow_origin = not allow_all_origins or allow_credentials simple_headers = {} if allow_all_origins: simple_headers["Access-Control-Allow-Origin"] = "*" if allow_credentials: simple_headers["Access-Control-Allow-Credentials"] = "true" if expose_headers: simple_headers["Access-Control-Expose-Headers"] = ", ".join(expose_headers) preflight_headers = {} if preflight_explicit_allow_origin: # The origin value will be set in preflight_response() if it is allowed. 
preflight_headers["Vary"] = "Origin" else: preflight_headers["Access-Control-Allow-Origin"] = "*" preflight_headers.update( { "Access-Control-Allow-Methods": ", ".join(allow_methods), "Access-Control-Max-Age": str(max_age), } ) allow_headers = sorted(SAFELISTED_HEADERS | set(allow_headers)) if allow_headers and not allow_all_headers: preflight_headers["Access-Control-Allow-Headers"] = ", ".join(allow_headers) if allow_credentials: preflight_headers["Access-Control-Allow-Credentials"] = "true" self.app = app self.allow_origins = allow_origins self.allow_methods = allow_methods self.allow_headers = [h.lower() for h in allow_headers] self.allow_all_origins = allow_all_origins self.allow_all_headers = allow_all_headers self.preflight_explicit_allow_origin = preflight_explicit_allow_origin self.allow_origin_regex = compiled_allow_origin_regex self.simple_headers = simple_headers self.preflight_headers = preflight_headers async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: if scope["type"] != "http": # pragma: no cover await self.app(scope, receive, send) return method = scope["method"] headers = Headers(scope=scope) origin = headers.get("origin") if origin is None: await self.app(scope, receive, send) return if method == "OPTIONS" and "access-control-request-method" in headers: response = self.preflight_response(request_headers=headers) await response(scope, receive, send) return await self.simple_response(scope, receive, send, request_headers=headers) def is_allowed_origin(self, origin: str) -> bool: if self.allow_all_origins: return True if self.allow_origin_regex is not None and self.allow_origin_regex.fullmatch( origin ): return True return origin in self.allow_origins def preflight_response(self, request_headers: Headers) -> Response: requested_origin = request_headers["origin"] requested_method = request_headers["access-control-request-method"] requested_headers = request_headers.get("access-control-request-headers") headers = 
dict(self.preflight_headers) failures = [] if self.is_allowed_origin(origin=requested_origin): if self.preflight_explicit_allow_origin: # The "else" case is already accounted for in self.preflight_headers # and the value would be "*". headers["Access-Control-Allow-Origin"] = requested_origin else: failures.append("origin") if requested_method not in self.allow_methods: failures.append("method") # If we allow all headers, then we have to mirror back any requested # headers in the response. if self.allow_all_headers and requested_headers is not None: headers["Access-Control-Allow-Headers"] = requested_headers elif requested_headers is not None: for header in [h.lower() for h in requested_headers.split(",")]: if header.strip() not in self.allow_headers: failures.append("headers") break # We don't strictly need to use 400 responses here, since it's up to # the browser to enforce the CORS policy, but it's more informative # if we do. if failures: failure_text = "Disallowed CORS " + ", ".join(failures) return PlainTextResponse(failure_text, status_code=400, headers=headers) return PlainTextResponse("OK", status_code=200, headers=headers) async def simple_response( self, scope: Scope, receive: Receive, send: Send, request_headers: Headers ) -> None: send = functools.partial(self.send, send=send, request_headers=request_headers) await self.app(scope, receive, send) async def send( self, message: Message, send: Send, request_headers: Headers ) -> None: if message["type"] != "http.response.start": await send(message) return message.setdefault("headers", []) headers = MutableHeaders(scope=message) headers.update(self.simple_headers) origin = request_headers["Origin"] has_cookie = "cookie" in request_headers # If the request includes any cookie headers, then we must respond # with the specific origin instead of '*'.
if self.allow_all_origins and has_cookie: self.allow_explicit_origin(headers, origin) # If we only allow specific origins, then we have to mirror back # the Origin header in the response. elif not self.allow_all_origins and self.is_allowed_origin(origin=origin): self.allow_explicit_origin(headers, origin) await send(message) @staticmethod def allow_explicit_origin(headers: MutableHeaders, origin: str) -> None: headers["Access-Control-Allow-Origin"] = origin headers.add_vary_header("Origin") starlette-0.18.0/starlette/middleware/errors.py0000644000175100001710000001714514173233741022435 0ustar runnerdocker00000000000000import asyncio import html import inspect import traceback import typing from starlette.concurrency import run_in_threadpool from starlette.requests import Request from starlette.responses import HTMLResponse, PlainTextResponse, Response from starlette.types import ASGIApp, Message, Receive, Scope, Send STYLES = """ p { color: #211c1c; } .traceback-container { border: 1px solid #038BB8; } .traceback-title { background-color: #038BB8; color: lemonchiffon; padding: 12px; font-size: 20px; margin-top: 0px; } .frame-line { padding-left: 10px; font-family: monospace; } .frame-filename { font-family: monospace; } .center-line { background-color: #038BB8; color: #f9f6e1; padding: 5px 0px 5px 5px; } .lineno { margin-right: 5px; } .frame-title { font-weight: unset; padding: 10px 10px 10px 10px; background-color: #E4F4FD; margin-right: 10px; color: #191f21; font-size: 17px; border: 1px solid #c7dce8; } .collapse-btn { float: right; padding: 0px 5px 1px 5px; border: solid 1px #96aebb; cursor: pointer; } .collapsed { display: none; } .source-code { font-family: courier; font-size: small; padding-bottom: 10px; } """ JS = """ """ TEMPLATE = """ Starlette Debugger

500 Server Error

{error}

Traceback

{exc_html}
{js} """ FRAME_TEMPLATE = """

File {frame_filename}, line {frame_lineno}, in {frame_name} {collapse_button}

{code_context}
""" # noqa: E501 LINE = """

{lineno}. {line}

""" CENTER_LINE = """

{lineno}. {line}

""" class ServerErrorMiddleware: """ Handles returning 500 responses when a server error occurs. If 'debug' is set, then traceback responses will be returned, otherwise the designated 'handler' will be called. This middleware class should generally be used to wrap *everything* else up, so that unhandled exceptions anywhere in the stack always result in an appropriate 500 response. """ def __init__( self, app: ASGIApp, handler: typing.Callable = None, debug: bool = False ) -> None: self.app = app self.handler = handler self.debug = debug async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: if scope["type"] != "http": await self.app(scope, receive, send) return response_started = False async def _send(message: Message) -> None: nonlocal response_started, send if message["type"] == "http.response.start": response_started = True await send(message) try: await self.app(scope, receive, _send) except Exception as exc: if not response_started: request = Request(scope) if self.debug: # In debug mode, return traceback responses. response = self.debug_response(request, exc) elif self.handler is None: # Use our default 500 error handler. response = self.error_response(request, exc) else: # Use an installed 500 error handler. if asyncio.iscoroutinefunction(self.handler): response = await self.handler(request, exc) else: response = await run_in_threadpool(self.handler, request, exc) await response(scope, receive, send) # We always continue to raise the exception. # This allows servers to log the error, or allows test clients # to optionally raise the error within the test case. 
raise exc def format_line( self, index: int, line: str, frame_lineno: int, frame_index: int ) -> str: values = { # HTML escape - line could contain < or > "line": html.escape(line).replace(" ", "&nbsp;"), "lineno": (frame_lineno - frame_index) + index, } if index != frame_index: return LINE.format(**values) return CENTER_LINE.format(**values) def generate_frame_html(self, frame: inspect.FrameInfo, is_collapsed: bool) -> str: code_context = "".join( self.format_line(index, line, frame.lineno, frame.index) # type: ignore for index, line in enumerate(frame.code_context or []) ) values = { # HTML escape - filename could contain < or >, especially if it's a virtual # file e.g. in the REPL "frame_filename": html.escape(frame.filename), "frame_lineno": frame.lineno, # HTML escape - if you try very hard it's possible to name a function with < # or > "frame_name": html.escape(frame.function), "code_context": code_context, "collapsed": "collapsed" if is_collapsed else "", "collapse_button": "+" if is_collapsed else "‒", } return FRAME_TEMPLATE.format(**values) def generate_html(self, exc: Exception, limit: int = 7) -> str: traceback_obj = traceback.TracebackException.from_exception( exc, capture_locals=True ) exc_html = "" is_collapsed = False exc_traceback = exc.__traceback__ if exc_traceback is not None: frames = inspect.getinnerframes(exc_traceback, limit) for frame in reversed(frames): exc_html += self.generate_frame_html(frame, is_collapsed) is_collapsed = True # escape error class and text error = ( f"{html.escape(traceback_obj.exc_type.__name__)}: " f"{html.escape(str(traceback_obj))}" ) return TEMPLATE.format(styles=STYLES, js=JS, error=error, exc_html=exc_html) def generate_plain_text(self, exc: Exception) -> str: return "".join(traceback.format_exception(type(exc), exc, exc.__traceback__)) def debug_response(self, request: Request, exc: Exception) -> Response: accept = request.headers.get("accept", "") if "text/html" in accept: content = self.generate_html(exc) 
return
HTMLResponse(content, status_code=500) content = self.generate_plain_text(exc) return PlainTextResponse(content, status_code=500) def error_response(self, request: Request, exc: Exception) -> Response: return PlainTextResponse("Internal Server Error", status_code=500) starlette-0.18.0/starlette/middleware/gzip.py0000644000175100001710000000776714173233741022103 0ustar runnerdocker00000000000000import gzip import io import typing from starlette.datastructures import Headers, MutableHeaders from starlette.types import ASGIApp, Message, Receive, Scope, Send class GZipMiddleware: def __init__( self, app: ASGIApp, minimum_size: int = 500, compresslevel: int = 9 ) -> None: self.app = app self.minimum_size = minimum_size self.compresslevel = compresslevel async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: if scope["type"] == "http": headers = Headers(scope=scope) if "gzip" in headers.get("Accept-Encoding", ""): responder = GZipResponder( self.app, self.minimum_size, compresslevel=self.compresslevel ) await responder(scope, receive, send) return await self.app(scope, receive, send) class GZipResponder: def __init__(self, app: ASGIApp, minimum_size: int, compresslevel: int = 9) -> None: self.app = app self.minimum_size = minimum_size self.send: Send = unattached_send self.initial_message: Message = {} self.started = False self.gzip_buffer = io.BytesIO() self.gzip_file = gzip.GzipFile( mode="wb", fileobj=self.gzip_buffer, compresslevel=compresslevel ) async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: self.send = send await self.app(scope, receive, self.send_with_gzip) async def send_with_gzip(self, message: Message) -> None: message_type = message["type"] if message_type == "http.response.start": # Don't send the initial message until we've determined how to # modify the outgoing headers correctly. 
self.initial_message = message elif message_type == "http.response.body" and not self.started: self.started = True body = message.get("body", b"") more_body = message.get("more_body", False) if len(body) < self.minimum_size and not more_body: # Don't apply GZip to small outgoing responses. await self.send(self.initial_message) await self.send(message) elif not more_body: # Standard GZip response. self.gzip_file.write(body) self.gzip_file.close() body = self.gzip_buffer.getvalue() headers = MutableHeaders(raw=self.initial_message["headers"]) headers["Content-Encoding"] = "gzip" headers["Content-Length"] = str(len(body)) headers.add_vary_header("Accept-Encoding") message["body"] = body await self.send(self.initial_message) await self.send(message) else: # Initial body in streaming GZip response. headers = MutableHeaders(raw=self.initial_message["headers"]) headers["Content-Encoding"] = "gzip" headers.add_vary_header("Accept-Encoding") del headers["Content-Length"] self.gzip_file.write(body) message["body"] = self.gzip_buffer.getvalue() self.gzip_buffer.seek(0) self.gzip_buffer.truncate() await self.send(self.initial_message) await self.send(message) elif message_type == "http.response.body": # Remaining body in streaming GZip response. 
body = message.get("body", b"") more_body = message.get("more_body", False) self.gzip_file.write(body) if not more_body: self.gzip_file.close() message["body"] = self.gzip_buffer.getvalue() self.gzip_buffer.seek(0) self.gzip_buffer.truncate() await self.send(message) async def unattached_send(message: Message) -> typing.NoReturn: raise RuntimeError("send awaitable not set") # pragma: no cover starlette-0.18.0/starlette/middleware/httpsredirect.py0000644000175100001710000000152014173233741023773 0ustar runnerdocker00000000000000from starlette.datastructures import URL from starlette.responses import RedirectResponse from starlette.types import ASGIApp, Receive, Scope, Send class HTTPSRedirectMiddleware: def __init__(self, app: ASGIApp) -> None: self.app = app async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: if scope["type"] in ("http", "websocket") and scope["scheme"] in ("http", "ws"): url = URL(scope=scope) redirect_scheme = {"http": "https", "ws": "wss"}[url.scheme] netloc = url.hostname if url.port in (80, 443) else url.netloc url = url.replace(scheme=redirect_scheme, netloc=netloc) response = RedirectResponse(url, status_code=307) await response(scope, receive, send) else: await self.app(scope, receive, send) starlette-0.18.0/starlette/middleware/sessions.py0000644000175100001710000000627414173233741022770 0ustar runnerdocker00000000000000import json import typing from base64 import b64decode, b64encode import itsdangerous from itsdangerous.exc import BadSignature from starlette.datastructures import MutableHeaders, Secret from starlette.requests import HTTPConnection from starlette.types import ASGIApp, Message, Receive, Scope, Send class SessionMiddleware: def __init__( self, app: ASGIApp, secret_key: typing.Union[str, Secret], session_cookie: str = "session", max_age: typing.Optional[int] = 14 * 24 * 60 * 60, # 14 days, in seconds same_site: str = "lax", https_only: bool = False, ) -> None: self.app = app self.signer = 
itsdangerous.TimestampSigner(str(secret_key)) self.session_cookie = session_cookie self.max_age = max_age self.security_flags = "httponly; samesite=" + same_site if https_only: # Secure flag can be used with HTTPS only self.security_flags += "; secure" async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: if scope["type"] not in ("http", "websocket"): # pragma: no cover await self.app(scope, receive, send) return connection = HTTPConnection(scope) initial_session_was_empty = True if self.session_cookie in connection.cookies: data = connection.cookies[self.session_cookie].encode("utf-8") try: data = self.signer.unsign(data, max_age=self.max_age) scope["session"] = json.loads(b64decode(data)) initial_session_was_empty = False except BadSignature: scope["session"] = {} else: scope["session"] = {} async def send_wrapper(message: Message) -> None: if message["type"] == "http.response.start": path = scope.get("root_path", "") or "/" if scope["session"]: # We have session data to persist. data = b64encode(json.dumps(scope["session"]).encode("utf-8")) data = self.signer.sign(data) headers = MutableHeaders(scope=message) header_value = "{session_cookie}={data}; path={path}; {max_age}{security_flags}".format( # noqa E501 session_cookie=self.session_cookie, data=data.decode("utf-8"), path=path, max_age=f"Max-Age={self.max_age}; " if self.max_age else "", security_flags=self.security_flags, ) headers.append("Set-Cookie", header_value) elif not initial_session_was_empty: # The session has been cleared. 
headers = MutableHeaders(scope=message) header_value = "{}={}; {}".format( self.session_cookie, f"null; path={path}; expires=Thu, 01 Jan 1970 00:00:00 GMT;", self.security_flags, ) headers.append("Set-Cookie", header_value) await send(message) await self.app(scope, receive, send_wrapper) starlette-0.18.0/starlette/middleware/trustedhost.py0000644000175100001710000000421614173233741023504 0ustar runnerdocker00000000000000import typing from starlette.datastructures import URL, Headers from starlette.responses import PlainTextResponse, RedirectResponse, Response from starlette.types import ASGIApp, Receive, Scope, Send ENFORCE_DOMAIN_WILDCARD = "Domain wildcard patterns must be like '*.example.com'." class TrustedHostMiddleware: def __init__( self, app: ASGIApp, allowed_hosts: typing.Sequence[str] = None, www_redirect: bool = True, ) -> None: if allowed_hosts is None: allowed_hosts = ["*"] for pattern in allowed_hosts: assert "*" not in pattern[1:], ENFORCE_DOMAIN_WILDCARD if pattern.startswith("*") and pattern != "*": assert pattern.startswith("*."), ENFORCE_DOMAIN_WILDCARD self.app = app self.allowed_hosts = list(allowed_hosts) self.allow_any = "*" in allowed_hosts self.www_redirect = www_redirect async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: if self.allow_any or scope["type"] not in ( "http", "websocket", ): # pragma: no cover await self.app(scope, receive, send) return headers = Headers(scope=scope) host = headers.get("host", "").split(":")[0] is_valid_host = False found_www_redirect = False for pattern in self.allowed_hosts: if host == pattern or ( pattern.startswith("*") and host.endswith(pattern[1:]) ): is_valid_host = True break elif "www." + host == pattern: found_www_redirect = True if is_valid_host: await self.app(scope, receive, send) else: response: Response if found_www_redirect and self.www_redirect: url = URL(scope=scope) redirect_url = url.replace(netloc="www." 
+ url.netloc) response = RedirectResponse(url=str(redirect_url)) else: response = PlainTextResponse("Invalid host header", status_code=400) await response(scope, receive, send) starlette-0.18.0/starlette/middleware/wsgi.py0000644000175100001710000001111114173233741022055 0ustar runnerdocker00000000000000import io import math import sys import typing import anyio from starlette.types import Receive, Scope, Send def build_environ(scope: Scope, body: bytes) -> dict: """ Builds a WSGI environ object from an ASGI scope and request body. """ environ = { "REQUEST_METHOD": scope["method"], "SCRIPT_NAME": scope.get("root_path", "").encode("utf8").decode("latin1"), "PATH_INFO": scope["path"].encode("utf8").decode("latin1"), "QUERY_STRING": scope["query_string"].decode("ascii"), "SERVER_PROTOCOL": f"HTTP/{scope['http_version']}", "wsgi.version": (1, 0), "wsgi.url_scheme": scope.get("scheme", "http"), "wsgi.input": io.BytesIO(body), "wsgi.errors": sys.stdout, "wsgi.multithread": True, "wsgi.multiprocess": True, "wsgi.run_once": False, } # Get server name and port - required in WSGI, not in ASGI server = scope.get("server") or ("localhost", 80) environ["SERVER_NAME"] = server[0] environ["SERVER_PORT"] = server[1] # Get client IP address if scope.get("client"): environ["REMOTE_ADDR"] = scope["client"][0] # Go through headers and make them into environ entries for name, value in scope.get("headers", []): name = name.decode("latin1") if name == "content-length": corrected_name = "CONTENT_LENGTH" elif name == "content-type": corrected_name = "CONTENT_TYPE" else: corrected_name = f"HTTP_{name}".upper().replace("-", "_") # HTTPbis says only ASCII chars are allowed in headers, but we decode as latin1 just in case value = value.decode("latin1") if corrected_name in environ: value = environ[corrected_name] + "," + value environ[corrected_name] = value return environ class WSGIMiddleware: def __init__(self, app: typing.Callable) -> None: self.app = app async def __call__(self, scope: Scope, receive:
Receive, send: Send) -> None: assert scope["type"] == "http" responder = WSGIResponder(self.app, scope) await responder(receive, send) class WSGIResponder: def __init__(self, app: typing.Callable, scope: Scope) -> None: self.app = app self.scope = scope self.status = None self.response_headers = None self.stream_send, self.stream_receive = anyio.create_memory_object_stream( math.inf ) self.response_started = False self.exc_info: typing.Any = None async def __call__(self, receive: Receive, send: Send) -> None: body = b"" more_body = True while more_body: message = await receive() body += message.get("body", b"") more_body = message.get("more_body", False) environ = build_environ(self.scope, body) async with anyio.create_task_group() as task_group: task_group.start_soon(self.sender, send) async with self.stream_send: await anyio.to_thread.run_sync(self.wsgi, environ, self.start_response) if self.exc_info is not None: raise self.exc_info[0].with_traceback(self.exc_info[1], self.exc_info[2]) async def sender(self, send: Send) -> None: async with self.stream_receive: async for message in self.stream_receive: await send(message) def start_response( self, status: str, response_headers: typing.List[typing.Tuple[str, str]], exc_info: typing.Any = None, ) -> None: self.exc_info = exc_info if not self.response_started: self.response_started = True status_code_string, _ = status.split(" ", 1) status_code = int(status_code_string) headers = [ (name.strip().encode("ascii").lower(), value.strip().encode("ascii")) for name, value in response_headers ] anyio.from_thread.run( self.stream_send.send, { "type": "http.response.start", "status": status_code, "headers": headers, }, ) def wsgi(self, environ: dict, start_response: typing.Callable) -> None: for chunk in self.app(environ, start_response): anyio.from_thread.run( self.stream_send.send, {"type": "http.response.body", "body": chunk, "more_body": True}, ) anyio.from_thread.run( self.stream_send.send, {"type": "http.response.body", 
"body": b""} ) starlette-0.18.0/starlette/requests.py0000644000175100001710000002231514173233741020652 0ustar runnerdocker00000000000000import json import typing from collections.abc import Mapping from http import cookies as http_cookies import anyio from starlette.datastructures import URL, Address, FormData, Headers, QueryParams, State from starlette.formparsers import FormParser, MultiPartParser from starlette.types import Message, Receive, Scope, Send try: from multipart.multipart import parse_options_header except ImportError: # pragma: nocover parse_options_header = None if typing.TYPE_CHECKING: from starlette.routing import Router SERVER_PUSH_HEADERS_TO_COPY = { "accept", "accept-encoding", "accept-language", "cache-control", "user-agent", } def cookie_parser(cookie_string: str) -> typing.Dict[str, str]: """ This function parses a ``Cookie`` HTTP header into a dict of key/value pairs. It attempts to mimic browser cookie parsing behavior: browsers and web servers frequently disregard the spec (RFC 6265) when setting and reading cookies, so we attempt to suit the common scenarios here. This function has been adapted from Django 3.1.0. Note: we are explicitly _NOT_ using `SimpleCookie.load` because it is based on an outdated spec and will fail on lots of input we want to support """ cookie_dict: typing.Dict[str, str] = {} for chunk in cookie_string.split(";"): if "=" in chunk: key, val = chunk.split("=", 1) else: # Assume an empty name per # https://bugzilla.mozilla.org/show_bug.cgi?id=169091 key, val = "", chunk key, val = key.strip(), val.strip() if key or val: # unquote using Python's algorithm. cookie_dict[key] = http_cookies._unquote(val) return cookie_dict class ClientDisconnect(Exception): pass class HTTPConnection(Mapping): """ A base class for incoming HTTP connections, that is used to provide any functionality that is common to both `Request` and `WebSocket`. 
""" def __init__(self, scope: Scope, receive: Receive = None) -> None: assert scope["type"] in ("http", "websocket") self.scope = scope def __getitem__(self, key: str) -> typing.Any: return self.scope[key] def __iter__(self) -> typing.Iterator[str]: return iter(self.scope) def __len__(self) -> int: return len(self.scope) # Don't use the `abc.Mapping.__eq__` implementation. # Connection instances should never be considered equal # unless `self is other`. __eq__ = object.__eq__ __hash__ = object.__hash__ @property def app(self) -> typing.Any: return self.scope["app"] @property def url(self) -> URL: if not hasattr(self, "_url"): self._url = URL(scope=self.scope) return self._url @property def base_url(self) -> URL: if not hasattr(self, "_base_url"): base_url_scope = dict(self.scope) base_url_scope["path"] = "/" base_url_scope["query_string"] = b"" base_url_scope["root_path"] = base_url_scope.get( "app_root_path", base_url_scope.get("root_path", "") ) self._base_url = URL(scope=base_url_scope) return self._base_url @property def headers(self) -> Headers: if not hasattr(self, "_headers"): self._headers = Headers(scope=self.scope) return self._headers @property def query_params(self) -> QueryParams: if not hasattr(self, "_query_params"): self._query_params = QueryParams(self.scope["query_string"]) return self._query_params @property def path_params(self) -> dict: return self.scope.get("path_params", {}) @property def cookies(self) -> typing.Dict[str, str]: if not hasattr(self, "_cookies"): cookies: typing.Dict[str, str] = {} cookie_header = self.headers.get("cookie") if cookie_header: cookies = cookie_parser(cookie_header) self._cookies = cookies return self._cookies @property def client(self) -> Address: host, port = self.scope.get("client") or (None, None) return Address(host=host, port=port) @property def session(self) -> dict: assert ( "session" in self.scope ), "SessionMiddleware must be installed to access request.session" return self.scope["session"] @property def 
auth(self) -> typing.Any: assert ( "auth" in self.scope ), "AuthenticationMiddleware must be installed to access request.auth" return self.scope["auth"] @property def user(self) -> typing.Any: assert ( "user" in self.scope ), "AuthenticationMiddleware must be installed to access request.user" return self.scope["user"] @property def state(self) -> State: if not hasattr(self, "_state"): # Ensure 'state' has an empty dict if it's not already populated. self.scope.setdefault("state", {}) # Create a state instance with a reference to the dict in which it should # store info self._state = State(self.scope["state"]) return self._state def url_for(self, name: str, **path_params: typing.Any) -> str: router: Router = self.scope["router"] url_path = router.url_path_for(name, **path_params) return url_path.make_absolute_url(base_url=self.base_url) async def empty_receive() -> typing.NoReturn: raise RuntimeError("Receive channel has not been made available") async def empty_send(message: Message) -> typing.NoReturn: raise RuntimeError("Send channel has not been made available") class Request(HTTPConnection): def __init__( self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send ): super().__init__(scope) assert scope["type"] == "http" self._receive = receive self._send = send self._stream_consumed = False self._is_disconnected = False @property def method(self) -> str: return self.scope["method"] @property def receive(self) -> Receive: return self._receive async def stream(self) -> typing.AsyncGenerator[bytes, None]: if hasattr(self, "_body"): yield self._body yield b"" return if self._stream_consumed: raise RuntimeError("Stream consumed") self._stream_consumed = True while True: message = await self._receive() if message["type"] == "http.request": body = message.get("body", b"") if body: yield body if not message.get("more_body", False): break elif message["type"] == "http.disconnect": self._is_disconnected = True raise ClientDisconnect() yield b"" async 
def body(self) -> bytes: if not hasattr(self, "_body"): chunks = [] async for chunk in self.stream(): chunks.append(chunk) self._body = b"".join(chunks) return self._body async def json(self) -> typing.Any: if not hasattr(self, "_json"): body = await self.body() self._json = json.loads(body) return self._json async def form(self) -> FormData: if not hasattr(self, "_form"): assert ( parse_options_header is not None ), "The `python-multipart` library must be installed to use form parsing." content_type_header = self.headers.get("Content-Type") content_type, options = parse_options_header(content_type_header) if content_type == b"multipart/form-data": multipart_parser = MultiPartParser(self.headers, self.stream()) self._form = await multipart_parser.parse() elif content_type == b"application/x-www-form-urlencoded": form_parser = FormParser(self.headers, self.stream()) self._form = await form_parser.parse() else: self._form = FormData() return self._form async def close(self) -> None: if hasattr(self, "_form"): await self._form.close() async def is_disconnected(self) -> bool: if not self._is_disconnected: message: Message = {} # If message isn't immediately available, move on with anyio.CancelScope() as cs: cs.cancel() message = await self._receive() if message.get("type") == "http.disconnect": self._is_disconnected = True return self._is_disconnected async def send_push_promise(self, path: str) -> None: if "http.response.push" in self.scope.get("extensions", {}): raw_headers = [] for name in SERVER_PUSH_HEADERS_TO_COPY: for value in self.headers.getlist(name): raw_headers.append( (name.encode("latin-1"), value.encode("latin-1")) ) await self._send( {"type": "http.response.push", "path": path, "headers": raw_headers} ) starlette-0.18.0/starlette/responses.py0000644000175100001710000002634614173233741021030 0ustar runnerdocker00000000000000import http.cookies import json import os import stat import sys import typing from email.utils import formatdate from functools 
import partial from mimetypes import guess_type as mimetypes_guess_type from urllib.parse import quote import anyio from starlette._compat import md5_hexdigest from starlette.background import BackgroundTask from starlette.concurrency import iterate_in_threadpool from starlette.datastructures import URL, MutableHeaders from starlette.types import Receive, Scope, Send # Workaround for adding samesite support to pre 3.8 python http.cookies.Morsel._reserved["samesite"] = "SameSite" # type: ignore # Compatibility wrapper for `mimetypes.guess_type` to support `os.PathLike` on typing.Tuple[typing.Optional[str], typing.Optional[str]]: if sys.version_info < (3, 8): # pragma: no cover url = os.fspath(url) return mimetypes_guess_type(url, strict) class Response: media_type = None charset = "utf-8" def __init__( self, content: typing.Any = None, status_code: int = 200, headers: dict = None, media_type: str = None, background: BackgroundTask = None, ) -> None: self.status_code = status_code if media_type is not None: self.media_type = media_type self.background = background self.body = self.render(content) self.init_headers(headers) def render(self, content: typing.Any) -> bytes: if content is None: return b"" if isinstance(content, bytes): return content return content.encode(self.charset) def init_headers(self, headers: typing.Mapping[str, str] = None) -> None: if headers is None: raw_headers: typing.List[typing.Tuple[bytes, bytes]] = [] populate_content_length = True populate_content_type = True else: raw_headers = [ (k.lower().encode("latin-1"), v.encode("latin-1")) for k, v in headers.items() ] keys = [h[0] for h in raw_headers] populate_content_length = b"content-length" not in keys populate_content_type = b"content-type" not in keys body = getattr(self, "body", None) if ( body is not None and populate_content_length and not (self.status_code < 200 or self.status_code in (204, 304)) ): content_length = str(len(body)) raw_headers.append((b"content-length", 
content_length.encode("latin-1"))) content_type = self.media_type if content_type is not None and populate_content_type: if content_type.startswith("text/"): content_type += "; charset=" + self.charset raw_headers.append((b"content-type", content_type.encode("latin-1"))) self.raw_headers = raw_headers @property def headers(self) -> MutableHeaders: if not hasattr(self, "_headers"): self._headers = MutableHeaders(raw=self.raw_headers) return self._headers def set_cookie( self, key: str, value: str = "", max_age: int = None, expires: int = None, path: str = "/", domain: str = None, secure: bool = False, httponly: bool = False, samesite: str = "lax", ) -> None: cookie: http.cookies.BaseCookie = http.cookies.SimpleCookie() cookie[key] = value if max_age is not None: cookie[key]["max-age"] = max_age if expires is not None: cookie[key]["expires"] = expires if path is not None: cookie[key]["path"] = path if domain is not None: cookie[key]["domain"] = domain if secure: cookie[key]["secure"] = True if httponly: cookie[key]["httponly"] = True if samesite is not None: assert samesite.lower() in [ "strict", "lax", "none", ], "samesite must be either 'strict', 'lax' or 'none'" cookie[key]["samesite"] = samesite cookie_val = cookie.output(header="").strip() self.raw_headers.append((b"set-cookie", cookie_val.encode("latin-1"))) def delete_cookie( self, key: str, path: str = "/", domain: str = None, secure: bool = False, httponly: bool = False, samesite: str = "lax", ) -> None: self.set_cookie( key, max_age=0, expires=0, path=path, domain=domain, secure=secure, httponly=httponly, samesite=samesite, ) async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: await send( { "type": "http.response.start", "status": self.status_code, "headers": self.raw_headers, } ) await send({"type": "http.response.body", "body": self.body}) if self.background is not None: await self.background() class HTMLResponse(Response): media_type = "text/html" class 
PlainTextResponse(Response): media_type = "text/plain" class JSONResponse(Response): media_type = "application/json" def render(self, content: typing.Any) -> bytes: return json.dumps( content, ensure_ascii=False, allow_nan=False, indent=None, separators=(",", ":"), ).encode("utf-8") class RedirectResponse(Response): def __init__( self, url: typing.Union[str, URL], status_code: int = 307, headers: dict = None, background: BackgroundTask = None, ) -> None: super().__init__( content=b"", status_code=status_code, headers=headers, background=background ) self.headers["location"] = quote(str(url), safe=":/%#?=@[]!$&'()*+,;") class StreamingResponse(Response): def __init__( self, content: typing.Any, status_code: int = 200, headers: dict = None, media_type: str = None, background: BackgroundTask = None, ) -> None: if isinstance(content, typing.AsyncIterable): self.body_iterator = content else: self.body_iterator = iterate_in_threadpool(content) self.status_code = status_code self.media_type = self.media_type if media_type is None else media_type self.background = background self.init_headers(headers) async def listen_for_disconnect(self, receive: Receive) -> None: while True: message = await receive() if message["type"] == "http.disconnect": break async def stream_response(self, send: Send) -> None: await send( { "type": "http.response.start", "status": self.status_code, "headers": self.raw_headers, } ) async for chunk in self.body_iterator: if not isinstance(chunk, bytes): chunk = chunk.encode(self.charset) await send({"type": "http.response.body", "body": chunk, "more_body": True}) await send({"type": "http.response.body", "body": b"", "more_body": False}) async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: async with anyio.create_task_group() as task_group: async def wrap(func: typing.Callable[[], typing.Coroutine]) -> None: await func() task_group.cancel_scope.cancel() task_group.start_soon(wrap, partial(self.stream_response, send)) await 
wrap(partial(self.listen_for_disconnect, receive)) if self.background is not None: await self.background() class FileResponse(Response): chunk_size = 64 * 1024 def __init__( self, path: typing.Union[str, "os.PathLike[str]"], status_code: int = 200, headers: dict = None, media_type: str = None, background: BackgroundTask = None, filename: str = None, stat_result: os.stat_result = None, method: str = None, ) -> None: self.path = path self.status_code = status_code self.filename = filename self.send_header_only = method is not None and method.upper() == "HEAD" if media_type is None: media_type = guess_type(filename or path)[0] or "text/plain" self.media_type = media_type self.background = background self.init_headers(headers) if self.filename is not None: content_disposition_filename = quote(self.filename) if content_disposition_filename != self.filename: content_disposition = "attachment; filename*=utf-8''{}".format( content_disposition_filename ) else: content_disposition = f'attachment; filename="{self.filename}"' self.headers.setdefault("content-disposition", content_disposition) self.stat_result = stat_result if stat_result is not None: self.set_stat_headers(stat_result) def set_stat_headers(self, stat_result: os.stat_result) -> None: content_length = str(stat_result.st_size) last_modified = formatdate(stat_result.st_mtime, usegmt=True) etag_base = str(stat_result.st_mtime) + "-" + str(stat_result.st_size) etag = md5_hexdigest(etag_base.encode(), usedforsecurity=False) self.headers.setdefault("content-length", content_length) self.headers.setdefault("last-modified", last_modified) self.headers.setdefault("etag", etag) async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: if self.stat_result is None: try: stat_result = await anyio.to_thread.run_sync(os.stat, self.path) self.set_stat_headers(stat_result) except FileNotFoundError: raise RuntimeError(f"File at path {self.path} does not exist.") else: mode = stat_result.st_mode if not 
stat.S_ISREG(mode): raise RuntimeError(f"File at path {self.path} is not a file.") await send( { "type": "http.response.start", "status": self.status_code, "headers": self.raw_headers, } ) if self.send_header_only: await send({"type": "http.response.body", "body": b"", "more_body": False}) else: async with await anyio.open_file(self.path, mode="rb") as file: more_body = True while more_body: chunk = await file.read(self.chunk_size) more_body = len(chunk) == self.chunk_size await send( { "type": "http.response.body", "body": chunk, "more_body": more_body, } ) if self.background is not None: await self.background() starlette-0.18.0/starlette/routing.py0000644000175100001710000006654214173233741020500 0ustar runnerdocker00000000000000import asyncio import contextlib import functools import inspect import re import sys import traceback import types import typing import warnings from enum import Enum from starlette.concurrency import run_in_threadpool from starlette.convertors import CONVERTOR_TYPES, Convertor from starlette.datastructures import URL, Headers, URLPath from starlette.exceptions import HTTPException from starlette.requests import Request from starlette.responses import PlainTextResponse, RedirectResponse from starlette.types import ASGIApp, Receive, Scope, Send from starlette.websockets import WebSocket, WebSocketClose if sys.version_info >= (3, 7): from contextlib import asynccontextmanager # pragma: no cover else: from contextlib2 import asynccontextmanager # pragma: no cover class NoMatchFound(Exception): """ Raised by `.url_for(name, **path_params)` and `.url_path_for(name, **path_params)` if no matching route exists. """ class Match(Enum): NONE = 0 PARTIAL = 1 FULL = 2 def iscoroutinefunction_or_partial(obj: typing.Any) -> bool: """ Correctly determines if an object is a coroutine function, including those wrapped in functools.partial objects. 
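
Example (a hedged, self-contained sketch that copies the unwrapping loop of
this helper so it runs without Starlette installed; `fetch` and
`sync_handler` are hypothetical endpoints): a `functools.partial` around an
`async def` function is still reported as a coroutine function, while a
plain synchronous function, wrapped or not, is not.

```python
import functools
import inspect
import typing


def iscoroutinefunction_or_partial(obj: typing.Any) -> bool:
    # Peel off any functools.partial layers, then ask inspect whether the
    # innermost callable was declared with `async def`.
    while isinstance(obj, functools.partial):
        obj = obj.func
    return inspect.iscoroutinefunction(obj)


async def fetch(request: object, retries: int) -> str:  # hypothetical endpoint
    return "ok"


def sync_handler(request: object) -> str:  # hypothetical endpoint
    return "ok"


wrapped = functools.partial(fetch, retries=3)
print(iscoroutinefunction_or_partial(wrapped))  # True
print(iscoroutinefunction_or_partial(sync_handler))  # False
print(iscoroutinefunction_or_partial(functools.partial(sync_handler)))  # False
```

This check is what lets `request_response` decide once, when the route is
built, whether to await the endpoint or run it in the thread pool.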
""" while isinstance(obj, functools.partial): obj = obj.func return inspect.iscoroutinefunction(obj) def request_response(func: typing.Callable) -> ASGIApp: """ Takes a function or coroutine `func(request) -> response`, and returns an ASGI application. """ is_coroutine = iscoroutinefunction_or_partial(func) async def app(scope: Scope, receive: Receive, send: Send) -> None: request = Request(scope, receive=receive, send=send) if is_coroutine: response = await func(request) else: response = await run_in_threadpool(func, request) await response(scope, receive, send) return app def websocket_session(func: typing.Callable) -> ASGIApp: """ Takes a coroutine `func(session)`, and returns an ASGI application. """ # assert asyncio.iscoroutinefunction(func), "WebSocket endpoints must be async" async def app(scope: Scope, receive: Receive, send: Send) -> None: session = WebSocket(scope, receive=receive, send=send) await func(session) return app def get_name(endpoint: typing.Callable) -> str: if inspect.isfunction(endpoint) or inspect.isclass(endpoint): return endpoint.__name__ return endpoint.__class__.__name__ def replace_params( path: str, param_convertors: typing.Dict[str, Convertor], path_params: typing.Dict[str, str], ) -> typing.Tuple[str, dict]: for key, value in list(path_params.items()): if "{" + key + "}" in path: convertor = param_convertors[key] value = convertor.to_string(value) path = path.replace("{" + key + "}", value) path_params.pop(key) return path, path_params # Match parameters in URL paths, eg. '{param}', and '{param:int}' PARAM_REGEX = re.compile("{([a-zA-Z_][a-zA-Z0-9_]*)(:[a-zA-Z_][a-zA-Z0-9_]*)?}") def compile_path( path: str, ) -> typing.Tuple[typing.Pattern, str, typing.Dict[str, Convertor]]: """ Given a path string, like: "/{username:str}", return a three-tuple of (regex, format, {param_name:convertor}). 

    regex:      "/(?P<username>[^/]+)"
    format:     "/{username}"
    convertors: {"username": StringConvertor()}
    """
    path_regex = "^"
    path_format = ""
    duplicated_params = set()

    idx = 0
    param_convertors = {}
    for match in PARAM_REGEX.finditer(path):
        param_name, convertor_type = match.groups("str")
        convertor_type = convertor_type.lstrip(":")
        assert (
            convertor_type in CONVERTOR_TYPES
        ), f"Unknown path convertor '{convertor_type}'"
        convertor = CONVERTOR_TYPES[convertor_type]

        path_regex += re.escape(path[idx : match.start()])
        path_regex += f"(?P<{param_name}>{convertor.regex})"

        path_format += path[idx : match.start()]
        path_format += "{%s}" % param_name

        if param_name in param_convertors:
            duplicated_params.add(param_name)

        param_convertors[param_name] = convertor

        idx = match.end()

    if duplicated_params:
        names = ", ".join(sorted(duplicated_params))
        ending = "s" if len(duplicated_params) > 1 else ""
        raise ValueError(f"Duplicated param name{ending} {names} at path {path}")

    path_regex += re.escape(path[idx:].split(":")[0]) + "$"
    path_format += path[idx:]

    return re.compile(path_regex), path_format, param_convertors


class BaseRoute:
    def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
        raise NotImplementedError()  # pragma: no cover

    def url_path_for(self, name: str, **path_params: typing.Any) -> URLPath:
        raise NotImplementedError()  # pragma: no cover

    async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
        raise NotImplementedError()  # pragma: no cover

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """
        A route may be used in isolation as a stand-alone ASGI app.
        This is a somewhat contrived case, as they'll almost always be used
        within a Router, but could be useful for some tooling and minimal apps.
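
Example (a simplified, hypothetical stand-in rather than Starlette's own
`Route`, so it runs with only the standard library): any ASGI callable can be
driven directly with a hand-built `scope` and in-memory `receive`/`send`
coroutines, which is the protocol this method speaks when a route is used
without a `Router`. `ExactPathRoute` and `call` are illustrative names;
unmatched HTTP paths get a plain-text 404, mirroring the `Match.NONE` branch
of this method.

```python
import asyncio
import typing

Scope = typing.Dict[str, typing.Any]
Message = typing.Dict[str, typing.Any]
Receive = typing.Callable[[], typing.Awaitable[Message]]
Send = typing.Callable[[Message], typing.Awaitable[None]]


class ExactPathRoute:
    # Hypothetical, simplified route used as a bare ASGI app: it answers on
    # one exact path and falls back to a plain-text 404 for anything else.
    def __init__(self, path: str, body: bytes) -> None:
        self.path = path
        self.body = body

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] == "http" and scope["path"] == self.path:
            status, body = 200, self.body
        else:
            status, body = 404, b"Not Found"
        await send(
            {
                "type": "http.response.start",
                "status": status,
                "headers": [(b"content-type", b"text/plain; charset=utf-8")],
            }
        )
        await send({"type": "http.response.body", "body": body})


async def call(app: ExactPathRoute, path: str) -> typing.List[Message]:
    # Drive the app with a hand-built scope; collect every sent message.
    sent: typing.List[Message] = []

    async def receive() -> Message:
        return {"type": "http.request", "body": b"", "more_body": False}

    async def send(message: Message) -> None:
        sent.append(message)

    await app({"type": "http", "path": path, "method": "GET"}, receive, send)
    return sent


route = ExactPathRoute("/health", b"ok")
print(asyncio.run(call(route, "/health"))[0]["status"])  # 200
print(asyncio.run(call(route, "/missing"))[1]["body"])  # b'Not Found'
```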
""" match, child_scope = self.matches(scope) if match == Match.NONE: if scope["type"] == "http": response = PlainTextResponse("Not Found", status_code=404) await response(scope, receive, send) elif scope["type"] == "websocket": websocket_close = WebSocketClose() await websocket_close(scope, receive, send) return scope.update(child_scope) await self.handle(scope, receive, send) class Route(BaseRoute): def __init__( self, path: str, endpoint: typing.Callable, *, methods: typing.List[str] = None, name: str = None, include_in_schema: bool = True, ) -> None: assert path.startswith("/"), "Routed paths must start with '/'" self.path = path self.endpoint = endpoint self.name = get_name(endpoint) if name is None else name self.include_in_schema = include_in_schema endpoint_handler = endpoint while isinstance(endpoint_handler, functools.partial): endpoint_handler = endpoint_handler.func if inspect.isfunction(endpoint_handler) or inspect.ismethod(endpoint_handler): # Endpoint is function or method. Treat it as `func(request) -> response`. self.app = request_response(endpoint) if methods is None: methods = ["GET"] else: # Endpoint is a class. Treat it as ASGI. 
            self.app = endpoint

        if methods is None:
            self.methods = None
        else:
            self.methods = {method.upper() for method in methods}
            if "GET" in self.methods:
                self.methods.add("HEAD")

        self.path_regex, self.path_format, self.param_convertors = compile_path(path)

    def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
        if scope["type"] == "http":
            match = self.path_regex.match(scope["path"])
            if match:
                matched_params = match.groupdict()
                for key, value in matched_params.items():
                    matched_params[key] = self.param_convertors[key].convert(value)
                path_params = dict(scope.get("path_params", {}))
                path_params.update(matched_params)
                child_scope = {"endpoint": self.endpoint, "path_params": path_params}
                if self.methods and scope["method"] not in self.methods:
                    return Match.PARTIAL, child_scope
                else:
                    return Match.FULL, child_scope
        return Match.NONE, {}

    def url_path_for(self, name: str, **path_params: typing.Any) -> URLPath:
        seen_params = set(path_params.keys())
        expected_params = set(self.param_convertors.keys())

        if name != self.name or seen_params != expected_params:
            raise NoMatchFound()

        path, remaining_params = replace_params(
            self.path_format, self.param_convertors, path_params
        )
        assert not remaining_params
        return URLPath(path=path, protocol="http")

    async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
        if self.methods and scope["method"] not in self.methods:
            if "app" in scope:
                raise HTTPException(status_code=405)
            else:
                response = PlainTextResponse("Method Not Allowed", status_code=405)
            await response(scope, receive, send)
        else:
            await self.app(scope, receive, send)

    def __eq__(self, other: typing.Any) -> bool:
        return (
            isinstance(other, Route)
            and self.path == other.path
            and self.endpoint == other.endpoint
            and self.methods == other.methods
        )


class WebSocketRoute(BaseRoute):
    def __init__(
        self, path: str, endpoint: typing.Callable, *, name: str = None
    ) -> None:
        assert path.startswith("/"), "Routed paths must start with '/'"
        self.path = path
        self.endpoint = endpoint
        self.name = get_name(endpoint) if name is None else name

        endpoint_handler = endpoint
        while isinstance(endpoint_handler, functools.partial):
            endpoint_handler = endpoint_handler.func
        if inspect.isfunction(endpoint_handler) or inspect.ismethod(endpoint_handler):
            # Endpoint is function or method. Treat it as `func(websocket)`.
            self.app = websocket_session(endpoint)
        else:
            # Endpoint is a class. Treat it as ASGI.
            self.app = endpoint

        self.path_regex, self.path_format, self.param_convertors = compile_path(path)

    def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
        if scope["type"] == "websocket":
            match = self.path_regex.match(scope["path"])
            if match:
                matched_params = match.groupdict()
                for key, value in matched_params.items():
                    matched_params[key] = self.param_convertors[key].convert(value)
                path_params = dict(scope.get("path_params", {}))
                path_params.update(matched_params)
                child_scope = {"endpoint": self.endpoint, "path_params": path_params}
                return Match.FULL, child_scope
        return Match.NONE, {}

    def url_path_for(self, name: str, **path_params: typing.Any) -> URLPath:
        seen_params = set(path_params.keys())
        expected_params = set(self.param_convertors.keys())

        if name != self.name or seen_params != expected_params:
            raise NoMatchFound()

        path, remaining_params = replace_params(
            self.path_format, self.param_convertors, path_params
        )
        assert not remaining_params
        return URLPath(path=path, protocol="websocket")

    async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
        await self.app(scope, receive, send)

    def __eq__(self, other: typing.Any) -> bool:
        return (
            isinstance(other, WebSocketRoute)
            and self.path == other.path
            and self.endpoint == other.endpoint
        )


class Mount(BaseRoute):
    def __init__(
        self,
        path: str,
        app: ASGIApp = None,
        routes: typing.Sequence[BaseRoute] = None,
        name: str = None,
    ) -> None:
        assert path == "" or path.startswith("/"), "Routed paths must start with '/'"
        assert (
            app is not None or routes is not None
        ), "Either 'app=...', or 'routes=' must be specified"
        self.path = path.rstrip("/")
        if app is not None:
            self.app: ASGIApp = app
        else:
            self.app = Router(routes=routes)
        self.name = name
        self.path_regex, self.path_format, self.param_convertors = compile_path(
            self.path + "/{path:path}"
        )

    @property
    def routes(self) -> typing.List[BaseRoute]:
        return getattr(self.app, "routes", [])

    def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
        if scope["type"] in ("http", "websocket"):
            path = scope["path"]
            match = self.path_regex.match(path)
            if match:
                matched_params = match.groupdict()
                for key, value in matched_params.items():
                    matched_params[key] = self.param_convertors[key].convert(value)
                remaining_path = "/" + matched_params.pop("path")
                matched_path = path[: -len(remaining_path)]
                path_params = dict(scope.get("path_params", {}))
                path_params.update(matched_params)
                root_path = scope.get("root_path", "")
                child_scope = {
                    "path_params": path_params,
                    "app_root_path": scope.get("app_root_path", root_path),
                    "root_path": root_path + matched_path,
                    "path": remaining_path,
                    "endpoint": self.app,
                }
                return Match.FULL, child_scope
        return Match.NONE, {}

    def url_path_for(self, name: str, **path_params: typing.Any) -> URLPath:
        if self.name is not None and name == self.name and "path" in path_params:
            # 'name' matches "<mount_name>".
            path_params["path"] = path_params["path"].lstrip("/")
            path, remaining_params = replace_params(
                self.path_format, self.param_convertors, path_params
            )
            if not remaining_params:
                return URLPath(path=path)
        elif self.name is None or name.startswith(self.name + ":"):
            if self.name is None:
                # No mount name.
                remaining_name = name
            else:
                # 'name' matches "<mount_name>:<child_name>".
                remaining_name = name[len(self.name) + 1 :]
            path_kwarg = path_params.get("path")
            path_params["path"] = ""
            path_prefix, remaining_params = replace_params(
                self.path_format, self.param_convertors, path_params
            )
            if path_kwarg is not None:
                remaining_params["path"] = path_kwarg
            for route in self.routes or []:
                try:
                    url = route.url_path_for(remaining_name, **remaining_params)
                    return URLPath(
                        path=path_prefix.rstrip("/") + str(url), protocol=url.protocol
                    )
                except NoMatchFound:
                    pass
        raise NoMatchFound()

    async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
        await self.app(scope, receive, send)

    def __eq__(self, other: typing.Any) -> bool:
        return (
            isinstance(other, Mount)
            and self.path == other.path
            and self.app == other.app
        )


class Host(BaseRoute):
    def __init__(self, host: str, app: ASGIApp, name: str = None) -> None:
        self.host = host
        self.app = app
        self.name = name
        self.host_regex, self.host_format, self.param_convertors = compile_path(host)

    @property
    def routes(self) -> typing.List[BaseRoute]:
        return getattr(self.app, "routes", [])

    def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
        if scope["type"] in ("http", "websocket"):
            headers = Headers(scope=scope)
            host = headers.get("host", "").split(":")[0]
            match = self.host_regex.match(host)
            if match:
                matched_params = match.groupdict()
                for key, value in matched_params.items():
                    matched_params[key] = self.param_convertors[key].convert(value)
                path_params = dict(scope.get("path_params", {}))
                path_params.update(matched_params)
                child_scope = {"path_params": path_params, "endpoint": self.app}
                return Match.FULL, child_scope
        return Match.NONE, {}

    def url_path_for(self, name: str, **path_params: typing.Any) -> URLPath:
        if self.name is not None and name == self.name and "path" in path_params:
            # 'name' matches "<mount_name>".
            path = path_params.pop("path")
            host, remaining_params = replace_params(
                self.host_format, self.param_convertors, path_params
            )
            if not remaining_params:
                return URLPath(path=path, host=host)
        elif self.name is None or name.startswith(self.name + ":"):
            if self.name is None:
                # No mount name.
                remaining_name = name
            else:
                # 'name' matches "<mount_name>:<child_name>".
                remaining_name = name[len(self.name) + 1 :]
            host, remaining_params = replace_params(
                self.host_format, self.param_convertors, path_params
            )
            for route in self.routes or []:
                try:
                    url = route.url_path_for(remaining_name, **remaining_params)
                    return URLPath(path=str(url), protocol=url.protocol, host=host)
                except NoMatchFound:
                    pass
        raise NoMatchFound()

    async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
        await self.app(scope, receive, send)

    def __eq__(self, other: typing.Any) -> bool:
        return (
            isinstance(other, Host)
            and self.host == other.host
            and self.app == other.app
        )


_T = typing.TypeVar("_T")


class _AsyncLiftContextManager(typing.AsyncContextManager[_T]):
    def __init__(self, cm: typing.ContextManager[_T]):
        self._cm = cm

    async def __aenter__(self) -> _T:
        return self._cm.__enter__()

    async def __aexit__(
        self,
        exc_type: typing.Optional[typing.Type[BaseException]],
        exc_value: typing.Optional[BaseException],
        traceback: typing.Optional[types.TracebackType],
    ) -> typing.Optional[bool]:
        return self._cm.__exit__(exc_type, exc_value, traceback)


def _wrap_gen_lifespan_context(
    lifespan_context: typing.Callable[[typing.Any], typing.Generator]
) -> typing.Callable[[typing.Any], typing.AsyncContextManager]:
    cmgr = contextlib.contextmanager(lifespan_context)

    @functools.wraps(cmgr)
    def wrapper(app: typing.Any) -> _AsyncLiftContextManager:
        return _AsyncLiftContextManager(cmgr(app))

    return wrapper


class _DefaultLifespan:
    def __init__(self, router: "Router"):
        self._router = router

    async def __aenter__(self) -> None:
        await self._router.startup()

    async def __aexit__(self, *exc_info: object) -> None:
        await self._router.shutdown()

    def __call__(self: _T, app: object) -> _T:
        return self


class Router:
    def __init__(
        self,
        routes: typing.Sequence[BaseRoute] = None,
        redirect_slashes: bool = True,
        default: ASGIApp = None,
        on_startup: typing.Sequence[typing.Callable] = None,
        on_shutdown: typing.Sequence[typing.Callable] = None,
        lifespan: typing.Callable[[typing.Any], typing.AsyncContextManager] = None,
    ) -> None:
        self.routes = [] if routes is None else list(routes)
        self.redirect_slashes = redirect_slashes
        self.default = self.not_found if default is None else default
        self.on_startup = [] if on_startup is None else list(on_startup)
        self.on_shutdown = [] if on_shutdown is None else list(on_shutdown)

        if lifespan is None:
            self.lifespan_context: typing.Callable[
                [typing.Any], typing.AsyncContextManager
            ] = _DefaultLifespan(self)

        elif inspect.isasyncgenfunction(lifespan):
            warnings.warn(
                "async generator function lifespans are deprecated, "
                "use an @contextlib.asynccontextmanager function instead",
                DeprecationWarning,
            )
            self.lifespan_context = asynccontextmanager(
                lifespan,  # type: ignore[arg-type]
            )
        elif inspect.isgeneratorfunction(lifespan):
            warnings.warn(
                "generator function lifespans are deprecated, "
                "use an @contextlib.asynccontextmanager function instead",
                DeprecationWarning,
            )
            self.lifespan_context = _wrap_gen_lifespan_context(
                lifespan,  # type: ignore[arg-type]
            )
        else:
            self.lifespan_context = lifespan

    async def not_found(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] == "websocket":
            websocket_close = WebSocketClose()
            await websocket_close(scope, receive, send)
            return

        # If we're running inside a starlette application then raise an
        # exception, so that the configurable exception handler can deal with
        # returning the response. For plain ASGI apps, just return the response.
        if "app" in scope:
            raise HTTPException(status_code=404)
        else:
            response = PlainTextResponse("Not Found", status_code=404)
        await response(scope, receive, send)

    def url_path_for(self, name: str, **path_params: typing.Any) -> URLPath:
        for route in self.routes:
            try:
                return route.url_path_for(name, **path_params)
            except NoMatchFound:
                pass
        raise NoMatchFound()

    async def startup(self) -> None:
        """
        Run any `.on_startup` event handlers.
        """
        for handler in self.on_startup:
            if asyncio.iscoroutinefunction(handler):
                await handler()
            else:
                handler()

    async def shutdown(self) -> None:
        """
        Run any `.on_shutdown` event handlers.
        """
        for handler in self.on_shutdown:
            if asyncio.iscoroutinefunction(handler):
                await handler()
            else:
                handler()

    async def lifespan(self, scope: Scope, receive: Receive, send: Send) -> None:
        """
        Handle ASGI lifespan messages, which allows us to manage application
        startup and shutdown events.
        """
        started = False
        app = scope.get("app")
        await receive()
        try:
            async with self.lifespan_context(app):
                await send({"type": "lifespan.startup.complete"})
                started = True
                await receive()
        except BaseException:
            exc_text = traceback.format_exc()
            if started:
                await send({"type": "lifespan.shutdown.failed", "message": exc_text})
            else:
                await send({"type": "lifespan.startup.failed", "message": exc_text})
            raise
        else:
            await send({"type": "lifespan.shutdown.complete"})

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """
        The main entry point to the Router class.
        """
        assert scope["type"] in ("http", "websocket", "lifespan")

        if "router" not in scope:
            scope["router"] = self

        if scope["type"] == "lifespan":
            await self.lifespan(scope, receive, send)
            return

        partial = None

        for route in self.routes:
            # Determine if any route matches the incoming scope,
            # and hand over to the matching route if found.
            match, child_scope = route.matches(scope)
            if match == Match.FULL:
                scope.update(child_scope)
                await route.handle(scope, receive, send)
                return
            elif match == Match.PARTIAL and partial is None:
                partial = route
                partial_scope = child_scope

        if partial is not None:
            # Handle partial matches. These are cases where an endpoint is
            # able to handle the request, but is not a preferred option.
            # We use this in particular to deal with "405 Method Not Allowed".
            scope.update(partial_scope)
            await partial.handle(scope, receive, send)
            return

        if scope["type"] == "http" and self.redirect_slashes and scope["path"] != "/":
            redirect_scope = dict(scope)
            if scope["path"].endswith("/"):
                redirect_scope["path"] = redirect_scope["path"].rstrip("/")
            else:
                redirect_scope["path"] = redirect_scope["path"] + "/"

            for route in self.routes:
                match, child_scope = route.matches(redirect_scope)
                if match != Match.NONE:
                    redirect_url = URL(scope=redirect_scope)
                    response = RedirectResponse(url=str(redirect_url))
                    await response(scope, receive, send)
                    return

        await self.default(scope, receive, send)

    def __eq__(self, other: typing.Any) -> bool:
        return isinstance(other, Router) and self.routes == other.routes

    # The following usages are now discouraged in favour of configuration
    # during Router.__init__(...)
    def mount(self, path: str, app: ASGIApp, name: str = None) -> None:
        route = Mount(path, app=app, name=name)
        self.routes.append(route)

    def host(self, host: str, app: ASGIApp, name: str = None) -> None:
        route = Host(host, app=app, name=name)
        self.routes.append(route)

    def add_route(
        self,
        path: str,
        endpoint: typing.Callable,
        methods: typing.List[str] = None,
        name: str = None,
        include_in_schema: bool = True,
    ) -> None:
        route = Route(
            path,
            endpoint=endpoint,
            methods=methods,
            name=name,
            include_in_schema=include_in_schema,
        )
        self.routes.append(route)

    def add_websocket_route(
        self, path: str, endpoint: typing.Callable, name: str = None
    ) -> None:
        route = WebSocketRoute(path, endpoint=endpoint, name=name)
        self.routes.append(route)

    def route(
        self,
        path: str,
        methods: typing.List[str] = None,
        name: str = None,
        include_in_schema: bool = True,
    ) -> typing.Callable:
        def decorator(func: typing.Callable) -> typing.Callable:
            self.add_route(
                path,
                func,
                methods=methods,
                name=name,
                include_in_schema=include_in_schema,
            )
            return func

        return decorator

    def websocket_route(self, path: str, name: str = None) -> typing.Callable:
        def decorator(func: typing.Callable) -> typing.Callable:
            self.add_websocket_route(path, func, name=name)
            return func

        return decorator

    def add_event_handler(self, event_type: str, func: typing.Callable) -> None:
        assert event_type in ("startup", "shutdown")

        if event_type == "startup":
            self.on_startup.append(func)
        else:
            self.on_shutdown.append(func)

    def on_event(self, event_type: str) -> typing.Callable:
        def decorator(func: typing.Callable) -> typing.Callable:
            self.add_event_handler(event_type, func)
            return func

        return decorator


# starlette-0.18.0/starlette/schemas.py
import inspect
import typing

from starlette.requests import Request
from starlette.responses import Response
from starlette.routing import BaseRoute, Mount, Route

try:
    import yaml
except ImportError:  # pragma: nocover
    yaml = None  # type: ignore


class OpenAPIResponse(Response):
    media_type = "application/vnd.oai.openapi"

    def render(self, content: typing.Any) -> bytes:
        assert yaml is not None, "`pyyaml` must be installed to use OpenAPIResponse."
        assert isinstance(
            content, dict
        ), "The schema passed to OpenAPIResponse should be a dictionary."
        return yaml.dump(content, default_flow_style=False).encode("utf-8")


class EndpointInfo(typing.NamedTuple):
    path: str
    http_method: str
    func: typing.Callable


class BaseSchemaGenerator:
    def get_schema(self, routes: typing.List[BaseRoute]) -> dict:
        raise NotImplementedError()  # pragma: no cover

    def get_endpoints(
        self, routes: typing.List[BaseRoute]
    ) -> typing.List[EndpointInfo]:
        """
        Given the routes, yields the following information:

        - path
            eg: /users/
        - http_method
            one of 'get', 'post', 'put', 'patch', 'delete', 'options'
        - func
            method ready to extract the docstring
        """
        endpoints_info: list = []

        for route in routes:
            if isinstance(route, Mount):
                routes = route.routes or []
                sub_endpoints = [
                    EndpointInfo(
                        path="".join((route.path, sub_endpoint.path)),
                        http_method=sub_endpoint.http_method,
                        func=sub_endpoint.func,
                    )
                    for sub_endpoint in self.get_endpoints(routes)
                ]
                endpoints_info.extend(sub_endpoints)

            elif not isinstance(route, Route) or not route.include_in_schema:
                continue

            elif inspect.isfunction(route.endpoint) or inspect.ismethod(route.endpoint):
                for method in route.methods or ["GET"]:
                    if method == "HEAD":
                        continue
                    endpoints_info.append(
                        EndpointInfo(route.path, method.lower(), route.endpoint)
                    )
            else:
                for method in ["get", "post", "put", "patch", "delete", "options"]:
                    if not hasattr(route.endpoint, method):
                        continue
                    func = getattr(route.endpoint, method)
                    endpoints_info.append(
                        EndpointInfo(route.path, method.lower(), func)
                    )

        return endpoints_info

    def parse_docstring(self, func_or_method: typing.Callable) -> dict:
        """
        Given a function, parse the docstring as YAML and return a dictionary of info.
""" docstring = func_or_method.__doc__ if not docstring: return {} assert yaml is not None, "`pyyaml` must be installed to use parse_docstring." # We support having regular docstrings before the schema # definition. Here we return just the schema part from # the docstring. docstring = docstring.split("---")[-1] parsed = yaml.safe_load(docstring) if not isinstance(parsed, dict): # A regular docstring (not yaml formatted) can return # a simple string here, which wouldn't follow the schema. return {} return parsed def OpenAPIResponse(self, request: Request) -> Response: routes = request.app.routes schema = self.get_schema(routes=routes) return OpenAPIResponse(schema) class SchemaGenerator(BaseSchemaGenerator): def __init__(self, base_schema: dict) -> None: self.base_schema = base_schema def get_schema(self, routes: typing.List[BaseRoute]) -> dict: schema = dict(self.base_schema) schema.setdefault("paths", {}) endpoints_info = self.get_endpoints(routes) for endpoint in endpoints_info: parsed = self.parse_docstring(endpoint.func) if not parsed: continue if endpoint.path not in schema["paths"]: schema["paths"][endpoint.path] = {} schema["paths"][endpoint.path][endpoint.http_method] = parsed return schema starlette-0.18.0/starlette/staticfiles.py0000644000175100001710000002030314173233741021304 0ustar runnerdocker00000000000000import importlib.util import os import stat import typing from email.utils import parsedate import anyio from starlette.datastructures import URL, Headers from starlette.exceptions import HTTPException from starlette.responses import FileResponse, RedirectResponse, Response from starlette.types import Receive, Scope, Send PathLike = typing.Union[str, "os.PathLike[str]"] class NotModifiedResponse(Response): NOT_MODIFIED_HEADERS = ( "cache-control", "content-location", "date", "etag", "expires", "vary", ) def __init__(self, headers: Headers): super().__init__( status_code=304, headers={ name: value for name, value in headers.items() if name in 
self.NOT_MODIFIED_HEADERS }, ) class StaticFiles: def __init__( self, *, directory: PathLike = None, packages: typing.List[typing.Union[str, typing.Tuple[str, str]]] = None, html: bool = False, check_dir: bool = True, ) -> None: self.directory = directory self.packages = packages self.all_directories = self.get_directories(directory, packages) self.html = html self.config_checked = False if check_dir and directory is not None and not os.path.isdir(directory): raise RuntimeError(f"Directory '{directory}' does not exist") def get_directories( self, directory: PathLike = None, packages: typing.List[typing.Union[str, typing.Tuple[str, str]]] = None, ) -> typing.List[PathLike]: """ Given `directory` and `packages` arguments, return a list of all the directories that should be used for serving static files from. """ directories = [] if directory is not None: directories.append(directory) for package in packages or []: if isinstance(package, tuple): package, statics_dir = package else: statics_dir = "statics" spec = importlib.util.find_spec(package) assert spec is not None, f"Package {package!r} could not be found." assert spec.origin is not None, f"Package {package!r} could not be found." package_directory = os.path.normpath( os.path.join(spec.origin, "..", statics_dir) ) assert os.path.isdir( package_directory ), f"Directory '{statics_dir!r}' in package {package!r} could not be found." directories.append(package_directory) return directories async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: """ The ASGI entry point. """ assert scope["type"] == "http" if not self.config_checked: await self.check_config() self.config_checked = True path = self.get_path(scope) response = await self.get_response(path, scope) await response(scope, receive, send) def get_path(self, scope: Scope) -> str: """ Given the ASGI scope, return the `path` string to serve up, with OS specific path separators, and any '..', '.' components removed. 
""" return os.path.normpath(os.path.join(*scope["path"].split("/"))) async def get_response(self, path: str, scope: Scope) -> Response: """ Returns an HTTP response, given the incoming path, method and request headers. """ if scope["method"] not in ("GET", "HEAD"): raise HTTPException(status_code=405) try: full_path, stat_result = await anyio.to_thread.run_sync( self.lookup_path, path ) except PermissionError: raise HTTPException(status_code=401) except OSError: raise if stat_result and stat.S_ISREG(stat_result.st_mode): # We have a static file to serve. return self.file_response(full_path, stat_result, scope) elif stat_result and stat.S_ISDIR(stat_result.st_mode) and self.html: # We're in HTML mode, and have got a directory URL. # Check if we have 'index.html' file to serve. index_path = os.path.join(path, "index.html") full_path, stat_result = await anyio.to_thread.run_sync( self.lookup_path, index_path ) if stat_result is not None and stat.S_ISREG(stat_result.st_mode): if not scope["path"].endswith("/"): # Directory URLs should redirect to always end in "/". url = URL(scope=scope) url = url.replace(path=url.path + "/") return RedirectResponse(url=url) return self.file_response(full_path, stat_result, scope) if self.html: # Check for '404.html' if we're in HTML mode. full_path, stat_result = await anyio.to_thread.run_sync( self.lookup_path, "404.html" ) if stat_result and stat.S_ISREG(stat_result.st_mode): return FileResponse( full_path, stat_result=stat_result, method=scope["method"], status_code=404, ) raise HTTPException(status_code=404) def lookup_path( self, path: str ) -> typing.Tuple[str, typing.Optional[os.stat_result]]: for directory in self.all_directories: full_path = os.path.realpath(os.path.join(directory, path)) directory = os.path.realpath(directory) if os.path.commonprefix([full_path, directory]) != directory: # Don't allow misbehaving clients to break out of the static files # directory. 
                continue
            try:
                return full_path, os.stat(full_path)
            except (FileNotFoundError, NotADirectoryError):
                continue
        return "", None

    def file_response(
        self,
        full_path: PathLike,
        stat_result: os.stat_result,
        scope: Scope,
        status_code: int = 200,
    ) -> Response:
        method = scope["method"]
        request_headers = Headers(scope=scope)

        response = FileResponse(
            full_path, status_code=status_code, stat_result=stat_result, method=method
        )
        if self.is_not_modified(response.headers, request_headers):
            return NotModifiedResponse(response.headers)
        return response

    async def check_config(self) -> None:
        """
        Perform a one-off configuration check that StaticFiles is actually
        pointed at a directory, so that we can raise loud errors rather than
        just returning 404 responses.
        """
        if self.directory is None:
            return

        try:
            stat_result = await anyio.to_thread.run_sync(os.stat, self.directory)
        except FileNotFoundError:
            raise RuntimeError(
                f"StaticFiles directory '{self.directory}' does not exist."
            )
        if not (stat.S_ISDIR(stat_result.st_mode) or stat.S_ISLNK(stat_result.st_mode)):
            raise RuntimeError(
                f"StaticFiles path '{self.directory}' is not a directory."
            )

    def is_not_modified(
        self, response_headers: Headers, request_headers: Headers
    ) -> bool:
        """
        Given the request and response headers, return `True` if an HTTP
        "Not Modified" response could be returned instead.
        """
        try:
            if_none_match = request_headers["if-none-match"]
            etag = response_headers["etag"]
            if if_none_match == etag:
                return True
        except KeyError:
            pass

        try:
            if_modified_since = parsedate(request_headers["if-modified-since"])
            last_modified = parsedate(response_headers["last-modified"])
            if (
                if_modified_since is not None
                and last_modified is not None
                and if_modified_since >= last_modified
            ):
                return True
        except KeyError:
            pass

        return False

starlette-0.18.0/starlette/status.py

"""
HTTP codes
See HTTP Status Code Registry:
https://www.iana.org/assignments/http-status-codes/http-status-codes.xhtml

And RFC 2324 - https://tools.ietf.org/html/rfc2324
"""
HTTP_100_CONTINUE = 100
HTTP_101_SWITCHING_PROTOCOLS = 101
HTTP_102_PROCESSING = 102
HTTP_103_EARLY_HINTS = 103
HTTP_200_OK = 200
HTTP_201_CREATED = 201
HTTP_202_ACCEPTED = 202
HTTP_203_NON_AUTHORITATIVE_INFORMATION = 203
HTTP_204_NO_CONTENT = 204
HTTP_205_RESET_CONTENT = 205
HTTP_206_PARTIAL_CONTENT = 206
HTTP_207_MULTI_STATUS = 207
HTTP_208_ALREADY_REPORTED = 208
HTTP_226_IM_USED = 226
HTTP_300_MULTIPLE_CHOICES = 300
HTTP_301_MOVED_PERMANENTLY = 301
HTTP_302_FOUND = 302
HTTP_303_SEE_OTHER = 303
HTTP_304_NOT_MODIFIED = 304
HTTP_305_USE_PROXY = 305
HTTP_306_RESERVED = 306
HTTP_307_TEMPORARY_REDIRECT = 307
HTTP_308_PERMANENT_REDIRECT = 308
HTTP_400_BAD_REQUEST = 400
HTTP_401_UNAUTHORIZED = 401
HTTP_402_PAYMENT_REQUIRED = 402
HTTP_403_FORBIDDEN = 403
HTTP_404_NOT_FOUND = 404
HTTP_405_METHOD_NOT_ALLOWED = 405
HTTP_406_NOT_ACCEPTABLE = 406
HTTP_407_PROXY_AUTHENTICATION_REQUIRED = 407
HTTP_408_REQUEST_TIMEOUT = 408
HTTP_409_CONFLICT = 409
HTTP_410_GONE = 410
HTTP_411_LENGTH_REQUIRED = 411
HTTP_412_PRECONDITION_FAILED = 412
HTTP_413_REQUEST_ENTITY_TOO_LARGE = 413
HTTP_414_REQUEST_URI_TOO_LONG = 414
HTTP_415_UNSUPPORTED_MEDIA_TYPE = 415
HTTP_416_REQUESTED_RANGE_NOT_SATISFIABLE = 416
HTTP_417_EXPECTATION_FAILED = 417
HTTP_418_IM_A_TEAPOT = 418
HTTP_421_MISDIRECTED_REQUEST = 421
HTTP_422_UNPROCESSABLE_ENTITY = 422
HTTP_423_LOCKED = 423
HTTP_424_FAILED_DEPENDENCY = 424
HTTP_425_TOO_EARLY = 425
HTTP_426_UPGRADE_REQUIRED = 426
HTTP_428_PRECONDITION_REQUIRED = 428
HTTP_429_TOO_MANY_REQUESTS = 429
HTTP_431_REQUEST_HEADER_FIELDS_TOO_LARGE = 431
HTTP_451_UNAVAILABLE_FOR_LEGAL_REASONS = 451
HTTP_500_INTERNAL_SERVER_ERROR = 500
HTTP_501_NOT_IMPLEMENTED = 501
HTTP_502_BAD_GATEWAY = 502
HTTP_503_SERVICE_UNAVAILABLE = 503
HTTP_504_GATEWAY_TIMEOUT = 504
HTTP_505_HTTP_VERSION_NOT_SUPPORTED = 505
HTTP_506_VARIANT_ALSO_NEGOTIATES = 506
HTTP_507_INSUFFICIENT_STORAGE = 507
HTTP_508_LOOP_DETECTED = 508
HTTP_510_NOT_EXTENDED = 510
HTTP_511_NETWORK_AUTHENTICATION_REQUIRED = 511


"""
WebSocket codes
https://www.iana.org/assignments/websocket/websocket.xml#close-code-number
https://developer.mozilla.org/en-US/docs/Web/API/CloseEvent
"""
WS_1000_NORMAL_CLOSURE = 1000
WS_1001_GOING_AWAY = 1001
WS_1002_PROTOCOL_ERROR = 1002
WS_1003_UNSUPPORTED_DATA = 1003
WS_1004_NO_STATUS_RCVD = 1004
WS_1005_ABNORMAL_CLOSURE = 1005
WS_1007_INVALID_FRAME_PAYLOAD_DATA = 1007
WS_1008_POLICY_VIOLATION = 1008
WS_1009_MESSAGE_TOO_BIG = 1009
WS_1010_MANDATORY_EXT = 1010
WS_1011_INTERNAL_ERROR = 1011
WS_1012_SERVICE_RESTART = 1012
WS_1013_TRY_AGAIN_LATER = 1013
WS_1014_BAD_GATEWAY = 1014
WS_1015_TLS_HANDSHAKE = 1015

starlette-0.18.0/starlette/templating.py

import typing
from os import PathLike

from starlette.background import BackgroundTask
from starlette.responses import Response
from starlette.types import Receive, Scope, Send

try:
    import jinja2

    # @contextfunction renamed to @pass_context in Jinja 3.0, to be removed in 3.1
    if hasattr(jinja2, "pass_context"):
        pass_context = jinja2.pass_context
    else:  # pragma: nocover
        pass_context = jinja2.contextfunction
except ImportError:  # pragma: nocover
    jinja2 = None  # type: ignore


class _TemplateResponse(Response):
    media_type = "text/html"

    def __init__(
        self,
        template: typing.Any,
        context: dict,
        status_code: int = 200,
        headers: dict = None,
        media_type: str = None,
        background: BackgroundTask = None,
    ):
        self.template = template
        self.context = context
        content = template.render(context)
        super().__init__(content, status_code, headers, media_type, background)

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        request = self.context.get("request", {})
        extensions = request.get("extensions", {})
        if "http.response.template" in extensions:
            await send(
                {
                    "type": "http.response.template",
                    "template": self.template,
                    "context": self.context,
                }
            )
        await super().__call__(scope, receive, send)


class Jinja2Templates:
    """
    templates = Jinja2Templates("templates")

    return templates.TemplateResponse("index.html", {"request": request})
    """

    def __init__(
        self, directory: typing.Union[str, PathLike], **env_options: typing.Any
    ) -> None:
        assert jinja2 is not None, "jinja2 must be installed to use Jinja2Templates"
        self.env = self._create_env(directory, **env_options)

    def _create_env(
        self, directory: typing.Union[str, PathLike], **env_options: typing.Any
    ) -> "jinja2.Environment":
        @pass_context
        def url_for(context: dict, name: str, **path_params: typing.Any) -> str:
            request = context["request"]
            return request.url_for(name, **path_params)

        loader = jinja2.FileSystemLoader(directory)
        env_options.setdefault("loader", loader)
        env_options.setdefault("autoescape", True)

        env = jinja2.Environment(**env_options)
        env.globals["url_for"] = url_for
        return env

    def get_template(self, name: str) -> "jinja2.Template":
        return self.env.get_template(name)

    def TemplateResponse(
        self,
        name: str,
        context: dict,
        status_code: int = 200,
        headers: dict = None,
        media_type: str = None,
        background: BackgroundTask = None,
    ) -> _TemplateResponse:
        if "request" not in context:
            raise ValueError('context must include a "request" key')
        template = self.get_template(name)
        return _TemplateResponse(
            template,
            context,
            status_code=status_code,
            headers=headers,
            media_type=media_type,
            background=background,
        )

starlette-0.18.0/starlette/testclient.py

import asyncio
import contextlib
import http
import inspect
import io
import json
import math
import queue
import sys
import types
import typing
from concurrent.futures import Future
from urllib.parse import unquote, urljoin, urlsplit

import anyio.abc
import requests
from anyio.streams.stapled import StapledObjectStream

from starlette.types import Message, Receive, Scope, Send
from starlette.websockets import WebSocketDisconnect

if sys.version_info >= (3, 8):  # pragma: no cover
    from typing import TypedDict
else:  # pragma: no cover
    from typing_extensions import TypedDict


_PortalFactoryType = typing.Callable[
    [], typing.ContextManager[anyio.abc.BlockingPortal]
]


# Annotations for `Session.request()`
Cookies = typing.Union[
    typing.MutableMapping[str, str], requests.cookies.RequestsCookieJar
]
Params = typing.Union[bytes, typing.MutableMapping[str, str]]
DataType = typing.Union[bytes, typing.MutableMapping[str, str], typing.IO]
TimeOut = typing.Union[float, typing.Tuple[float, float]]
FileType = typing.MutableMapping[str, typing.IO]
AuthType = typing.Union[
    typing.Tuple[str, str],
    requests.auth.AuthBase,
    typing.Callable[[requests.PreparedRequest], requests.PreparedRequest],
]


ASGIInstance = typing.Callable[[Receive, Send], typing.Awaitable[None]]
ASGI2App = typing.Callable[[Scope], ASGIInstance]
ASGI3App = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]]


class _HeaderDict(requests.packages.urllib3._collections.HTTPHeaderDict):
    def get_all(self, key: str, default: str) -> str:
        return self.getheaders(key)


class _MockOriginalResponse:
    """
    We have to jump through some hoops to present the response as if
    it was made using urllib3.
    """

    def __init__(self, headers: typing.List[typing.Tuple[bytes, bytes]]) -> None:
        self.msg = _HeaderDict(headers)
        self.closed = False

    def isclosed(self) -> bool:
        return self.closed


class _Upgrade(Exception):
    def __init__(self, session: "WebSocketTestSession") -> None:
        self.session = session


def _get_reason_phrase(status_code: int) -> str:
    try:
        return http.HTTPStatus(status_code).phrase
    except ValueError:
        return ""


def _is_asgi3(app: typing.Union[ASGI2App, ASGI3App]) -> bool:
    if inspect.isclass(app):
        return hasattr(app, "__await__")
    elif inspect.isfunction(app):
        return asyncio.iscoroutinefunction(app)
    call = getattr(app, "__call__", None)
    return asyncio.iscoroutinefunction(call)


class _WrapASGI2:
    """
    Provide an ASGI3 interface onto an ASGI2 app.
    """

    def __init__(self, app: ASGI2App) -> None:
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        instance = self.app(scope)
        await instance(receive, send)


class _AsyncBackend(TypedDict):
    backend: str
    backend_options: typing.Dict[str, typing.Any]


class _ASGIAdapter(requests.adapters.HTTPAdapter):
    def __init__(
        self,
        app: ASGI3App,
        portal_factory: _PortalFactoryType,
        raise_server_exceptions: bool = True,
        root_path: str = "",
    ) -> None:
        self.app = app
        self.raise_server_exceptions = raise_server_exceptions
        self.root_path = root_path
        self.portal_factory = portal_factory

    def send(
        self, request: requests.PreparedRequest, *args: typing.Any, **kwargs: typing.Any
    ) -> requests.Response:
        scheme, netloc, path, query, fragment = (
            str(item) for item in urlsplit(request.url)
        )

        default_port = {"http": 80, "ws": 80, "https": 443, "wss": 443}[scheme]

        if ":" in netloc:
            host, port_string = netloc.split(":", 1)
            port = int(port_string)
        else:
            host = netloc
            port = default_port

        # Include the 'host' header.
        if "host" in request.headers:
            headers: typing.List[typing.Tuple[bytes, bytes]] = []
        elif port == default_port:
            headers = [(b"host", host.encode())]
        else:
            headers = [(b"host", (f"{host}:{port}").encode())]

        # Include other request headers.
        headers += [
            (key.lower().encode(), value.encode())
            for key, value in request.headers.items()
        ]

        scope: typing.Dict[str, typing.Any]

        if scheme in {"ws", "wss"}:
            subprotocol = request.headers.get("sec-websocket-protocol", None)
            if subprotocol is None:
                subprotocols: typing.Sequence[str] = []
            else:
                subprotocols = [value.strip() for value in subprotocol.split(",")]
            scope = {
                "type": "websocket",
                "path": unquote(path),
                "root_path": self.root_path,
                "scheme": scheme,
                "query_string": query.encode(),
                "headers": headers,
                "client": ["testclient", 50000],
                "server": [host, port],
                "subprotocols": subprotocols,
            }
            session = WebSocketTestSession(self.app, scope, self.portal_factory)
            raise _Upgrade(session)

        scope = {
            "type": "http",
            "http_version": "1.1",
            "method": request.method,
            "path": unquote(path),
            "root_path": self.root_path,
            "scheme": scheme,
            "query_string": query.encode(),
            "headers": headers,
            "client": ["testclient", 50000],
            "server": [host, port],
            "extensions": {"http.response.template": {}},
        }

        request_complete = False
        response_started = False
        response_complete: anyio.Event
        raw_kwargs: typing.Dict[str, typing.Any] = {"body": io.BytesIO()}
        template = None
        context = None

        async def receive() -> Message:
            nonlocal request_complete

            if request_complete:
                if not response_complete.is_set():
                    await response_complete.wait()
                return {"type": "http.disconnect"}

            body = request.body
            if isinstance(body, str):
                body_bytes: bytes = body.encode("utf-8")
            elif body is None:
                body_bytes = b""
            elif isinstance(body, types.GeneratorType):
                try:
                    chunk = body.send(None)
                    if isinstance(chunk, str):
                        chunk = chunk.encode("utf-8")
                    return {"type": "http.request", "body": chunk, "more_body": True}
                except StopIteration:
                    request_complete = True
                    return {"type": "http.request", "body": b""}
            else:
                body_bytes = body

            request_complete = True
            return {"type": "http.request", "body": body_bytes}

        async def send(message: Message) -> None:
            nonlocal raw_kwargs, response_started, template, context

            if message["type"] == "http.response.start":
                assert (
                    not response_started
                ), 'Received multiple "http.response.start" messages.'
                raw_kwargs["version"] = 11
                raw_kwargs["status"] = message["status"]
                raw_kwargs["reason"] = _get_reason_phrase(message["status"])
                raw_kwargs["headers"] = [
                    (key.decode(), value.decode())
                    for key, value in message.get("headers", [])
                ]
                raw_kwargs["preload_content"] = False
                raw_kwargs["original_response"] = _MockOriginalResponse(
                    raw_kwargs["headers"]
                )
                response_started = True
            elif message["type"] == "http.response.body":
                assert (
                    response_started
                ), 'Received "http.response.body" without "http.response.start".'
                assert (
                    not response_complete.is_set()
                ), 'Received "http.response.body" after response completed.'
                body = message.get("body", b"")
                more_body = message.get("more_body", False)
                if request.method != "HEAD":
                    raw_kwargs["body"].write(body)
                if not more_body:
                    raw_kwargs["body"].seek(0)
                    response_complete.set()
            elif message["type"] == "http.response.template":
                template = message["template"]
                context = message["context"]

        try:
            with self.portal_factory() as portal:
                response_complete = portal.call(anyio.Event)
                portal.call(self.app, scope, receive, send)
        except BaseException as exc:
            if self.raise_server_exceptions:
                raise exc

        if self.raise_server_exceptions:
            assert response_started, "TestClient did not receive any response."
        elif not response_started:
            raw_kwargs = {
                "version": 11,
                "status": 500,
                "reason": "Internal Server Error",
                "headers": [],
                "preload_content": False,
                "original_response": _MockOriginalResponse([]),
                "body": io.BytesIO(),
            }

        raw = requests.packages.urllib3.HTTPResponse(**raw_kwargs)
        response = self.build_response(request, raw)
        if template is not None:
            response.template = template
            response.context = context
        return response


class WebSocketTestSession:
    def __init__(
        self,
        app: ASGI3App,
        scope: Scope,
        portal_factory: _PortalFactoryType,
    ) -> None:
        self.app = app
        self.scope = scope
        self.accepted_subprotocol = None
        self.extra_headers = None
        self.portal_factory = portal_factory
        self._receive_queue: "queue.Queue[typing.Any]" = queue.Queue()
        self._send_queue: "queue.Queue[typing.Any]" = queue.Queue()

    def __enter__(self) -> "WebSocketTestSession":
        self.exit_stack = contextlib.ExitStack()
        self.portal = self.exit_stack.enter_context(self.portal_factory())

        try:
            _: "Future[None]" = self.portal.start_task_soon(self._run)
            self.send({"type": "websocket.connect"})
            message = self.receive()
            self._raise_on_close(message)
        except Exception:
            self.exit_stack.close()
            raise
        self.accepted_subprotocol = message.get("subprotocol", None)
        self.extra_headers = message.get("headers", None)
        return self

    def __exit__(self, *args: typing.Any) -> None:
        try:
            self.close(1000)
        finally:
            self.exit_stack.close()
        while not self._send_queue.empty():
            message = self._send_queue.get()
            if isinstance(message, BaseException):
                raise message

    async def _run(self) -> None:
        """
        The sub-thread in which the websocket session runs.
        """
        scope = self.scope
        receive = self._asgi_receive
        send = self._asgi_send
        try:
            await self.app(scope, receive, send)
        except BaseException as exc:
            self._send_queue.put(exc)
            raise

    async def _asgi_receive(self) -> Message:
        while self._receive_queue.empty():
            await anyio.sleep(0)
        return self._receive_queue.get()

    async def _asgi_send(self, message: Message) -> None:
        self._send_queue.put(message)

    def _raise_on_close(self, message: Message) -> None:
        if message["type"] == "websocket.close":
            raise WebSocketDisconnect(
                message.get("code", 1000), message.get("reason", "")
            )

    def send(self, message: Message) -> None:
        self._receive_queue.put(message)

    def send_text(self, data: str) -> None:
        self.send({"type": "websocket.receive", "text": data})

    def send_bytes(self, data: bytes) -> None:
        self.send({"type": "websocket.receive", "bytes": data})

    def send_json(self, data: typing.Any, mode: str = "text") -> None:
        assert mode in ["text", "binary"]
        text = json.dumps(data)
        if mode == "text":
            self.send({"type": "websocket.receive", "text": text})
        else:
            self.send({"type": "websocket.receive", "bytes": text.encode("utf-8")})

    def close(self, code: int = 1000) -> None:
        self.send({"type": "websocket.disconnect", "code": code})

    def receive(self) -> Message:
        message = self._send_queue.get()
        if isinstance(message, BaseException):
            raise message
        return message

    def receive_text(self) -> str:
        message = self.receive()
        self._raise_on_close(message)
        return message["text"]

    def receive_bytes(self) -> bytes:
        message = self.receive()
        self._raise_on_close(message)
        return message["bytes"]

    def receive_json(self, mode: str = "text") -> typing.Any:
        assert mode in ["text", "binary"]
        message = self.receive()
        self._raise_on_close(message)
        if mode == "text":
            text = message["text"]
        else:
            text = message["bytes"].decode("utf-8")
        return json.loads(text)


class TestClient(requests.Session):
    __test__ = False  # Prevent pytest from collecting this class as a test.
    task: "Future[None]"

    portal: typing.Optional[anyio.abc.BlockingPortal] = None

    def __init__(
        self,
        app: typing.Union[ASGI2App, ASGI3App],
        base_url: str = "http://testserver",
        raise_server_exceptions: bool = True,
        root_path: str = "",
        backend: str = "asyncio",
        backend_options: typing.Optional[typing.Dict[str, typing.Any]] = None,
    ) -> None:
        super().__init__()
        self.async_backend = _AsyncBackend(
            backend=backend, backend_options=backend_options or {}
        )
        if _is_asgi3(app):
            app = typing.cast(ASGI3App, app)
            asgi_app = app
        else:
            app = typing.cast(ASGI2App, app)
            asgi_app = _WrapASGI2(app)  # type: ignore
        adapter = _ASGIAdapter(
            asgi_app,
            portal_factory=self._portal_factory,
            raise_server_exceptions=raise_server_exceptions,
            root_path=root_path,
        )
        self.mount("http://", adapter)
        self.mount("https://", adapter)
        self.mount("ws://", adapter)
        self.mount("wss://", adapter)
        self.headers.update({"user-agent": "testclient"})
        self.app = asgi_app
        self.base_url = base_url

    @contextlib.contextmanager
    def _portal_factory(
        self,
    ) -> typing.Generator[anyio.abc.BlockingPortal, None, None]:
        if self.portal is not None:
            yield self.portal
        else:
            with anyio.start_blocking_portal(**self.async_backend) as portal:
                yield portal

    def request(  # type: ignore
        self,
        method: str,
        url: str,
        params: Params = None,
        data: DataType = None,
        headers: typing.MutableMapping[str, str] = None,
        cookies: Cookies = None,
        files: FileType = None,
        auth: AuthType = None,
        timeout: TimeOut = None,
        allow_redirects: bool = None,
        proxies: typing.MutableMapping[str, str] = None,
        hooks: typing.Any = None,
        stream: bool = None,
        verify: typing.Union[bool, str] = None,
        cert: typing.Union[str, typing.Tuple[str, str]] = None,
        json: typing.Any = None,
    ) -> requests.Response:
        url = urljoin(self.base_url, url)
        return super().request(
            method,
            url,
            params=params,
            data=data,
            headers=headers,
            cookies=cookies,
            files=files,
            auth=auth,
            timeout=timeout,
            allow_redirects=allow_redirects,
            proxies=proxies,
            hooks=hooks,
            stream=stream,
            verify=verify,
            cert=cert,
            json=json,
        )

    def websocket_connect(
        self, url: str, subprotocols: typing.Sequence[str] = None, **kwargs: typing.Any
    ) -> typing.Any:
        url = urljoin("ws://testserver", url)
        headers = kwargs.get("headers", {})
        headers.setdefault("connection", "upgrade")
        headers.setdefault("sec-websocket-key", "testserver==")
        headers.setdefault("sec-websocket-version", "13")
        if subprotocols is not None:
            headers.setdefault("sec-websocket-protocol", ", ".join(subprotocols))
        kwargs["headers"] = headers
        try:
            super().request("GET", url, **kwargs)
        except _Upgrade as exc:
            session = exc.session
        else:
            raise RuntimeError("Expected WebSocket upgrade")  # pragma: no cover

        return session

    def __enter__(self) -> "TestClient":
        with contextlib.ExitStack() as stack:
            self.portal = portal = stack.enter_context(
                anyio.start_blocking_portal(**self.async_backend)
            )

            @stack.callback
            def reset_portal() -> None:
                self.portal = None

            self.stream_send = StapledObjectStream(
                *anyio.create_memory_object_stream(math.inf)
            )
            self.stream_receive = StapledObjectStream(
                *anyio.create_memory_object_stream(math.inf)
            )
            self.task = portal.start_task_soon(self.lifespan)
            portal.call(self.wait_startup)

            @stack.callback
            def wait_shutdown() -> None:
                portal.call(self.wait_shutdown)

            self.exit_stack = stack.pop_all()

        return self

    def __exit__(self, *args: typing.Any) -> None:
        self.exit_stack.close()

    async def lifespan(self) -> None:
        scope = {"type": "lifespan"}
        try:
            await self.app(scope, self.stream_receive.receive, self.stream_send.send)
        finally:
            await self.stream_send.send(None)

    async def wait_startup(self) -> None:
        await self.stream_receive.send({"type": "lifespan.startup"})

        async def receive() -> typing.Any:
            message = await self.stream_send.receive()
            if message is None:
                self.task.result()
            return message

        message = await receive()
        assert message["type"] in (
            "lifespan.startup.complete",
            "lifespan.startup.failed",
        )
        if message["type"] == "lifespan.startup.failed":
            await receive()

    async def wait_shutdown(self) -> None:
        async def receive() -> typing.Any:
            message = await self.stream_send.receive()
            if message is None:
                self.task.result()
            return message

        async with self.stream_send:
            await self.stream_receive.send({"type": "lifespan.shutdown"})
            message = await receive()
            assert message["type"] in (
                "lifespan.shutdown.complete",
                "lifespan.shutdown.failed",
            )
            if message["type"] == "lifespan.shutdown.failed":
                await receive()

starlette-0.18.0/starlette/types.py

import typing

Scope = typing.MutableMapping[str, typing.Any]
Message = typing.MutableMapping[str, typing.Any]

Receive = typing.Callable[[], typing.Awaitable[Message]]
Send = typing.Callable[[Message], typing.Awaitable[None]]

ASGIApp = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]]

starlette-0.18.0/starlette/websockets.py

import enum
import json
import typing

from starlette.requests import HTTPConnection
from starlette.types import Message, Receive, Scope, Send


class WebSocketState(enum.Enum):
    CONNECTING = 0
    CONNECTED = 1
    DISCONNECTED = 2


class WebSocketDisconnect(Exception):
    def __init__(self, code: int = 1000, reason: str = None) -> None:
        self.code = code
        self.reason = reason or ""


class WebSocket(HTTPConnection):
    def __init__(self, scope: Scope, receive: Receive, send: Send) -> None:
        super().__init__(scope)
        assert scope["type"] == "websocket"
        self._receive = receive
        self._send = send
        self.client_state = WebSocketState.CONNECTING
        self.application_state = WebSocketState.CONNECTING

    async def receive(self) -> Message:
        """
        Receive ASGI websocket messages, ensuring valid state transitions.
        """
        if self.client_state == WebSocketState.CONNECTING:
            message = await self._receive()
            message_type = message["type"]
            assert message_type == "websocket.connect"
            self.client_state = WebSocketState.CONNECTED
            return message
        elif self.client_state == WebSocketState.CONNECTED:
            message = await self._receive()
            message_type = message["type"]
            assert message_type in {"websocket.receive", "websocket.disconnect"}
            if message_type == "websocket.disconnect":
                self.client_state = WebSocketState.DISCONNECTED
            return message
        else:
            raise RuntimeError(
                'Cannot call "receive" once a disconnect message has been received.'
            )

    async def send(self, message: Message) -> None:
        """
        Send ASGI websocket messages, ensuring valid state transitions.
        """
        if self.application_state == WebSocketState.CONNECTING:
            message_type = message["type"]
            assert message_type in {"websocket.accept", "websocket.close"}
            if message_type == "websocket.close":
                self.application_state = WebSocketState.DISCONNECTED
            else:
                self.application_state = WebSocketState.CONNECTED
            await self._send(message)
        elif self.application_state == WebSocketState.CONNECTED:
            message_type = message["type"]
            assert message_type in {"websocket.send", "websocket.close"}
            if message_type == "websocket.close":
                self.application_state = WebSocketState.DISCONNECTED
            await self._send(message)
        else:
            raise RuntimeError('Cannot call "send" once a close message has been sent.')

    async def accept(
        self,
        subprotocol: str = None,
        headers: typing.Iterable[typing.Tuple[bytes, bytes]] = None,
    ) -> None:
        headers = headers or []

        if self.client_state == WebSocketState.CONNECTING:
            # If we haven't yet seen the 'connect' message, then wait for it first.
            await self.receive()
        await self.send(
            {"type": "websocket.accept", "subprotocol": subprotocol, "headers": headers}
        )

    def _raise_on_disconnect(self, message: Message) -> None:
        if message["type"] == "websocket.disconnect":
            raise WebSocketDisconnect(message["code"])

    async def receive_text(self) -> str:
        assert self.application_state == WebSocketState.CONNECTED
        message = await self.receive()
        self._raise_on_disconnect(message)
        return message["text"]

    async def receive_bytes(self) -> bytes:
        assert self.application_state == WebSocketState.CONNECTED
        message = await self.receive()
        self._raise_on_disconnect(message)
        return message["bytes"]

    async def receive_json(self, mode: str = "text") -> typing.Any:
        assert mode in ["text", "binary"]
        assert self.application_state == WebSocketState.CONNECTED
        message = await self.receive()
        self._raise_on_disconnect(message)

        if mode == "text":
            text = message["text"]
        else:
            text = message["bytes"].decode("utf-8")
        return json.loads(text)

    async def iter_text(self) -> typing.AsyncIterator[str]:
        try:
            while True:
                yield await self.receive_text()
        except WebSocketDisconnect:
            pass

    async def iter_bytes(self) -> typing.AsyncIterator[bytes]:
        try:
            while True:
                yield await self.receive_bytes()
        except WebSocketDisconnect:
            pass

    async def iter_json(self) -> typing.AsyncIterator[typing.Any]:
        try:
            while True:
                yield await self.receive_json()
        except WebSocketDisconnect:
            pass

    async def send_text(self, data: str) -> None:
        await self.send({"type": "websocket.send", "text": data})

    async def send_bytes(self, data: bytes) -> None:
        await self.send({"type": "websocket.send", "bytes": data})

    async def send_json(self, data: typing.Any, mode: str = "text") -> None:
        assert mode in ["text", "binary"]
        text = json.dumps(data)
        if mode == "text":
            await self.send({"type": "websocket.send", "text": text})
        else:
            await self.send({"type": "websocket.send", "bytes": text.encode("utf-8")})

    async def close(self, code: int = 1000, reason: str = None) -> None:
        await self.send(
            {"type": "websocket.close", "code": code, "reason": reason or ""}
        )


class WebSocketClose:
    def __init__(self, code: int = 1000, reason: str = None) -> None:
        self.code = code
        self.reason = reason or ""

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        await send(
            {"type": "websocket.close", "code": self.code, "reason": self.reason}
        )

starlette-0.18.0/starlette.egg-info/PKG-INFO
Metadata-Version: 2.1
Name: starlette
Version: 0.18.0
Summary: The little ASGI library that shines.
Home-page: https://github.com/encode/starlette
Author: Tom Christie
Author-email: tom@tomchristie.com
License: BSD
Description:

✨ The little ASGI framework that shines. ✨

---

**Documentation**: [https://www.starlette.io/](https://www.starlette.io/)

---

# Starlette

Starlette is a lightweight [ASGI](https://asgi.readthedocs.io/en/latest/) framework/toolkit,
which is ideal for building high performance async services.

It is production-ready, and gives you the following:

* Seriously impressive performance.
* WebSocket support.
* In-process background tasks.
* Startup and shutdown events.
* Test client built on `requests`.
* CORS, GZip, Static Files, Streaming responses.
* Session and Cookie support.
* 100% test coverage.
* 100% type annotated codebase.
* Few hard dependencies.
* Compatible with `asyncio` and `trio` backends.

## Requirements

Python 3.6+

## Installation

```shell
$ pip3 install starlette
```

You'll also want to install an ASGI server, such as [uvicorn](http://www.uvicorn.org/), [daphne](https://github.com/django/daphne/), or [hypercorn](https://pgjones.gitlab.io/hypercorn/).

```shell
$ pip3 install uvicorn
```

## Example

**example.py**:

```python
from starlette.applications import Starlette
from starlette.responses import JSONResponse
from starlette.routing import Route


async def homepage(request):
    return JSONResponse({'hello': 'world'})

routes = [
    Route("/", endpoint=homepage)
]

app = Starlette(debug=True, routes=routes)
```

Then run the application using Uvicorn:

```shell
$ uvicorn example:app
```

For a more complete example, see [encode/starlette-example](https://github.com/encode/starlette-example).

## Dependencies

Starlette only requires `anyio`, and the following are optional:

* [`requests`][requests] - Required if you want to use the `TestClient`.
* [`jinja2`][jinja2] - Required if you want to use `Jinja2Templates`.
* [`python-multipart`][python-multipart] - Required if you want to support form parsing, with `request.form()`.
* [`itsdangerous`][itsdangerous] - Required for `SessionMiddleware` support.
* [`pyyaml`][pyyaml] - Required for `SchemaGenerator` support.

You can install all of these with `pip3 install starlette[full]`.

## Framework or Toolkit

Starlette is designed to be used either as a complete framework, or as an ASGI toolkit. You can use any of its components independently.

```python
from starlette.responses import PlainTextResponse


async def app(scope, receive, send):
    assert scope['type'] == 'http'
    response = PlainTextResponse('Hello, world!')
    await response(scope, receive, send)
```

Save the `app` above in `example.py` and run it:

```shell
$ uvicorn example:app
INFO: Started server process [11509]
INFO: Uvicorn running on http://127.0.0.1:8000 (Press CTRL+C to quit)
```

Run uvicorn with `--reload` to enable auto-reloading on code changes.

## Modularity

Starlette's modular design promotes building reusable components that can be shared across any ASGI framework. This should enable an ecosystem of shared middleware and mountable applications.

The clean API separation also means it's easier to understand each component in isolation.

## Performance

Independent TechEmpower benchmarks show Starlette applications running under Uvicorn as [one of the fastest Python frameworks available](https://www.techempower.com/benchmarks/#section=data-r17&hw=ph&test=fortune&l=zijzen-1). *(\*)*

For high throughput loads you should:

* Run under Gunicorn, using the `uvicorn` worker class.
* Use one or two workers per CPU core. (You might need to experiment with this.)
* Disable access logging.

For example:

```shell
gunicorn -w 4 -k uvicorn.workers.UvicornWorker --log-level warning example:app
```

Several of the ASGI servers also have pure Python implementations available, so you can also run under `PyPy` if your application code has parts that are CPU constrained.

Either programmatically:

```python
import uvicorn

uvicorn.run(..., http='h11', loop='asyncio')
```

Or using Gunicorn:

```shell
gunicorn -k uvicorn.workers.UvicornH11Worker ...
```
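The Modularity section says that components written against the bare ASGI interface can be shared between frameworks. As an illustration only, here is a minimal sketch of a framework-agnostic middleware. `AddHeaderMiddleware`, `hello_app` and `call_app` are hypothetical names made up for this example and are not Starlette APIs. The sketch relies only on the `(scope, receive, send)` calling convention. The type aliases copy the shapes in `starlette/types.py` so the block has no dependencies, and `asyncio.run` assumes Python 3.7+. The fake `receive`/`send` pair drives the app in-process without a server.

```python
import asyncio
import typing

# Same shapes as starlette/types.py, redefined so this sketch is dependency-free.
Scope = typing.MutableMapping[str, typing.Any]
Message = typing.MutableMapping[str, typing.Any]
Receive = typing.Callable[[], typing.Awaitable[Message]]
Send = typing.Callable[[Message], typing.Awaitable[None]]
ASGIApp = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]]


class AddHeaderMiddleware:
    """Hypothetical pure-ASGI middleware: appends one header to every HTTP response."""

    def __init__(self, app: ASGIApp, name: bytes, value: bytes) -> None:
        self.app = app
        # ASGI header names are lowercased byte strings.
        self.header = (name.lower(), value)

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] != "http":
            # Websocket and lifespan scopes pass through untouched.
            await self.app(scope, receive, send)
            return

        async def send_wrapper(message: Message) -> None:
            # Only the response-start message carries headers; rewrite a copy of it.
            if message["type"] == "http.response.start":
                headers = list(message.get("headers", []))
                headers.append(self.header)
                message = {**message, "headers": headers}
            await send(message)

        await self.app(scope, receive, send_wrapper)


async def hello_app(scope: Scope, receive: Receive, send: Send) -> None:
    # Hypothetical raw ASGI app that sends a plain-text response by hand.
    assert scope["type"] == "http"
    body = b"Hello, world!"
    await send({
        "type": "http.response.start",
        "status": 200,
        "headers": [
            (b"content-type", b"text/plain; charset=utf-8"),
            (b"content-length", str(len(body)).encode("ascii")),
        ],
    })
    await send({"type": "http.response.body", "body": body})


async def call_app(app: ASGIApp) -> typing.List[Message]:
    """Drive an ASGI app with one fake GET request and collect what it sends."""
    sent: typing.List[Message] = []

    async def receive() -> Message:
        return {"type": "http.request", "body": b"", "more_body": False}

    async def send(message: Message) -> None:
        sent.append(message)

    scope = {"type": "http", "method": "GET", "path": "/", "headers": []}
    await app(scope, receive, send)
    return sent


app = AddHeaderMiddleware(hello_app, b"X-Example", b"1")
messages = asyncio.run(call_app(app))
# The start message's header list now ends with (b"x-example", b"1").
```

The middleware only rewrites the `http.response.start` message and forwards everything else unchanged. For that reason it should not matter whether the wrapped callable is a hand-written function like the one above or a full `Starlette` application.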

— ⭐️ —

Starlette is BSD licensed code. Designed & built in Brighton, England.

[requests]: http://docs.python-requests.org/en/master/
[jinja2]: http://jinja.pocoo.org/
[python-multipart]: https://andrew-d.github.io/python-multipart/
[itsdangerous]: https://pythonhosted.org/itsdangerous/
[sqlalchemy]: https://www.sqlalchemy.org
[pyyaml]: https://pyyaml.org/wiki/PyYAMLDocumentation

Platform: UNKNOWN
Classifier: Development Status :: 3 - Alpha
Classifier: Environment :: Web Environment
Classifier: Intended Audience :: Developers
Classifier: License :: OSI Approved :: BSD License
Classifier: Operating System :: OS Independent
Classifier: Topic :: Internet :: WWW/HTTP
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.6
Classifier: Programming Language :: Python :: 3.7
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Requires-Python: >=3.6
Description-Content-Type: text/markdown
Provides-Extra: full

starlette-0.18.0/starlette.egg-info/SOURCES.txt
LICENSE.md
MANIFEST.in
README.md
setup.cfg
setup.py
starlette/__init__.py
starlette/_compat.py
starlette/applications.py
starlette/authentication.py
starlette/background.py
starlette/concurrency.py
starlette/config.py
starlette/convertors.py
starlette/datastructures.py
starlette/endpoints.py
starlette/exceptions.py
starlette/formparsers.py
starlette/requests.py
starlette/responses.py
starlette/routing.py
starlette/schemas.py
starlette/staticfiles.py
starlette/status.py
starlette/templating.py
starlette/testclient.py
starlette/types.py
starlette/websockets.py
starlette.egg-info/PKG-INFO
starlette.egg-info/SOURCES.txt
starlette.egg-info/dependency_links.txt
starlette.egg-info/not-zip-safe
starlette.egg-info/requires.txt
starlette.egg-info/top_level.txt
starlette/middleware/__init__.py
starlette/middleware/authentication.py
starlette/middleware/base.py
starlette/middleware/cors.py
starlette/middleware/errors.py
starlette/middleware/gzip.py
starlette/middleware/httpsredirect.py
starlette/middleware/sessions.py
starlette/middleware/trustedhost.py
starlette/middleware/wsgi.py

starlette-0.18.0/starlette.egg-info/dependency_links.txt

starlette-0.18.0/starlette.egg-info/not-zip-safe

starlette-0.18.0/starlette.egg-info/requires.txt
anyio<4,>=3.0.0

[:python_version < "3.10"]
typing_extensions

[:python_version < "3.7"]
contextlib2>=21.6.0

[full]
itsdangerous
jinja2
python-multipart
pyyaml
requests

starlette-0.18.0/starlette.egg-info/top_level.txt
starlette