--- magic-wormhole-transit-relay-0.3.1/.coveragerc ---
# -*- mode: conf -*-

[run]
# only record trace data for wormhole_transit_relay.*
source = wormhole_transit_relay
# and don't trace the test files themselves, or Versioneer's stuff
omit = src/wormhole_transit_relay/test/*
       src/wormhole_transit_relay/_version.py

# This allows 'coverage combine' to correlate the tracing data built while
# running tests in multiple tox virtualenvs. To take advantage of this
# properly, use "coverage erase" before tox, "coverage run --parallel-mode"
# inside tox to avoid overwriting the output data (by writing it into
# .coverage-XYZ instead of just .coverage), and run "coverage combine"
# afterwards.

[paths]
source =
    src/
    .tox/*/lib/python*/site-packages/
    .tox/pypy*/site-packages/

--- magic-wormhole-transit-relay-0.3.1/LICENSE ---
MIT License

Copyright (c) 2017 Brian Warner

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--- magic-wormhole-transit-relay-0.3.1/MANIFEST.in ---
include versioneer.py
include src/wormhole_transit_relay/_version.py
include LICENSE README.md NEWS.md
recursive-include docs *.md *.rst *.dot
include .coveragerc tox.ini
include misc/*.py
include misc/munin/wormhole_transit*

--- magic-wormhole-transit-relay-0.3.1/NEWS.md ---
User-visible changes in "magic-wormhole-transit-relay":

## unreleased

* drop Python 2, Python 3.5 and 3.6 support
* add Python 3.9, 3.10, 3.11 and 3.12 to CI
* update versioneer to 0.29

## Release 0.2.1 (11-Sep-2019)

* listen on IPv4+IPv6 properly (#12)

## Release 0.2.0 (10-Sep-2019)

* listen on IPv4+IPv6 socket by default (#12)
* enable SO_KEEPALIVE on all connections (#9)
* drop support for py3.3 and py3.4
* improve munin plugins

## Release 0.1.2 (19-Mar-2018)

* Allow more simultaneous connections, by increasing the rlimits() ceiling
  at startup
* Improve munin plugins
* Get tests working on Windows

## Release 0.1.1 (14-Feb-2018)

Improve logging and munin graphing tools: previous version would count bad
handshakes twice (once as "errory", and again as "lonely"). The munin
plugins have been renamed.

## Release 0.1.0 (12-Nov-2017)

Initial release. Forked from magic-wormhole-0.10.3 (12-Sep-2017).

--- magic-wormhole-transit-relay-0.3.1/PKG-INFO ---
Metadata-Version: 2.1
Name: magic-wormhole-transit-relay
Version: 0.3.1
Summary: Transit Relay server for Magic-Wormhole
Home-page: https://github.com/warner/magic-wormhole-transit-relay
Author: Brian Warner
Author-email: warner-magic-wormhole@lothar.com
License: MIT
License-File: LICENSE
Requires-Dist: twisted>=21.2.0
Requires-Dist: autobahn>=21.3.1
Requires-Dist: pypiwin32; sys_platform == "win32"
Provides-Extra: dev
Requires-Dist: mock; extra == "dev"
Requires-Dist: tox; extra == "dev"
Requires-Dist: pyflakes; extra == "dev"
Provides-Extra: build
Requires-Dist: twine; extra == "build"
Requires-Dist: dulwich; extra == "build"
Requires-Dist: readme_renderer; extra == "build"
Requires-Dist: gpg; extra == "build"
Requires-Dist: wheel; extra == "build"

--- magic-wormhole-transit-relay-0.3.1/README.md ---
# magic-wormhole-transit-relay

[![PyPI](http://img.shields.io/pypi/v/magic-wormhole-transit-relay.svg)](https://pypi.python.org/pypi/magic-wormhole-transit-relay)
![Tests](https://github.com/magic-wormhole/magic-wormhole-transit-relay/workflows/Tests/badge.svg)
[![codecov.io](https://codecov.io/github/magic-wormhole/magic-wormhole-transit-relay/coverage.svg?branch=master)](https://codecov.io/github/magic-wormhole/magic-wormhole-transit-relay?branch=master)

Transit Relay server for Magic-Wormhole

This repository implements the Magic-Wormhole "Transit Relay", a server that
helps clients establish bulk-data transit connections even when both are
behind NAT boxes. Each side makes a TCP connection to this server and
presents a handshake. Two connections with identical handshakes are glued
together, allowing them to pretend they have a direct connection.
This server used to be included in the magic-wormhole repository, but was
split out into a separate repo to aid deployment and development.

## Quick Example (running on a VPS)

If you would like to set up a transit server on a VPS or other
publicly-accessible server running Ubuntu:

```
# Install Python 3 pip and twist
apt install python3-pip python3-twisted
# Install magic-wormhole-transit-relay
pip3 install magic-wormhole-transit-relay
# Run transit-relay in the background
twistd3 transitrelay
# Check on logs
cat twistd.log # or `tail -f twistd.log`
# Kill transit-relay
kill `cat twistd.pid`
```

Assuming you _haven't_ killed transit-relay, when you do `wormhole send`,
make sure you add the `--transit-helper` argument, like:

```
wormhole send --transit-helper=tcp:[server ip here]:4001 file-to-send
```

On the receiving end, paste in the command output by `wormhole send`.

## Further Instructions

See docs/running.md for instructions to launch the server.

--- magic-wormhole-transit-relay-0.3.1/docs/logging.md ---
# Usage Logs

The transit relay does not emit or record any logging by default. By adding
option flags to the twist/twistd command line, you can enable one of two
different kinds of logs.

To avoid collecting information which could later be used to correlate
clients with external network traces, logged information can be "blurred".
This reduces the resolution of the data, retaining enough to answer
questions about how much the server is being used, but discarding
fine-grained timestamps or exact transfer sizes. The ``--blur-usage=``
option enables this, and it takes an integer value (in seconds) to specify
the desired time window.

## Logging JSON Upon Each Connection

If ``--log-fd`` is provided, a line will be written to the given (numeric)
file descriptor after each connection is done. These events could be
delivered to a comprehensive logging system for offline analysis.

Each line will be a complete JSON object (starting with ``{``, ending with
``}\n``, and containing no internal newlines). The keys will be:

* ``started``: number, seconds since epoch
* ``total_time``: number, seconds from open to last close
* ``waiting_time``: number, seconds from start to 2nd side appearing, or
  null
* ``total_bytes``: number, total bytes relayed (sum of both directions)
* ``mood``: string, one of: happy, lonely, errory

A mood of ``happy`` means both sides gave a correct handshake. ``lonely``
means a second matching side never appeared (and thus ``waiting_time`` will
be null). ``errory`` means the first side gave an invalid handshake.
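For concreteness, a happy session might be logged as a single line like the
following (a hypothetical example; every value depends on the session):

```
{"started": 1628654321.0, "total_time": 45.2, "waiting_time": 3.1, "total_bytes": 24871234, "mood": "happy"}
```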
If ``--blur-usage=`` is provided, then ``started`` will be rounded to the
given time interval, and ``total_bytes`` will be rounded to a fixed set of
buckets:

* file sizes less than 1MB: rounded to the next largest multiple of 10kB
* less than 1GB: multiple of 1MB
* 1GB or larger: multiple of 100MB

## Usage Database

If ``--usage-db=`` is provided, the server will maintain a SQLite database
in the given file. Current, recent, and historical usage data will be
written to the database, and external tools can query the DB for metrics:
the munin plugins in misc/ may be useful. Timestamps and sizes in this file
will respect ``--blur-usage``. The four tables are:

``current`` contains a single row, with these columns:

* connected: number of paired connections
* waiting: number of not-yet-paired connections
* incomplete_bytes: bytes transmitted over not-yet-complete connections

``since_reboot`` contains a single row, with these columns:

* bytes: sum of ``total_bytes``
* connections: number of completed connections
* mood_happy: count of connections that finished "happy": both sides gave
  correct handshake
* mood_lonely: one side gave good handshake, other side never showed up
* mood_errory: one side gave a bad handshake

``all_time`` contains a single row, with these columns:

* bytes:
* connections:
* mood_happy:
* mood_lonely:
* mood_errory:

``usage`` contains one row per closed connection, with these columns:

* started: seconds since epoch, rounded to "blur time"
* total_time: seconds from first open to last close
* waiting_time: seconds from first open to second open, or None
* total_bytes: total bytes relayed (in both directions)
* result: (string) the mood: happy, lonely, errory

All tables will be updated after each connection is finished. In addition,
the ``current`` table will be updated at least once every 5 minutes.
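As a sketch of the kind of query such tools run (the database path is
whatever you passed to ``--usage-db=``; the mood list follows the v1
schema's comments):

```
import sqlite3

db = sqlite3.connect("usage.sqlite")
# "since reboot" means rows newer than the `rebooted` timestamp
rebooted, = db.execute("SELECT `rebooted` FROM `current`").fetchone()
for mood in ["happy", "lonely", "errory", "redundant"]:
    count, = db.execute("SELECT COUNT(*) FROM `usage`"
                        " WHERE `started` > ? AND `result` = ?",
                        (rebooted, mood)).fetchone()
    print(mood, count)
```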
## Logfiles for twistd

If daemonized by twistd, the server will write ``twistd.pid`` and
``twistd.log`` files as usual. By default ``twistd.log`` will only contain
startup, shutdown, and exception messages. Setting ``--log-fd=1`` (file
descriptor 1 is always stdout) will cause the per-connection JSON lines to
be interleaved with any messages sent to Twisted's logging system. It may
be better to use a different file descriptor.

--- magic-wormhole-transit-relay-0.3.1/docs/running.md ---
# Running the Transit Relay

First off, you probably don't need to run a relay. The ``wormhole``
command, as shipped from magic-wormhole.io, is configured to use a default
Transit Relay operated by the author of Magic-Wormhole. This can be changed
with the ``--transit-helper=`` argument, and other applications that import
the Wormhole library might point elsewhere.

The only reasons to run a separate relay are:

* You are a kind-hearted server admin who wishes to support the project by
  paying the bandwidth costs incurred by your friends, who you instruct in
  the use of ``--transit-helper=``.
* You publish a different application, and want to provide your users with
  a relay that fails at different times than the official one.

## Installation

To run a transit relay, first you need an environment to install it.

* create a virtualenv
* ``pip install magic-wormhole-transit-relay`` into this virtualenv

```
% virtualenv tr-venv
...
% tr-venv/bin/pip install magic-wormhole-transit-relay
...
```

## Running

The transit relay is not a standalone program: rather it is a plugin for
the Twisted application-running tools named ``twist`` (which only runs in
the foreground) and ``twistd`` (which daemonizes). To run the relay for
testing, use something like this:

```
% tr-venv/bin/twist transitrelay [ARGS]
2017-11-09T17:07:28-0800 [-] not blurring access times
2017-11-09T17:07:28-0800 [-] Transit starting on 4001
2017-11-09T17:07:28-0800 [wormhole_transit_relay.transit_server.Transit#info] Starting factory ...
```

The relevant arguments are:

* ``--port=``: the endpoint to listen on, like ``tcp:4001``
* ``--log-fd=``: writes JSON lines to the given file descriptor for each
  connection
* ``--usage-db=``: maintains a SQLite database with current and historical
  usage data
* ``--blur-usage=``: round logged timestamps and data sizes

For WebSockets support, two additional arguments:

* ``--websocket``: the endpoint to listen for websocket connections on,
  like ``tcp:4002``
* ``--websocket-url``: the URL of the WebSocket connection. This may be
  different from the listening endpoint because of port-forwarding and so
  forth. By default it will be ``ws://localhost:<port>/`` (using the
  WebSocket listening port) if not provided

When you use ``twist``, the relay runs in the foreground, so it will
generally exit as soon as the controlling terminal exits. For persistent
environments, you should daemonize the server.

## Minimizing Log Data

The server code attempts to strike a balance between minimizing data
collected about users, and recording enough information to manage the
server and monitor its operation. The standard `twistd.log` file does not
record IP addresses unless an error occurs.

The optional `--log-fd=` file (and the SQLite database generated if
`--usage-db=` is enabled) record the time at which the first side
connected, the time until the second side connected, the total transfer
time, the total number of bytes transferred, and the success/failure status
(the "mood").

If `--blur-usage=` is provided, these recorded file sizes are rounded down:
sizes less than 1kB are recorded as 0, sizes up to 1MB are rounded to the
nearest kB, sizes up to 1GB are rounded to the nearest MB, and sizes above
1GB are rounded to the nearest 100MB. The argument to `--blur-usage=` is
treated as a number of seconds, and the "first side connects" timestamp is
rounded to a multiple of this. For example, `--blur-usage=3600` means all
timestamps are rounded down to the nearest hour. The waiting time and total
time deltas are recorded without rounding.

## Daemonization

A production installation will want to daemonize the server somehow. One
option is to use ``twistd`` (the daemonizing version of ``twist``). This
takes the same plugin name and arguments as ``twist``, but forks into the
background, detaches from the controlling terminal, and writes all output
into a logfile:

```
% tr-venv/bin/twistd transitrelay [ARGS]
% cat twistd.log
2017-11-09T17:07:28-0800 [-] not blurring access times
2017-11-09T17:07:28-0800 [-] Transit starting on 4001
2017-11-09T17:07:28-0800 [wormhole_transit_relay.transit_server.Transit#info] Starting factory ...
% cat twistd.pid; echo
18985
```

To shut down a ``twistd``-based server, you'll need to look in the
``twistd.pid`` file for the process id, and kill it:

```
% kill `cat twistd.pid`
```

To start the server each time the host reboots, you might use a crontab
"@reboot" job, or a systemd unit.

Another option is to run ``twist`` underneath a daemonization tool like
``daemontools`` or ``start-stop-daemon``. Since ``twist`` is just a regular
program, this leaves the daemonization tool in charge of issues like
restarting a process that exits unexpectedly, limiting the rate of
respawning, and switching to the correct user-id and base directory.

Packagers who create an installable transit-relay server package should
choose a suitable daemonization tool that matches the practices of the
target operating system. For example, Debian/Ubuntu packages should
probably include a systemd unit that runs ``twist transitrelay`` in some
``/var/run/magic-wormhole-transit-relay/`` directory.
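As one concrete sketch (not shipped with this package; the user, paths, and
virtualenv location are all assumptions), such a systemd unit might look
like:

```
[Unit]
Description=Magic-Wormhole Transit Relay
After=network.target

[Service]
User=wormhole
WorkingDirectory=/home/wormhole/relay
ExecStart=/home/wormhole/tr-venv/bin/twist transitrelay --usage-db=usage.sqlite
Restart=on-failure

[Install]
WantedBy=multi-user.target
```

Since ``twist`` stays in the foreground, systemd itself handles the
daemonization and restart-on-failure duties.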
Production environments that want to monitor the server for capacity
management can use the ``--log-fd=`` option to emit logs, then route those
logs into a suitable analysis tool. Other environments might be content to
use ``--usage-db=`` and run the included Munin plugins to monitor usage.

There is also a
[Dockerfile](https://github.com/ggeorgovassilis/magic-wormhole-transit-relay-docker),
written by George Georgovassilis, which you might find useful.

## Configuring Clients

The transit relay will listen on an "endpoint" (usually a TCP port, but it
could be a unix-domain socket or any other Endpoint that Twisted knows how
to listen on). By default this is ``tcp:4001``. The relay does not know
what hostname or IP address might point at it.

Clients are configured with a "Transit Helper" setting that includes both
the hostname and the port number, like the default
``tcp:transit.magic-wormhole.io:4001``. The standard ``wormhole`` tool
takes a ``--transit-helper=`` argument to override this. Other applications
that use ``wormhole`` as a library will have internal means to configure
which transit relay they use. If you run your own transit relay, you will
need to provide the new settings to your clients for it to be used.

The standard ``wormhole`` tool is used by two sides: the sender and the
receiver. Both sides exchange their configured transit relay with their
partner. So if the sender overrides ``--transit-helper=`` but the receiver
does not, they might wind up using either relay server, depending upon
which one gets an established connection first.

--- magic-wormhole-transit-relay-0.3.1/docs/transit.md ---
# Transit Protocol

The Transit protocol is responsible for establishing an encrypted
bidirectional record stream between two programs. It must be given a
"transit key" and a set of "hints" which help locate the other end (which
are both delivered by Wormhole).

The protocol tries hard to create a **direct** connection between the two
ends, but if that fails, it uses a centralized relay server to ferry data
between two separate TCP streams (one to each client). This repository
provides that centralized relay server.

For details of the protocol spoken by the clients, and the client-side API,
please see ``transit.md`` in the magic-wormhole repository.

## Relay

The **Transit Relay** is a host which offers TURN-like services for
magic-wormhole instances. It uses a TCP-based protocol with a handshake to
determine which connection wants to be connected to which.

When connecting to a relay, the Transit client first writes RELAY-HANDSHAKE
to the socket, which is `please relay %s\n`, where `%s` is the hex-encoded
32-byte HKDF derivative of the transit key, using `transit_relay_token` as
the context. The client then waits for `ok\n`.

The relay waits for a second connection that uses the same token. When this
happens, the relay sends `ok\n` to both, then wires the connections
together, so that everything received after the token on one is written out
(after the ok) on the other. When either connection is lost, the other will
be closed (the relay does not support "half-close").

When clients use a relay connection, they perform the usual sender/receiver
handshake just after the `ok\n` is received: until that point they pretend
the connection doesn't even exist.
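A minimal client-side sketch of that handshake, assuming the
``cryptography`` package for the HKDF step (the authoritative derivation
lives in the magic-wormhole client's ``transit.md``; the empty salt here is
an assumption based on that spec):

```
from binascii import hexlify
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.hkdf import HKDF

def relay_handshake(transit_key: bytes) -> bytes:
    # derive the 32-byte relay token from the transit key, with the
    # protocol-fixed context string "transit_relay_token"
    token = HKDF(
        algorithm=hashes.SHA256(),
        length=32,
        salt=None,
        info=b"transit_relay_token",
    ).derive(transit_key)
    return b"please relay %s\n" % hexlify(token)
```

The client writes this line, reads until it sees `ok\n`, and only then
begins the real sender/receiver handshake.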
Direct connections are better, since they are faster and less expensive for
the relay operator. If there are any potentially-viable direct connection
hints available, the Transit instance will wait a few seconds before
attempting to use the relay. If it has no viable direct hints, it will
start using the relay right away. This prefers direct connections, but
doesn't introduce completely unnecessary stalls.

--- magic-wormhole-transit-relay-0.3.1/misc/migrate_usage_db.py ---
"""Migrate the usage data from the old bundled Transit Relay database.

The magic-wormhole package used to include both servers (Rendezvous and
Transit). "wormhole server" started both of these, and used the
"relay.sqlite" database to store both immediate server state and long-term
usage data.

These were split out to their own packages: version 0.11 omitted the
Transit Relay in favor of the new "magic-wormhole-transit-relay"
distribution.

This script reads the long-term Transit usage data from the pre-0.11
wormhole-server relay.sqlite, and copies it into a new "usage.sqlite"
database in the current directory.

It will refuse to touch an existing "usage.sqlite" file.

The resulting "usage.sqlite" should be passed into --usage-db=, e.g.
"twist transitrelay --usage-db=.../PATH/TO/usage.sqlite".
"""

import sys
from wormhole_transit_relay.database import open_existing_db, create_db

source_fn = sys.argv[1]
source_db = open_existing_db(source_fn)
target_db = create_db("usage.sqlite")

num_rows = 0
for row in source_db.execute("SELECT * FROM `transit_usage`"
                             " ORDER BY `started`").fetchall():
    target_db.execute("INSERT INTO `usage`"
                      " (`started`, `total_time`, `waiting_time`,"
                      "  `total_bytes`, `result`)"
                      " VALUES(?,?,?,?,?)",
                      (row["started"], row["total_time"],
                       row["waiting_time"], row["total_bytes"],
                       row["result"]))
    num_rows += 1
target_db.execute("INSERT INTO `current`"
                  " (`rebooted`, `updated`, `connected`, `waiting`,"
                  "  `incomplete_bytes`)"
                  " VALUES(?,?,?,?,?)",
                  (0, 0, 0, 0, 0))
target_db.commit()

print("usage database migrated (%d rows) into 'usage.sqlite'" % num_rows)
sys.exit(0)
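For example, the script above could be run like this (the source path is
wherever your old wormhole-server kept relay.sqlite, and the row count is
illustrative):

```
% python misc/migrate_usage_db.py ~/wormhole-server/relay.sqlite
usage database migrated (1234 rows) into 'usage.sqlite'
```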
--- magic-wormhole-transit-relay-0.3.1/misc/munin/wormhole_transit_active ---
#! /usr/bin/env python

"""
Use the following in /etc/munin/plugin-conf.d/wormhole :

  [wormhole_*]
  env.usagedb /path/to/your/wormhole/server/usage.sqlite
"""

import os, sys, time, sqlite3

CONFIG = """\
graph_title Magic-Wormhole Transit Active Channels
graph_vlabel Channels
graph_category wormhole
waiting.label Transit Waiting
waiting.draw LINE1
waiting.type GAUGE
connected.label Transit Connected
connected.draw LINE1
connected.type GAUGE
"""

if len(sys.argv) > 1 and sys.argv[1] == "config":
    print(CONFIG.rstrip())
    sys.exit(0)

dbfile = os.environ["usagedb"]
assert os.path.exists(dbfile)
db = sqlite3.connect(dbfile)

MINUTE = 60.0
updated,waiting,connected = db.execute("SELECT `updated`,`waiting`,`connected`"
                                       " FROM `current`").fetchone()
if time.time() > updated + 5*MINUTE:
    sys.exit(1) # expired

print("waiting.value", waiting)
print("connected.value", connected)

--- magic-wormhole-transit-relay-0.3.1/misc/munin/wormhole_transit_bytes ---
#! /usr/bin/env python

"""
Use the following in /etc/munin/plugin-conf.d/wormhole :

  [wormhole_*]
  env.usagedb /path/to/your/wormhole/server/usage.sqlite
"""

import os, sys, time, sqlite3

CONFIG = """\
graph_title Magic-Wormhole Transit Usage (since reboot)
graph_vlabel Bytes Since Reboot
graph_category wormhole
bytes.label Transit Bytes (complete)
bytes.draw LINE1
bytes.type GAUGE
incomplete.label Transit Bytes (incomplete)
incomplete.draw LINE1
incomplete.type GAUGE
"""

if len(sys.argv) > 1 and sys.argv[1] == "config":
    print(CONFIG.rstrip())
    sys.exit(0)

dbfile = os.environ["usagedb"]
assert os.path.exists(dbfile)
db = sqlite3.connect(dbfile)

MINUTE = 60.0
updated,rebooted,incomplete = db.execute("SELECT `updated`,`rebooted`,`incomplete_bytes` FROM `current`").fetchone()
if time.time() > updated + 5*MINUTE:
    sys.exit(1) # expired

complete = db.execute("SELECT SUM(`total_bytes`) FROM `usage`"
                      " WHERE `started` > ?", (rebooted,)).fetchone()[0] or 0
print("bytes.value", complete)
print("incomplete.value", complete+incomplete)
--- magic-wormhole-transit-relay-0.3.1/misc/munin/wormhole_transit_bytes_alltime ---
#! /usr/bin/env python

"""
Use the following in /etc/munin/plugin-conf.d/wormhole :

  [wormhole_*]
  env.usagedb /path/to/your/wormhole/server/usage.sqlite
"""

import os, sys, time, sqlite3

CONFIG = """\
graph_title Magic-Wormhole Transit Usage (all time)
graph_vlabel Bytes Since DB Creation
graph_category wormhole
bytes.label Transit Bytes (complete)
bytes.draw LINE1
bytes.type GAUGE
incomplete.label Transit Bytes (incomplete)
incomplete.draw LINE1
incomplete.type GAUGE
"""

if len(sys.argv) > 1 and sys.argv[1] == "config":
    print(CONFIG.rstrip())
    sys.exit(0)

dbfile = os.environ["usagedb"]
assert os.path.exists(dbfile)
db = sqlite3.connect(dbfile)

MINUTE = 60.0
updated,incomplete = db.execute("SELECT `updated`,`incomplete_bytes`"
                                " FROM `current`").fetchone()
if time.time() > updated + 5*MINUTE:
    sys.exit(1) # expired

complete = db.execute("SELECT SUM(`total_bytes`)"
                      " FROM `usage`").fetchone()[0] or 0
print("bytes.value", complete)
print("incomplete.value", complete+incomplete)

--- magic-wormhole-transit-relay-0.3.1/misc/munin/wormhole_transit_events ---
#! /usr/bin/env python

"""
Use the following in /etc/munin/plugin-conf.d/wormhole :

  [wormhole_*]
  env.usagedb /path/to/your/wormhole/server/usage.sqlite
"""

import os, sys, time, sqlite3

CONFIG = """\
graph_title Magic-Wormhole Transit Server Events (since reboot)
graph_vlabel Events Since Reboot
graph_category wormhole
happy.label Happy
happy.draw LINE1
happy.type GAUGE
errory.label Errory
errory.draw LINE1
errory.type GAUGE
lonely.label Lonely
lonely.draw LINE1
lonely.type GAUGE
redundant.label Redundant
redundant.draw LINE1
redundant.type GAUGE
"""

if len(sys.argv) > 1 and sys.argv[1] == "config":
    print(CONFIG.rstrip())
    sys.exit(0)

dbfile = os.environ["usagedb"]
assert os.path.exists(dbfile)
db = sqlite3.connect(dbfile)

MINUTE = 60.0
rebooted,updated = db.execute("SELECT `rebooted`, `updated` FROM `current`").fetchone()
if time.time() > updated + 5*MINUTE:
    sys.exit(1) # expired

count = db.execute("SELECT COUNT() FROM `usage`"
                   " WHERE"
                   " `started` > ? AND"
                   " `result` = 'happy'", (rebooted,)).fetchone()[0]
print("happy.value", count)
count = db.execute("SELECT COUNT() FROM `usage`"
                   " WHERE"
                   " `started` > ? AND"
                   " `result` = 'errory'", (rebooted,)).fetchone()[0]
print("errory.value", count)
count = db.execute("SELECT COUNT() FROM `usage`"
                   " WHERE"
                   " `started` > ? AND"
                   " `result` = 'lonely'", (rebooted,)).fetchone()[0]
print("lonely.value", count)
count = db.execute("SELECT COUNT() FROM `usage`"
                   " WHERE"
                   " `started` > ? AND"
                   " `result` = 'redundant'", (rebooted,)).fetchone()[0]
print("redundant.value", count)
--- magic-wormhole-transit-relay-0.3.1/misc/munin/wormhole_transit_events_alltime ---
#! /usr/bin/env python

"""
Use the following in /etc/munin/plugin-conf.d/wormhole :

  [wormhole_*]
  env.usagedb /path/to/your/wormhole/server/usage.sqlite
"""

import os, sys, time, sqlite3

CONFIG = """\
graph_title Magic-Wormhole Transit Server Events (all time)
graph_vlabel Events
graph_category wormhole
happy.label Happy
happy.draw LINE1
happy.type GAUGE
errory.label Errory
errory.draw LINE1
errory.type GAUGE
lonely.label Lonely
lonely.draw LINE1
lonely.type GAUGE
redundant.label Redundant
redundant.draw LINE1
redundant.type GAUGE
"""

if len(sys.argv) > 1 and sys.argv[1] == "config":
    print(CONFIG.rstrip())
    sys.exit(0)

dbfile = os.environ["usagedb"]
assert os.path.exists(dbfile)
db = sqlite3.connect(dbfile)

MINUTE = 60.0
rebooted,updated = db.execute("SELECT `rebooted`, `updated` FROM `current`").fetchone()
if time.time() > updated + 5*MINUTE:
    sys.exit(1) # expired

count = db.execute("SELECT COUNT() FROM `usage`"
                   " WHERE `result` = 'happy'",
                   ).fetchone()[0]
print("happy.value", count)
count = db.execute("SELECT COUNT() FROM `usage`"
                   " WHERE `result` = 'errory'",
                   ).fetchone()[0]
print("errory.value", count)
count = db.execute("SELECT COUNT() FROM `usage`"
                   " WHERE `result` = 'lonely'",
                   ).fetchone()[0]
print("lonely.value", count)
count = db.execute("SELECT COUNT() FROM `usage`"
                   " WHERE `result` = 'redundant'",
                   ).fetchone()[0]
print("redundant.value", count)

--- magic-wormhole-transit-relay-0.3.1/setup.cfg ---
[versioneer]
VCS = git
versionfile_source = src/wormhole_transit_relay/_version.py
versionfile_build = wormhole_transit_relay/_version.py
tag_prefix =
parentdir_prefix = magic-wormhole-transit-relay

[egg_info]
tag_build =
tag_date = 0

--- magic-wormhole-transit-relay-0.3.1/setup.py ---
from setuptools import setup
import versioneer

commands = versioneer.get_cmdclass()

setup(name="magic-wormhole-transit-relay",
      version=versioneer.get_version(),
      description="Transit Relay server for Magic-Wormhole",
      author="Brian Warner",
      author_email="warner-magic-wormhole@lothar.com",
      license="MIT",
      url="https://github.com/warner/magic-wormhole-transit-relay",
      package_dir={"": "src"},
      packages=["wormhole_transit_relay",
                "wormhole_transit_relay.test",
                "twisted.plugins",
                ],
      package_data={"wormhole_transit_relay": ["db-schemas/*.sql"]},
      install_requires=[
          "twisted >= 21.2.0",
          "autobahn >= 21.3.1",
      ],
      extras_require={
          ':sys_platform=="win32"': ["pypiwin32"],
          "dev": ["mock", "tox", "pyflakes"],
          "build": ["twine", "dulwich", "readme_renderer", "gpg", "wheel"],
      },
      test_suite="wormhole_transit_relay.test",
      cmdclass=commands,
      )
--- magic-wormhole-transit-relay-0.3.1/src/magic_wormhole_transit_relay.egg-info/PKG-INFO ---
Metadata-Version: 2.1
Name: magic-wormhole-transit-relay
Version: 0.3.1
Summary: Transit Relay server for Magic-Wormhole
Home-page: https://github.com/warner/magic-wormhole-transit-relay
Author: Brian Warner
Author-email: warner-magic-wormhole@lothar.com
License: MIT
License-File: LICENSE
Requires-Dist: twisted>=21.2.0
Requires-Dist: autobahn>=21.3.1
Requires-Dist: pypiwin32; sys_platform == "win32"
Provides-Extra: dev
Requires-Dist: mock; extra == "dev"
Requires-Dist: tox; extra == "dev"
Requires-Dist: pyflakes; extra == "dev"
Provides-Extra: build
Requires-Dist: twine; extra == "build"
Requires-Dist: dulwich; extra == "build"
Requires-Dist: readme_renderer; extra == "build"
Requires-Dist: gpg; extra == "build"
Requires-Dist: wheel; extra == "build"

--- magic-wormhole-transit-relay-0.3.1/src/magic_wormhole_transit_relay.egg-info/SOURCES.txt ---
.coveragerc
LICENSE
MANIFEST.in
NEWS.md
README.md
setup.cfg
setup.py
tox.ini
versioneer.py
docs/logging.md
docs/running.md
docs/transit.md
misc/migrate_usage_db.py
misc/munin/wormhole_transit_active
misc/munin/wormhole_transit_bytes
misc/munin/wormhole_transit_bytes_alltime
misc/munin/wormhole_transit_events
misc/munin/wormhole_transit_events_alltime
src/magic_wormhole_transit_relay.egg-info/PKG-INFO
src/magic_wormhole_transit_relay.egg-info/SOURCES.txt
src/magic_wormhole_transit_relay.egg-info/dependency_links.txt
src/magic_wormhole_transit_relay.egg-info/requires.txt
src/magic_wormhole_transit_relay.egg-info/top_level.txt
src/twisted/plugins/magic_wormhole_transit_relay.py
src/wormhole_transit_relay/__init__.py
src/wormhole_transit_relay/_version.py
src/wormhole_transit_relay/database.py
src/wormhole_transit_relay/increase_rlimits.py
src/wormhole_transit_relay/server_state.py
src/wormhole_transit_relay/server_tap.py
src/wormhole_transit_relay/transit_server.py
src/wormhole_transit_relay/usage.py
src/wormhole_transit_relay/db-schemas/v1.sql
src/wormhole_transit_relay/test/__init__.py
src/wormhole_transit_relay/test/common.py
src/wormhole_transit_relay/test/test_backpressure.py
src/wormhole_transit_relay/test/test_config.py
src/wormhole_transit_relay/test/test_database.py
src/wormhole_transit_relay/test/test_rlimits.py
src/wormhole_transit_relay/test/test_service.py
src/wormhole_transit_relay/test/test_stats.py
src/wormhole_transit_relay/test/test_transit_server.py

--- magic-wormhole-transit-relay-0.3.1/src/magic_wormhole_transit_relay.egg-info/dependency_links.txt ---

--- magic-wormhole-transit-relay-0.3.1/src/magic_wormhole_transit_relay.egg-info/requires.txt ---
twisted>=21.2.0
autobahn>=21.3.1

[:sys_platform=="win32"]
pypiwin32

[build]
twine
dulwich
readme_renderer
gpg
wheel

[dev]
mock
tox
pyflakes

--- magic-wormhole-transit-relay-0.3.1/src/magic_wormhole_transit_relay.egg-info/top_level.txt ---
twisted
wormhole_transit_relay
--- magic-wormhole-transit-relay-0.3.1/src/twisted/plugins/magic_wormhole_transit_relay.py ---
from twisted.application.service import ServiceMaker

TransitRelay = ServiceMaker(
    "Magic-Wormhole Transit Relay",       # name
    "wormhole_transit_relay.server_tap",  # module
    "Provide the Transit Relay server for Magic-Wormhole clients.",  # desc
    "transitrelay",                       # tapname
)

--- magic-wormhole-transit-relay-0.3.1/src/wormhole_transit_relay/__init__.py ---
from . import _version
__version__ = _version.get_versions()['version']

--- magic-wormhole-transit-relay-0.3.1/src/wormhole_transit_relay/_version.py ---

# This file was generated by 'versioneer.py' (0.29) from
# revision-control system data, or from the parent directory name of an
# unpacked source archive. Distribution tarballs contain a pre-generated
# copy of this file.

import json

version_json = '''
{
 "date": "2024-08-18T20:29:48-0400",
 "dirty": false,
 "error": null,
 "full-revisionid": "019713c8f640bfcc12f994054e86e8a6be328d5c",
 "version": "0.3.1"
}
'''  # END VERSION_JSON


def get_versions():
    return json.loads(version_json)

--- magic-wormhole-transit-relay-0.3.1/src/wormhole_transit_relay/database.py ---
import os
import sqlite3
import tempfile
from pkg_resources import resource_string
from twisted.python import log

class DBError(Exception):
    pass

def get_schema(version):
    schema_bytes = resource_string("wormhole_transit_relay",
                                   "db-schemas/v%d.sql" % version)
    return schema_bytes.decode("utf-8")

## def get_upgrader(new_version):
##     schema_bytes = resource_string("wormhole_transit_relay",
##                                    "db-schemas/upgrade-to-v%d.sql" % new_version)
##     return schema_bytes.decode("utf-8")

TARGET_VERSION = 1

def dict_factory(cursor, row):
    d = {}
    for idx, col in enumerate(cursor.description):
        d[col[0]] = row[idx]
    return d

def _initialize_db_schema(db, target_version):
    """Creates the application schema in the given database.
    """
    log.msg("populating new database with schema v%s" % target_version)
    schema = get_schema(target_version)
    db.executescript(schema)
    db.execute("INSERT INTO version (version) VALUES (?)",
               (target_version,))
    db.commit()

def _initialize_db_connection(db):
    """Sets up the db connection object with a row factory and with necessary
    foreign key settings.
""" db.row_factory = dict_factory db.execute("PRAGMA foreign_keys = ON") problems = db.execute("PRAGMA foreign_key_check").fetchall() if problems: raise DBError("failed foreign key check: %s" % (problems,)) def _open_db_connection(dbfile): """Open a new connection to the SQLite3 database at the given path. """ try: db = sqlite3.connect(dbfile) _initialize_db_connection(db) except (EnvironmentError, sqlite3.OperationalError, sqlite3.DatabaseError) as e: # this indicates that the file is not a compatible database format. # Perhaps it was created with an old version, or it might be junk. raise DBError("Unable to create/open db file %s: %s" % (dbfile, e)) return db def _get_temporary_dbfile(dbfile): """Get a temporary filename near the given path. """ fd, name = tempfile.mkstemp( prefix=os.path.basename(dbfile) + ".", dir=os.path.dirname(dbfile) ) os.close(fd) return name def _atomic_create_and_initialize_db(dbfile, target_version): """Create and return a new database, initialized with the application schema. If anything goes wrong, nothing is left at the ``dbfile`` path. """ temp_dbfile = _get_temporary_dbfile(dbfile) db = _open_db_connection(temp_dbfile) _initialize_db_schema(db, target_version) db.close() os.rename(temp_dbfile, dbfile) return _open_db_connection(dbfile) def get_db(dbfile, target_version=TARGET_VERSION): """Open or create the given db file. The parent directory must exist. Returns the db connection object, or raises DBError. """ if dbfile == ":memory:": db = _open_db_connection(dbfile) _initialize_db_schema(db, target_version) elif os.path.exists(dbfile): db = _open_db_connection(dbfile) else: db = _atomic_create_and_initialize_db(dbfile, target_version) version = db.execute("SELECT version FROM version").fetchone()["version"] ## while version < target_version: ## log.msg(" need to upgrade from %s to %s" % (version, target_version)) ## try: ## upgrader = get_upgrader(version+1) ## except ValueError: # ResourceError?? ## log.msg(" unable to upgrade %s to %s" % (version, version+1)) ## raise DBError("Unable to upgrade %s to version %s, left at %s" ## % (dbfile, version+1, version)) ## log.msg(" executing upgrader v%s->v%s" % (version, version+1)) ## db.executescript(upgrader) ## db.commit() ## version = version+1 if version != target_version: raise DBError("Unable to handle db version %s" % version) return db class DBDoesntExist(Exception): pass def open_existing_db(dbfile): assert dbfile != ":memory:" if not os.path.exists(dbfile): raise DBDoesntExist() return _open_db_connection(dbfile) class DBAlreadyExists(Exception): pass def create_db(dbfile): """Create the given db file. Refuse to touch a pre-existing file. 
    This is meant for use by migration tools, to create the output target"""
    if dbfile == ":memory:":
        db = _open_db_connection(dbfile)
        _initialize_db_schema(db, TARGET_VERSION)
    elif os.path.exists(dbfile):
        raise DBAlreadyExists()
    else:
        db = _atomic_create_and_initialize_db(dbfile, TARGET_VERSION)
    return db

def dump_db(db):
    # to let _iterdump work, we need to restore the original row factory
    orig = db.row_factory
    try:
        db.row_factory = sqlite3.Row
        return "".join(db.iterdump())
    finally:
        db.row_factory = orig

--- magic-wormhole-transit-relay-0.3.1/src/wormhole_transit_relay/db-schemas/v1.sql ---

CREATE TABLE `version` -- contains one row
(
 `version` INTEGER -- set to 1
);

CREATE TABLE `current` -- contains one row
(
 `rebooted` INTEGER, -- seconds since epoch of most recent reboot
 `updated` INTEGER, -- when `current` was last updated
 `connected` INTEGER, -- number of current paired connections
 `waiting` INTEGER, -- number of not-yet-paired connections
 `incomplete_bytes` INTEGER -- bytes sent through not-yet-complete connections
);

CREATE TABLE `usage`
(
 `started` INTEGER, -- seconds since epoch, rounded to "blur time"
 `total_time` INTEGER, -- seconds from open to last close
 `waiting_time` INTEGER, -- seconds from start to 2nd side appearing, or None
 `total_bytes` INTEGER, -- total bytes relayed (both directions)
 `result` VARCHAR -- happy, scary, lonely, errory, pruney
 -- transit moods:
 --  "errory": one side gave the wrong handshake
 --  "lonely": good handshake, but the other side never showed up
 --  "redundant": good handshake, abandoned in favor of different connection
 --  "happy": both sides gave correct handshake
);
CREATE INDEX `usage_started_index` ON `usage` (`started`);
CREATE INDEX `usage_result_index` ON `usage` (`result`);

--- magic-wormhole-transit-relay-0.3.1/src/wormhole_transit_relay/increase_rlimits.py ---
try:
    # 'resource' is unix-only
    from resource import getrlimit, setrlimit, RLIMIT_NOFILE
except ImportError: # pragma: nocover
    getrlimit, setrlimit, RLIMIT_NOFILE = None, None, None # pragma: nocover
from twisted.python import log

def increase_rlimits():
    if getrlimit is None:
        log.msg("unable to import 'resource', leaving rlimit alone")
        return
    soft, hard = getrlimit(RLIMIT_NOFILE)
    if soft >= 10000:
        log.msg("RLIMIT_NOFILE.soft was %d, leaving it alone" % soft)
        return
    # OS-X defaults to soft=7168, and reports a huge number for 'hard',
    # but won't accept anything more than soft=10240, so we can't just
    # set soft=hard. Linux returns (1024, 1048576) and is fine with
    # soft=hard. Cygwin is reported to return (256,-1) and accepts up to
    # soft=3200. So we try multiple values until something works.
    for newlimit in [hard, 10000, 3200, 1024]:
        log.msg("changing RLIMIT_NOFILE from (%s,%s) to (%s,%s)" %
                (soft, hard, newlimit, hard))
        try:
            setrlimit(RLIMIT_NOFILE, (newlimit, hard))
            log.msg("setrlimit successful")
            return
        except ValueError as e:
            log.msg("error during setrlimit: %s" % e)
            continue
        except:
            log.msg("other error during setrlimit, leaving it alone")
            log.err()
            return
    log.msg("unable to change rlimit, leaving it alone")

--- magic-wormhole-transit-relay-0.3.1/src/wormhole_transit_relay/server_state.py ---
from collections import defaultdict

import automat
from twisted.python import log
from zope.interface import (
    Interface,
    Attribute,
)


class ITransitClient(Interface):
    """
    Represents the client side of a connection to this transit relay.

    This is used by TransitServerState instances.
    """

    started_time = Attribute("timestamp when the connection was established")

    def send(data):
        """
        Send some bytes to the client
        """

    def disconnect():
        """
        Disconnect the client transport
        """

    def connect_partner(other):
        """
        Hook up to our partner.

        :param ITransitClient other: our partner
        """

    def disconnect_partner():
        """
        Disconnect our partner's transport
        """


class ActiveConnections(object):
    """
    Tracks active connections. A connection is 'active' when both
    sides have shown up and they are glued together (and thus could be
    passing data back and forth if any is flowing).
    """
    def __init__(self):
        self._connections = set()

    def register(self, side0, side1):
        """
        A connection has become active so register both its sides

        :param TransitConnection side0: one side of the connection
        :param TransitConnection side1: one side of the connection
        """
        self._connections.add(side0)
        self._connections.add(side1)

    def unregister(self, side):
        """
        One side of a connection has become inactive.

        :param TransitConnection side: an inactive side of a connection
        """
        self._connections.discard(side)


class PendingRequests(object):
    """
    Tracks outstanding (non-"active") requests.

    We register client connections against the tokens we have
    received. When the other side shows up we can thus match it to the
    correct partner connection. At this point, the connection becomes
    "active" and is thus no longer "pending", and so will no longer be
    in this collection.
    """
    def __init__(self, active_connections):
        """
        :param active_connections: an instance of ActiveConnections where
            connections are put when both sides arrive.
        """
        self._requests = defaultdict(set) # token -> set((side, TransitConnection))
        self._active = active_connections

    def unregister(self, token, side, tc):
        """
        We no longer care about a particular client (e.g. it has
        disconnected).
        """
        if token in self._requests:
            self._requests[token].discard((side, tc))
            if not self._requests[token]:
                # no more sides; token is dead
                del self._requests[token]
        self._active.unregister(tc)

    def register(self, token, new_side, new_tc):
        """
        A client has connected and successfully offered a token (and
        optional 'side' token).

        If this is the first one for this token, we merely remember it.
        If it is the second side for this token we connect them
        together.

        :param bytes token: the token for this connection.
        :param bytes new_side: None or the side token for this connection

        :param TransitServerState new_tc: the state-machine of the
            connection

        :returns bool: True if we are the first side to register this
            token
        """
        potentials = self._requests[token]
        for old in potentials:
            (old_side, old_tc) = old
            if ((old_side is None)
                or (new_side is None)
                or (old_side != new_side)):
                # we found a match
                # drop and stop tracking the rest
                potentials.remove(old)
                for (_, leftover_tc) in potentials.copy():
                    # Don't record this as errory. It's just a spare connection
                    # from the same side as a connection that got used. This
                    # can happen if the connection hint contains multiple
                    # addresses (we don't currently support those, but it'd
                    # probably be useful in the future).
                    leftover_tc.partner_connection_lost()
                self._requests.pop(token, None)

                # glue the two ends together
                self._active.register(new_tc, old_tc)
                new_tc.got_partner(old_tc)
                old_tc.got_partner(new_tc)
                return False

        potentials.add((new_side, new_tc))
        return True
        # TODO: timer


class TransitServerState(object):
    """
    Encapsulates the state-machine of the server side of a transit
    relay connection.

    Once the protocol has been told to relay (or to relay for a side)
    it starts passing all received bytes to the other side until it
    closes.
    """

    _machine = automat.MethodicalMachine()
    _client = None
    _buddy = None
    _token = None
    _side = None
    _first = None
    _mood = "empty"
    _total_sent = 0

    def __init__(self, pending_requests, usage_recorder):
        self._pending_requests = pending_requests
        self._usage = usage_recorder

    def get_token(self):
        """
        :returns str: a string describing our token. This will be "-" if
            we have no token yet, or "{16 chars}-" if we have just a
            token or "{16 chars}-{16 chars}" if we have a token and a
            side.
        """
        d = "-"
        if self._token is not None:
            d = self._token[:16].decode("ascii")

            if self._side is not None:
                d += "-" + self._side.decode("ascii")
            else:
                d += "-"
        return d

    @_machine.input()
    def connection_made(self, client):
        """
        A client has connected. May only be called once.

        :param ITransitClient client: our client.
        """
        # NB: the "only called once" is enforced by the state-machine;
        # this input is only valid for the "listening" state, to which
        # we never return.

    @_machine.input()
    def please_relay(self, token):
        """
        A 'please relay X' message has been received (the original
        version of the protocol).
        """

    @_machine.input()
    def please_relay_for_side(self, token, side):
        """
        A 'please relay X for side Y' message has been received (the
        second version of the protocol).
        """

    @_machine.input()
    def bad_token(self):
        """
        A bad token / relay line was received (e.g. couldn't be parsed)
        """

    @_machine.input()
    def got_partner(self, client):
        """
        The partner for this relay session has been found
        """

    @_machine.input()
    def connection_lost(self):
        """
        Our transport has failed.
        """

    @_machine.input()
    def partner_connection_lost(self):
        """
        Our partner's transport has failed.
        """

    @_machine.input()
    def got_bytes(self, data):
        """
        Some bytes have arrived (that aren't part of the handshake)
        """

    @_machine.output()
    def _remember_client(self, client):
        self._client = client

    # note that there is no corresponding "_forget_client" because we
    # may still want to access it after it is gone ..
    # for example, to get the .started_time for logging purposes

    @_machine.output()
    def _register_token(self, token):
        return self._real_register_token_for_side(token, None)

    @_machine.output()
    def _register_token_for_side(self, token, side):
        return self._real_register_token_for_side(token, side)

    @_machine.output()
    def _unregister(self):
        """
        remove us from the thing that remembers tokens and sides
        """
        return self._pending_requests.unregister(
            self._token,
            self._side,
            self,
        )

    @_machine.output()
    def _send_bad(self):
        self._mood = "errory"
        self._client.send(b"bad handshake\n")
        if self._client.factory.log_requests:
            log.msg("transit handshake failure")

    @_machine.output()
    def _send_ok(self):
        self._client.send(b"ok\n")

    @_machine.output()
    def _send_impatient(self):
        self._client.send(b"impatient\n")
        if self._client.factory.log_requests:
            log.msg("transit impatience failure")

    @_machine.output()
    def _count_bytes(self, data):
        self._total_sent += len(data)

    @_machine.output()
    def _send_to_partner(self, data):
        self._buddy._client.send(data)

    @_machine.output()
    def _connect_partner(self, client):
        self._buddy = client
        self._client.connect_partner(client)

    @_machine.output()
    def _disconnect(self):
        self._client.disconnect()

    @_machine.output()
    def _disconnect_partner(self):
        self._client.disconnect_partner()

    # some outputs to record "usage" information ..

    @_machine.output()
    def _record_usage(self):
        if self._mood == "jilted":
            if self._buddy and self._buddy._mood == "happy":
                return
        self._usage.record(
            started=self._client.started_time,
            buddy_started=self._buddy._client.started_time if self._buddy is not None else None,
            result=self._mood,
            bytes_sent=self._total_sent,
            buddy_bytes=self._buddy._total_sent if self._buddy is not None else None
        )

    # some outputs to record the "mood" ..

    @_machine.output()
    def _mood_happy(self):
        self._mood = "happy"

    @_machine.output()
    def _mood_lonely(self):
        self._mood = "lonely"

    @_machine.output()
    def _mood_redundant(self):
        self._mood = "redundant"

    @_machine.output()
    def _mood_impatient(self):
        self._mood = "impatient"

    @_machine.output()
    def _mood_errory(self):
        self._mood = "errory"

    @_machine.output()
    def _mood_happy_if_first(self):
        """
        We disconnected first so we're only happy if we also connected
        first.
        """
        if self._first:
            self._mood = "happy"
        else:
            self._mood = "jilted"

    def _real_register_token_for_side(self, token, side):
        """
        A client has connected and sent a valid version 1 or version 2
        handshake. If the former, `side` will be None.

        In either case, we remember the tokens and register ourselves.
        This might result in 'got_partner' notifications to two
        state-machines if this is the second side for a given token.

        :param bytes token: the token
        :param bytes side: The side token (or None)
        """
        self._token = token
        self._side = side
        self._first = self._pending_requests.register(token, side, self)

    @_machine.state(initial=True)
    def listening(self):
        """
        Initial state, awaiting connection.
""" @_machine.state() def wait_relay(self): """ Waiting for a 'relay' message """ @_machine.state() def wait_partner(self): """ Waiting for our partner to connect """ @_machine.state() def relaying(self): """ Relaying bytes to our partner """ @_machine.state() def done(self): """ Terminal state """ listening.upon( connection_made, enter=wait_relay, outputs=[_remember_client], ) listening.upon( connection_lost, enter=done, outputs=[_mood_errory], ) wait_relay.upon( please_relay, enter=wait_partner, outputs=[_mood_lonely, _register_token], ) wait_relay.upon( please_relay_for_side, enter=wait_partner, outputs=[_mood_lonely, _register_token_for_side], ) wait_relay.upon( bad_token, enter=done, outputs=[_mood_errory, _send_bad, _disconnect, _record_usage], ) wait_relay.upon( got_bytes, enter=done, outputs=[_count_bytes, _mood_errory, _disconnect, _record_usage], ) wait_relay.upon( connection_lost, enter=done, outputs=[_disconnect, _record_usage], ) wait_partner.upon( got_partner, enter=relaying, outputs=[_mood_happy, _send_ok, _connect_partner], ) wait_partner.upon( connection_lost, enter=done, outputs=[_mood_lonely, _unregister, _record_usage], ) wait_partner.upon( got_bytes, enter=done, outputs=[_mood_impatient, _send_impatient, _disconnect, _unregister, _record_usage], ) wait_partner.upon( partner_connection_lost, enter=done, outputs=[_mood_redundant, _disconnect, _record_usage], ) relaying.upon( got_bytes, enter=relaying, outputs=[_count_bytes, _send_to_partner], ) relaying.upon( connection_lost, enter=done, outputs=[_mood_happy_if_first, _disconnect_partner, _unregister, _record_usage], ) done.upon( connection_lost, enter=done, outputs=[], ) done.upon( partner_connection_lost, enter=done, outputs=[], ) # uncomment to turn on state-machine tracing # set_trace_function = _machine._setTrace ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1663026487.0 magic-wormhole-transit-relay-0.3.1/src/wormhole_transit_relay/server_tap.py0000644000175000017500000000622314307742467026655 0ustar00meejahmeejahimport os from twisted.internet import reactor from twisted.python import usage from twisted.application.service import MultiService from twisted.application.internet import (TimerService, StreamServerEndpointService) from twisted.internet import endpoints from twisted.internet import protocol from autobahn.twisted.websocket import WebSocketServerFactory from . import transit_server from .usage import create_usage_tracker from .increase_rlimits import increase_rlimits from .database import get_db LONGDESC = """\ This plugin sets up a 'Transit Relay' server for magic-wormhole. This service listens for TCP connections, finds pairs which present the same handshake, and glues the two TCP sockets together. 
""" class Options(usage.Options): synopsis = "[--port=] [--log-fd] [--blur-usage=] [--usage-db=]" longdesc = LONGDESC optParameters = [ ("port", "p", "tcp:4001:interface=\:\:", "endpoint to listen on"), ("websocket", "w", None, "endpoint to listen for WebSocket connections"), ("websocket-url", "u", None, "WebSocket URL (derived from endpoint if not provided)"), ("blur-usage", None, None, "blur timestamps and data sizes in logs"), ("log-fd", None, None, "write JSON usage logs to this file descriptor"), ("usage-db", None, None, "record usage data (SQLite)"), ] def opt_blur_usage(self, arg): self["blur-usage"] = int(arg) def makeService(config, reactor=reactor): increase_rlimits() tcp_ep = endpoints.serverFromString(reactor, config["port"]) # to listen ws_ep = ( endpoints.serverFromString(reactor, config["websocket"]) if config["websocket"] is not None else None ) log_file = ( os.fdopen(int(config["log-fd"]), "w") if config["log-fd"] is not None else None ) db = None if config["usage-db"] is None else get_db(config["usage-db"]) usage = create_usage_tracker( blur_usage=config["blur-usage"], log_file=log_file, usage_db=db, ) transit = transit_server.Transit(usage, reactor.seconds) tcp_factory = protocol.ServerFactory() tcp_factory.protocol = transit_server.TransitConnection tcp_factory.log_requests = False if ws_ep is not None: ws_url = config["websocket-url"] if ws_url is None: # we're using a "private" attribute here but I don't see # any useful alternative unless we also want to parse # Twisted endpoint-strings. ws_url = "ws://localhost:{}/".format(ws_ep._port) print("Using WebSocket URL '{}'".format(ws_url)) ws_factory = WebSocketServerFactory(ws_url) ws_factory.protocol = transit_server.WebSocketTransitConnection ws_factory.transit = transit ws_factory.log_requests = False tcp_factory.transit = transit parent = MultiService() StreamServerEndpointService(tcp_ep, tcp_factory).setServiceParent(parent) if ws_ep is not None: StreamServerEndpointService(ws_ep, ws_factory).setServiceParent(parent) TimerService(5*60.0, transit.update_stats).setServiceParent(parent) return parent ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1724027394.8470917 magic-wormhole-transit-relay-0.3.1/src/wormhole_transit_relay/test/0000755000175000017500000000000014660511003025063 5ustar00meejahmeejah././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1610992109.0 magic-wormhole-transit-relay-0.3.1/src/wormhole_transit_relay/test/__init__.py0000644000175000017500000000000014001344755027171 0ustar00meejahmeejah././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1665609062.0 magic-wormhole-transit-relay-0.3.1/src/wormhole_transit_relay/test/common.py0000644000175000017500000000751114321626546026746 0ustar00meejahmeejahfrom twisted.internet.protocol import ( ClientFactory, Protocol, ) from twisted.test import iosim from zope.interface import ( Interface, Attribute, implementer, ) from ..transit_server import ( Transit, TransitConnection, ) from twisted.internet.protocol import ServerFactory from ..usage import create_usage_tracker class IRelayTestClient(Interface): """ The client interface used by tests. """ connected = Attribute("True if we are currently connected else False") def send(data): """ Send some bytes. :param bytes data: the data to send """ def disconnect(): """ Terminate the connection. """ def get_received_data(): """ :returns: all the bytes received from the server on this connection. 
""" def reset_data(): """ Erase any received data to this point. """ class ServerBase: log_requests = False def setUp(self): self._pumps = [] self._lp = None if self.log_requests: blur_usage = None else: blur_usage = 60.0 self._setup_relay(blur_usage=blur_usage) def flush(self): did_work = False for pump in self._pumps: did_work = pump.flush() or did_work if did_work: self.flush() def _setup_relay(self, blur_usage=None, log_file=None, usage_db=None): usage = create_usage_tracker( blur_usage=blur_usage, log_file=log_file, usage_db=usage_db, ) self._transit_server = Transit(usage, lambda: 123456789.0) def new_protocol(self): """ This should be overridden by derived test-case classes to decide if they want a TCP or WebSockets protocol. """ raise NotImplementedError() def new_protocol_tcp(self): """ Create a new client protocol connected to the server. :returns: a IRelayTestClient implementation """ server_factory = ServerFactory() server_factory.protocol = TransitConnection server_factory.transit = self._transit_server server_factory.log_requests = self.log_requests server_protocol = server_factory.buildProtocol(('127.0.0.1', 0)) @implementer(IRelayTestClient) class TransitClientProtocolTcp(Protocol): """ Speak the transit client protocol used by the tests over TCP """ _received = b"" connected = False # override Protocol callbacks def connectionMade(self): self.connected = True return Protocol.connectionMade(self) def connectionLost(self, reason): self.connected = False return Protocol.connectionLost(self, reason) def dataReceived(self, data): self._received = self._received + data # IRelayTestClient def send(self, data): self.transport.write(data) def disconnect(self): self.transport.loseConnection() def reset_received_data(self): self._received = b"" def get_received_data(self): return self._received client_factory = ClientFactory() client_factory.protocol = TransitClientProtocolTcp client_protocol = client_factory.buildProtocol(('127.0.0.1', 31337)) pump = iosim.connect( server_protocol, iosim.makeFakeServer(server_protocol), client_protocol, iosim.makeFakeClient(client_protocol), ) pump.flush() self._pumps.append(pump) return client_protocol def tearDown(self): if self._lp: return self._lp.stopListening() ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1724027250.0 magic-wormhole-transit-relay-0.3.1/src/wormhole_transit_relay/test/test_backpressure.py0000644000175000017500000001511614660510562031202 0ustar00meejahmeejahfrom io import ( StringIO, ) import sys import shutil from twisted.trial import unittest from twisted.internet.interfaces import ( IPullProducer, ) from twisted.internet.protocol import ( ProcessProtocol, ) from twisted.internet.defer import ( inlineCallbacks, Deferred, ) from autobahn.twisted.websocket import ( WebSocketClientProtocol, create_client_agent, ) from zope.interface import implementer class _CollectOutputProtocol(ProcessProtocol): """ Internal helper. Collects all output (stdout + stderr) into self.output, and callback's on done with all of it after the process exits (for any reason). 
""" def __init__(self): self.done = Deferred() self.running = Deferred() self.output = StringIO() def processEnded(self, reason): if not self.done.called: self.done.callback(self.output.getvalue()) def outReceived(self, data): print(data.decode(), end="", flush=True) self.output.write(data.decode(sys.getfilesystemencoding())) if not self.running.called: if "on 8088" in self.output.getvalue(): self.running.callback(None) def errReceived(self, data): print("ERR: {}".format(data.decode(sys.getfilesystemencoding()))) self.output.write(data.decode(sys.getfilesystemencoding())) def run_transit(reactor, proto, tcp_port=None, websocket_port=None): exe = shutil.which("twistd") args = [ exe, "-n", "transitrelay", ] if tcp_port is not None: args.append("--port") args.append(tcp_port) if websocket_port is not None: args.append("--websocket") args.append(websocket_port) proc = reactor.spawnProcess(proto, exe, args) return proc class Sender(WebSocketClientProtocol): """ """ def __init__(self, *args, **kw): WebSocketClientProtocol.__init__(self, *args, **kw) self.done = Deferred() self.got_ok = Deferred() def onMessage(self, payload, is_binary): print("onMessage") if not self.got_ok.called: if payload == b"ok\n": self.got_ok.callback(None) print("send: {}".format(payload.decode("utf8"))) def onClose(self, clean, code, reason): print(f"close: {clean} {code} {reason}") self.done.callback(None) class Receiver(WebSocketClientProtocol): """ """ def __init__(self, *args, **kw): WebSocketClientProtocol.__init__(self, *args, **kw) self.done = Deferred() self.first_message = Deferred() self.received = 0 def onMessage(self, payload, is_binary): print("recv: {}".format(len(payload))) self.received += len(payload) if not self.first_message.called: self.first_message.callback(None) def onClose(self, clean, code, reason): print(f"close: {clean} {code} {reason}") self.done.callback(None) class TransitWebSockets(unittest.TestCase): """ Integration-style tests of the transit WebSocket relay, using the real reactor (and running transit as a subprocess). """ @inlineCallbacks def test_buffer_fills(self): """ A running transit relay stops accepting incoming data at a reasonable amount if the peer isn't reading. This test defines that as 'less than 100MiB' although in practice Twisted seems to stop before 10MiB. """ from twisted.internet import reactor transit_proto = _CollectOutputProtocol() transit_proc = run_transit(reactor, transit_proto, websocket_port="tcp:8088") def cleanup_process(): transit_proc.signalProcess("HUP") return transit_proto.done self.addCleanup(cleanup_process) yield transit_proto.running print("Transit running") agent = create_client_agent(reactor) side_a = yield agent.open("ws://localhost:8088", {}, lambda: Sender()) side_b = yield agent.open("ws://localhost:8088", {}, lambda: Receiver()) side_a.sendMessage(b"please relay aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa for side aaaaaaaaaaaaaaaa", True) side_b.sendMessage(b"please relay aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa for side bbbbbbbbbbbbbbbb", True) yield side_a.got_ok yield side_b.first_message # remove side_b's filedescriptor from the reactor .. this # means it will not read any more data reactor.removeReader(side_b.transport) # attempt to send up to 100MiB through side_a .. 
we should get # backpressure before that works which only manifests itself # as this producer not being asked to produce more max_data = 1024*1024*100 # 100MiB @implementer(IPullProducer) class ProduceMessages: def __init__(self, ws, on_produce): self._ws = ws self._sent = 0 self._max = max_data self._on_produce = on_produce def resumeProducing(self): self._on_produce() if self._sent >= self._max: self._ws.sendClose() return data = b"a" * 1024*1024 self._ws.sendMessage(data, True) self._sent += len(data) print("sent {}, total {}".format(len(data), self._sent)) # our only signal is, "did our producer get asked to produce # more data" which it should do periodically. We want to stop # if we haven't seen a new data request for a while -- defined # as "more than 5 seconds". done = Deferred() last_produce = None timeout = 2 # seconds def asked_for_data(): nonlocal last_produce last_produce = reactor.seconds() data = ProduceMessages(side_a, asked_for_data) side_a.transport.registerProducer(data, False) data.resumeProducing() def check_if_done(): if last_produce is not None: if reactor.seconds() - last_produce > timeout: done.callback(None) return # recursive call to ourselves to check again soon reactor.callLater(.1, check_if_done) check_if_done() yield done mib = 1024*1024.0 print("Sent {}MiB of {}MiB before backpressure".format(data._sent / mib, max_data / mib)) self.assertTrue(data._sent < max_data, "Too much data sent") side_a.sendClose() side_b.sendClose() yield side_a.done yield side_b.done ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1661535598.0 magic-wormhole-transit-relay-0.3.1/src/wormhole_transit_relay/test/test_config.py0000644000175000017500000000323114302202556027743 0ustar00meejahmeejahfrom twisted.trial import unittest from .. import server_tap PORT = "tcp:4001:interface=\:\:" class Config(unittest.TestCase): def test_defaults(self): o = server_tap.Options() o.parseOptions([]) self.assertEqual(o, {"blur-usage": None, "log-fd": None, "usage-db": None, "port": PORT, "websocket": None, "websocket-url": None}) def test_blur(self): o = server_tap.Options() o.parseOptions(["--blur-usage=60"]) self.assertEqual(o, {"blur-usage": 60, "log-fd": None, "usage-db": None, "port": PORT, "websocket": None, "websocket-url": None}) def test_websocket(self): o = server_tap.Options() o.parseOptions(["--websocket=tcp:4004"]) self.assertEqual(o, {"blur-usage": None, "log-fd": None, "usage-db": None, "port": PORT, "websocket": "tcp:4004", "websocket-url": None}) def test_websocket_url(self): o = server_tap.Options() o.parseOptions(["--websocket=tcp:4004", "--websocket-url=ws://example.com/"]) self.assertEqual(o, {"blur-usage": None, "log-fd": None, "usage-db": None, "port": PORT, "websocket": "tcp:4004", "websocket-url": "ws://example.com/"}) def test_string(self): o = server_tap.Options() s = str(o) self.assertIn("This plugin sets up a 'Transit Relay'", s) self.assertIn("--blur-usage=", s) self.assertIn("blur timestamps and data sizes in logs", s) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1618510667.0 magic-wormhole-transit-relay-0.3.1/src/wormhole_transit_relay/test/test_database.py0000644000175000017500000001167014036101513030243 0ustar00meejahmeejahimport os from twisted.python import filepath from twisted.trial import unittest from .. 
import database from ..database import get_db, TARGET_VERSION, dump_db, DBError class Get(unittest.TestCase): def test_create_default(self): db_url = ":memory:" db = get_db(db_url) rows = db.execute("SELECT * FROM version").fetchall() self.assertEqual(len(rows), 1) self.assertEqual(rows[0]["version"], TARGET_VERSION) def test_open_existing_file(self): basedir = self.mktemp() os.mkdir(basedir) fn = os.path.join(basedir, "normal.db") db = get_db(fn) rows = db.execute("SELECT * FROM version").fetchall() self.assertEqual(len(rows), 1) self.assertEqual(rows[0]["version"], TARGET_VERSION) db2 = get_db(fn) rows = db2.execute("SELECT * FROM version").fetchall() self.assertEqual(len(rows), 1) self.assertEqual(rows[0]["version"], TARGET_VERSION) def test_open_bad_version(self): basedir = self.mktemp() os.mkdir(basedir) fn = os.path.join(basedir, "old.db") db = get_db(fn) db.execute("UPDATE version SET version=999") db.commit() with self.assertRaises(DBError) as e: get_db(fn) self.assertIn("Unable to handle db version 999", str(e.exception)) def test_open_corrupt(self): basedir = self.mktemp() os.mkdir(basedir) fn = os.path.join(basedir, "corrupt.db") with open(fn, "wb") as f: f.write(b"I am not a database") with self.assertRaises(DBError) as e: get_db(fn) self.assertIn("not a database", str(e.exception)) def test_failed_create_allows_subsequent_create(self): patch = self.patch(database, "get_schema", lambda version: b"this is a broken schema") dbfile = filepath.FilePath(self.mktemp()) self.assertRaises(Exception, lambda: get_db(dbfile.path)) patch.restore() get_db(dbfile.path) def OFF_test_upgrade(self): # disabled until we add a v2 schema basedir = self.mktemp() os.mkdir(basedir) fn = os.path.join(basedir, "upgrade.db") self.assertNotEqual(TARGET_VERSION, 2) # create an old-version DB in a file db = get_db(fn, 2) rows = db.execute("SELECT * FROM version").fetchall() self.assertEqual(len(rows), 1) self.assertEqual(rows[0]["version"], 2) del db # then upgrade the file to the latest version dbA = get_db(fn, TARGET_VERSION) rows = dbA.execute("SELECT * FROM version").fetchall() self.assertEqual(len(rows), 1) self.assertEqual(rows[0]["version"], TARGET_VERSION) dbA_text = dump_db(dbA) del dbA # make sure the upgrades got committed to disk dbB = get_db(fn, TARGET_VERSION) dbB_text = dump_db(dbB) del dbB self.assertEqual(dbA_text, dbB_text) # The upgraded schema should be equivalent to that of a new DB. # However a text dump will differ because ALTER TABLE always appends # the new column to the end of a table, whereas our schema puts it # somewhere in the middle (wherever it fits naturally). Also ALTER # TABLE doesn't include comments. 
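        # (the disabled block below is a manual debugging aid: turn it on
        # to write both dumps to disk, then compare them as described)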
if False: latest_db = get_db(":memory:", TARGET_VERSION) latest_text = dump_db(latest_db) with open("up.sql","w") as f: f.write(dbA_text) with open("new.sql","w") as f: f.write(latest_text) # check with "diff -u _trial_temp/up.sql _trial_temp/new.sql" self.assertEqual(dbA_text, latest_text) class Create(unittest.TestCase): def test_memory(self): db = database.create_db(":memory:") latest_text = dump_db(db) self.assertIn("CREATE TABLE", latest_text) def test_preexisting(self): basedir = self.mktemp() os.mkdir(basedir) fn = os.path.join(basedir, "preexisting.db") with open(fn, "w"): pass with self.assertRaises(database.DBAlreadyExists): database.create_db(fn) def test_create(self): basedir = self.mktemp() os.mkdir(basedir) fn = os.path.join(basedir, "created.db") db = database.create_db(fn) latest_text = dump_db(db) self.assertIn("CREATE TABLE", latest_text) class Open(unittest.TestCase): def test_open(self): basedir = self.mktemp() os.mkdir(basedir) fn = os.path.join(basedir, "created.db") db1 = database.create_db(fn) latest_text = dump_db(db1) self.assertIn("CREATE TABLE", latest_text) db2 = database.open_existing_db(fn) self.assertIn("CREATE TABLE", dump_db(db2)) def test_doesnt_exist(self): basedir = self.mktemp() os.mkdir(basedir) fn = os.path.join(basedir, "created.db") with self.assertRaises(database.DBDoesntExist): database.open_existing_db(fn) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1618510667.0 magic-wormhole-transit-relay-0.3.1/src/wormhole_transit_relay/test/test_rlimits.py0000644000175000017500000000526714036101513030167 0ustar00meejahmeejahfrom unittest import mock from twisted.trial import unittest from ..increase_rlimits import increase_rlimits class RLimits(unittest.TestCase): def test_rlimit(self): def patch_r(name, *args, **kwargs): return mock.patch("wormhole_transit_relay.increase_rlimits." 
+ name, *args, **kwargs) fakelog = [] def checklog(*expected): self.assertEqual(fakelog, list(expected)) fakelog[:] = [] NF = "NOFILE" mock_NF = patch_r("RLIMIT_NOFILE", NF) with patch_r("log.msg", fakelog.append): with patch_r("getrlimit", None): increase_rlimits() checklog("unable to import 'resource', leaving rlimit alone") with mock_NF: with patch_r("getrlimit", return_value=(20000, 30000)) as gr: increase_rlimits() self.assertEqual(gr.mock_calls, [mock.call(NF)]) checklog("RLIMIT_NOFILE.soft was 20000, leaving it alone") with patch_r("getrlimit", return_value=(10, 30000)) as gr: with patch_r("setrlimit", side_effect=TypeError("other")): with patch_r("log.err") as err: increase_rlimits() self.assertEqual(err.mock_calls, [mock.call()]) checklog("changing RLIMIT_NOFILE from (10,30000) to (30000,30000)", "other error during setrlimit, leaving it alone") for maxlimit in [40000, 20000, 9000, 2000, 1000]: def setrlimit(which, newlimit): if newlimit[0] > maxlimit: raise ValueError("nope") return None calls = [] expected = [] for tries in [30000, 10000, 3200, 1024]: calls.append(mock.call(NF, (tries, 30000))) expected.append("changing RLIMIT_NOFILE from (10,30000) to (%d,30000)" % tries) if tries > maxlimit: expected.append("error during setrlimit: nope") else: expected.append("setrlimit successful") break else: expected.append("unable to change rlimit, leaving it alone") with patch_r("setrlimit", side_effect=setrlimit) as sr: increase_rlimits() self.assertEqual(sr.mock_calls, calls) checklog(*expected) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1661535598.0 magic-wormhole-transit-relay-0.3.1/src/wormhole_transit_relay/test/test_service.py0000644000175000017500000000503714302202556030144 0ustar00meejahmeejahfrom twisted.trial import unittest from unittest import mock from twisted.application.service import MultiService from autobahn.twisted.websocket import WebSocketServerFactory from .. 
import server_tap class Service(unittest.TestCase): def test_defaults(self): o = server_tap.Options() o.parseOptions([]) with mock.patch("wormhole_transit_relay.server_tap.create_usage_tracker") as t: s = server_tap.makeService(o) self.assertEqual(t.mock_calls, [mock.call(blur_usage=None, log_file=None, usage_db=None)]) self.assertIsInstance(s, MultiService) def test_blur(self): o = server_tap.Options() o.parseOptions(["--blur-usage=60"]) with mock.patch("wormhole_transit_relay.server_tap.create_usage_tracker") as t: server_tap.makeService(o) self.assertEqual(t.mock_calls, [mock.call(blur_usage=60, log_file=None, usage_db=None)]) def test_log_fd(self): o = server_tap.Options() o.parseOptions(["--log-fd=99"]) fd = object() with mock.patch("wormhole_transit_relay.server_tap.create_usage_tracker") as t: with mock.patch("wormhole_transit_relay.server_tap.os.fdopen", return_value=fd) as f: server_tap.makeService(o) self.assertEqual(f.mock_calls, [mock.call(99, "w")]) self.assertEqual(t.mock_calls, [mock.call(blur_usage=None, log_file=fd, usage_db=None)]) def test_websocket(self): """ A websocket factory is created when passing --websocket """ o = server_tap.Options() o.parseOptions(["--websocket=tcp:4004"]) services = server_tap.makeService(o) self.assertTrue( any( isinstance(s.factory, WebSocketServerFactory) for s in services.services ) ) def test_websocket_explicit_url(self): """ A websocket factory is created with --websocket and --websocket-url """ o = server_tap.Options() o.parseOptions([ "--websocket=tcp:4004", "--websocket-url=ws://example.com:4004", ]) services = server_tap.makeService(o) self.assertTrue( any( isinstance(s.factory, WebSocketServerFactory) for s in services.services ) ) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1661535598.0 magic-wormhole-transit-relay-0.3.1/src/wormhole_transit_relay/test/test_stats.py0000644000175000017500000001113414302202556027635 0ustar00meejahmeejahimport os, io, json from unittest import mock from twisted.trial import unittest from ..transit_server import Transit from ..usage import create_usage_tracker from .. 
import database class DB(unittest.TestCase): def test_db(self): T = 1519075308.0 class Timer: t = T def __call__(self): return self.t get_time = Timer() d = self.mktemp() os.mkdir(d) usage_db = os.path.join(d, "usage.sqlite") db = database.get_db(usage_db) t = Transit( create_usage_tracker(blur_usage=None, log_file=None, usage_db=db), get_time, ) self.assertEqual(len(t.usage._backends), 1) usage = list(t.usage._backends)[0] get_time.t = T + 1 usage.record_usage(started=123, mood="happy", total_bytes=100, total_time=10, waiting_time=2) t.update_stats() self.assertEqual(db.execute("SELECT * FROM `usage`").fetchall(), [dict(result="happy", started=123, total_bytes=100, total_time=10, waiting_time=2), ]) self.assertEqual(db.execute("SELECT * FROM `current`").fetchone(), dict(rebooted=T+0, updated=T+1, incomplete_bytes=0, waiting=0, connected=0)) get_time.t = T + 2 usage.record_usage(started=150, mood="errory", total_bytes=200, total_time=11, waiting_time=3) t.update_stats() self.assertEqual(db.execute("SELECT * FROM `usage`").fetchall(), [dict(result="happy", started=123, total_bytes=100, total_time=10, waiting_time=2), dict(result="errory", started=150, total_bytes=200, total_time=11, waiting_time=3), ]) self.assertEqual(db.execute("SELECT * FROM `current`").fetchone(), dict(rebooted=T+0, updated=T+2, incomplete_bytes=0, waiting=0, connected=0)) get_time.t = T + 3 t.update_stats() self.assertEqual(db.execute("SELECT * FROM `current`").fetchone(), dict(rebooted=T+0, updated=T+3, incomplete_bytes=0, waiting=0, connected=0)) def test_no_db(self): t = Transit( create_usage_tracker(blur_usage=None, log_file=None, usage_db=None), lambda: 0, ) self.assertEqual(0, len(t.usage._backends)) class LogToStdout(unittest.TestCase): def test_log(self): # emit lines of JSON to log_file, if set log_file = io.StringIO() t = Transit( create_usage_tracker(blur_usage=None, log_file=log_file, usage_db=None), lambda: 0, ) with mock.patch("time.time", return_value=133): t.usage.record( started=123, buddy_started=125, result="happy", bytes_sent=100, buddy_bytes=0, ) self.assertEqual(json.loads(log_file.getvalue()), {"started": 123, "total_time": 10, "waiting_time": 2, "total_bytes": 100, "mood": "happy"}) def test_log_blurred(self): # if blurring is enabled, timestamps should be rounded to the # requested amount, and sizes should be rounded up too log_file = io.StringIO() t = Transit( create_usage_tracker(blur_usage=60, log_file=log_file, usage_db=None), lambda: 0, ) with mock.patch("time.time", return_value=123 + 10): t.usage.record( started=123, buddy_started=125, result="happy", bytes_sent=11999, buddy_bytes=0, ) print(log_file.getvalue()) self.assertEqual(json.loads(log_file.getvalue()), {"started": 120, "total_time": 10, "waiting_time": 2, "total_bytes": 20000, "mood": "happy"}) def test_do_not_log(self): t = Transit( create_usage_tracker(blur_usage=60, log_file=None, usage_db=None), lambda: 0, ) t.usage.record( started=123, buddy_started=124, result="happy", bytes_sent=11999, buddy_bytes=12, ) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1665609062.0 magic-wormhole-transit-relay-0.3.1/src/wormhole_transit_relay/test/test_transit_server.py0000644000175000017500000005151314321626546031570 0ustar00meejahmeejahfrom binascii import hexlify from twisted.trial import unittest from twisted.test import iosim from autobahn.twisted.websocket import ( WebSocketServerFactory, WebSocketClientFactory, WebSocketClientProtocol, ) from autobahn.twisted.testing import ( create_pumper, 
MemoryReactorClockResolver, ) from autobahn.exception import Disconnected from zope.interface import implementer from .common import ( ServerBase, IRelayTestClient, ) from ..usage import ( MemoryUsageRecorder, blur_size, ) from ..transit_server import ( WebSocketTransitConnection, TransitServerState, ) def handshake(token, side=None): hs = b"please relay " + hexlify(token) if side is not None: hs += b" for side " + hexlify(side) hs += b"\n" return hs class _Transit: def count(self): return sum([ len(potentials) for potentials in self._transit_server.pending_requests._requests.values() ]) def test_blur_size(self): self.failUnlessEqual(blur_size(0), 0) self.failUnlessEqual(blur_size(1), 10e3) self.failUnlessEqual(blur_size(10e3), 10e3) self.failUnlessEqual(blur_size(10e3+1), 20e3) self.failUnlessEqual(blur_size(15e3), 20e3) self.failUnlessEqual(blur_size(20e3), 20e3) self.failUnlessEqual(blur_size(1e6), 1e6) self.failUnlessEqual(blur_size(1e6+1), 2e6) self.failUnlessEqual(blur_size(1.5e6), 2e6) self.failUnlessEqual(blur_size(2e6), 2e6) self.failUnlessEqual(blur_size(900e6), 900e6) self.failUnlessEqual(blur_size(1000e6), 1000e6) self.failUnlessEqual(blur_size(1050e6), 1100e6) self.failUnlessEqual(blur_size(1100e6), 1100e6) self.failUnlessEqual(blur_size(1150e6), 1200e6) def test_register(self): p1 = self.new_protocol() token1 = b"\x00"*32 side1 = b"\x01"*8 p1.send(handshake(token1, side1)) self.flush() self.assertEqual(self.count(), 1) p1.disconnect() self.flush() self.assertEqual(self.count(), 0) # the token should be removed too self.assertEqual(len(self._transit_server.pending_requests._requests), 0) def test_both_unsided(self): p1 = self.new_protocol() p2 = self.new_protocol() token1 = b"\x00"*32 p1.send(handshake(token1, side=None)) self.flush() p2.send(handshake(token1, side=None)) self.flush() # a correct handshake yields an ack, after which we can send exp = b"ok\n" self.assertEqual(p1.get_received_data(), exp) self.assertEqual(p2.get_received_data(), exp) p1.reset_received_data() p2.reset_received_data() s1 = b"data1" p1.send(s1) self.flush() self.assertEqual(p2.get_received_data(), s1) p1.disconnect() self.flush() def test_sided_unsided(self): p1 = self.new_protocol() p2 = self.new_protocol() token1 = b"\x00"*32 side1 = b"\x01"*8 p1.send(handshake(token1, side=side1)) self.flush() p2.send(handshake(token1, side=None)) self.flush() # a correct handshake yields an ack, after which we can send exp = b"ok\n" self.assertEqual(p1.get_received_data(), exp) self.assertEqual(p2.get_received_data(), exp) p1.reset_received_data() p2.reset_received_data() # all data they sent after the handshake should be given to us s1 = b"data1" p1.send(s1) self.flush() self.assertEqual(p2.get_received_data(), s1) p1.disconnect() self.flush() def test_unsided_sided(self): p1 = self.new_protocol() p2 = self.new_protocol() token1 = b"\x00"*32 side1 = b"\x01"*8 p1.send(handshake(token1, side=None)) p2.send(handshake(token1, side=side1)) self.flush() # a correct handshake yields an ack, after which we can send exp = b"ok\n" self.assertEqual(p1.get_received_data(), exp) self.assertEqual(p2.get_received_data(), exp) p1.reset_received_data() p2.reset_received_data() # all data they sent after the handshake should be given to us s1 = b"data1" p1.send(s1) self.flush() self.assertEqual(p2.get_received_data(), s1) p1.disconnect() p2.disconnect() def test_both_sided(self): p1 = self.new_protocol() p2 = self.new_protocol() token1 = b"\x00"*32 side1 = b"\x01"*8 side2 = b"\x02"*8 p1.send(handshake(token1, side=side1)) 
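        # flush the in-memory pump so the relay processes this handshake
        # before the second side connects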
self.flush() p2.send(handshake(token1, side=side2)) self.flush() # a correct handshake yields an ack, after which we can send exp = b"ok\n" self.assertEqual(p1.get_received_data(), exp) self.assertEqual(p2.get_received_data(), exp) p1.reset_received_data() p2.reset_received_data() # all data they sent after the handshake should be given to us s1 = b"data1" p1.send(s1) self.flush() self.assertEqual(p2.get_received_data(), s1) p1.disconnect() p2.disconnect() def test_ignore_same_side(self): p1 = self.new_protocol() p2 = self.new_protocol() p3 = self.new_protocol() token1 = b"\x00"*32 side1 = b"\x01"*8 p1.send(handshake(token1, side=side1)) self.flush() self.assertEqual(self.count(), 1) p2.send(handshake(token1, side=side1)) self.flush() self.flush() self.assertEqual(self.count(), 2) # same-side connections don't match # when the second side arrives, the spare first connection should be # closed side2 = b"\x02"*8 p3.send(handshake(token1, side=side2)) self.flush() self.assertEqual(self.count(), 0) self.assertEqual(len(self._transit_server.pending_requests._requests), 0) self.assertEqual(len(self._transit_server.active_connections._connections), 2) # That will trigger a disconnect on exactly one of (p1 or p2). # The other connection should still be connected self.assertEqual(sum([int(t.connected) for t in [p1, p2]]), 1) p1.disconnect() p2.disconnect() p3.disconnect() def test_bad_handshake_old(self): p1 = self.new_protocol() token1 = b"\x00"*32 p1.send(b"please DELAY " + hexlify(token1) + b"\n") self.flush() exp = b"bad handshake\n" self.assertEqual(p1.get_received_data(), exp) p1.disconnect() def test_bad_handshake_old_slow(self): p1 = self.new_protocol() p1.send(b"please DELAY ") self.flush() # As in test_impatience_new_slow, the current state machine has code # that can only be reached if we insert a stall here, so dataReceived # gets called twice. Hopefully we can delete this test once # dataReceived is refactored to remove that state. token1 = b"\x00"*32 # the server waits for the exact number of bytes in the expected # handshake message. to trigger "bad handshake", we must match. p1.send(hexlify(token1) + b"\n") self.flush() exp = b"bad handshake\n" self.assertEqual(p1.get_received_data(), exp) p1.disconnect() def test_bad_handshake_new(self): p1 = self.new_protocol() token1 = b"\x00"*32 side1 = b"\x01"*8 # the server waits for the exact number of bytes in the expected # handshake message. to trigger "bad handshake", we must match. p1.send(b"please DELAY " + hexlify(token1) + b" for side " + hexlify(side1) + b"\n") self.flush() exp = b"bad handshake\n" self.assertEqual(p1.get_received_data(), exp) p1.disconnect() def test_binary_handshake(self): p1 = self.new_protocol() binary_bad_handshake = b"\x00\x01\xe0\x0f\n\xff" # the embedded \n makes the server trigger early, before the full # expected handshake length has arrived. A non-wormhole client # writing non-ascii junk to the transit port used to trigger a # UnicodeDecodeError when it tried to coerce the incoming handshake # to unicode, due to the ("\n" in buf) check. This was fixed to use # (b"\n" in buf). This exercises the old failure. p1.send(binary_bad_handshake) self.flush() exp = b"bad handshake\n" self.assertEqual(p1.get_received_data(), exp) p1.disconnect() def test_impatience_old(self): p1 = self.new_protocol() token1 = b"\x00"*32 # sending too many bytes is impatience. 
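        # (any bytes that arrive before the relay has answered "ok" are
        # treated as impatience, and the connection is dropped)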
p1.send(b"please relay " + hexlify(token1)) p1.send(b"\nNOWNOWNOW") self.flush() exp = b"impatient\n" self.assertEqual(p1.get_received_data(), exp) p1.disconnect() def test_impatience_new(self): p1 = self.new_protocol() token1 = b"\x00"*32 side1 = b"\x01"*8 # sending too many bytes is impatience. p1.send(b"please relay " + hexlify(token1) + b" for side " + hexlify(side1)) p1.send(b"\nNOWNOWNOW") self.flush() exp = b"impatient\n" self.assertEqual(p1.get_received_data(), exp) p1.disconnect() def test_impatience_new_slow(self): p1 = self.new_protocol() # For full coverage, we need dataReceived to see a particular framing # of these two pieces of data, and ITCPTransport doesn't have flush() # (which probably wouldn't work anyways). For now, force a 100ms # stall between the two writes. I tried setTcpNoDelay(True) but it # didn't seem to help without the stall. The long-term fix is to # rewrite dataReceived() to remove the multiple "impatient" # codepaths, deleting the particular clause that this test exercises, # then remove this test. token1 = b"\x00"*32 side1 = b"\x01"*8 # sending too many bytes is impatience. p1.send(b"please relay " + hexlify(token1) + b" for side " + hexlify(side1) + b"\n") self.flush() p1.send(b"NOWNOWNOW") self.flush() exp = b"impatient\n" self.assertEqual(p1.get_received_data(), exp) p1.disconnect() def test_short_handshake(self): p1 = self.new_protocol() # hang up before sending a complete handshake p1.send(b"short") self.flush() p1.disconnect() def test_empty_handshake(self): p1 = self.new_protocol() # hang up before sending anything p1.disconnect() class TransitWithLogs(_Transit, ServerBase, unittest.TestCase): log_requests = True def new_protocol(self): return self.new_protocol_tcp() class TransitWithoutLogs(_Transit, ServerBase, unittest.TestCase): log_requests = False def new_protocol(self): return self.new_protocol_tcp() def _new_protocol_ws(transit_server, log_requests): """ Internal helper for test-suites that need to provide WebSocket client/server pairs. 
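    Both ends are wired together in memory with twisted.test.iosim
    (no real sockets are involved), so callers must flush() the
    returned pump to move bytes between client and server.
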
    :returns: a 2-tuple: (iosim.IOPump, protocol)
    """
    ws_factory = WebSocketServerFactory("ws://localhost:4002")
    ws_factory.protocol = WebSocketTransitConnection
    ws_factory.transit = transit_server
    ws_factory.log_requests = log_requests
    ws_protocol = ws_factory.buildProtocol(('127.0.0.1', 0))

    @implementer(IRelayTestClient)
    class TransitWebSocketClientProtocol(WebSocketClientProtocol):
        _received = b""
        connected = False

        def connectionMade(self):
            self.connected = True
            return super(TransitWebSocketClientProtocol, self).connectionMade()

        def connectionLost(self, reason):
            self.connected = False
            return super(TransitWebSocketClientProtocol, self).connectionLost(reason)

        def onMessage(self, data, isBinary):
            self._received = self._received + data

        def send(self, data):
            self.sendMessage(data, True)

        def get_received_data(self):
            return self._received

        def reset_received_data(self):
            self._received = b""

        def disconnect(self):
            self.sendClose(1000, True)

    client_factory = WebSocketClientFactory()
    client_factory.protocol = TransitWebSocketClientProtocol
    client_protocol = client_factory.buildProtocol(('127.0.0.1', 31337))
    client_protocol.disconnect = client_protocol.dropConnection

    pump = iosim.connect(
        ws_protocol,
        iosim.makeFakeServer(ws_protocol),
        client_protocol,
        iosim.makeFakeClient(client_protocol),
    )
    return pump, client_protocol


class TransitWebSockets(_Transit, ServerBase, unittest.TestCase):

    def new_protocol(self):
        return self.new_protocol_ws()

    def new_protocol_ws(self):
        pump, proto = _new_protocol_ws(self._transit_server, self.log_requests)
        self._pumps.append(pump)
        return proto

    def test_websocket_to_tcp(self):
        """
        One client is WebSocket and one is TCP
        """
        p1 = self.new_protocol_ws()
        p2 = self.new_protocol_tcp()

        token1 = b"\x00"*32
        side1 = b"\x01"*8
        side2 = b"\x02"*8
        p1.send(handshake(token1, side=side1))
        self.flush()
        p2.send(handshake(token1, side=side2))
        self.flush()

        # a correct handshake yields an ack, after which we can send
        exp = b"ok\n"
        self.assertEqual(p1.get_received_data(), exp)
        self.assertEqual(p2.get_received_data(), exp)

        p1.reset_received_data()
        p2.reset_received_data()

        # all data they sent after the handshake should be given to us
        s1 = b"data1"
        p1.send(s1)
        self.flush()
        self.assertEqual(p2.get_received_data(), s1)

        p1.disconnect()
        p2.disconnect()
        self.flush()

    def test_bad_handshake_old_slow(self):
        """
        This test only makes sense for TCP
        """

    def test_send_closed_partner(self):
        """
        Sending data to a closed partner causes an error that
        propagates to the sender.
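        (With the WebSocket transport this surfaces as autobahn's
        Disconnected exception on the next send, which is what this
        test asserts.)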
""" p1 = self.new_protocol() p2 = self.new_protocol() # set up a successful connection token = b"a" * 32 p1.send(handshake(token)) p2.send(handshake(token)) self.flush() # p2 loses connection, then p1 sends a message p2.transport.loseConnection() self.flush() # at this point, p1 learns that p2 is disconnected (because it # tried to relay "a message" but failed) # try to send more (our partner p2 is gone now though so it # should be an immediate error) with self.assertRaises(Disconnected): p1.send(b"more message") self.flush() class Usage(ServerBase, unittest.TestCase): log_requests = True def setUp(self): super(Usage, self).setUp() self._usage = MemoryUsageRecorder() self._transit_server.usage.add_backend(self._usage) def new_protocol(self): return self.new_protocol_tcp() def test_empty(self): p1 = self.new_protocol() # hang up before sending anything p1.disconnect() self.flush() # that will log the "empty" usage event self.assertEqual(len(self._usage.events), 1, self._usage) self.assertEqual(self._usage.events[0]["mood"], "empty", self._usage) def test_short(self): # Note: this test only runs on TCP clients because WebSockets # already does framing (so it's either "a bad handshake" or # there's no handshake at all yet .. you can't have a "short" # one). p1 = self.new_protocol() # hang up before sending a complete handshake p1.send(b"short") p1.disconnect() self.flush() # that will log the "empty" usage event self.assertEqual(len(self._usage.events), 1, self._usage) self.assertEqual("empty", self._usage.events[0]["mood"]) def test_errory(self): p1 = self.new_protocol() p1.send(b"this is a very bad handshake\n") self.flush() # that will log the "errory" usage event, then drop the connection p1.disconnect() self.assertEqual(len(self._usage.events), 1, self._usage) self.assertEqual(self._usage.events[0]["mood"], "errory", self._usage) def test_lonely(self): p1 = self.new_protocol() token1 = b"\x00"*32 side1 = b"\x01"*8 p1.send(handshake(token1, side=side1)) self.flush() # now we disconnect before the peer connects p1.disconnect() self.flush() self.assertEqual(len(self._usage.events), 1, self._usage) self.assertEqual(self._usage.events[0]["mood"], "lonely", self._usage) self.assertIdentical(self._usage.events[0]["waiting_time"], None) def test_one_happy_one_jilted(self): p1 = self.new_protocol() p2 = self.new_protocol() token1 = b"\x00"*32 side1 = b"\x01"*8 side2 = b"\x02"*8 p1.send(handshake(token1, side=side1)) self.flush() p2.send(handshake(token1, side=side2)) self.flush() self.assertEqual(self._usage.events, []) # no events yet p1.send(b"\x00" * 13) self.flush() p2.send(b"\xff" * 7) self.flush() p1.disconnect() self.flush() self.assertEqual(len(self._usage.events), 1, self._usage) self.assertEqual(self._usage.events[0]["mood"], "happy", self._usage) self.assertEqual(self._usage.events[0]["total_bytes"], 20) self.assertNotIdentical(self._usage.events[0]["waiting_time"], None) def test_redundant(self): p1a = self.new_protocol() p1b = self.new_protocol() p1c = self.new_protocol() p2 = self.new_protocol() token1 = b"\x00"*32 side1 = b"\x01"*8 side2 = b"\x02"*8 p1a.send(handshake(token1, side=side1)) self.flush() p1b.send(handshake(token1, side=side1)) self.flush() # connect and disconnect a third client (for side1) to exercise the # code that removes a pending connection without removing the entire # token p1c.send(handshake(token1, side=side1)) p1c.disconnect() self.flush() self.assertEqual(len(self._usage.events), 1, self._usage) self.assertEqual(self._usage.events[0]["mood"], "lonely") 
        p2.send(handshake(token1, side=side2))
        self.flush()
        self.assertEqual(len(self._transit_server.pending_requests._requests), 0)
        self.assertEqual(len(self._usage.events), 2, self._usage)
        self.assertEqual(self._usage.events[1]["mood"], "redundant")

        # one of these is unnecessary, but probably harmless
        p1a.disconnect()
        p1b.disconnect()
        self.flush()

        self.assertEqual(len(self._usage.events), 3, self._usage)
        self.assertEqual(self._usage.events[2]["mood"], "happy")


class UsageWebSockets(Usage):
    """
    All the tests of 'Usage' except with a WebSocket (instead of TCP)
    transport.

    This overrides ServerBase.new_protocol to achieve this. It might
    be nicer to parametrize these tests in a way that doesn't use
    inheritance .. but all the support classes are set up that way
    already.
    """

    def setUp(self):
        super(UsageWebSockets, self).setUp()
        self._pump = create_pumper()
        self._reactor = MemoryReactorClockResolver()
        return self._pump.start()

    def tearDown(self):
        return self._pump.stop()

    def new_protocol(self):
        return self.new_protocol_ws()

    def new_protocol_ws(self):
        pump, proto = _new_protocol_ws(self._transit_server, self.log_requests)
        self._pumps.append(pump)
        return proto

    def test_short(self):
        """
        This test essentially just tests the framing of the
        line-oriented TCP protocol; it doesn't make sense for the
        WebSockets case because WS handles framing: you either sent a
        'bad handshake' (because it is semantically invalid) or no
        handshake at all (yet).
        """

    def test_send_non_binary_message(self):
        """
        A non-binary WebSocket message is an error
        """
        ws_factory = WebSocketServerFactory("ws://localhost:4002")
        ws_factory.protocol = WebSocketTransitConnection
        ws_protocol = ws_factory.buildProtocol(('127.0.0.1', 0))
        with self.assertRaises(ValueError):
            ws_protocol.onMessage(u"foo", isBinary=False)


class State(unittest.TestCase):
    """
    Tests related to server_state.TransitServerState
    """

    def setUp(self):
        self.state = TransitServerState(None, None)

    def test_empty_token(self):
        self.assertEqual(
            "-",
            self.state.get_token(),
        )
././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1724027250.0 magic-wormhole-transit-relay-0.3.1/src/wormhole_transit_relay/transit_server.py0000644000175000017500000002127214660510562027545 0ustar00meejahmeejah
import re
import time

from twisted.python import log
from twisted.protocols.basic import LineReceiver
from autobahn.twisted.websocket import WebSocketServerProtocol

SECONDS = 1.0
MINUTE = 60*SECONDS
HOUR = 60*MINUTE
DAY = 24*HOUR
MB = 1000*1000

from wormhole_transit_relay.server_state import (
    TransitServerState,
    PendingRequests,
    ActiveConnections,
    ITransitClient,
)
from zope.interface import implementer


@implementer(ITransitClient)
class TransitConnection(LineReceiver):
    delimiter = b'\n'

    # maximum length of a line we will accept before the handshake is
    # complete. This must be >= the longest possible handshake message.
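    # The longest expected line is the new-style handshake, "please relay
    # <64 hex chars> for side <16 hex chars>": 13 + 64 + 10 + 16 = 103
    # bytes before the delimiter, so 1024 leaves generous headroom.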
MAX_LENGTH = 1024 started_time = None def send(self, data): """ ITransitClient API """ self.transport.write(data) def disconnect(self): """ ITransitClient API """ self.transport.loseConnection() def connect_partner(self, other): """ ITransitClient API """ self._buddy = other self._buddy._client.transport.registerProducer(self.transport, True) def disconnect_partner(self): """ ITransitClient API """ assert self._buddy is not None, "internal error: no buddy" if self.factory.log_requests: log.msg("buddy_disconnected {}".format(self._buddy.get_token())) self._buddy._client.disconnect() self._buddy = None def connectionMade(self): # ideally more like self._reactor.seconds() ... but Twisted # doesn't have a good way to get the reactor for a protocol # (besides "use the global one") self.started_time = time.time() self._state = TransitServerState( self.factory.transit.pending_requests, self.factory.transit.usage, ) self._state.connection_made(self) self.transport.setTcpKeepAlive(True) # uncomment to turn on state-machine tracing # def tracer(oldstate, theinput, newstate): # print("TRACE: {}: {} --{}--> {}".format(id(self), oldstate, theinput, newstate)) # self._state.set_trace_function(tracer) def lineReceived(self, line): """ LineReceiver API """ # old: "please relay {64}\n" token = None old = re.search(br"^please relay (\w{64})$", line) if old: token = old.group(1) self._state.please_relay(token) # new: "please relay {64} for side {16}\n" new = re.search(br"^please relay (\w{64}) for side (\w{16})$", line) if new: token = new.group(1) side = new.group(2) self._state.please_relay_for_side(token, side) if token is None: self._state.bad_token() else: self.setRawMode() def rawDataReceived(self, data): """ LineReceiver API """ # We are an IPushProducer to our buddy's IConsumer, so they'll # throttle us (by calling pauseProducing()) when their outbound # buffer is full (e.g. when their downstream pipe is full). In # practice, this buffers about 10MB per connection, after which # point the sender will only transmit data as fast as the # receiver can handle it. self._state.got_bytes(data) def connectionLost(self, reason): self._state.connection_lost() class Transit(object): """ I manage pairs of simultaneous connections to a secondary TCP port, both forwarded to the other. Clients must begin each connection with "please relay TOKEN for SIDE\n" (or a legacy form without the "for SIDE"). Two connections match if they use the same TOKEN and have different SIDEs (the redundant connections are dropped when a match is made). Legacy connections match any with the same TOKEN, ignoring SIDE (so two legacy connections will match each other). I will send "ok\n" when the matching connection is established, or disconnect if no matching connection is made within MAX_WAIT_TIME seconds. I will disconnect if you send data before the "ok\n". All data you get after the "ok\n" will be from the other side. You will not receive "ok\n" until the other side has also connected and submitted a matching token (and differing SIDE). In addition, the connections will be dropped after MAXLENGTH bytes have been sent by either side, or MAXTIME seconds have elapsed after the matching connections were established. A future API will reveal these limits to clients instead of causing mysterious spontaneous failures. These relay connections are not half-closeable (unlike full TCP connections, applications will not receive any data after half-closing their outgoing side). 
Applications must negotiate shutdown with their peer and not close the connection until all data has finished transferring in both directions. Applications which only need to send data in one direction can use close() as usual. """ # TODO: unused MAX_WAIT_TIME = 30*SECONDS # TODO: unused MAXLENGTH = 10*MB # TODO: unused MAXTIME = 60*SECONDS def __init__(self, usage, get_timestamp): self.active_connections = ActiveConnections() self.pending_requests = PendingRequests(self.active_connections) self.usage = usage self._timestamp = get_timestamp self._rebooted = self._timestamp() def update_stats(self): # TODO: when a connection is half-closed, len(active) will be odd. a # moment later (hopefully) the other side will disconnect, but # _update_stats isn't updated until later. # "waiting" doesn't count multiple parallel connections from the same # side self.usage.update_stats( rebooted=self._rebooted, updated=self._timestamp(), connected=len(self.active_connections._connections), waiting=len(self.pending_requests._requests), incomplete_bytes=sum( tc._total_sent for tc in self.active_connections._connections ), ) @implementer(ITransitClient) class WebSocketTransitConnection(WebSocketServerProtocol): started_time = None def send(self, data): """ ITransitClient API """ self.sendMessage(data, isBinary=True) def disconnect(self): """ ITransitClient API """ self.sendClose(1000, None) def connect_partner(self, other): """ ITransitClient API """ self._buddy = other self._buddy._client.transport.registerProducer(self.transport, True) def disconnect_partner(self): """ ITransitClient API """ assert self._buddy is not None, "internal error: no buddy" if self.factory.log_requests: log.msg("buddy_disconnected {}".format(self._buddy.get_token())) self._buddy._client.disconnect() self._buddy = None def connectionMade(self): """ IProtocol API """ super(WebSocketTransitConnection, self).connectionMade() self.started_time = time.time() self._first_message = True self._state = TransitServerState( self.factory.transit.pending_requests, self.factory.transit.usage, ) # uncomment to turn on state-machine tracing # def tracer(oldstate, theinput, newstate): # print("WSTRACE: {}: {} --{}--> {}".format(id(self), oldstate, theinput, newstate)) # self._state.set_trace_function(tracer) def onOpen(self): self._state.connection_made(self) def onMessage(self, payload, isBinary): """ We may have a 'handshake' on our hands or we may just have some bytes to relay """ if not isBinary: raise ValueError( "All messages must be binary" ) if self._first_message: self._first_message = False token = None old = re.search(br"^please relay (\w{64})$", payload) if old: token = old.group(1) self._state.please_relay(token) # new: "please relay {64} for side {16}\n" new = re.search(br"^please relay (\w{64}) for side (\w{16})$", payload) if new: token = new.group(1) side = new.group(2) self._state.please_relay_for_side(token, side) if token is None: self._state.bad_token() else: self._state.got_bytes(payload) def onClose(self, wasClean, code, reason): """ IWebSocketChannel API """ self._state.connection_lost() ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1661535598.0 magic-wormhole-transit-relay-0.3.1/src/wormhole_transit_relay/usage.py0000644000175000017500000001645014302202556025573 0ustar00meejahmeejahimport time import json from twisted.python import log from zope.interface import ( implementer, Interface, ) def create_usage_tracker(blur_usage, log_file, usage_db): """ :param int blur_usage: see UsageTracker 
    :param log_file: None or a file-like object to write JSON-encoded
        lines of usage information to.

    :param usage_db: None or an sqlite3 database connection

    :returns: a new UsageTracker instance configured with backends.
    """
    tracker = UsageTracker(blur_usage)
    if usage_db:
        tracker.add_backend(DatabaseUsageRecorder(usage_db))
    if log_file:
        tracker.add_backend(LogFileUsageRecorder(log_file))
    return tracker


class IUsageWriter(Interface):
    """
    Records actual usage statistics in some way
    """

    def record_usage(started=None, total_time=None, waiting_time=None, total_bytes=None, mood=None):
        """
        :param int started: timestamp when this connection began

        :param float total_time: total seconds this connection lasted

        :param float waiting_time: None or the total seconds one side
            waited for the other

        :param int total_bytes: the total bytes sent. In case the
            connection was concluded successfully, only one side will
            record the total bytes (but count both).

        :param str mood: the 'mood' of the connection
        """


@implementer(IUsageWriter)
class MemoryUsageRecorder:
    """
    Remembers usage records in memory.
    """

    def __init__(self):
        self.events = []

    def record_usage(self, started=None, total_time=None, waiting_time=None, total_bytes=None, mood=None):
        """
        IUsageWriter.
        """
        data = {
            "started": started,
            "total_time": total_time,
            "waiting_time": waiting_time,
            "total_bytes": total_bytes,
            "mood": mood,
        }
        self.events.append(data)


@implementer(IUsageWriter)
class LogFileUsageRecorder:
    """
    Writes usage records to a file. The records are written in JSON,
    one record per line.
    """

    def __init__(self, writable_file):
        self._file = writable_file

    def record_usage(self, started=None, total_time=None, waiting_time=None, total_bytes=None, mood=None):
        """
        IUsageWriter.
        """
        data = {
            "started": started,
            "total_time": total_time,
            "waiting_time": waiting_time,
            "total_bytes": total_bytes,
            "mood": mood,
        }
        self._file.write(json.dumps(data) + "\n")
        self._file.flush()


@implementer(IUsageWriter)
class DatabaseUsageRecorder:
    """
    Write usage records into a database
    """

    def __init__(self, db):
        self._db = db

    def record_usage(self, started=None, total_time=None, waiting_time=None, total_bytes=None, mood=None):
        """
        IUsageWriter.
        """
        self._db.execute(
            "INSERT INTO `usage`"
            " (`started`, `total_time`, `waiting_time`,"
            " `total_bytes`, `result`)"
            " VALUES (?,?,?,?,?)",
            (started, total_time, waiting_time, total_bytes, mood)
        )
        # original code did "self._update_stats()" here, thus causing
        # "global" stats update on every connection update .. should
        # we repeat this behavior, or really only record every
        # 60-seconds with the timer?
        self._db.commit()


class UsageTracker(object):
    """
    Tracks usage statistics of connections
    """

    def __init__(self, blur_usage):
        """
        :param int blur_usage: None or the number of seconds to use as a
            window around which to blur time statistics (e.g. "60" means
            times will be rounded to 1 minute intervals). When blur_usage
            is non-zero, sizes will also be rounded into buckets of "one
            megabyte", "one gigabyte" or "lots"
        """
        self._backends = set()
        self._blur_usage = blur_usage
        if blur_usage:
            log.msg("blurring access times to %d seconds" % self._blur_usage)
        else:
            log.msg("not blurring access times")

    def add_backend(self, backend):
        """
        Add a new backend.
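
        Backends live in a set, so adding the same backend twice is
        harmless: it will only be notified once per usage record.
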
:param IUsageWriter backend: the backend to add """ self._backends.add(backend) def record(self, started, buddy_started, result, bytes_sent, buddy_bytes): """ :param int started: timestamp when our connection started :param int buddy_started: None, or the timestamp when our partner's connection started (will be None if we don't yet have a partner). :param str result: a label for the result of the connection (one of the "moods"). :param int bytes_sent: number of bytes we sent :param int buddy_bytes: number of bytes our partner sent """ # ideally self._reactor.seconds() or similar, but .. finished = time.time() if buddy_started is not None: starts = [started, buddy_started] total_time = finished - min(starts) waiting_time = max(starts) - min(starts) total_bytes = bytes_sent + buddy_bytes else: total_time = finished - started waiting_time = None total_bytes = bytes_sent # note that "bytes_sent" should always be 0 here, but # we're recording what the state-machine remembered in any # case if self._blur_usage: started = self._blur_usage * (started // self._blur_usage) total_bytes = blur_size(total_bytes) # This is "a dict" instead of "kwargs" because we have to make # it into a dict for the log use-case and in-memory/testing # use-case anyway so this is less repeats of the names. self._notify_backends({ "started": started, "total_time": total_time, "waiting_time": waiting_time, "total_bytes": total_bytes, "mood": result, }) def update_stats(self, rebooted, updated, connected, waiting, incomplete_bytes): """ Update general statistics. """ # in original code, this is only recorded in the database # .. perhaps a better way to do this, but .. for backend in self._backends: if isinstance(backend, DatabaseUsageRecorder): backend._db.execute("DELETE FROM `current`") backend._db.execute( "INSERT INTO `current`" " (`rebooted`, `updated`, `connected`, `waiting`," " `incomplete_bytes`)" " VALUES (?, ?, ?, ?, ?)", (int(rebooted), int(updated), connected, waiting, incomplete_bytes) ) def _notify_backends(self, data): """ Internal helper. Tell every backend we have about a new usage record. """ for backend in self._backends: backend.record_usage(**data) def round_to(size, coarseness): return int(coarseness*(1+int((size-1)/coarseness))) def blur_size(size): if size == 0: return 0 if size < 1e6: return round_to(size, 10e3) if size < 1e9: return round_to(size, 1e6) return round_to(size, 100e6) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1618510667.0 magic-wormhole-transit-relay-0.3.1/tox.ini0000644000175000017500000000127414036101513020036 0ustar00meejahmeejah# Tox (http://tox.testrun.org/) is a tool for running tests # in multiple virtualenvs. This configuration file will run the # test suite on all supported python versions. To use it, "pip install tox" # and then run "tox" from this directory. 
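# A single environment can be run on its own, e.g. "tox -e coverage"
# for the coverage run defined at the bottom of this file.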
[tox] envlist = {py37,py38,py39,py310,pypy} skip_missing_interpreters = True minversion = 2.4.0 [testenv] usedevelop = True extras = dev deps = pyflakes >= 1.2.3 commands = pyflakes setup.py src python -m twisted.trial {posargs:wormhole_transit_relay} [testenv:coverage] deps = pyflakes >= 1.2.3 coverage commands = pyflakes setup.py src coverage run --branch -m twisted.trial {posargs:wormhole_transit_relay} coverage xml ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1724027226.0 magic-wormhole-transit-relay-0.3.1/versioneer.py0000644000175000017500000025122514660510532021271 0ustar00meejahmeejah # Version: 0.29 """The Versioneer - like a rocketeer, but for versions. The Versioneer ============== * like a rocketeer, but for versions! * https://github.com/python-versioneer/python-versioneer * Brian Warner * License: Public Domain (Unlicense) * Compatible with: Python 3.7, 3.8, 3.9, 3.10, 3.11 and pypy3 * [![Latest Version][pypi-image]][pypi-url] * [![Build Status][travis-image]][travis-url] This is a tool for managing a recorded version number in setuptools-based python projects. The goal is to remove the tedious and error-prone "update the embedded version string" step from your release process. Making a new release should be as easy as recording a new tag in your version-control system, and maybe making new tarballs. ## Quick Install Versioneer provides two installation modes. The "classic" vendored mode installs a copy of versioneer into your repository. The experimental build-time dependency mode is intended to allow you to skip this step and simplify the process of upgrading. ### Vendored mode * `pip install versioneer` to somewhere in your $PATH * A [conda-forge recipe](https://github.com/conda-forge/versioneer-feedstock) is available, so you can also use `conda install -c conda-forge versioneer` * add a `[tool.versioneer]` section to your `pyproject.toml` or a `[versioneer]` section to your `setup.cfg` (see [Install](INSTALL.md)) * Note that you will need to add `tomli; python_version < "3.11"` to your build-time dependencies if you use `pyproject.toml` * run `versioneer install --vendor` in your source tree, commit the results * verify version information with `python setup.py version` ### Build-time dependency mode * `pip install versioneer` to somewhere in your $PATH * A [conda-forge recipe](https://github.com/conda-forge/versioneer-feedstock) is available, so you can also use `conda install -c conda-forge versioneer` * add a `[tool.versioneer]` section to your `pyproject.toml` or a `[versioneer]` section to your `setup.cfg` (see [Install](INSTALL.md)) * add `versioneer` (with `[toml]` extra, if configuring in `pyproject.toml`) to the `requires` key of the `build-system` table in `pyproject.toml`: ```toml [build-system] requires = ["setuptools", "versioneer[toml]"] build-backend = "setuptools.build_meta" ``` * run `versioneer install --no-vendor` in your source tree, commit the results * verify version information with `python setup.py version` ## Version Identifiers Source trees come from a variety of places: * a version-control system checkout (mostly used by developers) * a nightly tarball, produced by build automation * a snapshot tarball, produced by a web-based VCS browser, like github's "tarball from tag" feature * a release tarball, produced by "setup.py sdist", distributed through PyPI Within each source tree, the version identifier (either a string or a number, this tool is format-agnostic) can come from a variety of places: * ask the VCS 
tool itself, e.g. "git describe" (for checkouts), which knows about recent "tags" and an absolute revision-id * the name of the directory into which the tarball was unpacked * an expanded VCS keyword ($Id$, etc) * a `_version.py` created by some earlier build step For released software, the version identifier is closely related to a VCS tag. Some projects use tag names that include more than just the version string (e.g. "myproject-1.2" instead of just "1.2"), in which case the tool needs to strip the tag prefix to extract the version identifier. For unreleased software (between tags), the version identifier should provide enough information to help developers recreate the same tree, while also giving them an idea of roughly how old the tree is (after version 1.2, before version 1.3). Many VCS systems can report a description that captures this, for example `git describe --tags --dirty --always` reports things like "0.7-1-g574ab98-dirty" to indicate that the checkout is one revision past the 0.7 tag, has a unique revision id of "574ab98", and is "dirty" (it has uncommitted changes). The version identifier is used for multiple purposes: * to allow the module to self-identify its version: `myproject.__version__` * to choose a name and prefix for a 'setup.py sdist' tarball ## Theory of Operation Versioneer works by adding a special `_version.py` file into your source tree, where your `__init__.py` can import it. This `_version.py` knows how to dynamically ask the VCS tool for version information at import time. `_version.py` also contains `$Revision$` markers, and the installation process marks `_version.py` to have this marker rewritten with a tag name during the `git archive` command. As a result, generated tarballs will contain enough information to get the proper version. To allow `setup.py` to compute a version too, a `versioneer.py` is added to the top level of your source tree, next to `setup.py` and the `setup.cfg` that configures it. This overrides several distutils/setuptools commands to compute the version when invoked, and changes `setup.py build` and `setup.py sdist` to replace `_version.py` with a small static file that contains just the generated version data. ## Installation See [INSTALL.md](./INSTALL.md) for detailed installation instructions. ## Version-String Flavors Code which uses Versioneer can learn about its version string at runtime by importing `_version` from your main `__init__.py` file and running the `get_versions()` function. From the "outside" (e.g. in `setup.py`), you can import the top-level `versioneer.py` and run `get_versions()`. Both functions return a dictionary with different flavors of version information: * `['version']`: A condensed version string, rendered using the selected style. This is the most commonly used value for the project's version string. The default "pep440" style yields strings like `0.11`, `0.11+2.g1076c97`, or `0.11+2.g1076c97.dirty`. See the "Styles" section below for alternative styles. * `['full-revisionid']`: detailed revision identifier. For Git, this is the full SHA1 commit id, e.g. "1076c978a8d3cfc70f408fe5974aa6c092c949ac". * `['date']`: Date and time of the latest `HEAD` commit. For Git, it is the commit date in ISO 8601 format. This will be None if the date is not available. * `['dirty']`: a boolean, True if the tree has uncommitted changes. 
Note that this is only accurate if run in a VCS checkout, otherwise it is likely to be False or None.
* `['error']`: if the version string could not be computed, this will be set to a string describing the problem, otherwise it will be None. It may be useful to throw an exception in setup.py if this is set, to avoid e.g. creating tarballs with a version string of "unknown".

Some variants are more useful than others. Including `full-revisionid` in a bug report should allow developers to reconstruct the exact code being tested (or indicate the presence of local changes that should be shared with the developers). `version` is suitable for display in an "about" box or a CLI `--version` output: it can be easily compared against release notes and lists of bugs fixed in various releases.

The installer adds the following text to your `__init__.py` to place a basic version in `YOURPROJECT.__version__`:

    from ._version import get_versions
    __version__ = get_versions()['version']
    del get_versions

## Styles

The setup.cfg `style=` configuration controls how the VCS information is rendered into a version string.

The default style, "pep440", produces a PEP440-compliant string, equal to the un-prefixed tag name for actual releases, and containing an additional "local version" section with more detail for in-between builds. For Git, this is TAG[+DISTANCE.gHEX[.dirty]] , using information from `git describe --tags --dirty --always`. For example "0.11+2.g1076c97.dirty" indicates that the tree is like the "1076c97" commit but has uncommitted changes (".dirty"), and that this commit is two revisions ("+2") beyond the "0.11" tag. For released software (exactly equal to a known tag), the identifier will only contain the stripped tag, e.g. "0.11".

Other styles are available. See [details.md](details.md) in the Versioneer source tree for descriptions.

## Debugging

Versioneer tries to avoid fatal errors: if something goes wrong, it will tend to return a version of "0+unknown". To investigate the problem, run `setup.py version`, which will run the version-lookup code in a verbose mode, and will display the full contents of `get_versions()` (including the `error` string, which may help identify what went wrong).

## Known Limitations

Some situations are known to cause problems for Versioneer. This section details the most significant ones; more can be found on the GitHub [issues page](https://github.com/python-versioneer/python-versioneer/issues).

### Subprojects

Versioneer has limited support for source trees in which `setup.py` is not in the root directory (e.g. `setup.py` and `.git/` are *not* siblings). There are two common reasons why `setup.py` might not be in the root:

* Source trees which contain multiple subprojects, such as [Buildbot](https://github.com/buildbot/buildbot), which contains both "master" and "slave" subprojects, each with their own `setup.py`, `setup.cfg`, and `tox.ini`. Projects like these produce multiple PyPI distributions (and upload multiple independently-installable tarballs).
* Source trees whose main purpose is to contain a C library, but which also provide bindings to Python (and perhaps other languages) in subdirectories.

Versioneer will look for `.git` in parent directories, and most operations should get the right version string. However `pip` and `setuptools` have bugs and implementation details which frequently cause `pip install .` from a subproject directory to fail to find a correct version string (so it usually defaults to `0+unknown`). `pip install --editable .` should work correctly.
`setup.py install` might work too. Pip-8.1.1 is known to have this problem, but hopefully it will get fixed in some later version. [Bug #38](https://github.com/python-versioneer/python-versioneer/issues/38) is tracking this issue. The discussion in [PR #61](https://github.com/python-versioneer/python-versioneer/pull/61) describes the issue from the Versioneer side in more detail. [pip PR#3176](https://github.com/pypa/pip/pull/3176) and [pip PR#3615](https://github.com/pypa/pip/pull/3615) contain work to improve pip to let Versioneer work correctly. Versioneer-0.16 and earlier only looked for a `.git` directory next to the `setup.cfg`, so subprojects were completely unsupported with those releases. ### Editable installs with setuptools <= 18.5 `setup.py develop` and `pip install --editable .` allow you to install a project into a virtualenv once, then continue editing the source code (and test) without re-installing after every change. "Entry-point scripts" (`setup(entry_points={"console_scripts": ..})`) are a convenient way to specify executable scripts that should be installed along with the python package. These both work as expected when using modern setuptools. When using setuptools-18.5 or earlier, however, certain operations will cause `pkg_resources.DistributionNotFound` errors when running the entrypoint script, which must be resolved by re-installing the package. This happens when the install happens with one version, then the egg_info data is regenerated while a different version is checked out. Many setup.py commands cause egg_info to be rebuilt (including `sdist`, `wheel`, and installing into a different virtualenv), so this can be surprising. [Bug #83](https://github.com/python-versioneer/python-versioneer/issues/83) describes this one, but upgrading to a newer version of setuptools should probably resolve it. ## Updating Versioneer To upgrade your project to a new release of Versioneer, do the following: * install the new Versioneer (`pip install -U versioneer` or equivalent) * edit `setup.cfg` and `pyproject.toml`, if necessary, to include any new configuration settings indicated by the release notes. See [UPGRADING](./UPGRADING.md) for details. * re-run `versioneer install --[no-]vendor` in your source tree, to replace `SRC/_version.py` * commit any changed files ## Future Directions This tool is designed to make it easily extended to other version-control systems: all VCS-specific components are in separate directories like src/git/ . The top-level `versioneer.py` script is assembled from these components by running make-versioneer.py . In the future, make-versioneer.py will take a VCS name as an argument, and will construct a version of `versioneer.py` that is specific to the given VCS. It might also take the configuration arguments that are currently provided manually during installation by editing setup.py . Alternatively, it might go the other direction and include code from all supported VCS systems, reducing the number of intermediate scripts. ## Similar projects * [setuptools_scm](https://github.com/pypa/setuptools_scm/) - a non-vendored build-time dependency * [minver](https://github.com/jbweston/miniver) - a lightweight reimplementation of versioneer * [versioningit](https://github.com/jwodder/versioningit) - a PEP 518-based setuptools plugin ## License To make Versioneer easier to embed, all its code is dedicated to the public domain. The `_version.py` that it creates is also in the public domain. 
Specifically, both are released under the "Unlicense", as described in
https://unlicense.org/.

[pypi-image]: https://img.shields.io/pypi/v/versioneer.svg
[pypi-url]: https://pypi.python.org/pypi/versioneer/
[travis-image]: https://img.shields.io/travis/com/python-versioneer/python-versioneer.svg
[travis-url]: https://travis-ci.com/github/python-versioneer/python-versioneer

"""
# pylint:disable=invalid-name,import-outside-toplevel,missing-function-docstring
# pylint:disable=missing-class-docstring,too-many-branches,too-many-statements
# pylint:disable=raise-missing-from,too-many-lines,too-many-locals,import-error
# pylint:disable=too-few-public-methods,redefined-outer-name,consider-using-with
# pylint:disable=attribute-defined-outside-init,too-many-arguments

import configparser
import errno
import json
import os
import re
import subprocess
import sys
from pathlib import Path
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
from typing import NoReturn
import functools

have_tomllib = True
if sys.version_info >= (3, 11):
    import tomllib
else:
    try:
        import tomli as tomllib
    except ImportError:
        have_tomllib = False


class VersioneerConfig:
    """Container for Versioneer configuration parameters."""

    VCS: str
    style: str
    tag_prefix: str
    versionfile_source: str
    versionfile_build: Optional[str]
    parentdir_prefix: Optional[str]
    verbose: Optional[bool]


def get_root() -> str:
    """Get the project root directory.

    We require that all commands are run from the project root, i.e. the
    directory that contains setup.py, setup.cfg, and versioneer.py .
    """
    root = os.path.realpath(os.path.abspath(os.getcwd()))
    setup_py = os.path.join(root, "setup.py")
    pyproject_toml = os.path.join(root, "pyproject.toml")
    versioneer_py = os.path.join(root, "versioneer.py")
    if not (
        os.path.exists(setup_py)
        or os.path.exists(pyproject_toml)
        or os.path.exists(versioneer_py)
    ):
        # allow 'python path/to/setup.py COMMAND'
        root = os.path.dirname(os.path.realpath(os.path.abspath(sys.argv[0])))
        setup_py = os.path.join(root, "setup.py")
        pyproject_toml = os.path.join(root, "pyproject.toml")
        versioneer_py = os.path.join(root, "versioneer.py")
    if not (
        os.path.exists(setup_py)
        or os.path.exists(pyproject_toml)
        or os.path.exists(versioneer_py)
    ):
        err = ("Versioneer was unable to find the project root directory. "
               "Versioneer requires setup.py to be executed from "
               "its immediate directory (like 'python setup.py COMMAND'), "
               "or in a way that lets it use sys.argv[0] to find the root "
               "(like 'python path/to/setup.py COMMAND').")
        raise VersioneerBadRootError(err)
    try:
        # Certain runtime workflows (setup.py install/develop in a setuptools
        # tree) execute all dependencies in a single python process, so
        # "versioneer" may be imported multiple times, and python's shared
        # module-import table will cache the first one. So we can't use
        # os.path.dirname(__file__), as that will find whichever
        # versioneer.py was first imported, even in later projects.
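        # Illustration (hypothetical scenario, not from the original): if
        # project A's setup.py builds dependency B in the same process, B's
        # "import versioneer" silently reuses A's cached module; the path
        # comparison below detects that mismatch and warns, naming both files.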
my_path = os.path.realpath(os.path.abspath(__file__)) me_dir = os.path.normcase(os.path.splitext(my_path)[0]) vsr_dir = os.path.normcase(os.path.splitext(versioneer_py)[0]) if me_dir != vsr_dir and "VERSIONEER_PEP518" not in globals(): print("Warning: build in %s is using versioneer.py from %s" % (os.path.dirname(my_path), versioneer_py)) except NameError: pass return root def get_config_from_root(root: str) -> VersioneerConfig: """Read the project setup.cfg file to determine Versioneer config.""" # This might raise OSError (if setup.cfg is missing), or # configparser.NoSectionError (if it lacks a [versioneer] section), or # configparser.NoOptionError (if it lacks "VCS="). See the docstring at # the top of versioneer.py for instructions on writing your setup.cfg . root_pth = Path(root) pyproject_toml = root_pth / "pyproject.toml" setup_cfg = root_pth / "setup.cfg" section: Union[Dict[str, Any], configparser.SectionProxy, None] = None if pyproject_toml.exists() and have_tomllib: try: with open(pyproject_toml, 'rb') as fobj: pp = tomllib.load(fobj) section = pp['tool']['versioneer'] except (tomllib.TOMLDecodeError, KeyError) as e: print(f"Failed to load config from {pyproject_toml}: {e}") print("Try to load it from setup.cfg") if not section: parser = configparser.ConfigParser() with open(setup_cfg) as cfg_file: parser.read_file(cfg_file) parser.get("versioneer", "VCS") # raise error if missing section = parser["versioneer"] # `cast`` really shouldn't be used, but its simplest for the # common VersioneerConfig users at the moment. We verify against # `None` values elsewhere where it matters cfg = VersioneerConfig() cfg.VCS = section['VCS'] cfg.style = section.get("style", "") cfg.versionfile_source = cast(str, section.get("versionfile_source")) cfg.versionfile_build = section.get("versionfile_build") cfg.tag_prefix = cast(str, section.get("tag_prefix")) if cfg.tag_prefix in ("''", '""', None): cfg.tag_prefix = "" cfg.parentdir_prefix = section.get("parentdir_prefix") if isinstance(section, configparser.SectionProxy): # Make sure configparser translates to bool cfg.verbose = section.getboolean("verbose") else: cfg.verbose = section.get("verbose") return cfg class NotThisMethod(Exception): """Exception raised if a method is not valid for the current scenario.""" # these dictionaries contain VCS-specific tools LONG_VERSION_PY: Dict[str, str] = {} HANDLERS: Dict[str, Dict[str, Callable]] = {} def register_vcs_handler(vcs: str, method: str) -> Callable: # decorator """Create decorator to mark a method as the handler of a VCS.""" def decorate(f: Callable) -> Callable: """Store f in HANDLERS[vcs][method].""" HANDLERS.setdefault(vcs, {})[method] = f return f return decorate def run_command( commands: List[str], args: List[str], cwd: Optional[str] = None, verbose: bool = False, hide_stderr: bool = False, env: Optional[Dict[str, str]] = None, ) -> Tuple[Optional[str], Optional[int]]: """Call the given command(s).""" assert isinstance(commands, list) process = None popen_kwargs: Dict[str, Any] = {} if sys.platform == "win32": # This hides the console window if pythonw.exe is used startupinfo = subprocess.STARTUPINFO() startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW popen_kwargs["startupinfo"] = startupinfo for command in commands: try: dispcmd = str([command] + args) # remember shell=False, so use git.cmd on windows, not just git process = subprocess.Popen([command] + args, cwd=cwd, env=env, stdout=subprocess.PIPE, stderr=(subprocess.PIPE if hide_stderr else None), **popen_kwargs) break except 
OSError as e: if e.errno == errno.ENOENT: continue if verbose: print("unable to run %s" % dispcmd) print(e) return None, None else: if verbose: print("unable to find command, tried %s" % (commands,)) return None, None stdout = process.communicate()[0].strip().decode() if process.returncode != 0: if verbose: print("unable to run %s (error)" % dispcmd) print("stdout was %s" % stdout) return None, process.returncode return stdout, process.returncode LONG_VERSION_PY['git'] = r''' # This file helps to compute a version number in source trees obtained from # git-archive tarball (such as those provided by githubs download-from-tag # feature). Distribution tarballs (built by setup.py sdist) and build # directories (produced by setup.py build) will contain a much shorter file # that just contains the computed version number. # This file is released into the public domain. # Generated by versioneer-0.29 # https://github.com/python-versioneer/python-versioneer """Git implementation of _version.py.""" import errno import os import re import subprocess import sys from typing import Any, Callable, Dict, List, Optional, Tuple import functools def get_keywords() -> Dict[str, str]: """Get the keywords needed to look up the version information.""" # these strings will be replaced by git during git-archive. # setup.py/versioneer.py will grep for the variable names, so they must # each be defined on a line of their own. _version.py will just call # get_keywords(). git_refnames = "%(DOLLAR)sFormat:%%d%(DOLLAR)s" git_full = "%(DOLLAR)sFormat:%%H%(DOLLAR)s" git_date = "%(DOLLAR)sFormat:%%ci%(DOLLAR)s" keywords = {"refnames": git_refnames, "full": git_full, "date": git_date} return keywords class VersioneerConfig: """Container for Versioneer configuration parameters.""" VCS: str style: str tag_prefix: str parentdir_prefix: str versionfile_source: str verbose: bool def get_config() -> VersioneerConfig: """Create, populate and return the VersioneerConfig() object.""" # these strings are filled in when 'setup.py versioneer' creates # _version.py cfg = VersioneerConfig() cfg.VCS = "git" cfg.style = "%(STYLE)s" cfg.tag_prefix = "%(TAG_PREFIX)s" cfg.parentdir_prefix = "%(PARENTDIR_PREFIX)s" cfg.versionfile_source = "%(VERSIONFILE_SOURCE)s" cfg.verbose = False return cfg class NotThisMethod(Exception): """Exception raised if a method is not valid for the current scenario.""" LONG_VERSION_PY: Dict[str, str] = {} HANDLERS: Dict[str, Dict[str, Callable]] = {} def register_vcs_handler(vcs: str, method: str) -> Callable: # decorator """Create decorator to mark a method as the handler of a VCS.""" def decorate(f: Callable) -> Callable: """Store f in HANDLERS[vcs][method].""" if vcs not in HANDLERS: HANDLERS[vcs] = {} HANDLERS[vcs][method] = f return f return decorate def run_command( commands: List[str], args: List[str], cwd: Optional[str] = None, verbose: bool = False, hide_stderr: bool = False, env: Optional[Dict[str, str]] = None, ) -> Tuple[Optional[str], Optional[int]]: """Call the given command(s).""" assert isinstance(commands, list) process = None popen_kwargs: Dict[str, Any] = {} if sys.platform == "win32": # This hides the console window if pythonw.exe is used startupinfo = subprocess.STARTUPINFO() startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW popen_kwargs["startupinfo"] = startupinfo for command in commands: try: dispcmd = str([command] + args) # remember shell=False, so use git.cmd on windows, not just git process = subprocess.Popen([command] + args, cwd=cwd, env=env, stdout=subprocess.PIPE, 
stderr=(subprocess.PIPE if hide_stderr else None), **popen_kwargs) break except OSError as e: if e.errno == errno.ENOENT: continue if verbose: print("unable to run %%s" %% dispcmd) print(e) return None, None else: if verbose: print("unable to find command, tried %%s" %% (commands,)) return None, None stdout = process.communicate()[0].strip().decode() if process.returncode != 0: if verbose: print("unable to run %%s (error)" %% dispcmd) print("stdout was %%s" %% stdout) return None, process.returncode return stdout, process.returncode def versions_from_parentdir( parentdir_prefix: str, root: str, verbose: bool, ) -> Dict[str, Any]: """Try to determine the version from the parent directory name. Source tarballs conventionally unpack into a directory that includes both the project name and a version string. We will also support searching up two directory levels for an appropriately named parent directory """ rootdirs = [] for _ in range(3): dirname = os.path.basename(root) if dirname.startswith(parentdir_prefix): return {"version": dirname[len(parentdir_prefix):], "full-revisionid": None, "dirty": False, "error": None, "date": None} rootdirs.append(root) root = os.path.dirname(root) # up a level if verbose: print("Tried directories %%s but none started with prefix %%s" %% (str(rootdirs), parentdir_prefix)) raise NotThisMethod("rootdir doesn't start with parentdir_prefix") @register_vcs_handler("git", "get_keywords") def git_get_keywords(versionfile_abs: str) -> Dict[str, str]: """Extract version information from the given file.""" # the code embedded in _version.py can just fetch the value of these # keywords. When used from setup.py, we don't want to import _version.py, # so we do it with a regexp instead. This function is not used from # _version.py. keywords: Dict[str, str] = {} try: with open(versionfile_abs, "r") as fobj: for line in fobj: if line.strip().startswith("git_refnames ="): mo = re.search(r'=\s*"(.*)"', line) if mo: keywords["refnames"] = mo.group(1) if line.strip().startswith("git_full ="): mo = re.search(r'=\s*"(.*)"', line) if mo: keywords["full"] = mo.group(1) if line.strip().startswith("git_date ="): mo = re.search(r'=\s*"(.*)"', line) if mo: keywords["date"] = mo.group(1) except OSError: pass return keywords @register_vcs_handler("git", "keywords") def git_versions_from_keywords( keywords: Dict[str, str], tag_prefix: str, verbose: bool, ) -> Dict[str, Any]: """Get version information from git keywords.""" if "refnames" not in keywords: raise NotThisMethod("Short version file found") date = keywords.get("date") if date is not None: # Use only the last line. Previous lines may contain GPG signature # information. date = date.splitlines()[-1] # git-2.2.0 added "%%cI", which expands to an ISO-8601 -compliant # datestamp. However we prefer "%%ci" (which expands to an "ISO-8601 # -like" string, which we must then edit to make compliant), because # it's been around since git-1.5.3, and it's too difficult to # discover which version we're using, or to work around using an # older one. date = date.strip().replace(" ", "T", 1).replace(" ", "", 1) refnames = keywords["refnames"].strip() if refnames.startswith("$Format"): if verbose: print("keywords are unexpanded, not using") raise NotThisMethod("unexpanded keywords, not a git-archive tarball") refs = {r.strip() for r in refnames.strip("()").split(",")} # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of # just "foo-1.0". If we see a "tag: " prefix, prefer those. 
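    # Illustration (hypothetical expansion, not in the original): a
    # git-archive tarball might expand the refnames keyword to
    #   " (HEAD -> master, tag: 0.3.1)"
    # giving refs == {"HEAD -> master", "tag: 0.3.1"}; the "tag: " filter
    # below then keeps just {"0.3.1"}.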
TAG = "tag: " tags = {r[len(TAG):] for r in refs if r.startswith(TAG)} if not tags: # Either we're using git < 1.8.3, or there really are no tags. We use # a heuristic: assume all version tags have a digit. The old git %%d # expansion behaves like git log --decorate=short and strips out the # refs/heads/ and refs/tags/ prefixes that would let us distinguish # between branches and tags. By ignoring refnames without digits, we # filter out many common branch names like "release" and # "stabilization", as well as "HEAD" and "master". tags = {r for r in refs if re.search(r'\d', r)} if verbose: print("discarding '%%s', no digits" %% ",".join(refs - tags)) if verbose: print("likely tags: %%s" %% ",".join(sorted(tags))) for ref in sorted(tags): # sorting will prefer e.g. "2.0" over "2.0rc1" if ref.startswith(tag_prefix): r = ref[len(tag_prefix):] # Filter out refs that exactly match prefix or that don't start # with a number once the prefix is stripped (mostly a concern # when prefix is '') if not re.match(r'\d', r): continue if verbose: print("picking %%s" %% r) return {"version": r, "full-revisionid": keywords["full"].strip(), "dirty": False, "error": None, "date": date} # no suitable tags, so version is "0+unknown", but full hex is still there if verbose: print("no suitable tags, using unknown + full revision id") return {"version": "0+unknown", "full-revisionid": keywords["full"].strip(), "dirty": False, "error": "no suitable tags", "date": None} @register_vcs_handler("git", "pieces_from_vcs") def git_pieces_from_vcs( tag_prefix: str, root: str, verbose: bool, runner: Callable = run_command ) -> Dict[str, Any]: """Get version from 'git describe' in the root of the source tree. This only gets called if the git-archive 'subst' keywords were *not* expanded, and _version.py hasn't already been rewritten with a short version string, meaning we're inside a checked out source tree. """ GITS = ["git"] if sys.platform == "win32": GITS = ["git.cmd", "git.exe"] # GIT_DIR can interfere with correct operation of Versioneer. # It may be intended to be passed to the Versioneer-versioned project, # but that should not change where we get our version from. 
env = os.environ.copy() env.pop("GIT_DIR", None) runner = functools.partial(runner, env=env) _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, hide_stderr=not verbose) if rc != 0: if verbose: print("Directory %%s not under git control" %% root) raise NotThisMethod("'git rev-parse --git-dir' returned error") # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] # if there isn't one, this yields HEX[-dirty] (no NUM) describe_out, rc = runner(GITS, [ "describe", "--tags", "--dirty", "--always", "--long", "--match", f"{tag_prefix}[[:digit:]]*" ], cwd=root) # --long was added in git-1.5.5 if describe_out is None: raise NotThisMethod("'git describe' failed") describe_out = describe_out.strip() full_out, rc = runner(GITS, ["rev-parse", "HEAD"], cwd=root) if full_out is None: raise NotThisMethod("'git rev-parse' failed") full_out = full_out.strip() pieces: Dict[str, Any] = {} pieces["long"] = full_out pieces["short"] = full_out[:7] # maybe improved later pieces["error"] = None branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], cwd=root) # --abbrev-ref was added in git-1.6.3 if rc != 0 or branch_name is None: raise NotThisMethod("'git rev-parse --abbrev-ref' returned error") branch_name = branch_name.strip() if branch_name == "HEAD": # If we aren't exactly on a branch, pick a branch which represents # the current commit. If all else fails, we are on a branchless # commit. branches, rc = runner(GITS, ["branch", "--contains"], cwd=root) # --contains was added in git-1.5.4 if rc != 0 or branches is None: raise NotThisMethod("'git branch --contains' returned error") branches = branches.split("\n") # Remove the first line if we're running detached if "(" in branches[0]: branches.pop(0) # Strip off the leading "* " from the list of branches. branches = [branch[2:] for branch in branches] if "master" in branches: branch_name = "master" elif not branches: branch_name = None else: # Pick the first branch that is returned. Good or bad. branch_name = branches[0] pieces["branch"] = branch_name # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] # TAG might have hyphens. git_describe = describe_out # look for -dirty suffix dirty = git_describe.endswith("-dirty") pieces["dirty"] = dirty if dirty: git_describe = git_describe[:git_describe.rindex("-dirty")] # now we have TAG-NUM-gHEX or HEX if "-" in git_describe: # TAG-NUM-gHEX mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) if not mo: # unparsable. Maybe git-describe is misbehaving? pieces["error"] = ("unable to parse git-describe output: '%%s'" %% describe_out) return pieces # tag full_tag = mo.group(1) if not full_tag.startswith(tag_prefix): if verbose: fmt = "tag '%%s' doesn't start with prefix '%%s'" print(fmt %% (full_tag, tag_prefix)) pieces["error"] = ("tag '%%s' doesn't start with prefix '%%s'" %% (full_tag, tag_prefix)) return pieces pieces["closest-tag"] = full_tag[len(tag_prefix):] # distance: number of commits since tag pieces["distance"] = int(mo.group(2)) # commit: short hex revision ID pieces["short"] = mo.group(3) else: # HEX: no tags pieces["closest-tag"] = None out, rc = runner(GITS, ["rev-list", "HEAD", "--left-right"], cwd=root) pieces["distance"] = len(out.split()) # total number of commits # commit date: see ISO-8601 comment in git_versions_from_keywords() date = runner(GITS, ["show", "-s", "--format=%%ci", "HEAD"], cwd=root)[0].strip() # Use only the last line. Previous lines may contain GPG signature # information. 
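    # Illustration (hypothetical timestamp, not in the original): after the
    # splitlines()/replace() edits below, "2017-11-12 13:14:15 -0600"
    # becomes the ISO-8601-compliant "2017-11-12T13:14:15-0600".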
date = date.splitlines()[-1] pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) return pieces def plus_or_dot(pieces: Dict[str, Any]) -> str: """Return a + if we don't already have one, else return a .""" if "+" in pieces.get("closest-tag", ""): return "." return "+" def render_pep440(pieces: Dict[str, Any]) -> str: """Build up version string, with post-release "local version identifier". Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty Exceptions: 1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty] """ if pieces["closest-tag"]: rendered = pieces["closest-tag"] if pieces["distance"] or pieces["dirty"]: rendered += plus_or_dot(pieces) rendered += "%%d.g%%s" %% (pieces["distance"], pieces["short"]) if pieces["dirty"]: rendered += ".dirty" else: # exception #1 rendered = "0+untagged.%%d.g%%s" %% (pieces["distance"], pieces["short"]) if pieces["dirty"]: rendered += ".dirty" return rendered def render_pep440_branch(pieces: Dict[str, Any]) -> str: """TAG[[.dev0]+DISTANCE.gHEX[.dirty]] . The ".dev0" means not master branch. Note that .dev0 sorts backwards (a feature branch will appear "older" than the master branch). Exceptions: 1: no tags. 0[.dev0]+untagged.DISTANCE.gHEX[.dirty] """ if pieces["closest-tag"]: rendered = pieces["closest-tag"] if pieces["distance"] or pieces["dirty"]: if pieces["branch"] != "master": rendered += ".dev0" rendered += plus_or_dot(pieces) rendered += "%%d.g%%s" %% (pieces["distance"], pieces["short"]) if pieces["dirty"]: rendered += ".dirty" else: # exception #1 rendered = "0" if pieces["branch"] != "master": rendered += ".dev0" rendered += "+untagged.%%d.g%%s" %% (pieces["distance"], pieces["short"]) if pieces["dirty"]: rendered += ".dirty" return rendered def pep440_split_post(ver: str) -> Tuple[str, Optional[int]]: """Split pep440 version string at the post-release segment. Returns the release segments before the post-release and the post-release version number (or -1 if no post-release segment is present). """ vc = str.split(ver, ".post") return vc[0], int(vc[1] or 0) if len(vc) == 2 else None def render_pep440_pre(pieces: Dict[str, Any]) -> str: """TAG[.postN.devDISTANCE] -- No -dirty. Exceptions: 1: no tags. 0.post0.devDISTANCE """ if pieces["closest-tag"]: if pieces["distance"]: # update the post release segment tag_version, post_version = pep440_split_post(pieces["closest-tag"]) rendered = tag_version if post_version is not None: rendered += ".post%%d.dev%%d" %% (post_version + 1, pieces["distance"]) else: rendered += ".post0.dev%%d" %% (pieces["distance"]) else: # no commits, use the tag as the version rendered = pieces["closest-tag"] else: # exception #1 rendered = "0.post0.dev%%d" %% pieces["distance"] return rendered def render_pep440_post(pieces: Dict[str, Any]) -> str: """TAG[.postDISTANCE[.dev0]+gHEX] . The ".dev0" means dirty. Note that .dev0 sorts backwards (a dirty tree will appear "older" than the corresponding clean one), but you shouldn't be releasing software with -dirty anyways. Exceptions: 1: no tags. 
0.postDISTANCE[.dev0] """ if pieces["closest-tag"]: rendered = pieces["closest-tag"] if pieces["distance"] or pieces["dirty"]: rendered += ".post%%d" %% pieces["distance"] if pieces["dirty"]: rendered += ".dev0" rendered += plus_or_dot(pieces) rendered += "g%%s" %% pieces["short"] else: # exception #1 rendered = "0.post%%d" %% pieces["distance"] if pieces["dirty"]: rendered += ".dev0" rendered += "+g%%s" %% pieces["short"] return rendered def render_pep440_post_branch(pieces: Dict[str, Any]) -> str: """TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] . The ".dev0" means not master branch. Exceptions: 1: no tags. 0.postDISTANCE[.dev0]+gHEX[.dirty] """ if pieces["closest-tag"]: rendered = pieces["closest-tag"] if pieces["distance"] or pieces["dirty"]: rendered += ".post%%d" %% pieces["distance"] if pieces["branch"] != "master": rendered += ".dev0" rendered += plus_or_dot(pieces) rendered += "g%%s" %% pieces["short"] if pieces["dirty"]: rendered += ".dirty" else: # exception #1 rendered = "0.post%%d" %% pieces["distance"] if pieces["branch"] != "master": rendered += ".dev0" rendered += "+g%%s" %% pieces["short"] if pieces["dirty"]: rendered += ".dirty" return rendered def render_pep440_old(pieces: Dict[str, Any]) -> str: """TAG[.postDISTANCE[.dev0]] . The ".dev0" means dirty. Exceptions: 1: no tags. 0.postDISTANCE[.dev0] """ if pieces["closest-tag"]: rendered = pieces["closest-tag"] if pieces["distance"] or pieces["dirty"]: rendered += ".post%%d" %% pieces["distance"] if pieces["dirty"]: rendered += ".dev0" else: # exception #1 rendered = "0.post%%d" %% pieces["distance"] if pieces["dirty"]: rendered += ".dev0" return rendered def render_git_describe(pieces: Dict[str, Any]) -> str: """TAG[-DISTANCE-gHEX][-dirty]. Like 'git describe --tags --dirty --always'. Exceptions: 1: no tags. HEX[-dirty] (note: no 'g' prefix) """ if pieces["closest-tag"]: rendered = pieces["closest-tag"] if pieces["distance"]: rendered += "-%%d-g%%s" %% (pieces["distance"], pieces["short"]) else: # exception #1 rendered = pieces["short"] if pieces["dirty"]: rendered += "-dirty" return rendered def render_git_describe_long(pieces: Dict[str, Any]) -> str: """TAG-DISTANCE-gHEX[-dirty]. Like 'git describe --tags --dirty --always -long'. The distance/hash is unconditional. Exceptions: 1: no tags. 
HEX[-dirty] (note: no 'g' prefix) """ if pieces["closest-tag"]: rendered = pieces["closest-tag"] rendered += "-%%d-g%%s" %% (pieces["distance"], pieces["short"]) else: # exception #1 rendered = pieces["short"] if pieces["dirty"]: rendered += "-dirty" return rendered def render(pieces: Dict[str, Any], style: str) -> Dict[str, Any]: """Render the given version pieces into the requested style.""" if pieces["error"]: return {"version": "unknown", "full-revisionid": pieces.get("long"), "dirty": None, "error": pieces["error"], "date": None} if not style or style == "default": style = "pep440" # the default if style == "pep440": rendered = render_pep440(pieces) elif style == "pep440-branch": rendered = render_pep440_branch(pieces) elif style == "pep440-pre": rendered = render_pep440_pre(pieces) elif style == "pep440-post": rendered = render_pep440_post(pieces) elif style == "pep440-post-branch": rendered = render_pep440_post_branch(pieces) elif style == "pep440-old": rendered = render_pep440_old(pieces) elif style == "git-describe": rendered = render_git_describe(pieces) elif style == "git-describe-long": rendered = render_git_describe_long(pieces) else: raise ValueError("unknown style '%%s'" %% style) return {"version": rendered, "full-revisionid": pieces["long"], "dirty": pieces["dirty"], "error": None, "date": pieces.get("date")} def get_versions() -> Dict[str, Any]: """Get version information or return default if unable to do so.""" # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have # __file__, we can work backwards from there to the root. Some # py2exe/bbfreeze/non-CPython implementations don't do __file__, in which # case we can only use expanded keywords. cfg = get_config() verbose = cfg.verbose try: return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, verbose) except NotThisMethod: pass try: root = os.path.realpath(__file__) # versionfile_source is the relative path from the top of the source # tree (where the .git directory might live) to this file. Invert # this to find the root from __file__. for _ in cfg.versionfile_source.split('/'): root = os.path.dirname(root) except NameError: return {"version": "0+unknown", "full-revisionid": None, "dirty": None, "error": "unable to find root of source tree", "date": None} try: pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose) return render(pieces, cfg.style) except NotThisMethod: pass try: if cfg.parentdir_prefix: return versions_from_parentdir(cfg.parentdir_prefix, root, verbose) except NotThisMethod: pass return {"version": "0+unknown", "full-revisionid": None, "dirty": None, "error": "unable to compute version", "date": None} ''' @register_vcs_handler("git", "get_keywords") def git_get_keywords(versionfile_abs: str) -> Dict[str, str]: """Extract version information from the given file.""" # the code embedded in _version.py can just fetch the value of these # keywords. When used from setup.py, we don't want to import _version.py, # so we do it with a regexp instead. This function is not used from # _version.py. 
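    # For reference (values illustrative, not from the original), the scan
    # below looks for lines in _version.py of the form:
    #   git_refnames = " (tag: 0.3.1)"
    #   git_full = "1076c978a8d3cfc70f408fe5974aa6c092c949ac"
    #   git_date = "2017-11-12 13:14:15 -0600"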
keywords: Dict[str, str] = {} try: with open(versionfile_abs, "r") as fobj: for line in fobj: if line.strip().startswith("git_refnames ="): mo = re.search(r'=\s*"(.*)"', line) if mo: keywords["refnames"] = mo.group(1) if line.strip().startswith("git_full ="): mo = re.search(r'=\s*"(.*)"', line) if mo: keywords["full"] = mo.group(1) if line.strip().startswith("git_date ="): mo = re.search(r'=\s*"(.*)"', line) if mo: keywords["date"] = mo.group(1) except OSError: pass return keywords @register_vcs_handler("git", "keywords") def git_versions_from_keywords( keywords: Dict[str, str], tag_prefix: str, verbose: bool, ) -> Dict[str, Any]: """Get version information from git keywords.""" if "refnames" not in keywords: raise NotThisMethod("Short version file found") date = keywords.get("date") if date is not None: # Use only the last line. Previous lines may contain GPG signature # information. date = date.splitlines()[-1] # git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant # datestamp. However we prefer "%ci" (which expands to an "ISO-8601 # -like" string, which we must then edit to make compliant), because # it's been around since git-1.5.3, and it's too difficult to # discover which version we're using, or to work around using an # older one. date = date.strip().replace(" ", "T", 1).replace(" ", "", 1) refnames = keywords["refnames"].strip() if refnames.startswith("$Format"): if verbose: print("keywords are unexpanded, not using") raise NotThisMethod("unexpanded keywords, not a git-archive tarball") refs = {r.strip() for r in refnames.strip("()").split(",")} # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of # just "foo-1.0". If we see a "tag: " prefix, prefer those. TAG = "tag: " tags = {r[len(TAG):] for r in refs if r.startswith(TAG)} if not tags: # Either we're using git < 1.8.3, or there really are no tags. We use # a heuristic: assume all version tags have a digit. The old git %d # expansion behaves like git log --decorate=short and strips out the # refs/heads/ and refs/tags/ prefixes that would let us distinguish # between branches and tags. By ignoring refnames without digits, we # filter out many common branch names like "release" and # "stabilization", as well as "HEAD" and "master". tags = {r for r in refs if re.search(r'\d', r)} if verbose: print("discarding '%s', no digits" % ",".join(refs - tags)) if verbose: print("likely tags: %s" % ",".join(sorted(tags))) for ref in sorted(tags): # sorting will prefer e.g. "2.0" over "2.0rc1" if ref.startswith(tag_prefix): r = ref[len(tag_prefix):] # Filter out refs that exactly match prefix or that don't start # with a number once the prefix is stripped (mostly a concern # when prefix is '') if not re.match(r'\d', r): continue if verbose: print("picking %s" % r) return {"version": r, "full-revisionid": keywords["full"].strip(), "dirty": False, "error": None, "date": date} # no suitable tags, so version is "0+unknown", but full hex is still there if verbose: print("no suitable tags, using unknown + full revision id") return {"version": "0+unknown", "full-revisionid": keywords["full"].strip(), "dirty": False, "error": "no suitable tags", "date": None} @register_vcs_handler("git", "pieces_from_vcs") def git_pieces_from_vcs( tag_prefix: str, root: str, verbose: bool, runner: Callable = run_command ) -> Dict[str, Any]: """Get version from 'git describe' in the root of the source tree. 
This only gets called if the git-archive 'subst' keywords were *not* expanded, and _version.py hasn't already been rewritten with a short version string, meaning we're inside a checked out source tree. """ GITS = ["git"] if sys.platform == "win32": GITS = ["git.cmd", "git.exe"] # GIT_DIR can interfere with correct operation of Versioneer. # It may be intended to be passed to the Versioneer-versioned project, # but that should not change where we get our version from. env = os.environ.copy() env.pop("GIT_DIR", None) runner = functools.partial(runner, env=env) _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, hide_stderr=not verbose) if rc != 0: if verbose: print("Directory %s not under git control" % root) raise NotThisMethod("'git rev-parse --git-dir' returned error") # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] # if there isn't one, this yields HEX[-dirty] (no NUM) describe_out, rc = runner(GITS, [ "describe", "--tags", "--dirty", "--always", "--long", "--match", f"{tag_prefix}[[:digit:]]*" ], cwd=root) # --long was added in git-1.5.5 if describe_out is None: raise NotThisMethod("'git describe' failed") describe_out = describe_out.strip() full_out, rc = runner(GITS, ["rev-parse", "HEAD"], cwd=root) if full_out is None: raise NotThisMethod("'git rev-parse' failed") full_out = full_out.strip() pieces: Dict[str, Any] = {} pieces["long"] = full_out pieces["short"] = full_out[:7] # maybe improved later pieces["error"] = None branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], cwd=root) # --abbrev-ref was added in git-1.6.3 if rc != 0 or branch_name is None: raise NotThisMethod("'git rev-parse --abbrev-ref' returned error") branch_name = branch_name.strip() if branch_name == "HEAD": # If we aren't exactly on a branch, pick a branch which represents # the current commit. If all else fails, we are on a branchless # commit. branches, rc = runner(GITS, ["branch", "--contains"], cwd=root) # --contains was added in git-1.5.4 if rc != 0 or branches is None: raise NotThisMethod("'git branch --contains' returned error") branches = branches.split("\n") # Remove the first line if we're running detached if "(" in branches[0]: branches.pop(0) # Strip off the leading "* " from the list of branches. branches = [branch[2:] for branch in branches] if "master" in branches: branch_name = "master" elif not branches: branch_name = None else: # Pick the first branch that is returned. Good or bad. branch_name = branches[0] pieces["branch"] = branch_name # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] # TAG might have hyphens. git_describe = describe_out # look for -dirty suffix dirty = git_describe.endswith("-dirty") pieces["dirty"] = dirty if dirty: git_describe = git_describe[:git_describe.rindex("-dirty")] # now we have TAG-NUM-gHEX or HEX if "-" in git_describe: # TAG-NUM-gHEX mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) if not mo: # unparsable. Maybe git-describe is misbehaving? 
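            # (e.g. a hypothetical output of "0.11-2", missing the "-gHEX"
            # part, would fail the TAG-NUM-gHEX regex above)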
pieces["error"] = ("unable to parse git-describe output: '%s'" % describe_out) return pieces # tag full_tag = mo.group(1) if not full_tag.startswith(tag_prefix): if verbose: fmt = "tag '%s' doesn't start with prefix '%s'" print(fmt % (full_tag, tag_prefix)) pieces["error"] = ("tag '%s' doesn't start with prefix '%s'" % (full_tag, tag_prefix)) return pieces pieces["closest-tag"] = full_tag[len(tag_prefix):] # distance: number of commits since tag pieces["distance"] = int(mo.group(2)) # commit: short hex revision ID pieces["short"] = mo.group(3) else: # HEX: no tags pieces["closest-tag"] = None out, rc = runner(GITS, ["rev-list", "HEAD", "--left-right"], cwd=root) pieces["distance"] = len(out.split()) # total number of commits # commit date: see ISO-8601 comment in git_versions_from_keywords() date = runner(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[0].strip() # Use only the last line. Previous lines may contain GPG signature # information. date = date.splitlines()[-1] pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) return pieces def do_vcs_install(versionfile_source: str, ipy: Optional[str]) -> None: """Git-specific installation logic for Versioneer. For Git, this means creating/changing .gitattributes to mark _version.py for export-subst keyword substitution. """ GITS = ["git"] if sys.platform == "win32": GITS = ["git.cmd", "git.exe"] files = [versionfile_source] if ipy: files.append(ipy) if "VERSIONEER_PEP518" not in globals(): try: my_path = __file__ if my_path.endswith((".pyc", ".pyo")): my_path = os.path.splitext(my_path)[0] + ".py" versioneer_file = os.path.relpath(my_path) except NameError: versioneer_file = "versioneer.py" files.append(versioneer_file) present = False try: with open(".gitattributes", "r") as fobj: for line in fobj: if line.strip().startswith(versionfile_source): if "export-subst" in line.strip().split()[1:]: present = True break except OSError: pass if not present: with open(".gitattributes", "a+") as fobj: fobj.write(f"{versionfile_source} export-subst\n") files.append(".gitattributes") run_command(GITS, ["add", "--"] + files) def versions_from_parentdir( parentdir_prefix: str, root: str, verbose: bool, ) -> Dict[str, Any]: """Try to determine the version from the parent directory name. Source tarballs conventionally unpack into a directory that includes both the project name and a version string. We will also support searching up two directory levels for an appropriately named parent directory """ rootdirs = [] for _ in range(3): dirname = os.path.basename(root) if dirname.startswith(parentdir_prefix): return {"version": dirname[len(parentdir_prefix):], "full-revisionid": None, "dirty": False, "error": None, "date": None} rootdirs.append(root) root = os.path.dirname(root) # up a level if verbose: print("Tried directories %s but none started with prefix %s" % (str(rootdirs), parentdir_prefix)) raise NotThisMethod("rootdir doesn't start with parentdir_prefix") SHORT_VERSION_PY = """ # This file was generated by 'versioneer.py' (0.29) from # revision-control system data, or from the parent directory name of an # unpacked source archive. Distribution tarballs contain a pre-generated copy # of this file. 
import json version_json = ''' %s ''' # END VERSION_JSON def get_versions(): return json.loads(version_json) """ def versions_from_file(filename: str) -> Dict[str, Any]: """Try to determine the version from _version.py if present.""" try: with open(filename) as f: contents = f.read() except OSError: raise NotThisMethod("unable to read _version.py") mo = re.search(r"version_json = '''\n(.*)''' # END VERSION_JSON", contents, re.M | re.S) if not mo: mo = re.search(r"version_json = '''\r\n(.*)''' # END VERSION_JSON", contents, re.M | re.S) if not mo: raise NotThisMethod("no version_json in _version.py") return json.loads(mo.group(1)) def write_to_version_file(filename: str, versions: Dict[str, Any]) -> None: """Write the given version number to the given _version.py file.""" contents = json.dumps(versions, sort_keys=True, indent=1, separators=(",", ": ")) with open(filename, "w") as f: f.write(SHORT_VERSION_PY % contents) print("set %s to '%s'" % (filename, versions["version"])) def plus_or_dot(pieces: Dict[str, Any]) -> str: """Return a + if we don't already have one, else return a .""" if "+" in pieces.get("closest-tag", ""): return "." return "+" def render_pep440(pieces: Dict[str, Any]) -> str: """Build up version string, with post-release "local version identifier". Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty Exceptions: 1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty] """ if pieces["closest-tag"]: rendered = pieces["closest-tag"] if pieces["distance"] or pieces["dirty"]: rendered += plus_or_dot(pieces) rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) if pieces["dirty"]: rendered += ".dirty" else: # exception #1 rendered = "0+untagged.%d.g%s" % (pieces["distance"], pieces["short"]) if pieces["dirty"]: rendered += ".dirty" return rendered def render_pep440_branch(pieces: Dict[str, Any]) -> str: """TAG[[.dev0]+DISTANCE.gHEX[.dirty]] . The ".dev0" means not master branch. Note that .dev0 sorts backwards (a feature branch will appear "older" than the master branch). Exceptions: 1: no tags. 0[.dev0]+untagged.DISTANCE.gHEX[.dirty] """ if pieces["closest-tag"]: rendered = pieces["closest-tag"] if pieces["distance"] or pieces["dirty"]: if pieces["branch"] != "master": rendered += ".dev0" rendered += plus_or_dot(pieces) rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) if pieces["dirty"]: rendered += ".dirty" else: # exception #1 rendered = "0" if pieces["branch"] != "master": rendered += ".dev0" rendered += "+untagged.%d.g%s" % (pieces["distance"], pieces["short"]) if pieces["dirty"]: rendered += ".dirty" return rendered def pep440_split_post(ver: str) -> Tuple[str, Optional[int]]: """Split pep440 version string at the post-release segment. Returns the release segments before the post-release and the post-release version number (or -1 if no post-release segment is present). """ vc = str.split(ver, ".post") return vc[0], int(vc[1] or 0) if len(vc) == 2 else None def render_pep440_pre(pieces: Dict[str, Any]) -> str: """TAG[.postN.devDISTANCE] -- No -dirty. Exceptions: 1: no tags. 
0.post0.devDISTANCE """ if pieces["closest-tag"]: if pieces["distance"]: # update the post release segment tag_version, post_version = pep440_split_post(pieces["closest-tag"]) rendered = tag_version if post_version is not None: rendered += ".post%d.dev%d" % (post_version + 1, pieces["distance"]) else: rendered += ".post0.dev%d" % (pieces["distance"]) else: # no commits, use the tag as the version rendered = pieces["closest-tag"] else: # exception #1 rendered = "0.post0.dev%d" % pieces["distance"] return rendered def render_pep440_post(pieces: Dict[str, Any]) -> str: """TAG[.postDISTANCE[.dev0]+gHEX] . The ".dev0" means dirty. Note that .dev0 sorts backwards (a dirty tree will appear "older" than the corresponding clean one), but you shouldn't be releasing software with -dirty anyways. Exceptions: 1: no tags. 0.postDISTANCE[.dev0] """ if pieces["closest-tag"]: rendered = pieces["closest-tag"] if pieces["distance"] or pieces["dirty"]: rendered += ".post%d" % pieces["distance"] if pieces["dirty"]: rendered += ".dev0" rendered += plus_or_dot(pieces) rendered += "g%s" % pieces["short"] else: # exception #1 rendered = "0.post%d" % pieces["distance"] if pieces["dirty"]: rendered += ".dev0" rendered += "+g%s" % pieces["short"] return rendered def render_pep440_post_branch(pieces: Dict[str, Any]) -> str: """TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] . The ".dev0" means not master branch. Exceptions: 1: no tags. 0.postDISTANCE[.dev0]+gHEX[.dirty] """ if pieces["closest-tag"]: rendered = pieces["closest-tag"] if pieces["distance"] or pieces["dirty"]: rendered += ".post%d" % pieces["distance"] if pieces["branch"] != "master": rendered += ".dev0" rendered += plus_or_dot(pieces) rendered += "g%s" % pieces["short"] if pieces["dirty"]: rendered += ".dirty" else: # exception #1 rendered = "0.post%d" % pieces["distance"] if pieces["branch"] != "master": rendered += ".dev0" rendered += "+g%s" % pieces["short"] if pieces["dirty"]: rendered += ".dirty" return rendered def render_pep440_old(pieces: Dict[str, Any]) -> str: """TAG[.postDISTANCE[.dev0]] . The ".dev0" means dirty. Exceptions: 1: no tags. 0.postDISTANCE[.dev0] """ if pieces["closest-tag"]: rendered = pieces["closest-tag"] if pieces["distance"] or pieces["dirty"]: rendered += ".post%d" % pieces["distance"] if pieces["dirty"]: rendered += ".dev0" else: # exception #1 rendered = "0.post%d" % pieces["distance"] if pieces["dirty"]: rendered += ".dev0" return rendered def render_git_describe(pieces: Dict[str, Any]) -> str: """TAG[-DISTANCE-gHEX][-dirty]. Like 'git describe --tags --dirty --always'. Exceptions: 1: no tags. HEX[-dirty] (note: no 'g' prefix) """ if pieces["closest-tag"]: rendered = pieces["closest-tag"] if pieces["distance"]: rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) else: # exception #1 rendered = pieces["short"] if pieces["dirty"]: rendered += "-dirty" return rendered def render_git_describe_long(pieces: Dict[str, Any]) -> str: """TAG-DISTANCE-gHEX[-dirty]. Like 'git describe --tags --dirty --always -long'. The distance/hash is unconditional. Exceptions: 1: no tags. 
HEX[-dirty] (note: no 'g' prefix) """ if pieces["closest-tag"]: rendered = pieces["closest-tag"] rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) else: # exception #1 rendered = pieces["short"] if pieces["dirty"]: rendered += "-dirty" return rendered def render(pieces: Dict[str, Any], style: str) -> Dict[str, Any]: """Render the given version pieces into the requested style.""" if pieces["error"]: return {"version": "unknown", "full-revisionid": pieces.get("long"), "dirty": None, "error": pieces["error"], "date": None} if not style or style == "default": style = "pep440" # the default if style == "pep440": rendered = render_pep440(pieces) elif style == "pep440-branch": rendered = render_pep440_branch(pieces) elif style == "pep440-pre": rendered = render_pep440_pre(pieces) elif style == "pep440-post": rendered = render_pep440_post(pieces) elif style == "pep440-post-branch": rendered = render_pep440_post_branch(pieces) elif style == "pep440-old": rendered = render_pep440_old(pieces) elif style == "git-describe": rendered = render_git_describe(pieces) elif style == "git-describe-long": rendered = render_git_describe_long(pieces) else: raise ValueError("unknown style '%s'" % style) return {"version": rendered, "full-revisionid": pieces["long"], "dirty": pieces["dirty"], "error": None, "date": pieces.get("date")} class VersioneerBadRootError(Exception): """The project root directory is unknown or missing key files.""" def get_versions(verbose: bool = False) -> Dict[str, Any]: """Get the project version from whatever source is available. Returns dict with two keys: 'version' and 'full'. """ if "versioneer" in sys.modules: # see the discussion in cmdclass.py:get_cmdclass() del sys.modules["versioneer"] root = get_root() cfg = get_config_from_root(root) assert cfg.VCS is not None, "please set [versioneer]VCS= in setup.cfg" handlers = HANDLERS.get(cfg.VCS) assert handlers, "unrecognized VCS '%s'" % cfg.VCS verbose = verbose or bool(cfg.verbose) # `bool()` used to avoid `None` assert cfg.versionfile_source is not None, \ "please set versioneer.versionfile_source" assert cfg.tag_prefix is not None, "please set versioneer.tag_prefix" versionfile_abs = os.path.join(root, cfg.versionfile_source) # extract version from first of: _version.py, VCS command (e.g. 'git # describe'), parentdir. This is meant to work for developers using a # source checkout, for users of a tarball created by 'setup.py sdist', # and for users of a tarball/zipball created by 'git archive' or github's # download-from-tag feature or the equivalent in other VCSes. 
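    # The fallback chain below, in order: expanded git-archive keywords,
    # then a static _version.py, then a live VCS query ('git describe'),
    # then the parent-directory name, and finally the "0+unknown" default.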
    get_keywords_f = handlers.get("get_keywords")
    from_keywords_f = handlers.get("keywords")
    if get_keywords_f and from_keywords_f:
        try:
            keywords = get_keywords_f(versionfile_abs)
            ver = from_keywords_f(keywords, cfg.tag_prefix, verbose)
            if verbose:
                print("got version from expanded keyword %s" % ver)
            return ver
        except NotThisMethod:
            pass

    try:
        ver = versions_from_file(versionfile_abs)
        if verbose:
            print("got version from file %s %s" % (versionfile_abs, ver))
        return ver
    except NotThisMethod:
        pass

    from_vcs_f = handlers.get("pieces_from_vcs")
    if from_vcs_f:
        try:
            pieces = from_vcs_f(cfg.tag_prefix, root, verbose)
            ver = render(pieces, cfg.style)
            if verbose:
                print("got version from VCS %s" % ver)
            return ver
        except NotThisMethod:
            pass

    try:
        if cfg.parentdir_prefix:
            ver = versions_from_parentdir(cfg.parentdir_prefix, root, verbose)
            if verbose:
                print("got version from parentdir %s" % ver)
            return ver
    except NotThisMethod:
        pass

    if verbose:
        print("unable to compute version")
    return {"version": "0+unknown", "full-revisionid": None,
            "dirty": None, "error": "unable to compute version",
            "date": None}


def get_version() -> str:
    """Get the short version string for this project."""
    return get_versions()["version"]


def get_cmdclass(cmdclass: Optional[Dict[str, Any]] = None):
    """Get the custom setuptools subclasses used by Versioneer.

    If the package uses a different cmdclass (e.g. one from numpy), it
    should be provided as an argument.
    """
    if "versioneer" in sys.modules:
        del sys.modules["versioneer"]
        # this fixes the "python setup.py develop" case (also 'install' and
        # 'easy_install .'), in which subdependencies of the main project are
        # built (using setup.py bdist_egg) in the same python process. Assume
        # a main project A and a dependency B, which use different versions
        # of Versioneer. A's setup.py imports A's Versioneer, leaving it in
        # sys.modules by the time B's setup.py is executed, causing B to run
        # with the wrong versioneer. Setuptools wraps the sub-dep builds in a
        # sandbox that restores sys.modules to its pre-build state, so the
        # parent is protected against the child's "import versioneer". By
        # removing ourselves from sys.modules here, before the child build
        # happens, we protect the child from the parent's versioneer too.
        # Also see https://github.com/python-versioneer/python-versioneer/issues/52

    cmds = {} if cmdclass is None else cmdclass.copy()

    # we add "version" to setuptools
    from setuptools import Command

    class cmd_version(Command):
        description = "report generated version string"
        user_options: List[Tuple[str, str, str]] = []
        boolean_options: List[str] = []

        def initialize_options(self) -> None:
            pass

        def finalize_options(self) -> None:
            pass

        def run(self) -> None:
            vers = get_versions(verbose=True)
            print("Version: %s" % vers["version"])
            print(" full-revisionid: %s" % vers.get("full-revisionid"))
            print(" dirty: %s" % vers.get("dirty"))
            print(" date: %s" % vers.get("date"))
            if vers["error"]:
                print(" error: %s" % vers["error"])
    cmds["version"] = cmd_version

    # we override "build_py" in setuptools
    #
    # most invocation pathways end up running build_py:
    #  distutils/build -> build_py
    #  distutils/install -> distutils/build ->..
    #  setuptools/bdist_wheel -> distutils/install ->..
    #  setuptools/bdist_egg -> distutils/install_lib -> build_py
    #  setuptools/install -> bdist_egg ->..
    #  setuptools/develop -> ?
    #  pip install:
    #   copies source tree to a tempdir before running egg_info/etc
    #   if .git isn't copied too, 'git describe' will fail
    #   then does setup.py bdist_wheel, or sometimes setup.py install
    #  setup.py egg_info -> ?
# pip install -e . and setuptool/editable_wheel will invoke build_py # but the build_py command is not expected to copy any files. # we override different "build_py" commands for both environments if 'build_py' in cmds: _build_py: Any = cmds['build_py'] else: from setuptools.command.build_py import build_py as _build_py class cmd_build_py(_build_py): def run(self) -> None: root = get_root() cfg = get_config_from_root(root) versions = get_versions() _build_py.run(self) if getattr(self, "editable_mode", False): # During editable installs `.py` and data files are # not copied to build_lib return # now locate _version.py in the new build/ directory and replace # it with an updated value if cfg.versionfile_build: target_versionfile = os.path.join(self.build_lib, cfg.versionfile_build) print("UPDATING %s" % target_versionfile) write_to_version_file(target_versionfile, versions) cmds["build_py"] = cmd_build_py if 'build_ext' in cmds: _build_ext: Any = cmds['build_ext'] else: from setuptools.command.build_ext import build_ext as _build_ext class cmd_build_ext(_build_ext): def run(self) -> None: root = get_root() cfg = get_config_from_root(root) versions = get_versions() _build_ext.run(self) if self.inplace: # build_ext --inplace will only build extensions in # build/lib<..> dir with no _version.py to write to. # As in place builds will already have a _version.py # in the module dir, we do not need to write one. return # now locate _version.py in the new build/ directory and replace # it with an updated value if not cfg.versionfile_build: return target_versionfile = os.path.join(self.build_lib, cfg.versionfile_build) if not os.path.exists(target_versionfile): print(f"Warning: {target_versionfile} does not exist, skipping " "version update. This can happen if you are running build_ext " "without first running build_py.") return print("UPDATING %s" % target_versionfile) write_to_version_file(target_versionfile, versions) cmds["build_ext"] = cmd_build_ext if "cx_Freeze" in sys.modules: # cx_freeze enabled? from cx_Freeze.dist import build_exe as _build_exe # type: ignore # nczeczulin reports that py2exe won't like the pep440-style string # as FILEVERSION, but it can be used for PRODUCTVERSION, e.g. # setup(console=[{ # "version": versioneer.get_version().split("+", 1)[0], # FILEVERSION # "product_version": versioneer.get_version(), # ... class cmd_build_exe(_build_exe): def run(self) -> None: root = get_root() cfg = get_config_from_root(root) versions = get_versions() target_versionfile = cfg.versionfile_source print("UPDATING %s" % target_versionfile) write_to_version_file(target_versionfile, versions) _build_exe.run(self) os.unlink(target_versionfile) with open(cfg.versionfile_source, "w") as f: LONG = LONG_VERSION_PY[cfg.VCS] f.write(LONG % {"DOLLAR": "$", "STYLE": cfg.style, "TAG_PREFIX": cfg.tag_prefix, "PARENTDIR_PREFIX": cfg.parentdir_prefix, "VERSIONFILE_SOURCE": cfg.versionfile_source, }) cmds["build_exe"] = cmd_build_exe del cmds["build_py"] if 'py2exe' in sys.modules: # py2exe enabled? 

    if 'py2exe' in sys.modules:  # py2exe enabled?
        try:
            from py2exe.setuptools_buildexe import py2exe as _py2exe  # type: ignore
        except ImportError:
            from py2exe.distutils_buildexe import py2exe as _py2exe  # type: ignore

        class cmd_py2exe(_py2exe):
            def run(self) -> None:
                root = get_root()
                cfg = get_config_from_root(root)
                versions = get_versions()
                target_versionfile = cfg.versionfile_source
                print("UPDATING %s" % target_versionfile)
                write_to_version_file(target_versionfile, versions)

                _py2exe.run(self)
                os.unlink(target_versionfile)
                with open(cfg.versionfile_source, "w") as f:
                    LONG = LONG_VERSION_PY[cfg.VCS]
                    f.write(LONG %
                            {"DOLLAR": "$",
                             "STYLE": cfg.style,
                             "TAG_PREFIX": cfg.tag_prefix,
                             "PARENTDIR_PREFIX": cfg.parentdir_prefix,
                             "VERSIONFILE_SOURCE": cfg.versionfile_source,
                             })
        cmds["py2exe"] = cmd_py2exe

    # sdist farms its file list building out to egg_info
    if 'egg_info' in cmds:
        _egg_info: Any = cmds['egg_info']
    else:
        from setuptools.command.egg_info import egg_info as _egg_info

    class cmd_egg_info(_egg_info):
        def find_sources(self) -> None:
            # egg_info.find_sources builds the manifest list and writes it
            # in one shot
            super().find_sources()

            # Modify the filelist and normalize it
            root = get_root()
            cfg = get_config_from_root(root)
            self.filelist.append('versioneer.py')
            if cfg.versionfile_source:
                # There are rare cases where versionfile_source might not be
                # included by default, so we must be explicit
                self.filelist.append(cfg.versionfile_source)
            self.filelist.sort()
            self.filelist.remove_duplicates()

            # The write method is hidden in the manifest_maker instance that
            # generated the filelist and was thrown away
            # We will instead replicate their final normalization (to unicode,
            # and POSIX-style paths)
            from setuptools import unicode_utils
            normalized = [unicode_utils.filesys_decode(f).replace(os.sep, '/')
                          for f in self.filelist.files]

            manifest_filename = os.path.join(self.egg_info, 'SOURCES.txt')
            with open(manifest_filename, 'w') as fobj:
                fobj.write('\n'.join(normalized))

    cmds['egg_info'] = cmd_egg_info

    # we override different "sdist" commands for both environments
    if 'sdist' in cmds:
        _sdist: Any = cmds['sdist']
    else:
        from setuptools.command.sdist import sdist as _sdist

    class cmd_sdist(_sdist):
        def run(self) -> None:
            versions = get_versions()
            self._versioneer_generated_versions = versions
            # unless we update this, the command will keep using the old
            # version
            self.distribution.metadata.version = versions["version"]
            return _sdist.run(self)

        def make_release_tree(self, base_dir: str, files: List[str]) -> None:
            root = get_root()
            cfg = get_config_from_root(root)
            _sdist.make_release_tree(self, base_dir, files)
            # now locate _version.py in the new base_dir directory
            # (remembering that it may be a hardlink) and replace it with an
            # updated value
            target_versionfile = os.path.join(base_dir, cfg.versionfile_source)
            print("UPDATING %s" % target_versionfile)
            write_to_version_file(target_versionfile,
                                  self._versioneer_generated_versions)
    cmds["sdist"] = cmd_sdist

    return cmds
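
# Illustrative sketch of handing an existing cmdclass to get_cmdclass(), as
# its docstring allows; "custom_build_py" and "mybuildhelpers" are
# hypothetical stand-ins for any third-party command class (numpy's, for
# example), not names this module provides:
#
#     import versioneer
#     from setuptools import setup
#     from mybuildhelpers import custom_build_py  # hypothetical
#
#     setup(
#         version=versioneer.get_version(),
#         cmdclass=versioneer.get_cmdclass({"build_py": custom_build_py}),
#     )
#
# get_cmdclass() copies the supplied dict, subclasses the provided
# "build_py" in cmd_build_py above, and returns the merged mapping, so the
# version-file rewrite happens on top of the project's own customizations.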
""" SAMPLE_CONFIG = """ # See the docstring in versioneer.py for instructions. Note that you must # re-run 'versioneer.py setup' after changing this section, and commit the # resulting files. [versioneer] #VCS = git #style = pep440 #versionfile_source = #versionfile_build = #tag_prefix = #parentdir_prefix = """ OLD_SNIPPET = """ from ._version import get_versions __version__ = get_versions()['version'] del get_versions """ INIT_PY_SNIPPET = """ from . import {0} __version__ = {0}.get_versions()['version'] """ def do_setup() -> int: """Do main VCS-independent setup function for installing Versioneer.""" root = get_root() try: cfg = get_config_from_root(root) except (OSError, configparser.NoSectionError, configparser.NoOptionError) as e: if isinstance(e, (OSError, configparser.NoSectionError)): print("Adding sample versioneer config to setup.cfg", file=sys.stderr) with open(os.path.join(root, "setup.cfg"), "a") as f: f.write(SAMPLE_CONFIG) print(CONFIG_ERROR, file=sys.stderr) return 1 print(" creating %s" % cfg.versionfile_source) with open(cfg.versionfile_source, "w") as f: LONG = LONG_VERSION_PY[cfg.VCS] f.write(LONG % {"DOLLAR": "$", "STYLE": cfg.style, "TAG_PREFIX": cfg.tag_prefix, "PARENTDIR_PREFIX": cfg.parentdir_prefix, "VERSIONFILE_SOURCE": cfg.versionfile_source, }) ipy = os.path.join(os.path.dirname(cfg.versionfile_source), "__init__.py") maybe_ipy: Optional[str] = ipy if os.path.exists(ipy): try: with open(ipy, "r") as f: old = f.read() except OSError: old = "" module = os.path.splitext(os.path.basename(cfg.versionfile_source))[0] snippet = INIT_PY_SNIPPET.format(module) if OLD_SNIPPET in old: print(" replacing boilerplate in %s" % ipy) with open(ipy, "w") as f: f.write(old.replace(OLD_SNIPPET, snippet)) elif snippet not in old: print(" appending to %s" % ipy) with open(ipy, "a") as f: f.write(snippet) else: print(" %s unmodified" % ipy) else: print(" %s doesn't exist, ok" % ipy) maybe_ipy = None # Make VCS-specific changes. For git, this means creating/changing # .gitattributes to mark _version.py for export-subst keyword # substitution. do_vcs_install(cfg.versionfile_source, maybe_ipy) return 0 def scan_setup_py() -> int: """Validate the contents of setup.py against Versioneer's expectations.""" found = set() setters = False errors = 0 with open("setup.py", "r") as f: for line in f.readlines(): if "import versioneer" in line: found.add("import") if "versioneer.get_cmdclass()" in line: found.add("cmdclass") if "versioneer.get_version()" in line: found.add("get_version") if "versioneer.VCS" in line: setters = True if "versioneer.versionfile_source" in line: setters = True if len(found) != 3: print("") print("Your setup.py appears to be missing some important items") print("(but I might be wrong). Please make sure it has something") print("roughly like the following:") print("") print(" import versioneer") print(" setup( version=versioneer.get_version(),") print(" cmdclass=versioneer.get_cmdclass(), ...)") print("") errors += 1 if setters: print("You should remove lines like 'versioneer.VCS = ' and") print("'versioneer.versionfile_source = ' . This configuration") print("now lives in setup.cfg, and should be removed from setup.py") print("") errors += 1 return errors def setup_command() -> NoReturn: """Set up Versioneer and exit with appropriate error code.""" errors = do_setup() errors += scan_setup_py() sys.exit(1 if errors else 0) if __name__ == "__main__": cmd = sys.argv[1] if cmd == "setup": setup_command()