././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1625583712.7161546 fbtftp-0.5/0000755000076500000240000000000000000000000012236 5ustar00skozlovstaff././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1625578385.0 fbtftp-0.5/AUTHORS0000644000076500000240000000024600000000000013310 0ustar00skozlovstaff
fbtftp was created by Angelo Failla for Facebook.

Here is the list of people who contributed:

* Marcin Wyszynski
* Andrea Barberio

././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1625578385.0 fbtftp-0.5/CONTRIBUTING.md0000644000076500000240000000436300000000000014475 0ustar00skozlovstaff
# How to contribute

Contributing to `fbtftp` follows the same process as contributing to any GitHub repository. In short:

* forking the project
* making your changes
* making a pull request

More details in the following paragraphs.

# Prerequisites

In order to contribute to `fbtftp` you need a GitHub account. If you don't have one already, please [sign up on GitHub](https://github.com/signup/free) first.

# Making your changes

To make changes you have to:

* fork the `fbtftp` repository. See [Fork a repo](https://help.github.com/articles/fork-a-repo/) on GitHub's documentation
* make your changes locally. See Coding Style below
* make a pull request. See [Using pull requests](https://help.github.com/articles/using-pull-requests/) on GitHub's documentation

Once we receive your pull request, one of our project members will review your changes. If necessary they will ask you to make additional changes, and if the patch is good enough, it will be merged in the main repository.

# Coding style

`fbtftp` is written in Python 3 and follows the [PEP-8 Style Guide](https://www.python.org/dev/peps/pep-0008/) plus some Facebook specific style guides. We want to keep the style consistent throughout the code, so we will not accept pull requests that do not pass the style checks.
The style checking is done when running `make test`; please make sure to run it before submitting your patch.

You might also consider installing and using `yapf` to automatically format the code to follow our style guidelines. A `.style.yapf` is provided to facilitate this.

Run this before you send your PR:

```
$ pip3 install yapf
$ make clean
$ yapf -i $(find . -name "*.py")
```

# I don't want to make a pull request!

We love pull requests, but it's not necessary to write code to contribute. If for any reason you can't make a pull request (e.g. you just want to suggest an improvement), let us know. [Create an issue](https://help.github.com/articles/creating-an-issue/) on the `fbtftp` issue tracker and we will review your request.

# Code of Conduct

Facebook has adopted a Code of Conduct that we expect project participants to adhere to. Please [read the full text](https://code.facebook.com/codeofconduct) so that you can understand what actions will and will not be tolerated.

././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1625578385.0 fbtftp-0.5/LICENSE0000644000076500000240000000207600000000000013250 0ustar00skozlovstaff
MIT License

Copyright (c) Facebook, Inc. and its affiliates.

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1625578385.0 fbtftp-0.5/LICENSE-examples0000644000076500000240000000121200000000000015053 0ustar00skozlovstaffCopyright (c) 2016-present, Facebook, Inc. All rights reserved. The examples provided by Facebook are for non-commercial testing and evaluation purposes only. Facebook reserves all rights not expressly granted. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL FACEBOOK BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1625578385.0 fbtftp-0.5/MANIFEST.in0000644000076500000240000000056700000000000014004 0ustar00skozlovstaff# Copyright (c) Facebook, Inc. and its affiliates. # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
include AUTHORS include CONTRIBUTING.md include LICENSE include LICENSE-examples include PATENTS include README.md recursive-include fbtftp * recursive-include examples * recursive-include tests * ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1625583712.7163372 fbtftp-0.5/PKG-INFO0000644000076500000240000002072700000000000013343 0ustar00skozlovstaffMetadata-Version: 2.1 Name: fbtftp Version: 0.5 Summary: A python3 framework to build dynamic TFTP servers Home-page: https://www.github.com/facebook/fbtftp Author: Angelo Failla Author-email: pallotron@fb.com License: BSD Keywords: tftp daemon infrastructure provisioning netboot Platform: UNKNOWN Classifier: Development Status :: 5 - Production/Stable Classifier: License :: OSI Approved :: MIT License Classifier: Operating System :: POSIX :: Linux Classifier: Programming Language :: Python :: 3 :: Only Classifier: Programming Language :: Python :: 3.5 Classifier: Programming Language :: Python :: 3.6 Classifier: Programming Language :: Python :: 3.7 Classifier: Programming Language :: Python :: 3.8 Classifier: Programming Language :: Python :: 3.9 Classifier: Topic :: Software Development :: Libraries :: Application Frameworks Classifier: Topic :: System :: Boot Classifier: Topic :: Utilities Classifier: Intended Audience :: Developers Description-Content-Type: text/markdown License-File: LICENSE License-File: LICENSE-examples License-File: AUTHORS [![Build Status](https://travis-ci.org/facebook/fbtftp.svg?branch=master)](https://travis-ci.org/facebook/fbtftp) [![codebeat badge](https://codebeat.co/badges/2d4c7650-4752-4adf-a570-1948ecb4d6a8)](https://codebeat.co/projects/github-com-facebook-fbtftp) # What is fbtftp? `fbtftp` is Facebook's implementation of a dynamic TFTP server framework. It lets you create custom TFTP servers and wrap your own logic into it in a very simple manner. 
Facebook currently uses it in production, and it's deployed at global scale across all of our data centers.

# Why did you do that?

We love to use existing open source software and to contribute upstream, but sometimes it's just not enough at our scale. We ended up writing our own tftp framework and decided to open source it.

`fbtftp` was born from the need for an easy-to-configure and easy-to-extend TFTP server that would work at large scale. The standard `in.tftpd` is a 20+ year old piece of software written in C that is very difficult to extend.

`fbtftp` is written in `python3` and lets you plug in your own logic to:

* publish per session and server wide statistics to your infrastructure
* define how response data is built:
  * can be a file from disk;
  * can be a file created dynamically;
  * you name it!

# How do you use `fbtftp` at Facebook?

We created our own Facebook-specific server based on the framework to:

* stream static files (initrd and kernels) from our http repositories (no need to fill your tftp root directory with files);
* generate grub2 per-machine configuration dynamically (no need to copy grub2 configuration files on disk);
* publish per-server and per-connection statistics to our internal monitoring systems;
* keep deployment easy and "container-ready": just copy the application somewhere, start it and you are done.

# Is it better than the other TFTP servers?

It depends on your needs! `fbtftp` is written in Python 3 using a multiprocessing model; its primary focus is not speed, but flexibility and scalability. Yet it is fast enough at our datacenter scale :) It is well-suited for large installations where scalability and custom features are needed.

# What does it support?
The framework implements the following RFCs:

* [RFC 1350](https://tools.ietf.org/html/rfc1350) (the main TFTP specification)
* [RFC 2347](https://tools.ietf.org/html/rfc2347) (Option Extension)
* [RFC 2348](https://tools.ietf.org/html/rfc2348) (Blocksize option)
* [RFC 2349](https://tools.ietf.org/html/rfc2349) (Timeout Interval and Transfer Size Options).

Note that the server framework only supports RRQ (read-only) operations. (Who uses WRQ TFTP requests in 2019? :P)

# How does it work?

All you need to do is understand three classes and two callback functions, and you are good to go:

* `BaseServer`: This class implements the process that accepts new requests on the UDP port provided. Default TFTP parameters like timeout, port number and number of retries can be passed. This class is not meant to be used directly; you must inherit from it and override the `get_handler()` method to return an instance of `BaseHandler`. The class accepts a `server_stats_callback`, more about it below. The callback is not re-entrant; if you need this you have to implement your own locking logic. This callback is executed periodically and you can use it to publish server level stats to your monitoring infrastructure. A series of predefined counters are provided. Refer to the class documentation to find out more.
* `BaseHandler`: This class deals with talking to a single client. It lives in its own separate process, which is spawned by the `BaseServer` class; the server makes sure to reap the child properly when the session is over. Do not use this class as is; instead inherit from it and override the `get_response_data()` method. This method must return an instance of a subclass of `ResponseData`.
* `ResponseData`: a file-like class that implements `read(num_bytes)`, `size()` and `close()`. As with the previous two classes, you'll have to inherit from this and implement those methods.
  This class basically lets you define how to return the actual data.
* `server_stats_callback`: function that is called periodically (every 60 seconds by default). The callback is not re-entrant; if you need this you have to implement your own locking logic. You can use it to publish server level stats to your monitoring infrastructure. A series of predefined counters are provided. Refer to the class documentation to find out more.
* `session_stats_callback`: function that is called when a client session is over.

# Requirements

* Linux (or any system that supports [`epoll`](http://linux.die.net/man/4/epoll))
* BSD (or any system that supports [`kqueue`](https://www.freebsd.org/cgi/man.cgi?query=kqueue&sektion=2))
* Python 3.4+

# Installation

`fbtftp` is distributed with the standard `distutils` package, so you can build it with:

```
python setup.py build
```

and install it with:

```
python setup.py install
```

Be sure to run as root if you want to install `fbtftp` system wide. You can also use a `virtualenv`, or install it as a user by running:

```
python setup.py install --user
```

# Example

Writing your own server is simple.
Let's take a look at how to write a simple server that serves files from disk: ```python from fbtftp.base_handler import BaseHandler from fbtftp.base_handler import ResponseData from fbtftp.base_server import BaseServer import os class FileResponseData(ResponseData): def __init__(self, path): self._size = os.stat(path).st_size self._reader = open(path, 'rb') def read(self, n): return self._reader.read(n) def size(self): return self._size def close(self): self._reader.close() def print_session_stats(stats): print(stats) def print_server_stats(stats): counters = stats.get_and_reset_all_counters() print('Server stats - every {} seconds'.format(stats.interval)) print(counters) class StaticHandler(BaseHandler): def __init__(self, server_addr, peer, path, options, root, stats_callback): self._root = root super().__init__(server_addr, peer, path, options, stats_callback) def get_response_data(self): return FileResponseData(os.path.join(self._root, self._path)) class StaticServer(BaseServer): def __init__(self, address, port, retries, timeout, root, handler_stats_callback, server_stats_callback=None): self._root = root self._handler_stats_callback = handler_stats_callback super().__init__(address, port, retries, timeout, server_stats_callback) def get_handler(self, server_addr, peer, path, options): return StaticHandler( server_addr, peer, path, options, self._root, self._handler_stats_callback) def main(): server = StaticServer(address='::', port=69, retries=3, timeout=5, root='/var/tftproot', handler_stats_callback=print_session_stats, server_stats_callback=print_server_stats) try: server.run() except KeyboardInterrupt: server.close() if __name__ == '__main__': main() ``` # Who wrote it? `fbtftp` was created by Marcin Wyszynski (@marcinwyszynski) and Angelo Failla at Facebook Ireland. 
Other honorable contributors:

* Andrea Barberio

# License

MIT License

././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1625578385.0 fbtftp-0.5/README.md0000644000076500000240000001665200000000000013527 0ustar00skozlovstaff
[![Build Status](https://travis-ci.org/facebook/fbtftp.svg?branch=master)](https://travis-ci.org/facebook/fbtftp)
[![codebeat badge](https://codebeat.co/badges/2d4c7650-4752-4adf-a570-1948ecb4d6a8)](https://codebeat.co/projects/github-com-facebook-fbtftp)

# What is fbtftp?

`fbtftp` is Facebook's implementation of a dynamic TFTP server framework. It lets you create custom TFTP servers and wrap your own logic into it in a very simple manner. Facebook currently uses it in production, and it's deployed at global scale across all of our data centers.

# Why did you do that?

We love to use existing open source software and to contribute upstream, but sometimes it's just not enough at our scale. We ended up writing our own tftp framework and decided to open source it.

`fbtftp` was born from the need for an easy-to-configure and easy-to-extend TFTP server that would work at large scale. The standard `in.tftpd` is a 20+ year old piece of software written in C that is very difficult to extend.

`fbtftp` is written in `python3` and lets you plug in your own logic to:

* publish per session and server wide statistics to your infrastructure
* define how response data is built:
  * can be a file from disk;
  * can be a file created dynamically;
  * you name it!

# How do you use `fbtftp` at Facebook?
We created our own Facebook-specific server based on the framework to:

* stream static files (initrd and kernels) from our http repositories (no need to fill your tftp root directory with files);
* generate grub2 per-machine configuration dynamically (no need to copy grub2 configuration files on disk);
* publish per-server and per-connection statistics to our internal monitoring systems;
* keep deployment easy and "container-ready": just copy the application somewhere, start it and you are done.

# Is it better than the other TFTP servers?

It depends on your needs! `fbtftp` is written in Python 3 using a multiprocessing model; its primary focus is not speed, but flexibility and scalability. Yet it is fast enough at our datacenter scale :) It is well-suited for large installations where scalability and custom features are needed.

# What does it support?

The framework implements the following RFCs:

* [RFC 1350](https://tools.ietf.org/html/rfc1350) (the main TFTP specification)
* [RFC 2347](https://tools.ietf.org/html/rfc2347) (Option Extension)
* [RFC 2348](https://tools.ietf.org/html/rfc2348) (Blocksize option)
* [RFC 2349](https://tools.ietf.org/html/rfc2349) (Timeout Interval and Transfer Size Options).

Note that the server framework only supports RRQ (read-only) operations. (Who uses WRQ TFTP requests in 2019? :P)

# How does it work?

All you need to do is understand three classes and two callback functions, and you are good to go:

* `BaseServer`: This class implements the process that accepts new requests on the UDP port provided. Default TFTP parameters like timeout, port number and number of retries can be passed. This class is not meant to be used directly; you must inherit from it and override the `get_handler()` method to return an instance of `BaseHandler`. The class accepts a `server_stats_callback`, more about it below. The callback is not re-entrant; if you need this you have to implement your own locking logic.
  This callback is executed periodically and you can use it to publish server level stats to your monitoring infrastructure. A series of predefined counters are provided. Refer to the class documentation to find out more.
* `BaseHandler`: This class deals with talking to a single client. It lives in its own separate process, which is spawned by the `BaseServer` class; the server makes sure to reap the child properly when the session is over. Do not use this class as is; instead inherit from it and override the `get_response_data()` method. This method must return an instance of a subclass of `ResponseData`.
* `ResponseData`: a file-like class that implements `read(num_bytes)`, `size()` and `close()`. As with the previous two classes, you'll have to inherit from this and implement those methods. This class basically lets you define how to return the actual data.
* `server_stats_callback`: function that is called periodically (every 60 seconds by default). The callback is not re-entrant; if you need this you have to implement your own locking logic. You can use it to publish server level stats to your monitoring infrastructure. A series of predefined counters are provided. Refer to the class documentation to find out more.
* `session_stats_callback`: function that is called when a client session is over.

# Requirements

* Linux (or any system that supports [`epoll`](http://linux.die.net/man/4/epoll))
* BSD (or any system that supports [`kqueue`](https://www.freebsd.org/cgi/man.cgi?query=kqueue&sektion=2))
* Python 3.4+

# Installation

`fbtftp` is distributed with the standard `distutils` package, so you can build it with:

```
python setup.py build
```

and install it with:

```
python setup.py install
```

Be sure to run as root if you want to install `fbtftp` system wide. You can also use a `virtualenv`, or install it as a user by running:

```
python setup.py install --user
```

# Example

Writing your own server is simple.
Let's take a look at how to write a simple server that serves files from disk: ```python from fbtftp.base_handler import BaseHandler from fbtftp.base_handler import ResponseData from fbtftp.base_server import BaseServer import os class FileResponseData(ResponseData): def __init__(self, path): self._size = os.stat(path).st_size self._reader = open(path, 'rb') def read(self, n): return self._reader.read(n) def size(self): return self._size def close(self): self._reader.close() def print_session_stats(stats): print(stats) def print_server_stats(stats): counters = stats.get_and_reset_all_counters() print('Server stats - every {} seconds'.format(stats.interval)) print(counters) class StaticHandler(BaseHandler): def __init__(self, server_addr, peer, path, options, root, stats_callback): self._root = root super().__init__(server_addr, peer, path, options, stats_callback) def get_response_data(self): return FileResponseData(os.path.join(self._root, self._path)) class StaticServer(BaseServer): def __init__(self, address, port, retries, timeout, root, handler_stats_callback, server_stats_callback=None): self._root = root self._handler_stats_callback = handler_stats_callback super().__init__(address, port, retries, timeout, server_stats_callback) def get_handler(self, server_addr, peer, path, options): return StaticHandler( server_addr, peer, path, options, self._root, self._handler_stats_callback) def main(): server = StaticServer(address='::', port=69, retries=3, timeout=5, root='/var/tftproot', handler_stats_callback=print_session_stats, server_stats_callback=print_server_stats) try: server.run() except KeyboardInterrupt: server.close() if __name__ == '__main__': main() ``` # Who wrote it? `fbtftp` was created by Marcin Wyszynski (@marcinwyszynski) and Angelo Failla at Facebook Ireland. 
Other honorable contributors: * Andrea Barberio # License MIT License ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1625583712.7090049 fbtftp-0.5/examples/0000755000076500000240000000000000000000000014054 5ustar00skozlovstaff././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1625578385.0 fbtftp-0.5/examples/server.py0000644000076500000240000000723100000000000015737 0ustar00skozlovstaff#!/usr/bin/env python3 # Copyright (c) Facebook, Inc. and its affiliates. # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import argparse import logging import os from fbtftp.base_handler import BaseHandler from fbtftp.base_handler import ResponseData from fbtftp.base_server import BaseServer class FileResponseData(ResponseData): def __init__(self, path): self._size = os.stat(path).st_size self._reader = open(path, "rb") def read(self, n): return self._reader.read(n) def size(self): return self._size def close(self): self._reader.close() def print_session_stats(stats): logging.info("Stats: for %r requesting %r" % (stats.peer, stats.file_path)) logging.info("Error: %r" % stats.error) logging.info("Time spent: %dms" % (stats.duration() * 1e3)) logging.info("Packets sent: %d" % stats.packets_sent) logging.info("Packets ACKed: %d" % stats.packets_acked) logging.info("Bytes sent: %d" % stats.bytes_sent) logging.info("Options: %r" % stats.options) logging.info("Blksize: %r" % stats.blksize) logging.info("Retransmits: %d" % stats.retransmits) logging.info("Server port: %d" % stats.server_addr[1]) logging.info("Client port: %d" % stats.peer[1]) def print_server_stats(stats): """ Print server stats - see the ServerStats class """ # NOTE: remember to reset the counters you use, to allow the next cycle to # start fresh counters = stats.get_and_reset_all_counters() logging.info("Server stats - every %d seconds" % stats.interval) if "process_count" in 
counters: logging.info( "Number of spawned TFTP workers in stats time frame : %d" % counters["process_count"] ) class StaticHandler(BaseHandler): def __init__(self, server_addr, peer, path, options, root, stats_callback): self._root = root super().__init__(server_addr, peer, path, options, stats_callback) def get_response_data(self): return FileResponseData(os.path.join(self._root, self._path)) class StaticServer(BaseServer): def __init__( self, address, port, retries, timeout, root, handler_stats_callback, server_stats_callback=None, ): self._root = root self._handler_stats_callback = handler_stats_callback super().__init__(address, port, retries, timeout, server_stats_callback) def get_handler(self, server_addr, peer, path, options): return StaticHandler( server_addr, peer, path, options, self._root, self._handler_stats_callback ) def get_arguments(): parser = argparse.ArgumentParser() parser.add_argument("--ip", type=str, default="::", help="IP address to bind to") parser.add_argument("--port", type=int, default=1969, help="port to bind to") parser.add_argument( "--retries", type=int, default=5, help="number of per-packet retries" ) parser.add_argument( "--timeout_s", type=int, default=2, help="timeout for packet retransmission" ) parser.add_argument( "--root", type=str, default="", help="root of the static filesystem" ) return parser.parse_args() def main(): args = get_arguments() logging.getLogger().setLevel(logging.DEBUG) server = StaticServer( args.ip, args.port, args.retries, args.timeout_s, args.root, print_session_stats, print_server_stats, ) try: server.run() except KeyboardInterrupt: server.close() if __name__ == "__main__": main() ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1625583712.7112944 fbtftp-0.5/fbtftp/0000755000076500000240000000000000000000000013523 5ustar00skozlovstaff././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1625578385.0 
fbtftp-0.5/fbtftp/__init__.py0000644000076500000240000000056700000000000015644 0ustar00skozlovstaff#!/usr/bin/env python3 # Copyright (c) Facebook, Inc. and its affiliates. # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. from .base_handler import BaseHandler, ResponseData, SessionStats from .base_server import BaseServer __all__ = ["BaseHandler", "BaseServer", "ResponseData", "SessionStats"] ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1625578385.0 fbtftp-0.5/fbtftp/base_handler.py0000644000076500000240000003604300000000000016512 0ustar00skozlovstaff#!/usr/bin/env python3 # Copyright (c) Facebook, Inc. and its affiliates. # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. from collections import OrderedDict import io import ipaddress import logging import multiprocessing import socket import struct import sys import time from . import constants from .netascii import NetasciiReader class ResponseData: """A base class representing a file-like object""" def read(self, n): raise NotImplementedError() def size(self): raise NotImplementedError() def close(self): raise NotImplementedError() class StringResponseData(ResponseData): """ A convenience subclass of `ResponseData` that transforms an input String into a file-like object. """ def __init__(self, string): self._size = len(string.encode("latin-1")) self._reader = io.StringIO(string) def read(self, n): return bytes(self._reader.read(n).encode("latin-1")) def size(self): return self._size def close(self): pass class SessionStats: """ SessionStats represents a digest of what happened during a session. Data inside the object gets populated at the end of a session. See `__init__` to see what you'll get. Note: You should never need to instantiate an object of this class. 
    This object is what gets passed to the callback you provide to the
    `BaseHandler` class.
    """

    def __init__(self, server_addr, peer, file_path):
        self.peer = peer
        self.server_addr = server_addr
        self.file_path = file_path
        self.error = {}
        self.options = {}
        self.start_time = time.time()
        self.packets_sent = 0
        self.packets_acked = 0
        self.bytes_sent = 0
        self.retransmits = 0
        self.blksize = constants.DEFAULT_BLKSIZE

    def duration(self):
        return time.time() - self.start_time


class BaseHandler(multiprocessing.Process):
    def __init__(self, server_addr, peer, path, options, stats_callback):
        """
        Class that deals with talking to a single client. Being a subclass of
        `multiprocessing.Process`, this will run in a separate process from the
        main process.

        Note:
            Do not use this class as is, inherit from it and override the
            `get_response_data` method which must return a subclass of
            `ResponseData`.

        Args:
            server_addr (tuple): (ip, port) of the server

            peer (tuple): (ip, port) of the peer

            path (string): requested file

            options (dict): a dictionary containing the options the client
                wants to negotiate.

            stats_callback (callable): a callable that will be executed at the
                end of the session. It gets passed an instance of the
                `SessionStats` class.
""" self._timeout = int(options["default_timeout"]) self._server_addr = server_addr self._reset_timeout() self._retries = int(options["retries"]) self._block_size = constants.DEFAULT_BLKSIZE self._last_block_sent = 0 self._retransmits = 0 self._global_retransmits = 0 self._current_block = None self._should_stop = False self._waiting_last_ack = False self._path = path self._options = options self._stats_callback = stats_callback self._response_data = None self._listener = None self._peer = peer logging.info( "New connection from peer `%s` asking for path `%s`" % (str(peer), str(path)) ) self._family = socket.AF_INET6 # the format of the peer tuple is different for v4 and v6 if isinstance(ipaddress.ip_address(server_addr[0]), ipaddress.IPv4Address): self._family = socket.AF_INET # peer address format is different in v4 world self._peer = (self._peer[0].replace("::ffff:", ""), self._peer[1]) self._stats = SessionStats(self._server_addr, self._peer, self._path) try: self._response_data = self.get_response_data() except FileNotFoundError as e: logging.warning(str(e)) self._stats.error = { "error_code": constants.ERR_FILE_NOT_FOUND, "error_message": str(e), } except Exception as e: logging.exception("Caught exception: %s." % e) self._stats.error = { "error_code": constants.ERR_UNDEFINED, "error_message": str(e), } super().__init__() def _get_listener(self): if not self._listener: self._listener = socket.socket(self._family, socket.SOCK_DGRAM) self._listener.bind((str(self._server_addr[0]), 0)) return self._listener def _on_close(self): """ Called at the end of a session. This method sets number of retransmissions and calls the stats callback at the end of the session. """ self._stats.retransmits = self._global_retransmits self._stats_callback(self._stats) def _close(self, test=False): """ Wrapper around `_on_close`. Its duty is to perform the necessary cleanup. Closing `ResponseData` object, closing UDP sockets, and gracefully exiting the process with exit code of 0. 
        """
        try:
            self._on_close()
        except Exception as e:
            logging.exception("Exception raised when calling _on_close: %s" % e)
        finally:
            logging.debug("Closing response data object")
            if self._response_data:
                self._response_data.close()
            logging.debug("Closing socket")
            self._get_listener().close()
            logging.debug("Dying.")
            if test is False:
                sys.exit(0)

    def _parse_options(self):
        """
        Method that deals with parsing/validation options provided by the
        client.
        """
        opts_to_ack = OrderedDict()
        # We remove retries and default_timeout from self._options because
        # we don't need to include them in the OACK response to the client.
        # Their value is already held in self._retries and self._timeout.
        del self._options["retries"]
        del self._options["default_timeout"]
        logging.info(
            "Options requested from peer {}: {}".format(self._peer, self._options)
        )
        self._stats.options_in = self._options
        if "mode" in self._options and self._options["mode"] == "netascii":
            self._response_data = NetasciiReader(self._response_data)
        elif "mode" in self._options and self._options["mode"] != "octet":
            self._stats.error = {
                "error_code": constants.ERR_ILLEGAL_OPERATION,
                "error_message": "Unknown mode: %r" % self._options["mode"],
            }
            self._transmit_error()
            self._close()
            return  # no way anything else will succeed now
        # Let's ack the options in the same order we got asked for them
        # The RFC mentions that option order is not significant, but it can't
        # hurt. This relies on Python 3.6 dicts to be ordered.
        for k, v in self._options.items():
            if k == "blksize":
                opts_to_ack["blksize"] = v
                self._block_size = int(v)
            if k == "tsize":
                self._tsize = self._response_data.size()
                if self._tsize is not None:
                    opts_to_ack["tsize"] = str(self._tsize)
            if k == "timeout":
                opts_to_ack["timeout"] = v
                self._timeout = int(v)
        self._options = opts_to_ack  # only ACK options we can handle
        logging.info(
            "Options to ack for peer {}: {}".format(self._peer, self._options)
        )
        self._stats.blksize = self._block_size
        self._stats.options = self._options
        self._stats.options_acked = self._options

    def run(self):
        """This is the main serving loop."""
        if self._stats.error:
            self._transmit_error()
            self._close()
            return
        self._parse_options()
        if self._options:
            self._transmit_oack()
        else:
            self._next_block()
            self._transmit_data()
        while not self._should_stop:
            try:
                self.run_once()
            except (KeyboardInterrupt, SystemExit):
                logging.info(
                    "Caught KeyboardInterrupt/SystemExit exception. " "Will exit."
                )
                break
        self._close()

    def run_once(self):
        """The main body of the server loop."""
        self.on_new_data()
        if time.time() > self._expire_ts:
            self._handle_timeout()

    def _reset_timeout(self):
        """
        This method resets the connection timeout in order to extend its
        lifetime. It does so by setting the timestamp in the future.
        """
        self._expire_ts = time.time() + self._timeout

    def on_new_data(self):
        """
        Called when new data is available on the socket.

        This method will extract acknowledged block numbers and handle
        possible errors.
        """
        # Note that we use a blocking socket, because it has its own dedicated
        # process. We read only 512 bytes.
try: listener = self._get_listener() listener.settimeout(self._timeout) data, peer = listener.recvfrom(constants.DEFAULT_BLKSIZE) listener.settimeout(None) except socket.timeout: return if peer != self._peer: logging.error("Unexpected peer: %s, expected %s" % (peer, self._peer)) self._should_stop = True return code, block_number = struct.unpack("!HH", data[:4]) if code == constants.OPCODE_ERROR: # When the client sends an OPCODE_ERROR, # the block number field carries one of the ERR codes in constants.py self._stats.error = { "error_code": block_number, "error_message": data[4:-1].decode("ascii", "ignore"), } # An error was reported by the client which terminates the exchange logging.error( "Error reported from client: %s" % self._stats.error["error_message"] ) self._transmit_error() self._should_stop = True return if code != constants.OPCODE_ACK: logging.error( "Expected an ACK opcode from %s, got: %d" % (self._peer, code) ) self._stats.error = { "error_code": constants.ERR_ILLEGAL_OPERATION, "error_message": "I only do reads, really", } self._transmit_error() self._should_stop = True return self._handle_ack(block_number) def _handle_ack(self, block_number): """Deals with a client ACK packet.""" if block_number != self._last_block_sent: # Unexpected ACK, let's ignore this. return self._reset_timeout() self._retransmits = 0 self._stats.packets_acked += 1 if self._waiting_last_ack: self._should_stop = True return self._next_block() self._transmit_data() def _handle_timeout(self): if self._retries >= self._retransmits: self._transmit_data() self._retransmits += 1 self._global_retransmits += 1 return error_msg = "timeout after {} retransmits.".format(self._retransmits) if self._waiting_last_ack: error_msg += " Missed last ack." self._stats.error = { "error_code": constants.ERR_UNDEFINED, "error_message": error_msg, } self._should_stop = True logging.error(self._stats.error["error_message"]) def _next_block(self): """ Reads the next block from `ResponseData`.
If there are problems reading from it, an error will be reported to the client. """ self._last_block_sent += 1 if self._last_block_sent > constants.MAX_BLOCK_NUMBER: self._last_block_sent = 0 # Wrap around the block counter. try: last_size = 0 # current_block size before read. Used to check EOF. self._current_block = self._response_data.read(self._block_size) while ( len(self._current_block) != self._block_size and len(self._current_block) != last_size ): last_size = len(self._current_block) self._current_block += self._response_data.read( self._block_size - last_size ) except Exception as e: logging.exception("Error while reading from source: %s" % e) self._stats.error = { "error_code": constants.ERR_UNDEFINED, "error_message": "Error while reading from source", } self._transmit_error() self._should_stop = True def _transmit_data(self): """Method that deals with sending a block to the wire.""" if self._current_block is None: self._transmit_oack() return fmt = "!HH%ds" % len(self._current_block) packet = struct.pack( fmt, constants.OPCODE_DATA, self._last_block_sent, self._current_block ) self._get_listener().sendto(packet, self._peer) self._stats.packets_sent += 1 self._stats.bytes_sent += len(self._current_block) if len(self._current_block) < self._block_size: self._waiting_last_ack = True def _transmit_oack(self): """Method that deals with sending OACK datagrams on the wire.""" opts = [] for key, val in self._options.items(): fmt = str("%dsx%ds" % (len(key), len(val))) opts.append( struct.pack( fmt, bytes(key.encode("latin-1")), bytes(val.encode("latin-1")) ) ) opts.append(b"") fmt = str("!H") packet = struct.pack(fmt, constants.OPCODE_OACK) + b"\x00".join(opts) self._get_listener().sendto(packet, self._peer) self._stats.packets_sent += 1 def _transmit_error(self): """Transmits an error to the client and terminates the exchange.""" fmt = str( "!HH%dsx" % (len(self._stats.error["error_message"].encode("latin-1"))) ) packet = struct.pack( fmt,
constants.OPCODE_ERROR, self._stats.error["error_code"], bytes(self._stats.error["error_message"].encode("latin-1")), ) self._get_listener().sendto(packet, self._peer) def get_response_data(self): """ This method has to be overridden and must return an object of type `ResponseData`. """ raise NotImplementedError() ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1625578385.0 fbtftp-0.5/fbtftp/base_server.py0000644000076500000240000003042700000000000016403 0ustar00skozlovstaff#!/usr/bin/env python3 # Copyright (c) Facebook, Inc. and its affiliates. # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import collections import ipaddress import logging import selectors import socket import struct import threading import time import traceback from . import constants class ServerStats: def __init__(self, server_addr=None, interval=None): """ `ServerStats` represents a digest of what happened during the server's lifetime. This class exposes a counter interface with get/set/reset methods and an atomic get-and-reset. An instance of this class is passed to a periodic function that is executed by a background thread inside the `BaseServer` object. See `stats_callback` in the `BaseServer` constructor. If you use it in a metric publishing callback, remember to use atomic operations and to reset the counters to have a fresh start. E.g. see `get_and_reset_all_counters`. Args: server_addr (str): the server address, either v4 or v6. interval (int): stats interval in seconds. Note: `server_addr` and `interval` are provided by the `BaseServer` class. They are not used in this class; they are there for the programmer's convenience, in case one wants to use them.
""" self.server_addr = server_addr self.interval = interval self.start_time = time.time() self._counters = collections.Counter() self._counters_lock = threading.Lock() def get_all_counters(self): """ Return all counters as a dictionary. This operation is atomic. Returns: dict: all the counters. """ with self._counters_lock: return dict(self._counters) def get_and_reset_all_counters(self): """ Return all counters as a dictionary and reset them. This operation is atomic. Returns: dict: all the counters """ with self._counters_lock: counters = dict(self._counters) self._counters.clear() return counters def get_counter(self, name): """ Get a counter value by name. Do not use this method if you have to reset a counter after getting it. Use `get_and_reset_counter` instead. Args: name (str): the counter Returns: int: the value of the counter """ return self._counters[name] def set_counter(self, name, value): """ Set a counter value by name, atomically. Args: name (str): counter's name value (str): counter's value """ with self._counters_lock: self._counters[name] = value def increment_counter(self, name, increment=1): """ Increment a counter value by name, atomically. The increment can be negative. Args: name (str): the counter's name increment (int): the increment step, defaults to 1. """ with self._counters_lock: self._counters[name] += increment def reset_counter(self, name): """ Reset counter atomically. Args: name (str): counter's name """ with self._counters_lock: self._counters[name] = 0 def get_and_reset_counter(self, name): """ Get and reset a counter value by name atomically. Args: name (str): counter's name Returns: : counter's value """ with self._counters_lock: value = self._counters[name] self._counters[name] = 0 return value def reset_all_counters(self): """ Reset all the counters atomically. """ with self._counters_lock: self._counters.clear() def duration(self): """ Return the server uptime using naive timestamps. Returns: float: uptime in seconds. 
""" return time.time() - self.start_time class BaseServer: def __init__( self, address, port, retries, timeout, server_stats_callback=None, stats_interval_seconds=constants.DATAPOINTS_INTERVAL_SECONDS, ): """ This base class implements the process which deals with accepting new requests. Note: This class doesn't have to be used directly, you must inherit from it and override the `get_handler()`` method to return an instance of `BaseHandler`. Args: address (str): address (IPv4 or IPv6) the server needs to bind to. port (int): the port the server needs to bind to. retries (int): number of retries, how many times the server has to retry sending a datagram before it will interrupt the communication. This is passed to the `BaseHandler` class. timeout (int): time in seconds, this is passed to the `BaseHandler` class. It used in two ways: - as timeout in `socket.socket.recvfrom()`. - as maximum time to expect an ACK from a client. server_stats_callback (callable): a callable, this gets called periodically by a background thread. The callable must accept one argument which is an instance of the `ServerStats` class. The statistics callback is not re-entrant, if you need this you have to implement your own locking logic. stats_interval_seconds (int): how often, in seconds, `server_stats_callback` will be executed. 
""" self._address = address self._port = port self._retries = retries self._timeout = timeout self._server_stats_callback = server_stats_callback # the format of the peer tuple is different for v4 and v6 self._family = socket.AF_INET6 if isinstance(ipaddress.ip_address(self._address), ipaddress.IPv4Address): self._family = socket.AF_INET self._listener = socket.socket(self._family, socket.SOCK_DGRAM) self._listener.setblocking(0) # non-blocking self._listener.bind((address, port)) self._selector = selectors.DefaultSelector() self._selector.register(self._listener, selectors.EVENT_READ) self._should_stop = False self._server_stats = ServerStats(address, stats_interval_seconds) self._metrics_timer = None def run(self, run_once=False): """ Run the infinite serving loop. Args: run_once (bool): If True it will exit the loop after first iteration. Note this is only used in unit tests. """ # First start of the server stats thread self.restart_stats_timer(run_once) while not self._should_stop: self.run_once() if run_once: break self._selector.close() self._listener.close() if self._metrics_timer is not None: self._metrics_timer.cancel() def _metrics_callback_wrapper(self, run_once=False): """ Runs the callback, catches and logs exceptions, reschedules a new run for the callback, only if run_once is False (this is used only in unit tests). """ logging.debug("Running the metrics callback") try: self._server_stats_callback(self._server_stats) except Exception as exc: logging.exception(str(exc)) if not run_once: self.restart_stats_timer() def restart_stats_timer(self, run_once=False): """ Start metric pushing timer thread, if a callback was specified. 
""" if self._server_stats_callback is None: logging.warning( "No callback specified for server statistics " "logging, will continue without" ) return self._metrics_timer = threading.Timer( self._server_stats.interval, self._metrics_callback_wrapper, [run_once] ) logging.debug( "Starting the metrics callback in {sec}s".format( sec=self._server_stats.interval ) ) self._metrics_timer.start() def run_once(self): """ Uses edge polling object (`socket.epoll`) as an event notification facility to know when data is ready to be retrived from the listening socket. See http://linux.die.net/man/4/epoll . """ events = self._selector.select() for key, mask in events: if not mask & selectors.EVENT_READ: continue if key.fd == self._listener.fileno(): self.on_new_data() continue def on_new_data(self): """ Deals with incoming RRQ packets. This is called by `run_once` when data is available on the listening socket. This method deals with extracting all the relevant information from the request (like file, transfer mode, path, and options). If all is good it will run the `get_handler` method, which returns a `BaseHandler` object. `BaseHandler` is a subclass of a `multiprocessing.Process` class so calling `start()` on it will cause a `fork()`. 
""" data, peer = self._listener.recvfrom(constants.DEFAULT_BLKSIZE) code = struct.unpack("!H", data[:2])[0] if code != constants.OPCODE_RRQ: logging.warning( "unexpected TFTP opcode %d, expected %d" % (code, constants.OPCODE_RRQ) ) return # extract options tokens = list(filter(bool, data[2:].decode("latin-1").split("\x00"))) if len(tokens) < 2 or len(tokens) % 2 != 0: logging.error( "Received malformed packet, ignoring " "(tokens length: {tl})".format(tl=len(tokens)) ) return path = tokens[0] options = collections.OrderedDict( [ ("mode", tokens[1].lower()), ("default_timeout", self._timeout), ("retries", self._retries), ] ) pos = 2 while pos < len(tokens): options[tokens[pos].lower()] = tokens[pos + 1] pos += 2 # fork a child process try: proc = self.get_handler((self._address, self._port), peer, path, options) if proc is None: logging.warning( "The handler is null! Not serving the request from %s", peer ) return proc.daemon = True proc.start() except Exception as e: logging.error( "creating a handler for %r raised an exception %s" % (path, e) ) logging.error(traceback.format_exc()) # Increment number of spawned TFTP workers in stats time frame self._server_stats.increment_counter("process_count") def get_handler(self, server_addr, peer, path, options): """ Returns an instance of `BaseHandler`. Note: This is a virtual method and must be overridden in a sub-class. This method must return an instance of `BaseHandler`. Args: server_addr (tuple): tuple containing ip of the server and listening port. peer (tuple): tuple containing ip and port of the client. path (string): the file path requested by the client options (dict): a dictionary containing the options the clients wants to negotiate. Example of options: - mode (string): can be netascii or octet. See RFC 1350. - retries (int) - timeout (int) - tsize (int): transfer size option. See RFC 1784. - blksize: size of blocks. See RFC 1783 and RFC 2349. 
""" raise NotImplementedError() def close(self): """ Stops the server, by setting a boolean flag which will be picked by the main while loop. """ self._should_stop = True ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1625578385.0 fbtftp-0.5/fbtftp/constants.py0000644000076500000240000000242100000000000016110 0ustar00skozlovstaff#!/usr/bin/env python3 # Copyright (c) Facebook, Inc. and its affiliates. # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. # TFTP opcodes OPCODE_RRQ = 1 OPCODE_WRQ = 2 OPCODE_DATA = 3 OPCODE_ACK = 4 OPCODE_ERROR = 5 OPCODE_OACK = 6 # TFTP modes (encodings) MODE_NETASCII = "netascii" MODE_BINARY = "octet" # TFTP error codes ERR_UNDEFINED = 0 # Not defined, see error msg (if any) - RFC 1350. ERR_FILE_NOT_FOUND = 1 # File not found - RFC 1350. ERR_ACCESS_VIOLATION = 2 # Access violation - RFC 1350. ERR_DISK_FULL = 3 # Disk full or allocation exceeded - RFC 1350. ERR_ILLEGAL_OPERATION = 4 # Illegal TFTP operation - RFC 1350. ERR_UNKNOWN_TRANSFER_ID = 5 # Unknown transfer ID - RFC 1350. ERR_FILE_EXISTS = 6 # File already exists - RFC 1350. ERR_NO_SUCH_USER = 7 # No such user - RFC 1350. ERR_INVALID_OPTIONS = 8 # One or more options are invalid - RFC 2347. # TFTP's block number is an unsigned 16 bit integer so for large files and # small window size we need to support rollover. MAX_BLOCK_NUMBER = 65535 # this is the default blksize as defined by RFC 1350 DEFAULT_BLKSIZE = 512 # Metric-related constants # How many seconds to aggregate before sampling datapoints DATAPOINTS_INTERVAL_SECONDS = 60 ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1625578385.0 fbtftp-0.5/fbtftp/netascii.py0000644000076500000240000000344700000000000015704 0ustar00skozlovstaff#!/usr/bin/env python3 # Copyright (c) Facebook, Inc. and its affiliates. 
# This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import io class NetasciiReader: """ NetasciiReader encodes data coming from a reader into NetASCII. If the size of the returned data needs to be known in advance this will actually have to load the whole content of its underlying reader into memory which is suboptimal but also the only way in which we can make NetASCII work with the 'tsize' TFTP extension. Note: This is an internal class and should not be modified. """ def __init__(self, reader): self._reader = reader self._buffer = bytearray() self._slurp = None self._size = None def read(self, size): if self._slurp is not None: return self._slurp.read(size) data, buffer_size = bytearray(), 0 if self._buffer: buffer_size = len(self._buffer) data.extend(self._buffer) for char in self._reader.read(size - buffer_size): if char == ord("\n"): data.extend([ord("\r"), ord("\n")]) elif char == ord("\r"): data.extend([ord("\r"), 0]) else: data.append(char) self._buffer = bytearray(data[size:]) return data[:size] def close(self): self._reader.close() def size(self): if self._size is not None: return self._size slurp, size = io.BytesIO(), 0 while True: data = self.read(512) if not data: break size += slurp.write(data) self._slurp, self._size = slurp, size self._slurp.seek(0) return size ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1625583712.7130976 fbtftp-0.5/fbtftp.egg-info/0000755000076500000240000000000000000000000015215 5ustar00skozlovstaff././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1625583712.0 fbtftp-0.5/fbtftp.egg-info/PKG-INFO0000644000076500000240000002072700000000000016322 0ustar00skozlovstaffMetadata-Version: 2.1 Name: fbtftp Version: 0.5 Summary: A python3 framework to build dynamic TFTP servers Home-page: https://www.github.com/facebook/fbtftp Author: Angelo Failla Author-email: pallotron@fb.com License: BSD Keywords: 
tftp daemon infrastructure provisioning netboot Platform: UNKNOWN Classifier: Development Status :: 5 - Production/Stable Classifier: License :: OSI Approved :: MIT License Classifier: Operating System :: POSIX :: Linux Classifier: Programming Language :: Python :: 3 :: Only Classifier: Programming Language :: Python :: 3.5 Classifier: Programming Language :: Python :: 3.6 Classifier: Programming Language :: Python :: 3.7 Classifier: Programming Language :: Python :: 3.8 Classifier: Programming Language :: Python :: 3.9 Classifier: Topic :: Software Development :: Libraries :: Application Frameworks Classifier: Topic :: System :: Boot Classifier: Topic :: Utilities Classifier: Intended Audience :: Developers Description-Content-Type: text/markdown License-File: LICENSE License-File: LICENSE-examples License-File: AUTHORS [![Build Status](https://travis-ci.org/facebook/fbtftp.svg?branch=master)](https://travis-ci.org/facebook/fbtftp) [![codebeat badge](https://codebeat.co/badges/2d4c7650-4752-4adf-a570-1948ecb4d6a8)](https://codebeat.co/projects/github-com-facebook-fbtftp) # What is fbtftp? `fbtftp` is Facebook's implementation of a dynamic TFTP server framework. It lets you create custom TFTP servers and wrap your own logic into it in a very simple manner. Facebook currently uses it in production, and it's deployed at global scale across all of our data centers. # Why did you do that? We love to use existing open source software and to contribute upstream, but sometimes it's just not enough at our scale. We ended up writing our own tftp framework and decided to open source it. `fbtftp` was born from the need of having an easy-to-configure and easy-to-expand TFTP server, that would work at large scale. The standard `in.tftpd` is a 20+ years old piece of software written in C that is very difficult to extend. 
`fbtftp` is written in `python3` and lets you plug in your own logic to: * publish per-session and server-wide statistics to your infrastructure * define how response data is built: * can be a file from disk; * can be a file created dynamically; * you name it! # How do you use `fbtftp` at Facebook? We created our own Facebook-specific server based on the framework to: * stream static files (initrd and kernels) from our http repositories (no need to fill your tftp root directory with files); * generate grub2 per-machine configuration dynamically (no need to copy grub2 configuration files on disk); * publish per-server and per-connection statistics to our internal monitoring systems; * deployment is easy and "container-ready", just copy the application somewhere, start it and you are done. # Is it better than the other TFTP servers? It depends on your needs! `fbtftp` is written in Python 3 using a multiprocessing model; its primary focus is not speed, but flexibility and scalability. Yet it is fast enough at our datacenter scale :) It is well-suited for large installations where scalability and custom features are needed. # What does it support? The framework implements the following RFCs: * [RFC 1350](https://tools.ietf.org/html/rfc1350) (the main TFTP specification) * [RFC 2347](https://tools.ietf.org/html/rfc2347) (Option Extension) * [RFC 2348](https://tools.ietf.org/html/rfc2348) (Blocksize option) * [RFC 2349](https://tools.ietf.org/html/rfc2349) (Timeout Interval and Transfer Size Options). Note that the server framework only supports RRQ (read-only) operations. (Who uses WRQ TFTP requests in 2019? :P) # How does it work? All you need to do is understand three classes and two callback functions, and you are good to go: * `BaseServer`: This class implements the process which deals with accepting new requests on the UDP port provided. Default TFTP parameters like timeout, port number and number of retries can be passed.
This class is not meant to be used directly; you must inherit from it and override the `get_handler()` method to return an instance of `BaseHandler`. The class accepts a `server_stats_callback`, described below. * `BaseHandler`: This class deals with talking to a single client. It runs in its own process, which is spawned by the `BaseServer` class; the server makes sure to reap the child properly when the session is over. Do not use this class as is, instead inherit from it and override the `get_response_data()` method. That method must return an instance of a subclass of `ResponseData`. * `ResponseData`: a file-like class that implements `read(num_bytes)`, `size()` and `close()`. As with the previous two classes, you have to inherit from it and implement those methods. This class lets you define how the actual data is returned. * `server_stats_callback`: function that is called periodically (every 60 seconds by default). The callback is not re-entrant; if you need this you have to implement your own locking logic. You can use it to publish server-level stats to your monitoring infrastructure. A series of predefined counters is provided. Refer to the class documentation to find out more. * `session_stats_callback`: function that is called when a client session is over.
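
To make the `ResponseData` contract more concrete, here is a minimal sketch of a response object that serves dynamically generated content from memory. It is an illustration under stated assumptions, not part of `fbtftp`: `BytesResponseData`, `render_boot_config` and `iter_blocks` are hypothetical names, and the class is kept standalone so the snippet runs without `fbtftp` installed. In a real server it would inherit from `fbtftp.base_handler.ResponseData` and be returned from your handler's `get_response_data()`.

```python
import io


class BytesResponseData:
    """Hypothetical stand-in for a `ResponseData` subclass backed by memory.

    It exposes the three methods the framework expects:
    read(num_bytes), size() and close().
    """

    def __init__(self, payload):
        self._size = len(payload)
        self._reader = io.BytesIO(payload)

    def read(self, num_bytes):
        # Returns at most num_bytes; empty bytes means no more data.
        return self._reader.read(num_bytes)

    def size(self):
        # Total payload length, which is what a `tsize` answer needs.
        return self._size

    def close(self):
        self._reader.close()


def render_boot_config(hostname):
    # Hypothetical template; real content depends on your boot setup.
    text = "# generated for {}\nset timeout=5\n".format(hostname)
    return text.encode("ascii")


def iter_blocks(response, block_size=512):
    """Yield blocks the way a TFTP sender consumes a response object.

    A transfer ends with the first block shorter than block_size,
    which may be empty when the payload is an exact multiple of it.
    """
    while True:
        block = response.read(block_size)
        yield block
        if len(block) < block_size:
            break


if __name__ == "__main__":
    response = BytesResponseData(render_boot_config("host1.example.com"))
    print("size:", response.size())
    for number, block in enumerate(iter_blocks(response), start=1):
        print("block", number, "has", len(block), "bytes")
    response.close()
```

Computing the size up front keeps `size()` cheap, at the cost of holding the whole payload in memory; for large files a disk-backed reader is likely a better fit.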
# Requirements * Linux (or any system that supports [`epoll`](http://linux.die.net/man/4/epoll)) * BSD (or any system that supports [`kqueue`](https://www.freebsd.org/cgi/man.cgi?query=kqueue&sektion=2)) * Python 3.4+ # Installation `fbtftp` is distributed with the standard `distutils` package, so you can build it with: ``` python setup.py build ``` and install it with: ``` python setup.py install ``` Be sure to run as root if you want to install `fbtftp` system wide. You can also use a `virtualenv`, or install it as user by running: ``` python setup.py install --user ``` # Example Writing your own server is simple. Let's take a look at how to write a simple server that serves files from disk: ```python from fbtftp.base_handler import BaseHandler from fbtftp.base_handler import ResponseData from fbtftp.base_server import BaseServer import os class FileResponseData(ResponseData): def __init__(self, path): self._size = os.stat(path).st_size self._reader = open(path, 'rb') def read(self, n): return self._reader.read(n) def size(self): return self._size def close(self): self._reader.close() def print_session_stats(stats): print(stats) def print_server_stats(stats): counters = stats.get_and_reset_all_counters() print('Server stats - every {} seconds'.format(stats.interval)) print(counters) class StaticHandler(BaseHandler): def __init__(self, server_addr, peer, path, options, root, stats_callback): self._root = root super().__init__(server_addr, peer, path, options, stats_callback) def get_response_data(self): return FileResponseData(os.path.join(self._root, self._path)) class StaticServer(BaseServer): def __init__(self, address, port, retries, timeout, root, handler_stats_callback, server_stats_callback=None): self._root = root self._handler_stats_callback = handler_stats_callback super().__init__(address, port, retries, timeout, server_stats_callback) def get_handler(self, server_addr, peer, path, options): return StaticHandler( server_addr, peer, path, options, 
self._root, self._handler_stats_callback) def main(): server = StaticServer(address='::', port=69, retries=3, timeout=5, root='/var/tftproot', handler_stats_callback=print_session_stats, server_stats_callback=print_server_stats) try: server.run() except KeyboardInterrupt: server.close() if __name__ == '__main__': main() ``` # Who wrote it? `fbtftp` was created by Marcin Wyszynski (@marcinwyszynski) and Angelo Failla at Facebook Ireland. Other honorable contributors: * Andrea Barberio # License MIT License ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1625583712.0 fbtftp-0.5/fbtftp.egg-info/SOURCES.txt0000644000076500000240000000075400000000000017107 0ustar00skozlovstaffAUTHORS CONTRIBUTING.md LICENSE LICENSE-examples MANIFEST.in README.md setup.cfg setup.py examples/server.py fbtftp/__init__.py fbtftp/base_handler.py fbtftp/base_server.py fbtftp/constants.py fbtftp/netascii.py fbtftp.egg-info/PKG-INFO fbtftp.egg-info/SOURCES.txt fbtftp.egg-info/dependency_links.txt fbtftp.egg-info/top_level.txt tests/base_handler_test.py tests/base_server_test.py tests/integration_test.py tests/malformed_request_test.py tests/netascii_test.py tests/server_stats_test.py././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1625583712.0 fbtftp-0.5/fbtftp.egg-info/dependency_links.txt0000644000076500000240000000000100000000000021263 0ustar00skozlovstaff ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1625583712.0 fbtftp-0.5/fbtftp.egg-info/top_level.txt0000644000076500000240000000000700000000000017744 0ustar00skozlovstafffbtftp ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1625583712.7173612 fbtftp-0.5/setup.cfg0000644000076500000240000000025500000000000014061 0ustar00skozlovstaff[nosetests] detailed-errors = 1 with-coverage = 1 cover-package = fbtftp cover-erase = 1 verbosity = 2 [flake8] max-line-length = 90 [egg_info] tag_build = tag_date = 0 
././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1625578513.0 fbtftp-0.5/setup.py0000644000076500000240000000400600000000000013750 0ustar00skozlovstaff# Copyright (c) Facebook, Inc. and its affiliates. # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. from os import path from setuptools import find_packages, setup from setuptools.command.test import test as TestCommand # Inspired by the example at https://pytest.org/latest/goodpractises.html class NoseTestCommand(TestCommand): def finalize_options(self): TestCommand.finalize_options(self) self.test_args = [] self.test_suite = True def run_tests(self): # Run nose ensuring that argv simulates running nosetests directly import nose nose.run_exit(argv=["nosetests"]) here = path.abspath(path.dirname(__file__)) with open(path.join(here, "README.md"), encoding="utf-8") as f: long_description = f.read() setup( name="fbtftp", version="0.5", description="A python3 framework to build dynamic TFTP servers", long_description=long_description, long_description_content_type="text/markdown", author="Angelo Failla", author_email="pallotron@fb.com", license="BSD", classifiers=[ "Development Status :: 5 - Production/Stable", "License :: OSI Approved :: MIT License", "Operating System :: POSIX :: Linux", "Programming Language :: Python :: 3 :: Only", "Programming Language :: Python :: 3.5", "Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Topic :: Software Development :: Libraries :: Application Frameworks", "Topic :: System :: Boot", "Topic :: Utilities", "Intended Audience :: Developers", ], keywords="tftp daemon infrastructure provisioning netboot", url="https://www.github.com/facebook/fbtftp", packages=find_packages(exclude=["tests"]), tests_require=["nose", "coverage", "mock"], cmdclass={"test": 
NoseTestCommand}, ) ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1625583712.7157924 fbtftp-0.5/tests/0000755000076500000240000000000000000000000013400 5ustar00skozlovstaff././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1625578385.0 fbtftp-0.5/tests/base_handler_test.py0000644000076500000240000003534400000000000017431 0ustar00skozlovstaff#!/usr/bin/env python3 # Copyright (c) Facebook, Inc. and its affiliates. # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. from collections import OrderedDict from unittest.mock import patch, Mock, call from fbtftp.netascii import NetasciiReader import socket import time import unittest from fbtftp.base_handler import BaseHandler, StringResponseData from fbtftp import constants class MockSocketListener: def __init__(self, network_queue, peer): self._network_queue = network_queue self._peer = peer def recvfrom(self, blocksize): return self._network_queue.pop(0), self._peer class MockHandler(BaseHandler): def __init__( self, server_addr, peer, path, options, stats_callback, network_queue=() ): self.response = StringResponseData("foo") super().__init__(server_addr, peer, path, options, stats_callback) self.network_queue = network_queue self.peer = peer self._listener = MockSocketListener(network_queue, peer) self._listener.sendto = Mock() self._listener.close = Mock() self._listener.settimeout = Mock() def get_response_data(self): """ returns a mock ResponseData object""" self._response_data = Mock() self._response_data.read = self.response.read self._response_data.size = self.response.size return self._response_data class testSessionHandler(unittest.TestCase): def setUp(self): self.options = OrderedDict( [ ("default_timeout", 10), ("retries", 2), ("mode", "netascii"), ("blksize", 1492), ("tsize", 0), ("timeout", 99), ] ) self.server_addr = ("127.0.0.1", 1234) self.peer = ("127.0.0.1", 5678) 
self.handler = MockHandler( server_addr=self.server_addr, peer=self.peer, path="want/bacon/file", options=self.options, stats_callback=self.stats_callback, ) def stats_callback(self): pass def init(self, universe=4): if universe == 4: server_addr = ("127.0.0.1", 1234) peer = ("127.0.0.1", 5678) else: server_addr = ("::1", 1234) peer = ("::1", 5678) handler = BaseHandler( server_addr=server_addr, peer=peer, path="want/bacon/file", options=self.options, stats_callback=self.stats_callback, ) self.assertEqual(handler._timeout, 10) self.assertEqual(handler._server_addr, server_addr) # make sure expire_ts is in the future self.assertGreater(handler._expire_ts, time.time()) self.assertEqual(handler._retries, 2) self.assertEqual(handler._block_size, constants.DEFAULT_BLKSIZE) self.assertEqual(handler._last_block_sent, 0) self.assertEqual(handler._retransmits, 0) self.assertEqual(handler._current_block, None) self.assertEqual(handler._should_stop, False) self.assertEqual(handler._path, "want/bacon/file") self.assertEqual(handler._options, self.options) self.assertEqual(handler._stats_callback, self.stats_callback) self.assertEqual(handler._peer, peer) self.assertIsInstance(handler._get_listener(), socket.socket) if universe == 6: self.assertEqual(handler._get_listener().family, socket.AF_INET6) else: self.assertEqual(handler._get_listener().family, socket.AF_INET) def testInitV6(self): self.init(universe=6) def testInitV4(self): self.init(universe=4) def testResponseDataException(self): server_addr = ("127.0.0.1", 1234) peer = ("127.0.0.1", 5678) with patch.object(MockHandler, "get_response_data") as mock: mock.side_effect = Exception("boom!") handler = MockHandler( server_addr=server_addr, peer=peer, path="want/bacon/file", options=self.options, stats_callback=self.stats_callback, ) self.assertEqual( handler._stats.error, {"error_message": "boom!", "error_code": 0} ) def testParseOptionsNetascii(self): self.handler._response_data = StringResponseData("foo\nbar\n") 
self.handler._parse_options() self.assertEqual( self.handler._stats.options_in, {"mode": "netascii", "blksize": 1492, "tsize": 0, "timeout": 99}, ) self.assertIsInstance(self.handler._response_data, NetasciiReader) self.assertEqual(self.handler._stats.blksize, 1492) # options acked by the server don't include the mode expected_opts_to_ack = self.options del expected_opts_to_ack["mode"] # tsize includes the number of bytes in the response expected_opts_to_ack["tsize"] = str(self.handler._response_data.size()) self.assertEqual(self.handler._stats.options, expected_opts_to_ack) self.assertEqual(self.handler._stats.options_acked, expected_opts_to_ack) self.assertEqual(self.handler._tsize, int(expected_opts_to_ack["tsize"])) def testParseOptionsBadMode(self): options = { "default_timeout": 10, "retries": 2, "mode": "IamBadAndIShoudlFeelBad", "blksize": 1492, "tsize": 0, "timeout": 99, } self.handler = MockHandler( server_addr=self.server_addr, peer=self.peer, path="want/bacon/file", options=options, stats_callback=Mock(), ) self.handler._close = Mock() self.handler._parse_options() self.handler._close.assert_called_with() self.assertEqual( self.handler._stats.error["error_code"], constants.ERR_ILLEGAL_OPERATION ) self.assertTrue( self.handler._stats.error["error_message"].startswith("Unknown mode:") ) self.handler._get_listener().sendto.assert_called_with( # \x00\x05 == OPCODE_ERROR # \x00\x04 == ERR_ILLEGAL_OPERATION b"\x00\x05\x00\x04Unknown mode: 'IamBadAndIShoudlFeelBad'\x00", ("127.0.0.1", 5678), ) def testClose(self): options = { "default_timeout": 10, "retries": 2, "mode": "IamBadAndIShoudlFeelBad", "blksize": 1492, "tsize": 0, "timeout": 99, } self.handler = MockHandler( server_addr=self.server_addr, peer=self.peer, path="want/bacon/file", options=options, stats_callback=Mock(), ) self.handler._retransmits = 100 self.handler._close(True) self.assertEqual(self.handler._retransmits, 100) self.handler._stats_callback.assert_called_with(self.handler._stats)
self.handler._get_listener().close.assert_called_with() self.handler._response_data.close.assert_called_with() self.handler._on_close = Mock() self.handler._on_close.side_effect = Exception("boom!") self.handler._close(True) def testRun(self): # mock methods self.handler._close = Mock() self.handler._transmit_error = Mock() self.handler._parse_options = Mock() self.handler._transmit_oack = Mock() self.handler._transmit_data = Mock() self.handler._next_block = Mock() self.handler._stats.error = {"error_message": "boom!", "error_code": 0} self.handler.run() self.handler._close.assert_called_with() self.handler._transmit_error.assert_called_with() self.handler._stats.error = {} self.handler._should_stop = True self.handler.run() self.handler._parse_options.assert_called_with() self.handler._transmit_oack.assert_called_with() self.handler._options = {} self.handler.run() self.handler._next_block.assert_called_with() self.handler._transmit_data.assert_called_with() def testRunOne(self): self.handler.on_new_data = Mock() self.handler._handle_timeout = Mock() self.handler._expire_ts = time.time() + 1000 self.handler.run_once() self.handler.on_new_data.assert_called_with() self.handler._expire_ts = time.time() - 1000 self.handler.run_once() self.handler.on_new_data.assert_called_with() self.handler._handle_timeout.assert_called_with() def testOnNewDataHandleAck(self): self.handler = MockHandler( server_addr=self.server_addr, peer=self.peer, path="want/bacon/file", options=self.options, stats_callback=self.stats_callback, # client acknowledges DATA block 1, we expect to send DATA block 2 network_queue=[b"\x00\x04\x00\x01"], ) self.handler._last_block_sent = 1 self.handler.on_new_data() self.handler._get_listener().settimeout.assert_has_calls( [call(self.handler._timeout), call(None)] ) # data response should look like this: # # 2 bytes 2 bytes n bytes # --------------------------------------- # | Opcode = 3 | Block # | Data | # ---------------------------------------
self.handler._get_listener().sendto.assert_called_with( # client acknowledges DATA block 1, we expect to send DATA block 2 b"\x00\x03\x00\x02foo", ("127.0.0.1", 5678), ) def testOnNewDataTimeout(self): self.handler._get_listener().recvfrom = Mock(side_effect=socket.timeout()) self.handler.on_new_data() self.assertFalse(self.handler._should_stop) self.assertEqual(self.handler._stats.error, {}) def testOnNewDataDifferentPeer(self): self.handler._get_listener().recvfrom = Mock( return_value=(b"data", ("1.2.3.4", "9999")) ) self.handler.on_new_data() self.assertTrue(self.handler._should_stop) def testOnNewDataOpCodeError(self): error = b"\x00\x05\x00\x04some_error\x00" self.handler._get_listener().recvfrom = Mock(return_value=(error, self.peer)) self.handler.on_new_data() self.assertTrue(self.handler._should_stop) self.handler._get_listener().sendto.assert_called_with(error, self.peer) def testOnNewDataNoAck(self): self.handler._get_listener().recvfrom = Mock( return_value=(b"\x00\x02\x00\x04", self.peer) ) self.handler.on_new_data() self.assertTrue(self.handler._should_stop) self.assertEqual( self.handler._stats.error, { "error_code": constants.ERR_ILLEGAL_OPERATION, "error_message": "I only do reads, really", }, ) def testHandleUnexpectedAck(self): self.handler._last_block_sent = 1 self.handler._reset_timeout = Mock() self.handler._next_block = Mock() self.handler._handle_ack(2) self.handler._reset_timeout.assert_not_called() def testHandleTimeout(self): self.handler._retries = 3 self.handler._retransmits = 2 self.handler._transmit_data = Mock() self.handler._handle_timeout() self.assertEqual(self.handler._retransmits, 3) self.handler._transmit_data.assert_called_with() self.assertEqual(self.handler._stats.error, {}) self.handler._retries = 1 self.handler._retransmits = 2 self.handler._handle_timeout() self.assertEqual( self.handler._stats.error, { "error_code": constants.ERR_UNDEFINED, "error_message": "timeout after 2 retransmits.", }, )
self.assertTrue(self.handler._should_stop) def testNextBlock(self): class MockResponse: def __init__(self, dataiter): self._dataiter = dataiter def read(self, size=0): try: return next(self._dataiter) except StopIteration: return None # single-packet file self.handler._last_block_sent = 0 self.handler._block_size = 1400 self.handler._response_data = StringResponseData("bacon") self.handler._next_block() self.assertEqual(self.handler._current_block, b"bacon") self.assertEqual(self.handler._last_block_sent, 1) # multi-packet file self.handler._last_block_sent = 0 self.handler._block_size = 1400 self.handler._response_data = StringResponseData("bacon" * 281) self.handler._next_block() self.assertEqual(self.handler._current_block, b"bacon" * 280) self.assertEqual(self.handler._last_block_sent, 1) self.handler._next_block() self.assertEqual(self.handler._current_block, b"bacon") self.assertEqual(self.handler._last_block_sent, 2) # partial read data = MockResponse(iter("bacon")) self.handler._last_block_sent = 0 self.handler._block_size = 1400 self.handler._response_data.read = data.read self.handler._next_block() self.assertEqual(self.handler._current_block, "bacon") self.assertEqual(self.handler._last_block_sent, 1) self.handler._last_block_sent = constants.MAX_BLOCK_NUMBER + 1 self.handler._next_block() self.assertEqual(self.handler._last_block_sent, 0) self.handler._response_data.read = Mock(side_effect=Exception("boom!")) self.handler._next_block() self.assertEqual( self.handler._stats.error, { "error_code": constants.ERR_UNDEFINED, "error_message": "Error while reading from source", }, ) self.assertTrue(self.handler._should_stop) def testTransmitData(self): # we have tested sending data so here we should just test the edge case # where there is no more data to send self.handler._current_block = b"" self.handler._transmit_data() self.handler._handle_ack(0) self.assertTrue(self.handler._should_stop) def testTransmitOACK(self): self.handler._options = {"opt1": 
"value1"} self.handler._get_listener().sendto = Mock() self.handler._stats.packets_sent = 1 self.handler._transmit_oack() self.assertEqual(self.handler._stats.packets_sent, 2) self.handler._get_listener().sendto.assert_called_with( # OACK code == 6 b"\x00\x06opt1\x00value1\x00", ("127.0.0.1", 5678), ) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1625578385.0 fbtftp-0.5/tests/base_server_test.py0000644000076500000240000001371000000000000017313 0ustar00skozlovstaff#!/usr/bin/env python3 # Copyright (c) Facebook, Inc. and its affiliates. # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. from unittest.mock import patch, Mock import unittest from fbtftp.base_server import BaseServer MOCK_SOCKET_FILENO = 100 SELECTORS_EVENT_READ = 1 class MockSocketListener: def __init__(self, network_queue): self._network_queue = network_queue def recvfrom(self, blocksize): data = self._network_queue.pop(0) peer = "::1" # assuming v6, but this is invariant for this test return data, peer def fileno(self): # just a given socket fileno that will have to be matched by # testBaseServer.poll_mock below. This is to trick the # BaseServer.run_once()'s' select.epoll.poll() method... 
return MOCK_SOCKET_FILENO def close(self): pass class StaticServer(BaseServer): def __init__( self, address, port, retries, timeout, root, stats_callback, stats_interval, network_queue, ): super().__init__( address, port, retries, timeout, stats_callback, stats_interval ) self._root = root # mock the network self._listener = MockSocketListener(network_queue) self._handler = None def get_handler(self, addr, peer, path, options): """ returns a mock handler """ self._handler = Mock(addr, peer, path, options) self._handler.addr = addr self._handler.peer = peer self._handler.path = path self._handler.options = options self._handler.start = Mock() return self._handler class testBaseServer(unittest.TestCase): def setUp(self): self.host = "::" # assuming v6, but this is invariant for this test self.port = 0 # let the kernel choose self.timeout = 100 self.retries = 200 self.interval = 1 self.network_queue = [] def select_mock(self): """ mock the select.epoll.poll() method, returns an iterable containing a list of (fileno, eventmask), the fileno constant matches the MockSocketListener.fileno() method, eventmask matches select.EPOLLIN """ if len(self.network_queue) > 0: obj = lambda: None obj.fd = MOCK_SOCKET_FILENO return [(obj, SELECTORS_EVENT_READ)] return [] def prepare_and_run(self, network_queue): server = StaticServer( self.host, self.port, self.retries, self.timeout, None, Mock(), self.interval, self.network_queue, ) server._server_stats.increment_counter = Mock() server.run(run_once=True) server.close() self.assertTrue(server._should_stop) self.assertTrue(server._handler.daemon) server._handler.start.assert_called_with() self.assertEqual(server._handler.addr, ("::", 0)) self.assertEqual(server._handler.peer, "::1") server._server_stats.increment_counter.assert_called_with("process_count") return server._handler @patch("selectors.DefaultSelector") def testRRQ(self, selector_mock): # link the self.select_mock() method with the patched selectors.DefaultSelector object
selector_mock.return_value.select.side_effect = self.select_mock self.network_queue = [ # RRQ + file name + mode + optname + optvalue b"\x00\x01some_file\x00binascii\x00opt1_key\x00opt1_val\x00" ] handler = self.prepare_and_run(self.network_queue) self.assertEqual(handler.path, "some_file") self.assertEqual( handler.options, { "default_timeout": 100, "mode": "binascii", "opt1_key": "opt1_val", "retries": 200, }, ) def start_timer_and_wait_for_callback(self, stats_callback): server = StaticServer( self.host, self.port, self.retries, self.timeout, None, stats_callback, self.interval, [], ) server.restart_stats_timer(run_once=True) # wait for the stats callback to be executed for _ in range(10): import time time.sleep(1) if stats_callback.mock_called: print("Stats callback executed") break server._metrics_timer.cancel() def testTimer(self): stats_callback = Mock() self.start_timer_and_wait_for_callback(stats_callback) def testTimerNoCallBack(self): stats_callback = None server = StaticServer( self.host, self.port, self.retries, self.timeout, None, stats_callback, self.interval, [], ) ret = server.restart_stats_timer(run_once=True) self.assertIsNone(ret) def testCallbackException(self): stats_callback = Mock() stats_callback.side_effect = Exception("boom!") self.start_timer_and_wait_for_callback(stats_callback) @patch("selectors.DefaultSelector") def testUnexpectedOpsCode(self, selector_mock): # link the self.select_mock() method with the patched selectors.DefaultSelector object selector_mock.return_value.select.side_effect = self.select_mock self.network_queue = [ # invalid opcode (0xff) + file name + mode + optname + optvalue b"\x00\xffsome_file\x00binascii\x00opt1_key\x00opt1_val\x00" ] server = StaticServer( self.host, self.port, self.retries, self.timeout, None, Mock(), self.interval, self.network_queue, ) server.run(run_once=True) self.assertIsNone(server._handler) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1625578385.0
fbtftp-0.5/tests/integration_test.py0000644000076500000240000001112300000000000017332 0ustar00skozlovstaff#!/usr/bin/env python3 # Copyright (c) Facebook, Inc. and its affiliates. # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. from distutils.spawn import find_executable import logging import os import subprocess import tempfile import unittest from fbtftp.base_handler import ResponseData, BaseHandler from fbtftp.base_server import BaseServer class FileResponseData(ResponseData): def __init__(self, path): self._size = os.stat(path).st_size self._reader = open(path, "rb") def read(self, n): return self._reader.read(n) def size(self): return self._size def close(self): self._reader.close() class StaticHandler(BaseHandler): def __init__(self, server_addr, peer, path, options, root, stats_callback): self._root = root super().__init__(server_addr, peer, path, options, stats_callback) def get_response_data(self): return FileResponseData(os.path.join(self._root, self._path)) class StaticServer(BaseServer): def __init__(self, address, port, retries, timeout, root, stats_callback): self._root = root self._stats_callback = stats_callback super().__init__(address, port, retries, timeout) def get_handler(self, server_addr, peer, path, options): return StaticHandler( server_addr, peer, path, options, self._root, self._stats_callback ) def busyboxClient(filename, blksize=1400, port=1069): # We use busybox cli to test various bulksizes p = subprocess.Popen( [ find_executable("busybox"), "tftp", "-l", "/dev/stdout", "-r", filename, "-g", "-b", str(blksize), "localhost", str(port), ], stdout=subprocess.PIPE, stderr=subprocess.PIPE, ) stdout, stderr = p.communicate(timeout=1) return (stdout, stderr, p.returncode) @unittest.skipUnless( find_executable("busybox"), "busybox binary not present, install it if you want to run " "integration tests", ) class integrationTest(unittest.TestCase): def setUp(self): 
logging.getLogger().setLevel(logging.DEBUG) self.tmpdirname = tempfile.TemporaryDirectory() logging.info("Created temporary directory %s" % self.tmpdirname) self.tmpfile = "%s/%s" % (self.tmpdirname.name, "test.file") self.tmpfile_data = os.urandom(512 * 5) with open(self.tmpfile, "wb") as fout: fout.write(self.tmpfile_data) self.called_stats_times = 0 def tearDown(self): self.tmpdirname.cleanup() def stats(self, data): logging.debug("Inside stats function") self.assertEqual(data.peer[0], "127.0.0.1") self.assertEqual(data.file_path, self.tmpfile) self.assertEqual({}, data.error) self.assertGreater(data.start_time, 0) self.assertTrue(data.packets_acked == data.packets_sent - 1) self.assertEqual(2560, data.bytes_sent) self.assertEqual(round(data.bytes_sent / self.blksize), data.packets_sent - 1) self.assertEqual(0, data.retransmits) self.assertEqual(self.blksize, data.blksize) self.called_stats_times += 1 def testDownloadBulkSizes(self): for b in (512, 1400): self.blksize = b server = StaticServer( "::", 0, # let the kernel decide the port 2, 2, self.tmpdirname.name, self.stats, ) child_pid = os.fork() if child_pid: # I am the parent try: (p_stdout, p_stderr, p_returncode) = busyboxClient( self.tmpfile, blksize=self.blksize, # use the port chosen for the server by the kernel port=server._listener.getsockname()[1], ) self.assertEqual(0, p_returncode) if p_returncode != 0: self.fail((p_stdout, p_stderr, p_returncode)) self.assertEqual(self.tmpfile_data, p_stdout) finally: os.kill(child_pid, 15) os.waitpid(child_pid, 0) else: # I am the child try: server.run() except KeyboardInterrupt: server.close() self.assertEqual(1, self.called_stats_times) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1625578385.0 fbtftp-0.5/tests/malformed_request_test.py0000644000076500000240000000404400000000000020531 0ustar00skozlovstaff#!/usr/bin/env python3 # Copyright (c) Facebook, Inc. and its affiliates. 
# This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import tempfile import unittest from fbtftp.base_server import BaseServer """ This script stresses the TFTP server by sending malformed RRQ packets and checking whether it crashed. NOTE: this test ONLY checks if the server crashed, no output or return code is checked. """ RRQ = b"\x00\x01" # if you want to add more packets for the tests, do it here TEST_PAYLOADS = ( RRQ + b"some_fi", RRQ + b"some_file\x00", RRQ + b"some_file\x00bina", RRQ + b"some_file\x00binascii\x00", RRQ + b"some_file\x00binascii\x00a", RRQ + b"some_file\x00binascii\x00a\x00", RRQ + b"some_file\x00binascii\x00a\x00b\x00", ) class MockSocketListener: def __init__(self, network_queue): self._network_queue = network_queue def recvfrom(self, blocksize): data = self._network_queue.pop(0) peer = "::1" # assuming v6, but this is invariant for this test return data, peer def close(self): pass class StaticServer(BaseServer): def __init__( self, address, port, retries, timeout, root, stats_callback, network_queue ): super().__init__(address, port, retries, timeout) self._root = root # mock the network self._listener = MockSocketListener(network_queue) class TestServerMalformedPacket(unittest.TestCase): def setUp(self): # this is removed automatically when the test ends self.tmpdir = tempfile.TemporaryDirectory() self.host = "::" # assuming v6, but this is invariant for this test self.port = 0 # let the kernel choose self.timeout = 2 def testMalformedPackets(self): for payload in TEST_PAYLOADS: server = StaticServer( self.host, self.port, 2, 2, self.tmpdir, None, [payload] ) server.on_new_data() server.close() del server ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1625578385.0 fbtftp-0.5/tests/netascii_test.py0000644000076500000240000000336500000000000016617 0ustar00skozlovstaff#!/usr/bin/env python3 # Copyright (c) Facebook, Inc. 
and its affiliates. # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import unittest from fbtftp.netascii import NetasciiReader from fbtftp.base_handler import StringResponseData class testNetAsciiReader(unittest.TestCase): def testNetAsciiReader(self): tests = [ # content, expected output ( "foo\nbar\nand another\none", bytearray(b"foo\r\nbar\r\nand another\r\none"), ), ( "foo\r\nbar\r\nand another\r\none", bytearray(b"foo\r\x00\r\nbar\r\x00\r\nand another\r\x00\r\none"), ), ] for input_content, expected in tests: with self.subTest(content=input_content): resp_data = StringResponseData(input_content) n = NetasciiReader(resp_data) self.assertGreater(n.size(), len(input_content)) output = n.read(512) self.assertEqual(output, expected) n.close() def testNetAsciiReaderBig(self): input_content = "I\nlike\ncrunchy\nbacon\n" for _ in range(5): input_content += input_content resp_data = StringResponseData(input_content) n = NetasciiReader(resp_data) self.assertGreater(n.size(), 0) self.assertGreater(n.size(), len(input_content)) block_size = 512 output = bytearray() while True: c = n.read(block_size) output += c if len(c) < block_size: break self.assertEqual(input_content.count("\n"), output.count(b"\r\n")) n.close() ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1625578385.0 fbtftp-0.5/tests/server_stats_test.py0000644000076500000240000000717200000000000017544 0ustar00skozlovstaff#!/usr/bin/env python3 # Copyright (c) Facebook, Inc. and its affiliates. # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
import collections import time import unittest import threading from unittest.mock import patch from fbtftp.base_server import ServerStats class testServerStats(unittest.TestCase): @patch("threading.Lock") def setUp(self, mock): self.st = ServerStats(server_addr="127.0.0.1", interval=2) self.start_time = time.time() self.assertEqual(self.st.server_addr, "127.0.0.1") self.assertEqual(self.st.interval, 2) self.assertLessEqual(self.st.start_time, self.start_time) self.assertIsInstance(self.st._counters, collections.Counter) self.assertIsInstance(self.st._counters_lock, type(threading.Lock())) self.st._counters_lock = mock() def testSetGetCounters(self): self.st.set_counter("testcounter", 100) self.assertEqual(self.st.get_counter("testcounter"), 100) self.assertEqual(self.st._counters_lock.__enter__.call_count, 1) self.assertEqual(self.st._counters_lock.__exit__.call_count, 1) def testIncrementCounter(self): self.st.set_counter("testcounter", 100) self.st.increment_counter("testcounter") self.assertEqual(self.st.get_counter("testcounter"), 101) self.assertEqual(self.st._counters_lock.__enter__.call_count, 2) self.assertEqual(self.st._counters_lock.__exit__.call_count, 2) def testResetCounter(self): self.st.set_counter("testcounter", 100) self.assertEqual(self.st.get_counter("testcounter"), 100) self.st.reset_counter("testcounter") self.assertEqual(self.st.get_counter("testcounter"), 0) self.assertEqual(self.st._counters_lock.__enter__.call_count, 2) self.assertEqual(self.st._counters_lock.__exit__.call_count, 2) def testGetAndResetCounter(self): self.st.set_counter("testcounter", 100) self.assertEqual(self.st.get_and_reset_counter("testcounter"), 100) self.assertEqual(self.st.get_counter("testcounter"), 0) self.assertEqual(self.st._counters_lock.__enter__.call_count, 2) self.assertEqual(self.st._counters_lock.__exit__.call_count, 2) def testGetAllCounters(self): self.st.set_counter("testcounter1", 100) self.st.set_counter("testcounter2", 200) counters = 
self.st.get_all_counters() self.assertEqual(len(counters), 2) self.assertEqual(self.st._counters_lock.__enter__.call_count, 3) self.assertEqual(self.st._counters_lock.__exit__.call_count, 3) def testGetAndResetAllCounters(self): self.st.set_counter("testcounter1", 100) self.st.set_counter("testcounter2", 200) counters = self.st.get_and_reset_all_counters() self.assertEqual(len(counters), 2) self.assertEqual(counters["testcounter1"], 100) self.assertEqual(counters["testcounter2"], 200) self.assertEqual(self.st._counters_lock.__enter__.call_count, 3) self.assertEqual(self.st._counters_lock.__exit__.call_count, 3) def testResetAllCounters(self): self.st.set_counter("testcounter1", 100) self.st.set_counter("testcounter2", 200) self.st.reset_all_counters() self.assertEqual(self.st.get_counter("testcounter1"), 0) self.assertEqual(self.st.get_counter("testcounter2"), 0) self.assertEqual(self.st._counters_lock.__enter__.call_count, 3) self.assertEqual(self.st._counters_lock.__exit__.call_count, 3) def testDuration(self): self.assertGreater(self.st.duration(), 0)