pax_global_header00006660000000000000000000000064137033456220014517gustar00rootroot0000000000000052 comment=6d699aed899784dfae5fac28e29567936bed81a3 JordanMilne-Advocate-6d699ae/000077500000000000000000000000001370334562200160765ustar00rootroot00000000000000JordanMilne-Advocate-6d699ae/.coveragerc000066400000000000000000000007731370334562200202260ustar00rootroot00000000000000[run] omit = advocate/packages/* [report] # Regexes for lines to exclude from consideration exclude_lines = # Have to re-enable the standard pragma pragma: no cover # Don't complain about missing debug-only code: def __repr__ if self\.debug # Don't complain if tests don't hit defensive assertion code: raise AssertionError raise NotImplementedError # Don't complain if non-runnable code isn't run: if 0: if __name__ == .__main__.: ignore_errors = True JordanMilne-Advocate-6d699ae/.gitignore000066400000000000000000000001701370334562200200640ustar00rootroot00000000000000*~ .*.sw? *.pyc *.pyo .cache .DS_Store *.diff *.patch *.idea *.egg-info .eggs .coverage /build /dist /dev_packages /env JordanMilne-Advocate-6d699ae/.travis.yml000066400000000000000000000012571370334562200202140ustar00rootroot00000000000000sudo: false language: python cache: false python: - "2.7" - "3.6" - "3.7" - "3.8" env: - REQUESTS_VERSION="2.18.4" - REQUESTS_VERSION="2.19.1" - REQUESTS_VERSION="2.20.1" - REQUESTS_VERSION="2.21.0" - REQUESTS_VERSION="2.22.0" - REQUESTS_VERSION="2.23.0" - REQUESTS_VERSION="2.24.0" install: - pip install --src build/ -e git+https://github.com/psf/requests@v${REQUESTS_VERSION}#egg=requests -r requirements-test.txt script: - pytest --cov=advocate --cov-config=.coveragerc - pushd build/requests && PYTHONPATH=$TRAVIS_BUILD_DIR pytest -p requests_pytest_plugin && popd # Coverage metrics before_install: - pip install codecov after_success: - codecov JordanMilne-Advocate-6d699ae/LICENSE000066400000000000000000000010511370334562200171000ustar00rootroot00000000000000Copyright 2015 Jordan Milne Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. JordanMilne-Advocate-6d699ae/README.rst000066400000000000000000000215431370334562200175720ustar00rootroot00000000000000.. role:: python(code) :language: python Advocate ======== .. image:: https://travis-ci.org/JordanMilne/Advocate.svg?branch=master :target: https://travis-ci.org/JordanMilne/Advocate/ .. image:: https://codecov.io/github/JordanMilne/Advocate/coverage.svg?branch=master :target: https://codecov.io/github/JordanMilne/Advocate .. image:: https://img.shields.io/pypi/pyversions/advocate.svg .. image:: https://img.shields.io/pypi/v/advocate.svg :target: https://pypi.python.org/pypi/advocate Advocate is a set of tools based around the `requests library `_ for safely making HTTP requests on behalf of a third party. Specifically, it aims to prevent common techniques that enable `SSRF attacks `_. Advocate was inspired by `fin1te's SafeCurl project `_. Installation ============ .. code-block:: bash pip install advocate Advocate is officially supported on CPython 2.7+, CPython 3.4+ and PyPy 2. 
PyPy 3 may work as well, but you'll need a copy of the ipaddress module from elsewhere. Examples ======== Advocate is more-or-less a drop-in replacement for requests. In most cases you can just replace "requests" with "advocate" where necessary and be good to go: .. code-block:: python >>> import advocate >>> print advocate.get("http://google.com/") Advocate also provides a subclassed :python:`requests.Session` with sane defaults for validation already set up: .. code-block:: python >>> import advocate >>> sess = advocate.Session() >>> print sess.get("http://google.com/") >>> print sess.get("http://localhost/") advocate.exceptions.UnacceptableAddressException: ('localhost', 80) All of the wrapped request functions accept a :python:`validator` kwarg where you can set additional rules: .. code-block:: python >>> import advocate >>> validator = advocate.AddrValidator(hostname_blacklist={"*.museum",}) >>> print advocate.get("http://educational.MUSEUM/", validator=validator) advocate.exceptions.UnacceptableAddressException: educational.MUSEUM If you require more advanced rules than the defaults, but don't want to have to pass the validator kwarg everywhere, there's :python:`RequestsAPIWrapper` . You can define a wrapper in a common file and import it instead of advocate: .. code-block:: python >>> from advocate import AddrValidator, RequestsAPIWrapper >>> from advocate.packages import ipaddress >>> dougs_advocate = RequestsAPIWrapper(AddrValidator(ip_blacklist={ ... # Contains data incomprehensible to mere mortals ... ipaddress.ip_network("42.42.42.42/32") ... })) >>> print dougs_advocate.get("http://42.42.42.42/") advocate.exceptions.UnacceptableAddressException: ('42.42.42.42', 80) Other than that, you can do just about everything with Advocate that you can with an unwrapped requests. Advocate passes requests' test suite with the exception of tests that require :python:`Session.mount()`. Conditionally bypassing protection ================================== If you want to allow certain users to bypass Advocate's restrictions, just use plain 'ol requests by doing something like: .. code-block:: python if user == "mr_skeltal": requests_module = requests else: requests_module = advocate resp = requests_module.get("http://example.com/doot_doot") requests-futures support ======================== A thin wrapper around `requests-futures `_ is provided to ease writing async-friendly code: .. code-block:: python >>> from advocate.futures import FuturesSession >>> sess = FuturesSession() >>> fut = sess.get("http://example.com/") >>> fut >>> fut.result() You can do basically everything you can do with regular :python:`FuturesSession` s and :python:`advocate.Session` s: .. code-block:: python >>> from advocate import AddrValidator >>> from advocate.futures import FuturesSession >>> sess = FuturesSession(max_workers=20, validator=AddrValidator(hostname_blacklist={"*.museum"})) >>> fut = sess.get("http://anice.museum/") >>> fut >>> fut.result() Traceback (most recent call last): # [...] advocate.exceptions.UnacceptableAddressException: anice.museum When should I use Advocate? =========================== Any time you're fetching resources over HTTP for / from someone you don't trust! When should I not use Advocate? =============================== That's a tough one. 
There are a few cases I can think of where I wouldn't: * When good, safe support for IPv6 is important * When internal hosts use globally routable addresses and you can't guess their prefix to blacklist it ahead of time * You already have a good handle on network security within your network Actually, if you're comfortable enough with Squid and network security, you should set up a secured Squid instance on a segregated subnet and proxy through that instead. Advocate attempts to guess whether an address references an internal host and block access, but it's definitely preferable to proxy through a host can't access anything internal in the first place! Of course, if you're writing an app / library that's meant to be usable OOTB on other people's networks, Advocate + a user-configurable blacklist is probably the safer bet. This seems like it's been done before ===================================== There've been a few similar projects, but in my opinion Advocate's approach is the best because: It sees URLs the same as the underlying HTTP library ---------------------------------------------------- Parsing URLs is hard, and no two URL parsers seem to behave exactly the same. The tiniest differences in parsing between your validator and the underlying HTTP library can lead to vulnerabilities. For example, differences between PHP's :python:`parse_url` and cURL's URL parser `allowed a blacklist bypass in SafeCurl `_. Advocate doesn't do URL parsing at all, and lets requests handle it. Advocate only looks at the address requests actually tries to open a socket to. It deals with DNS rebinding --------------------------- Two consecutive calls to :python:`socket.getaddrinfo` aren't guaranteed to return the same info, depending on the system configuration. If the "safe" looking record TTLs between the verification lookup and the lookup for actually opening the socket, we may end up connecting to a very different server than the one we OK'd! Advocate gets around this by only using one :python:`getaddrinfo` call for both verification and connecting the socket. In pseudocode: .. code-block:: python def connect_socket(host, port): for res in socket.getaddrinfo(host, port): # where `res` will be a tuple containing the IP for the host if not is_blacklisted(res): # ... connect the socket using `res` See `Wikipedia's article on DNS rebinding attacks `_ for more info. It handles redirects sanely --------------------------- Most of the other SSRF-prevention libs cover this, but I've seen a lot of sample code online that doesn't. Advocate will catch it since it inspects *every* connection attempt the underlying HTTP lib makes. TODO ==== Proper IPv6 Support? -------------------- Advocate's IPv6 support is still a work-in-progress, since I'm not that familiar with the spec, and there are so many ways to tunnel IPv4 over IPv6, as well as other non-obvious gotchas. IPv6 records are ignored by default for now, but you can enable by using an :python:`AddrValidator` with :python:`allow_ipv6=True`. It should mostly work as expected, but Advocate's approach might not even make sense with most IPv6 deployments, see `Issue #3 `_ for more info. If you can think of any improvements to the IPv6 handling, please submit an issue or PR! Caveats ======= * This is beta-quality software, the API might change without warning! * :python:`mount()` ing other adapters is disallowed to prevent Advocate's validating adapters from being clobbered. * Advocate does not, and might never support the use of HTTP proxies. 
* Proper IPv6 support is still a WIP as noted above. Acknowledgements ================ * https://github.com/fin1te/safecurl for inspiration * https://github.com/kennethreitz/requests for the lovely requests module * https://bitbucket.org/kwi/py2-ipaddress for the backport of ipaddress * https://github.com/hakobe/paranoidhttp a similar project targeting golang * https://github.com/uber-common/paranoid-request a similar project targeting Node * http://search.cpan.org/~tsibley/LWP-UserAgent-Paranoid/ a similar project targeting Perl 5 JordanMilne-Advocate-6d699ae/advocate/000077500000000000000000000000001370334562200176645ustar00rootroot00000000000000JordanMilne-Advocate-6d699ae/advocate/__init__.py000066400000000000000000000006701370334562200220000ustar00rootroot00000000000000__version__ = "1.0.0" from requests import utils from requests.models import Request, Response, PreparedRequest from requests.status_codes import codes from requests.exceptions import ( RequestException, Timeout, URLRequired, TooManyRedirects, HTTPError, ConnectionError ) from .adapters import ValidatingHTTPAdapter from .api import * from .addrvalidator import AddrValidator from .exceptions import UnacceptableAddressException JordanMilne-Advocate-6d699ae/advocate/adapters.py000066400000000000000000000021441370334562200220420ustar00rootroot00000000000000from requests.adapters import HTTPAdapter, DEFAULT_POOLBLOCK from .addrvalidator import AddrValidator from .exceptions import ProxyDisabledException from .poolmanager import ValidatingPoolManager class ValidatingHTTPAdapter(HTTPAdapter): __attrs__ = HTTPAdapter.__attrs__ + ['_validator'] def __init__(self, *args, **kwargs): self._validator = kwargs.pop('validator', None) if not self._validator: self._validator = AddrValidator() super(ValidatingHTTPAdapter, self).__init__(*args, **kwargs) def init_poolmanager(self, connections, maxsize, block=DEFAULT_POOLBLOCK, **pool_kwargs): self._pool_connections = connections self._pool_maxsize = maxsize self._pool_block = block self.poolmanager = ValidatingPoolManager( num_pools=connections, maxsize=maxsize, block=block, validator=self._validator, **pool_kwargs ) def proxy_manager_for(self, proxy, **proxy_kwargs): raise ProxyDisabledException("Proxies cannot be used with Advocate") JordanMilne-Advocate-6d699ae/advocate/addrvalidator.py000066400000000000000000000244111370334562200230600ustar00rootroot00000000000000import functools import fnmatch import re import six import netifaces from .exceptions import NameserverException from .packages import ipaddress def canonicalize_hostname(hostname): """Lowercase and punycodify a hostname""" # We do the lowercasing after IDNA encoding because we only want to # lowercase the *ASCII* chars. # TODO: The differences between IDNA2003 and IDNA2008 might be relevant # to us, but both specs are damn confusing. 
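    # Illustrative, assumed examples of the end result (the stdlib "idna"
    # codec leaves all-ASCII labels untouched, which is why the explicit
    # .lower() below is still needed):
    #   canonicalize_hostname(u"EXAMPLE.COM") == u"example.com"
    #   non-ASCII labels come back as punycoded "xn--..." ASCII labels,
    #   which is the form _hostname_matches_pattern() compares against.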
return six.text_type(hostname.encode("idna").lower(), 'utf-8') def determine_local_addresses(): """Get all IPs that refer to this machine according to netifaces""" ips = [] for interface in netifaces.interfaces(): if_families = netifaces.ifaddresses(interface) for family_kind in {netifaces.AF_INET, netifaces.AF_INET6}: addrs = if_families.get(family_kind, []) for addr in (x.get("addr", "") for x in addrs): if family_kind == netifaces.AF_INET6: # We can't do anything sensible with the scope here addr = addr.split("%")[0] ips.append(ipaddress.ip_network(addr)) return ips def add_local_address_arg(func): """Add the "_local_addresses" kwarg if it's missing IMO this information shouldn't be cached between calls (what if one of the adapters got a new IP at runtime?,) and we don't want each function to recalculate it. Just recalculate it if the caller didn't provide it for us. """ @functools.wraps(func) def wrapper(self, *args, **kwargs): if "_local_addresses" not in kwargs: if self.autodetect_local_addresses: kwargs["_local_addresses"] = determine_local_addresses() else: kwargs["_local_addresses"] = [] return func(self, *args, **kwargs) return wrapper class AddrValidator(object): _6TO4_RELAY_NET = ipaddress.ip_network("192.88.99.0/24") # Just the well known prefix, DNS64 servers can set their own # prefix, but in practice most probably don't. _DNS64_WK_PREFIX = ipaddress.ip_network("64:ff9b::/96") DEFAULT_PORT_WHITELIST = {80, 8080, 443, 8443, 8000} def __init__( self, ip_blacklist=None, ip_whitelist=None, port_whitelist=None, port_blacklist=None, hostname_blacklist=None, allow_ipv6=False, allow_teredo=False, allow_6to4=False, allow_dns64=False, autodetect_local_addresses=True, ): if not port_blacklist and not port_whitelist: # An assortment of common HTTPS? ports. port_whitelist = self.DEFAULT_PORT_WHITELIST.copy() self.ip_blacklist = ip_blacklist or set() self.ip_whitelist = ip_whitelist or set() self.port_blacklist = port_blacklist or set() self.port_whitelist = port_whitelist or set() # TODO: ATM this can contain either regexes or globs that are converted # to regexes upon every check. Create a collection that automagically # converts them to regexes on insert? self.hostname_blacklist = hostname_blacklist or set() self.allow_ipv6 = allow_ipv6 self.allow_teredo = allow_teredo self.allow_6to4 = allow_6to4 self.allow_dns64 = allow_dns64 self.autodetect_local_addresses = autodetect_local_addresses @add_local_address_arg def is_ip_allowed(self, addr_ip, _local_addresses=None): if not isinstance(addr_ip, (ipaddress.IPv4Address, ipaddress.IPv6Address)): addr_ip = ipaddress.ip_address(addr_ip) # The whitelist should take precedence over the blacklist so we can # punch holes in blacklisted ranges if any(addr_ip in net for net in self.ip_whitelist): return True if any(addr_ip in net for net in self.ip_blacklist): return False if any(addr_ip in net for net in _local_addresses): return False if addr_ip.version == 4: if not addr_ip.is_private: # IPs for carrier-grade NAT. Seems weird that it doesn't set # `is_private`, but we need to check `not is_global` if not ipaddress.ip_network(addr_ip).is_global: return False elif addr_ip.version == 6: # You'd better have a good reason for enabling IPv6 # because Advocate's techniques don't work well without NAT. if not self.allow_ipv6: return False # v6 addresses can also map to IPv4 addresses! Tricky! v4_nested = [] if addr_ip.ipv4_mapped: v4_nested.append(addr_ip.ipv4_mapped) # WTF IPv6? Why you gotta have a billion tunneling mechanisms? 
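            # For illustration (addresses assumed, not taken from the
            # project's tests): the 6to4 address 2002:c000:204:: embeds
            # 192.0.2.4 and exposes it via `.sixtofour`, and a Teredo
            # address exposes a (server, client) IPv4 pair via `.teredo`.
            # Every embedded v4 address collected in `v4_nested` is fed
            # back through is_ip_allowed() below.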
# XXX: Do we even really care about these? If we're tunneling # through public servers we shouldn't be able to access # addresses on our private network, right? if addr_ip.sixtofour: if not self.allow_6to4: return False v4_nested.append(addr_ip.sixtofour) if addr_ip.teredo: if not self.allow_teredo: return False # Check both the client *and* server IPs v4_nested.extend(addr_ip.teredo) if addr_ip in self._DNS64_WK_PREFIX: if not self.allow_dns64: return False # When using the well-known prefix the last 4 bytes # are the IPv4 addr v4_nested.append(ipaddress.ip_address(addr_ip.packed[-4:])) if not all(self.is_ip_allowed(addr_v4) for addr_v4 in v4_nested): return False # fec0::*, apparently deprecated? if addr_ip.is_site_local: return False else: raise ValueError("Unsupported IP version(?): %r" % addr_ip) # 169.254.XXX.XXX, AWS uses these for autoconfiguration if addr_ip.is_link_local: return False # 127.0.0.1, ::1, etc. if addr_ip.is_loopback: return False if addr_ip.is_multicast: return False # 192.168.XXX.XXX, 10.XXX.XXX.XXX if addr_ip.is_private: return False # 255.255.255.255, ::ffff:XXXX:XXXX (v6->v4) mapping if addr_ip.is_reserved: return False # There's no reason to connect directly to a 6to4 relay if addr_ip in self._6TO4_RELAY_NET: return False # 0.0.0.0 if addr_ip.is_unspecified: return False # It doesn't look bad, so... it's must be ok! return True def _hostname_matches_pattern(self, hostname, pattern): # If they specified a string, just assume they only want basic globbing. # This stops people from not realizing they're dealing in REs and # not escaping their periods unless they specifically pass in an RE. # This has the added benefit of letting us sanely handle globbed # IDNs by default. if isinstance(pattern, six.string_types): # convert the glob to a punycode glob, then a regex pattern = fnmatch.translate(canonicalize_hostname(pattern)) hostname = canonicalize_hostname(hostname) # Down the line the hostname may get treated as a null-terminated string # (as with `socket.getaddrinfo`.) Try to account for that. # # >>> socket.getaddrinfo("example.com\x00aaaa", 80) # [(2, 1, 6, '', ('93.184.216.34', 80)), [...] no_null_hostname = hostname.split("\x00")[0] return any(re.match(pattern, x.strip(".")) for x in (no_null_hostname, hostname)) def is_hostname_allowed(self, hostname): # Sometimes (like with "external" services that your IP has privileged # access to) you might not always know the IP range to blacklist access # to, or the `A` record might change without you noticing. # For e.x.: `foocorp.external.org`. # # Another option is doing something like: # # for addrinfo in socket.getaddrinfo("foocorp.external.org", 80): # global_validator.ip_blacklist.add(ip_address(addrinfo[4][0])) # # but that's not always a good idea if they're behind a third-party lb. for pattern in self.hostname_blacklist: if self._hostname_matches_pattern(hostname, pattern): return False return True @add_local_address_arg def is_addrinfo_allowed(self, addrinfo, _local_addresses=None): assert(len(addrinfo) == 5) # XXX: Do we care about any of the other elements? Guessing not. family, socktype, proto, canonname, sockaddr = addrinfo # The 4th elem inaddrinfo may either be a touple of two or four items, # depending on whether we're dealing with IPv4 or v6 if len(sockaddr) == 2: # v4 ip, port = sockaddr elif len(sockaddr) == 4: # v6 # XXX: what *are* `flow_info` and `scope_id`? Anything useful? # Seems like we can figure out all we need about the scope from # the `is_` properties. 
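            # For reference (values illustrative): a raw v6 getaddrinfo()
            # entry looks like
            #   (AF_INET6, SOCK_STREAM, 6, 'example.com', ('2001:db8::1', 80, 0, 0))
            # and the 4-tuple at the end is the sockaddr unpacked here.
            # When called through advocate's connection code,
            # connection.fix_addrinfo() has already parsed the IP and copied
            # the canon name across records by this point.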
ip, port, flow_info, scope_id = sockaddr else: raise ValueError("Unexpected addrinfo format %r" % sockaddr) # Probably won't help protect against SSRF, but might prevent our being # used to attack others' non-HTTP services. See # http://www.remote.org/jochen/sec/hfpa/ if self.port_whitelist and port not in self.port_whitelist: return False if port in self.port_blacklist: return False if self.hostname_blacklist: if not canonname: raise NameserverException( "addrinfo must contain the canon name to do blacklisting " "based on hostname. Make sure you use the " "`socket.AI_CANONNAME` flag, and that each record contains " "the canon name. Your DNS server might also be garbage." ) if not self.is_hostname_allowed(canonname): return False return self.is_ip_allowed(ip, _local_addresses=_local_addresses) JordanMilne-Advocate-6d699ae/advocate/api.py000066400000000000000000000232471370334562200210170ustar00rootroot00000000000000# -*- coding: utf-8 -*- """ advocate.api ~~~~~~~~~~~~ This module implements the Requests API, largely a copy/paste from `requests` itself. :copyright: (c) 2015 by Jordan Milne. :license: Apache2, see LICENSE for more details. """ from collections import OrderedDict import hashlib import pickle from requests import Session as RequestsSession import advocate from .adapters import ValidatingHTTPAdapter from .exceptions import MountDisabledException class Session(RequestsSession): __attrs__ = RequestsSession.__attrs__ + ["validator"] DEFAULT_VALIDATOR = None """Convenience wrapper around `requests.Session` set up for `advocate`ing""" def __init__(self, *args, **kwargs): self.validator = kwargs.pop("validator", self.DEFAULT_VALIDATOR) adapter_kwargs = kwargs.pop("_adapter_kwargs", {}) # `Session.__init__()` calls `mount()` internally, so we need to allow # it temporarily self.__mountAllowed = True RequestsSession.__init__(self, *args, **kwargs) # Drop any existing adapters self.adapters = OrderedDict() self.mount("http://", ValidatingHTTPAdapter(validator=self.validator, **adapter_kwargs)) self.mount("https://", ValidatingHTTPAdapter(validator=self.validator, **adapter_kwargs)) self.__mountAllowed = False def mount(self, *args, **kwargs): """Wrapper around `mount()` to prevent a protection bypass""" if self.__mountAllowed: super(Session, self).mount(*args, **kwargs) else: raise MountDisabledException( "mount() is disabled to prevent protection bypasses" ) def session(*args, **kwargs): return Session(*args, **kwargs) def request(method, url, **kwargs): """Constructs and sends a :class:`Request `. :param method: method for the new :class:`Request` object. :param url: URL for the new :class:`Request` object. :param params: (optional) Dictionary or bytes to be sent in the query string for the :class:`Request`. :param data: (optional) Dictionary, bytes, or file-like object to send in the body of the :class:`Request`. :param json: (optional) json data to send in the body of the :class:`Request`. :param headers: (optional) Dictionary of HTTP Headers to send with the :class:`Request`. :param cookies: (optional) Dict or CookieJar object to send with the :class:`Request`. :param files: (optional) Dictionary of ``'name': file-like-objects`` (or ``{'name': ('filename', fileobj)}``) for multipart encoding upload. :param auth: (optional) Auth tuple to enable Basic/Digest/Custom HTTP Auth. :param timeout: (optional) How long to wait for the server to send data before giving up, as a float, or a (`connect timeout, read timeout `_) tuple. 
:type timeout: float or tuple :param allow_redirects: (optional) Boolean. Set to True if POST/PUT/DELETE redirect following is allowed. :type allow_redirects: bool :param proxies: (optional) Dictionary mapping protocol to the URL of the proxy. :param verify: (optional) if ``True``, the SSL cert will be verified. A CA_BUNDLE path can also be provided. :param stream: (optional) if ``False``, the response content will be immediately downloaded. :param cert: (optional) if String, path to ssl client cert file (.pem). If Tuple, ('cert', 'key') pair. :return: :class:`Response ` object :rtype: requests.Response Usage:: >>> import advocate >>> req = advocate.request('GET', 'http://httpbin.org/get') """ validator = kwargs.pop("validator", None) with Session(validator=validator) as sess: response = sess.request(method=method, url=url, **kwargs) return response def get(url, **kwargs): """Sends a GET request. :param url: URL for the new :class:`Request` object. :param \*\*kwargs: Optional arguments that ``request`` takes. :return: :class:`Response ` object :rtype: requests.Response """ kwargs.setdefault('allow_redirects', True) return request('get', url, **kwargs) def options(url, **kwargs): """Sends a OPTIONS request. :param url: URL for the new :class:`Request` object. :param \*\*kwargs: Optional arguments that ``request`` takes. :return: :class:`Response ` object :rtype: requests.Response """ kwargs.setdefault('allow_redirects', True) return request('options', url, **kwargs) def head(url, **kwargs): """Sends a HEAD request. :param url: URL for the new :class:`Request` object. :param \*\*kwargs: Optional arguments that ``request`` takes. :return: :class:`Response ` object :rtype: requests.Response """ kwargs.setdefault('allow_redirects', False) return request('head', url, **kwargs) def post(url, data=None, json=None, **kwargs): """Sends a POST request. :param url: URL for the new :class:`Request` object. :param data: (optional) Dictionary, bytes, or file-like object to send in the body of the :class:`Request`. :param json: (optional) json data to send in the body of the :class:`Request`. :param \*\*kwargs: Optional arguments that ``request`` takes. :return: :class:`Response ` object :rtype: requests.Response """ return request('post', url, data=data, json=json, **kwargs) def put(url, data=None, **kwargs): """Sends a PUT request. :param url: URL for the new :class:`Request` object. :param data: (optional) Dictionary, bytes, or file-like object to send in the body of the :class:`Request`. :param \*\*kwargs: Optional arguments that ``request`` takes. :return: :class:`Response ` object :rtype: requests.Response """ return request('put', url, data=data, **kwargs) def patch(url, data=None, **kwargs): """Sends a PATCH request. :param url: URL for the new :class:`Request` object. :param data: (optional) Dictionary, bytes, or file-like object to send in the body of the :class:`Request`. :param \*\*kwargs: Optional arguments that ``request`` takes. :return: :class:`Response ` object :rtype: requests.Response """ return request('patch', url, data=data, **kwargs) def delete(url, **kwargs): """Sends a DELETE request. :param url: URL for the new :class:`Request` object. :param \*\*kwargs: Optional arguments that ``request`` takes. 
:return: :class:`Response ` object :rtype: requests.Response """ return request('delete', url, **kwargs) class RequestsAPIWrapper(object): """Provides a `requests.api`-like interface with a specific validator""" # Due to how the classes are dynamically constructed pickling may not work # correctly unless loaded within the same interpreter instance. # Enable at your peril. SUPPORT_WRAPPER_PICKLING = False def __init__(self, validator): # Do this here to avoid circular import issues try: from .futures import FuturesSession have_requests_futures = True except ImportError as e: have_requests_futures = False self.validator = validator outer_self = self class _WrappedSession(Session): """An `advocate.Session` that uses the wrapper's blacklist the wrapper is meant to be a transparent replacement for `requests`, so people should be able to subclass `wrapper.Session` and still get the desired validation behaviour """ DEFAULT_VALIDATOR = outer_self.validator self._make_wrapper_cls_global(_WrappedSession) if have_requests_futures: class _WrappedFuturesSession(FuturesSession): """Like _WrappedSession, but for `FuturesSession`s""" DEFAULT_VALIDATOR = outer_self.validator self._make_wrapper_cls_global(_WrappedFuturesSession) self.FuturesSession = _WrappedFuturesSession self.request = self._default_arg_wrapper(request) self.get = self._default_arg_wrapper(get) self.options = self._default_arg_wrapper(options) self.head = self._default_arg_wrapper(head) self.post = self._default_arg_wrapper(post) self.put = self._default_arg_wrapper(put) self.patch = self._default_arg_wrapper(patch) self.delete = self._default_arg_wrapper(delete) self.session = self._default_arg_wrapper(session) self.Session = _WrappedSession def __getattr__(self, item): # This class is meant to mimic the requests base module, so if we don't # have this attribute, it might be on the base module (like the Request # class, etc.) try: return object.__getattribute__(self, item) except AttributeError: return getattr(advocate, item) def _default_arg_wrapper(self, fun): def wrapped_func(*args, **kwargs): kwargs.setdefault("validator", self.validator) return fun(*args, **kwargs) return wrapped_func def _make_wrapper_cls_global(self, cls): if not self.SUPPORT_WRAPPER_PICKLING: return # Gnarly, but necessary to give pickle a consistent module-level # reference for each wrapper. wrapper_hash = hashlib.sha256(pickle.dumps(self)).hexdigest() cls.__name__ = "_".join((cls.__name__, wrapper_hash)) cls.__qualname__ = ".".join((__name__, cls.__name__)) if not globals().get(cls.__name__): globals()[cls.__name__] = cls __all__ = ( "delete", "get", "head", "options", "patch", "post", "put", "request", "session", "Session", "RequestsAPIWrapper", ) JordanMilne-Advocate-6d699ae/advocate/connection.py000066400000000000000000000150001370334562200223710ustar00rootroot00000000000000import socket from socket import timeout as SocketTimeout from requests.packages.urllib3.connection import HTTPSConnection, HTTPConnection from requests.packages.urllib3.exceptions import ConnectTimeoutError from requests.packages.urllib3.util.connection import _set_socket_options from requests.packages.urllib3.util.connection import create_connection as old_create_connection from . 
import addrvalidator from .exceptions import UnacceptableAddressException from .packages import ipaddress def advocate_getaddrinfo(host, port, get_canonname=False): addrinfo = socket.getaddrinfo( host, port, 0, socket.SOCK_STREAM, 0, # We need what the DNS client sees the hostname as, correctly handles # IDNs and tricky things like `private.foocorp.org\x00.google.com`. # All IDNs will be converted to punycode. socket.AI_CANONNAME if get_canonname else 0, ) return fix_addrinfo(addrinfo) def fix_addrinfo(records): """ Propagate the canonname across records and parse IPs I'm not sure if this is just the behaviour of `getaddrinfo` on Linux, but it seems like only the first record in the set has the canonname field populated. """ def fix_record(record, canonname): sa = record[4] sa = (ipaddress.ip_address(sa[0]),) + sa[1:] return record[0], record[1], record[2], canonname, sa canonname = None if records: # Apparently the canonical name is only included in the first record? # Add it to all of them. assert(len(records[0]) == 5) canonname = records[0][3] return tuple(fix_record(x, canonname) for x in records) # Lifted from requests' urllib3, which in turn lifted it from `socket.py`. Oy! def validating_create_connection(address, timeout=socket._GLOBAL_DEFAULT_TIMEOUT, source_address=None, socket_options=None, validator=None): """Connect to *address* and return the socket object. Convenience function. Connect to *address* (a 2-tuple ``(host, port)``) and return the socket object. Passing the optional *timeout* parameter will set the timeout on the socket instance before attempting to connect. If no *timeout* is supplied, the global default timeout setting returned by :func:`getdefaulttimeout` is used. If *source_address* is set it must be a tuple of (host, port) for the socket to bind as a source address before making the connection. An host of '' or port 0 tells the OS to use the default. """ host, port = address # We can skip asking for the canon name if we're not doing hostname-based # blacklisting. need_canonname = False if validator.hostname_blacklist: need_canonname = True # We check both the non-canonical and canonical hostnames so we can # catch both of these: # CNAME from nonblacklisted.com -> blacklisted.com # CNAME from blacklisted.com -> nonblacklisted.com if not validator.is_hostname_allowed(host): raise UnacceptableAddressException(host) err = None addrinfo = advocate_getaddrinfo(host, port, get_canonname=need_canonname) if addrinfo: if validator.autodetect_local_addresses: local_addresses = addrvalidator.determine_local_addresses() else: local_addresses = [] for res in addrinfo: # Are we allowed to connect with this result? if not validator.is_addrinfo_allowed( res, _local_addresses=local_addresses, ): continue af, socktype, proto, canonname, sa = res # Unparse the validated IP sa = (sa[0].exploded,) + sa[1:] sock = None try: sock = socket.socket(af, socktype, proto) # If provided, set socket level options before connecting. # This is the only addition urllib3 makes to this function. 
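                # (For context, not part of this function: with the stock
                # urllib3 connection classes, socket_options defaults to
                # something like [(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)];
                # the exact default may vary between urllib3 versions.)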
_set_socket_options(sock, socket_options) if timeout is not socket._GLOBAL_DEFAULT_TIMEOUT: sock.settimeout(timeout) if source_address: sock.bind(source_address) sock.connect(sa) return sock except socket.error as _: err = _ if sock is not None: sock.close() sock = None if err is None: # If we got here, none of the results were acceptable err = UnacceptableAddressException(address) if err is not None: raise err else: raise socket.error("getaddrinfo returns an empty list") # TODO: Is there a better way to add this to multiple classes with different # base classes? I tried a mixin, but it used the base method instead. def _validating_new_conn(self): """ Establish a socket connection and set nodelay settings on it. :return: New socket connection. """ extra_kw = {} if self.source_address: extra_kw['source_address'] = self.source_address if self.socket_options: extra_kw['socket_options'] = self.socket_options try: # Hack around HTTPretty's patched sockets # TODO: some better method of hacking around it that checks if we # _would have_ connected to a private addr? conn_func = validating_create_connection if socket.getaddrinfo.__module__.startswith("httpretty"): conn_func = old_create_connection else: extra_kw["validator"] = self._validator conn = conn_func( (self.host, self.port), self.timeout, **extra_kw ) except SocketTimeout: raise ConnectTimeoutError( self, "Connection to %s timed out. (connect timeout=%s)" % (self.host, self.timeout)) return conn # Don't silently break if the private API changes across urllib3 versions assert(hasattr(HTTPConnection, '_new_conn')) assert(hasattr(HTTPSConnection, '_new_conn')) class ValidatingHTTPConnection(HTTPConnection): _new_conn = _validating_new_conn def __init__(self, *args, **kwargs): self._validator = kwargs.pop("validator") HTTPConnection.__init__(self, *args, **kwargs) class ValidatingHTTPSConnection(HTTPSConnection): _new_conn = _validating_new_conn def __init__(self, *args, **kwargs): self._validator = kwargs.pop("validator") HTTPSConnection.__init__(self, *args, **kwargs) JordanMilne-Advocate-6d699ae/advocate/connectionpool.py000066400000000000000000000012651370334562200232730ustar00rootroot00000000000000from requests.packages.urllib3 import HTTPConnectionPool, HTTPSConnectionPool from .connection import ( ValidatingHTTPConnection, ValidatingHTTPSConnection, ) # Don't silently break if the private API changes across urllib3 versions assert(hasattr(HTTPConnectionPool, 'ConnectionCls')) assert(hasattr(HTTPSConnectionPool, 'ConnectionCls')) assert(hasattr(HTTPConnectionPool, 'scheme')) assert(hasattr(HTTPSConnectionPool, 'scheme')) class ValidatingHTTPConnectionPool(HTTPConnectionPool): scheme = 'http' ConnectionCls = ValidatingHTTPConnection class ValidatingHTTPSConnectionPool(HTTPSConnectionPool): scheme = 'https' ConnectionCls = ValidatingHTTPSConnection JordanMilne-Advocate-6d699ae/advocate/exceptions.py000066400000000000000000000004651370334562200224240ustar00rootroot00000000000000class AdvocateException(Exception): pass class UnacceptableAddressException(AdvocateException): pass class NameserverException(AdvocateException): pass class MountDisabledException(AdvocateException): pass class ProxyDisabledException(NotImplementedError, AdvocateException): pass JordanMilne-Advocate-6d699ae/advocate/futures.py000066400000000000000000000023471370334562200217410ustar00rootroot00000000000000import requests_futures.sessions from concurrent.futures import ThreadPoolExecutor from requests.adapters import DEFAULT_POOLSIZE from . 
import Session class FuturesSession(requests_futures.sessions.FuturesSession, Session): def __init__(self, executor=None, max_workers=2, session=None, *args, **kwargs): adapter_kwargs = {} if executor is None: executor = ThreadPoolExecutor(max_workers=max_workers) # set connection pool size equal to max_workers if needed if max_workers > DEFAULT_POOLSIZE: adapter_kwargs = dict(pool_connections=max_workers, pool_maxsize=max_workers) kwargs["_adapter_kwargs"] = adapter_kwargs Session.__init__(self, *args, **kwargs) self.executor = executor self.session = session @property def session(self): return None @session.setter def session(self, value): if value is not None and not isinstance(value, Session): raise NotImplementedError("Setting the .session property to " "non-advocate values disabled " "to prevent whitelist bypasses") JordanMilne-Advocate-6d699ae/advocate/packages/000077500000000000000000000000001370334562200214425ustar00rootroot00000000000000JordanMilne-Advocate-6d699ae/advocate/packages/__init__.py000066400000000000000000000101651370334562200235560ustar00rootroot00000000000000""" Copyright (c) Donald Stufft, pip, and individual contributors Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ from __future__ import absolute_import import sys class VendorAlias(object): def __init__(self, package_names): self._package_names = package_names self._vendor_name = __name__ self._vendor_pkg = self._vendor_name + "." self._vendor_pkgs = [ self._vendor_pkg + name for name in self._package_names ] def find_module(self, fullname, path=None): if fullname.startswith(self._vendor_pkg): return self def load_module(self, name): # Ensure that this only works for the vendored name if not name.startswith(self._vendor_pkg): raise ImportError( "Cannot import %s, must be a subpackage of '%s'." % ( name, self._vendor_name, ) ) if not (name == self._vendor_name or any(name.startswith(pkg) for pkg in self._vendor_pkgs)): raise ImportError( "Cannot import %s, must be one of %s." % ( name, self._vendor_pkgs ) ) # Check to see if we already have this item in sys.modules, if we do # then simply return that. 
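        # (Illustrative: after the first successful import of
        # "advocate.packages.ipaddress", that key is cached in sys.modules
        # by the code at the bottom of this method, so we land here.)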
if name in sys.modules: return sys.modules[name] # Check to see if we can import the vendor name try: # We do this dance here because we want to try and import this # module without hitting a recursion error because of a bunch of # VendorAlias instances on sys.meta_path real_meta_path = sys.meta_path[:] try: sys.meta_path = [ m for m in sys.meta_path if not isinstance(m, VendorAlias) ] __import__(name) module = sys.modules[name] finally: # Re-add any additions to sys.meta_path that were made while # during the import we just did, otherwise things like # requests.packages.urllib3.poolmanager will fail. for m in sys.meta_path: if m not in real_meta_path: real_meta_path.append(m) # Restore sys.meta_path with any new items. sys.meta_path = real_meta_path except ImportError: # We can't import the vendor name, so we'll try to import the # "real" name. real_name = name[len(self._vendor_pkg):] try: __import__(real_name) module = sys.modules[real_name] except ImportError: raise ImportError("No module named '%s'" % (name,)) # If we've gotten here we've found the module we're looking for, either # as part of our vendored package, or as the real name, so we'll add # it to sys.modules as the vendored name so that we don't have to do # the lookup again. sys.modules[name] = module # Finally, return the loaded module return module sys.meta_path.append(VendorAlias(["ipaddress"])) JordanMilne-Advocate-6d699ae/advocate/packages/ipaddress/000077500000000000000000000000001370334562200234205ustar00rootroot00000000000000JordanMilne-Advocate-6d699ae/advocate/packages/ipaddress/README000066400000000000000000000002331370334562200242760ustar00rootroot00000000000000ipaddress.py is from https://bitbucket.org/kwi/py2-ipaddress, any issues should be submitted there, or to Python proper if they affect the upstream module JordanMilne-Advocate-6d699ae/advocate/packages/ipaddress/__init__.py000066400000000000000000000005031370334562200255270ustar00rootroot00000000000000from __future__ import absolute_import # XXX: Python 3 before 3.3 doesn't have the `ipaddress` module and will # break on this import. try: # First try to use our bundled ipaddress module from .ipaddress import * except ImportError: # Try to pull from the global `ipaddress` module from ipaddress import * JordanMilne-Advocate-6d699ae/advocate/packages/ipaddress/ipaddress.py000066400000000000000000002201721370334562200257540ustar00rootroot00000000000000# Python 2.7 port of Python 3.4's ipaddress module. # List of compatibility changes: # Python 3 uses only new-style classes. # s/class \(\w\+\):/class \1(object):/ # Use iterator versions of map and range: from itertools import imap as map range = xrange # This backport uses bytearray instead of bytes, as bytes is the same # as str in Python 2.7. bytes = bytearray # Python 2 does not support exception chaining. # s/ from None$// # When checking for instances of int, also allow Python 2's long. _builtin_isinstance = isinstance def isinstance(val, types): if types is int: types = (int, long) elif type(types) is tuple and int in types: types += (long,) return _builtin_isinstance(val, types) # functools.lru_cache is Python 3.2+ only. # /@functools.lru_cache()/d # int().to_bytes is Python 3.2+ only. # s/\(\w+\)\.to_bytes(/_int_to_bytes(\1, / def _int_to_bytes(self, length, byteorder, signed=False): assert byteorder == 'big' and signed is False if self < 0 or self >= 256**length: raise OverflowError() return bytearray(('%0*x' % (length * 2, self)).decode('hex')) # int.from_bytes is Python 3.2+ only. 
# s/int\.from_bytes(/_int_from_bytes(/g def _int_from_bytes(what, byteorder, signed=False): assert byteorder == 'big' and signed is False return int(str(bytearray(what)).encode('hex'), 16) # ---------------------------------------------------------------------------- # Copyright 2007 Google Inc. # Licensed to PSF under a Contributor Agreement. """A fast, lightweight IPv4/IPv6 manipulation library in Python. This library is used to create/poke/manipulate IPv4 and IPv6 addresses and networks. """ __version__ = '1.0' import functools IPV4LENGTH = 32 IPV6LENGTH = 128 class AddressValueError(ValueError): """A Value Error related to the address.""" class NetmaskValueError(ValueError): """A Value Error related to the netmask.""" def ip_address(address): """Take an IP string/int and return an object of the correct type. Args: address: A string or integer, the IP address. Either IPv4 or IPv6 addresses may be supplied; integers less than 2**32 will be considered to be IPv4 by default. Returns: An IPv4Address or IPv6Address object. Raises: ValueError: if the *address* passed isn't either a v4 or a v6 address """ try: return IPv4Address(address) except (AddressValueError, NetmaskValueError): pass try: return IPv6Address(address) except (AddressValueError, NetmaskValueError): pass raise ValueError('%r does not appear to be an IPv4 or IPv6 address' % address) def ip_network(address, strict=True): """Take an IP string/int and return an object of the correct type. Args: address: A string or integer, the IP network. Either IPv4 or IPv6 networks may be supplied; integers less than 2**32 will be considered to be IPv4 by default. Returns: An IPv4Network or IPv6Network object. Raises: ValueError: if the string passed isn't either a v4 or a v6 address. Or if the network has host bits set. """ try: return IPv4Network(address, strict) except (AddressValueError, NetmaskValueError): pass try: return IPv6Network(address, strict) except (AddressValueError, NetmaskValueError): pass raise ValueError('%r does not appear to be an IPv4 or IPv6 network' % address) def ip_interface(address): """Take an IP string/int and return an object of the correct type. Args: address: A string or integer, the IP address. Either IPv4 or IPv6 addresses may be supplied; integers less than 2**32 will be considered to be IPv4 by default. Returns: An IPv4Interface or IPv6Interface object. Raises: ValueError: if the string passed isn't either a v4 or a v6 address. Notes: The IPv?Interface classes describe an Address on a particular Network, so they're basically a combination of both the Address and Network classes. """ try: return IPv4Interface(address) except (AddressValueError, NetmaskValueError): pass try: return IPv6Interface(address) except (AddressValueError, NetmaskValueError): pass raise ValueError('%r does not appear to be an IPv4 or IPv6 interface' % address) def v4_int_to_packed(address): """Represent an address as 4 packed bytes in network (big-endian) order. Args: address: An integer representation of an IPv4 IP address. Returns: The integer address packed as 4 bytes in network (big-endian) order. Raises: ValueError: If the integer is negative or too large to be an IPv4 IP address. """ try: return _int_to_bytes(address, 4, 'big') except: raise ValueError("Address negative or too large for IPv4") def v6_int_to_packed(address): """Represent an address as 16 packed bytes in network (big-endian) order. Args: address: An integer representation of an IPv6 IP address. 
Returns: The integer address packed as 16 bytes in network (big-endian) order. """ try: return _int_to_bytes(address, 16, 'big') except: raise ValueError("Address negative or too large for IPv6") def _split_optional_netmask(address): """Helper to split the netmask and raise AddressValueError if needed""" addr = str(address).split('/') if len(addr) > 2: raise AddressValueError("Only one '/' permitted in %r" % address) return addr def _find_address_range(addresses): """Find a sequence of IPv#Address. Args: addresses: a list of IPv#Address objects. Returns: A tuple containing the first and last IP addresses in the sequence. """ first = last = addresses[0] for ip in addresses[1:]: if ip._ip == last._ip + 1: last = ip else: break return (first, last) def _count_righthand_zero_bits(number, bits): """Count the number of zero bits on the right hand side. Args: number: an integer. bits: maximum number of bits to count. Returns: The number of zero bits on the right hand side of the number. """ if number == 0: return bits for i in range(bits): if (number >> i) & 1: return i # All bits of interest were zero, even if there are more in the number return bits def summarize_address_range(first, last): """Summarize a network range given the first and last IP addresses. Example: >>> list(summarize_address_range(IPv4Address('192.0.2.0'), ... IPv4Address('192.0.2.130'))) ... #doctest: +NORMALIZE_WHITESPACE [IPv4Network('192.0.2.0/25'), IPv4Network('192.0.2.128/31'), IPv4Network('192.0.2.130/32')] Args: first: the first IPv4Address or IPv6Address in the range. last: the last IPv4Address or IPv6Address in the range. Returns: An iterator of the summarized IPv(4|6) network objects. Raise: TypeError: If the first and last objects are not IP addresses. If the first and last objects are not the same version. ValueError: If the last object is not greater than the first. If the version of the first address is not 4 or 6. """ if (not (isinstance(first, _BaseAddress) and isinstance(last, _BaseAddress))): raise TypeError('first and last must be IP addresses, not networks') if first.version != last.version: raise TypeError("%s and %s are not of the same version" % ( first, last)) if first > last: raise ValueError('last IP address must be greater than first') if first.version == 4: ip = IPv4Network elif first.version == 6: ip = IPv6Network else: raise ValueError('unknown IP version') ip_bits = first._max_prefixlen first_int = first._ip last_int = last._ip while first_int <= last_int: nbits = min(_count_righthand_zero_bits(first_int, ip_bits), (last_int - first_int + 1).bit_length() - 1) net = ip('%s/%d' % (first, ip_bits - nbits)) yield net first_int += 1 << nbits if first_int - 1 == ip._ALL_ONES: break first = first.__class__(first_int) def _collapse_addresses_recursive(addresses): """Loops through the addresses, collapsing concurrent netblocks. Example: ip1 = IPv4Network('192.0.2.0/26') ip2 = IPv4Network('192.0.2.64/26') ip3 = IPv4Network('192.0.2.128/26') ip4 = IPv4Network('192.0.2.192/26') _collapse_addresses_recursive([ip1, ip2, ip3, ip4]) -> [IPv4Network('192.0.2.0/24')] This shouldn't be called directly; it is called via collapse_addresses([]). Args: addresses: A list of IPv4Network's or IPv6Network's Returns: A list of IPv4Network's or IPv6Network's depending on what we were passed. 
""" while True: last_addr = None ret_array = [] optimized = False for cur_addr in addresses: if not ret_array: last_addr = cur_addr ret_array.append(cur_addr) elif (cur_addr.network_address >= last_addr.network_address and cur_addr.broadcast_address <= last_addr.broadcast_address): optimized = True elif cur_addr == list(last_addr.supernet().subnets())[1]: ret_array[-1] = last_addr = last_addr.supernet() optimized = True else: last_addr = cur_addr ret_array.append(cur_addr) addresses = ret_array if not optimized: return addresses def collapse_addresses(addresses): """Collapse a list of IP objects. Example: collapse_addresses([IPv4Network('192.0.2.0/25'), IPv4Network('192.0.2.128/25')]) -> [IPv4Network('192.0.2.0/24')] Args: addresses: An iterator of IPv4Network or IPv6Network objects. Returns: An iterator of the collapsed IPv(4|6)Network objects. Raises: TypeError: If passed a list of mixed version objects. """ i = 0 addrs = [] ips = [] nets = [] # split IP addresses and networks for ip in addresses: if isinstance(ip, _BaseAddress): if ips and ips[-1]._version != ip._version: raise TypeError("%s and %s are not of the same version" % ( ip, ips[-1])) ips.append(ip) elif ip._prefixlen == ip._max_prefixlen: if ips and ips[-1]._version != ip._version: raise TypeError("%s and %s are not of the same version" % ( ip, ips[-1])) try: ips.append(ip.ip) except AttributeError: ips.append(ip.network_address) else: if nets and nets[-1]._version != ip._version: raise TypeError("%s and %s are not of the same version" % ( ip, nets[-1])) nets.append(ip) # sort and dedup ips = sorted(set(ips)) nets = sorted(set(nets)) while i < len(ips): (first, last) = _find_address_range(ips[i:]) i = ips.index(last) + 1 addrs.extend(summarize_address_range(first, last)) return iter(_collapse_addresses_recursive(sorted( addrs + nets, key=_BaseNetwork._get_networks_key))) def get_mixed_type_key(obj): """Return a key suitable for sorting between networks and addresses. Address and Network objects are not sortable by default; they're fundamentally different so the expression IPv4Address('192.0.2.0') <= IPv4Network('192.0.2.0/24') doesn't make any sense. There are some times however, where you may wish to have ipaddress sort these for you anyway. If you need to do this, you can use this function as the key= argument to sorted(). Args: obj: either a Network or Address object. Returns: appropriate key. 
""" if isinstance(obj, _BaseNetwork): return obj._get_networks_key() elif isinstance(obj, _BaseAddress): return obj._get_address_key() return NotImplemented class _TotalOrderingMixin(object): # Helper that derives the other comparison operations from # __lt__ and __eq__ # We avoid functools.total_ordering because it doesn't handle # NotImplemented correctly yet (http://bugs.python.org/issue10042) def __eq__(self, other): raise NotImplementedError def __ne__(self, other): equal = self.__eq__(other) if equal is NotImplemented: return NotImplemented return not equal def __lt__(self, other): raise NotImplementedError def __le__(self, other): less = self.__lt__(other) if less is NotImplemented or not less: return self.__eq__(other) return less def __gt__(self, other): less = self.__lt__(other) if less is NotImplemented: return NotImplemented equal = self.__eq__(other) if equal is NotImplemented: return NotImplemented return not (less or equal) def __ge__(self, other): less = self.__lt__(other) if less is NotImplemented: return NotImplemented return not less class _IPAddressBase(_TotalOrderingMixin): """The mother class.""" @property def exploded(self): """Return the longhand version of the IP address as a string.""" return self._explode_shorthand_ip_string() @property def compressed(self): """Return the shorthand version of the IP address as a string.""" return str(self) @property def version(self): msg = '%200s has no version specified' % (type(self),) raise NotImplementedError(msg) def _check_int_address(self, address): if address < 0: msg = "%d (< 0) is not permitted as an IPv%d address" raise AddressValueError(msg % (address, self._version)) if address > self._ALL_ONES: msg = "%d (>= 2**%d) is not permitted as an IPv%d address" raise AddressValueError(msg % (address, self._max_prefixlen, self._version)) def _check_packed_address(self, address, expected_len): address_len = len(address) if address_len != expected_len: msg = "%r (len %d != %d) is not permitted as an IPv%d address" raise AddressValueError(msg % (address, address_len, expected_len, self._version)) def _ip_int_from_prefix(self, prefixlen): """Turn the prefix length into a bitwise netmask Args: prefixlen: An integer, the prefix length. Returns: An integer. """ return self._ALL_ONES ^ (self._ALL_ONES >> prefixlen) def _prefix_from_ip_int(self, ip_int): """Return prefix length from the bitwise netmask. Args: ip_int: An integer, the netmask in axpanded bitwise format Returns: An integer, the prefix length. Raises: ValueError: If the input intermingles zeroes & ones """ trailing_zeroes = _count_righthand_zero_bits(ip_int, self._max_prefixlen) prefixlen = self._max_prefixlen - trailing_zeroes leading_ones = ip_int >> trailing_zeroes all_ones = (1 << prefixlen) - 1 if leading_ones != all_ones: byteslen = self._max_prefixlen // 8 details = _int_to_bytes(ip_int, byteslen, 'big') msg = 'Netmask pattern %r mixes zeroes & ones' raise ValueError(msg % details) return prefixlen def _report_invalid_netmask(self, netmask_str): msg = '%r is not a valid netmask' % netmask_str raise NetmaskValueError(msg) def _prefix_from_prefix_string(self, prefixlen_str): """Return prefix length from a numeric string Args: prefixlen_str: The string to be converted Returns: An integer, the prefix length. 
Raises: NetmaskValueError: If the input is not a valid netmask """ # int allows a leading +/- as well as surrounding whitespace, # so we ensure that isn't the case if not _BaseV4._DECIMAL_DIGITS.issuperset(prefixlen_str): self._report_invalid_netmask(prefixlen_str) try: prefixlen = int(prefixlen_str) except ValueError: self._report_invalid_netmask(prefixlen_str) if not (0 <= prefixlen <= self._max_prefixlen): self._report_invalid_netmask(prefixlen_str) return prefixlen def _prefix_from_ip_string(self, ip_str): """Turn a netmask/hostmask string into a prefix length Args: ip_str: The netmask/hostmask to be converted Returns: An integer, the prefix length. Raises: NetmaskValueError: If the input is not a valid netmask/hostmask """ # Parse the netmask/hostmask like an IP address. try: ip_int = self._ip_int_from_string(ip_str) except AddressValueError: self._report_invalid_netmask(ip_str) # Try matching a netmask (this would be /1*0*/ as a bitwise regexp). # Note that the two ambiguous cases (all-ones and all-zeroes) are # treated as netmasks. try: return self._prefix_from_ip_int(ip_int) except ValueError: pass # Invert the bits, and try matching a /0+1+/ hostmask instead. ip_int ^= self._ALL_ONES try: return self._prefix_from_ip_int(ip_int) except ValueError: self._report_invalid_netmask(ip_str) class _BaseAddress(_IPAddressBase): """A generic IP object. This IP class contains the version independent methods which are used by single IP addresses. """ def __init__(self, address): if (not isinstance(address, bytes) and '/' in str(address)): raise AddressValueError("Unexpected '/' in %r" % address) def __int__(self): return self._ip def __eq__(self, other): try: return (self._ip == other._ip and self._version == other._version) except AttributeError: return NotImplemented def __lt__(self, other): if self._version != other._version: raise TypeError('%s and %s are not of the same version' % ( self, other)) if not isinstance(other, _BaseAddress): raise TypeError('%s and %s are not of the same type' % ( self, other)) if self._ip != other._ip: return self._ip < other._ip return False # Shorthand for Integer addition and subtraction. This is not # meant to ever support addition/subtraction of addresses. def __add__(self, other): if not isinstance(other, int): return NotImplemented return self.__class__(int(self) + other) def __sub__(self, other): if not isinstance(other, int): return NotImplemented return self.__class__(int(self) - other) def __repr__(self): return '%s(%r)' % (self.__class__.__name__, str(self)) def __str__(self): return str(self._string_from_ip_int(self._ip)) def __hash__(self): return hash(hex(int(self._ip))) def _get_address_key(self): return (self._version, self) class _BaseNetwork(_IPAddressBase): """A generic IP network object. This IP class contains the version independent methods which are used by networks. """ def __init__(self, address): self._cache = {} def __repr__(self): return '%s(%r)' % (self.__class__.__name__, str(self)) def __str__(self): return '%s/%d' % (self.network_address, self.prefixlen) def hosts(self): """Generate Iterator over usable hosts in a network. This is like __iter__ except it doesn't return the network or broadcast addresses. 
""" network = int(self.network_address) broadcast = int(self.broadcast_address) for x in range(network + 1, broadcast): yield self._address_class(x) def __iter__(self): network = int(self.network_address) broadcast = int(self.broadcast_address) for x in range(network, broadcast + 1): yield self._address_class(x) def __getitem__(self, n): network = int(self.network_address) broadcast = int(self.broadcast_address) if n >= 0: if network + n > broadcast: raise IndexError return self._address_class(network + n) else: n += 1 if broadcast + n < network: raise IndexError return self._address_class(broadcast + n) def __lt__(self, other): if self._version != other._version: raise TypeError('%s and %s are not of the same version' % ( self, other)) if not isinstance(other, _BaseNetwork): raise TypeError('%s and %s are not of the same type' % ( self, other)) if self.network_address != other.network_address: return self.network_address < other.network_address if self.netmask != other.netmask: return self.netmask < other.netmask return False def __eq__(self, other): try: return (self._version == other._version and self.network_address == other.network_address and int(self.netmask) == int(other.netmask)) except AttributeError: return NotImplemented def __hash__(self): return hash(int(self.network_address) ^ int(self.netmask)) def __contains__(self, other): # always false if one is v4 and the other is v6. if self._version != other._version: return False # dealing with another network. if isinstance(other, _BaseNetwork): return False # dealing with another address else: # address return (int(self.network_address) <= int(other._ip) <= int(self.broadcast_address)) def overlaps(self, other): """Tell if self is partly contained in other.""" return self.network_address in other or ( self.broadcast_address in other or ( other.network_address in self or ( other.broadcast_address in self))) @property def broadcast_address(self): x = self._cache.get('broadcast_address') if x is None: x = self._address_class(int(self.network_address) | int(self.hostmask)) self._cache['broadcast_address'] = x return x @property def hostmask(self): x = self._cache.get('hostmask') if x is None: x = self._address_class(int(self.netmask) ^ self._ALL_ONES) self._cache['hostmask'] = x return x @property def with_prefixlen(self): return '%s/%d' % (self.network_address, self._prefixlen) @property def with_netmask(self): return '%s/%s' % (self.network_address, self.netmask) @property def with_hostmask(self): return '%s/%s' % (self.network_address, self.hostmask) @property def num_addresses(self): """Number of hosts in the current subnet.""" return int(self.broadcast_address) - int(self.network_address) + 1 @property def _address_class(self): # Returning bare address objects (rather than interfaces) allows for # more consistent behaviour across the network address, broadcast # address and individual host addresses. msg = '%200s has no associated address class' % (type(self),) raise NotImplementedError(msg) @property def prefixlen(self): return self._prefixlen def address_exclude(self, other): """Remove an address from a larger block. 
For example: addr1 = ip_network('192.0.2.0/28') addr2 = ip_network('192.0.2.1/32') addr1.address_exclude(addr2) = [IPv4Network('192.0.2.0/32'), IPv4Network('192.0.2.2/31'), IPv4Network('192.0.2.4/30'), IPv4Network('192.0.2.8/29')] or IPv6: addr1 = ip_network('2001:db8::1/32') addr2 = ip_network('2001:db8::1/128') addr1.address_exclude(addr2) = [ip_network('2001:db8::1/128'), ip_network('2001:db8::2/127'), ip_network('2001:db8::4/126'), ip_network('2001:db8::8/125'), ... ip_network('2001:db8:8000::/33')] Args: other: An IPv4Network or IPv6Network object of the same type. Returns: An iterator of the IPv(4|6)Network objects which is self minus other. Raises: TypeError: If self and other are of differing address versions, or if other is not a network object. ValueError: If other is not completely contained by self. """ if not self._version == other._version: raise TypeError("%s and %s are not of the same version" % ( self, other)) if not isinstance(other, _BaseNetwork): raise TypeError("%s is not a network object" % other) if not (other.network_address >= self.network_address and other.broadcast_address <= self.broadcast_address): raise ValueError('%s not contained in %s' % (other, self)) if other == self: raise StopIteration # Make sure we're comparing the network of other. other = other.__class__('%s/%s' % (other.network_address, other.prefixlen)) s1, s2 = self.subnets() while s1 != other and s2 != other: if (other.network_address >= s1.network_address and other.broadcast_address <= s1.broadcast_address): yield s2 s1, s2 = s1.subnets() elif (other.network_address >= s2.network_address and other.broadcast_address <= s2.broadcast_address): yield s1 s1, s2 = s2.subnets() else: # If we got here, there's a bug somewhere. raise AssertionError('Error performing exclusion: ' 's1: %s s2: %s other: %s' % (s1, s2, other)) if s1 == other: yield s2 elif s2 == other: yield s1 else: # If we got here, there's a bug somewhere. raise AssertionError('Error performing exclusion: ' 's1: %s s2: %s other: %s' % (s1, s2, other)) def compare_networks(self, other): """Compare two IP objects. This is only concerned about the comparison of the integer representation of the network addresses. This means that the host bits aren't considered at all in this method. If you want to compare host bits, you can easily enough do a 'HostA._ip < HostB._ip' Args: other: An IP object. Returns: If the IP versions of self and other are the same, returns: -1 if self < other: eg: IPv4Network('192.0.2.0/25') < IPv4Network('192.0.2.128/25') IPv6Network('2001:db8::1000/124') < IPv6Network('2001:db8::2000/124') 0 if self == other eg: IPv4Network('192.0.2.0/24') == IPv4Network('192.0.2.0/24') IPv6Network('2001:db8::1000/124') == IPv6Network('2001:db8::1000/124') 1 if self > other eg: IPv4Network('192.0.2.128/25') > IPv4Network('192.0.2.0/25') IPv6Network('2001:db8::2000/124') > IPv6Network('2001:db8::1000/124') Raises: TypeError if the IP versions are different. """ # does this need to raise a ValueError? if self._version != other._version: raise TypeError('%s and %s are not of the same type' % ( self, other)) # self._version == other._version below here: if self.network_address < other.network_address: return -1 if self.network_address > other.network_address: return 1 # self.network_address == other.network_address below here: if self.netmask < other.netmask: return -1 if self.netmask > other.netmask: return 1 return 0 def _get_networks_key(self): """Network-only key function. 
Returns an object that identifies this address' network and netmask. This function is a suitable "key" argument for sorted() and list.sort(). """ return (self._version, self.network_address, self.netmask) def subnets(self, prefixlen_diff=1, new_prefix=None): """The subnets which join to make the current subnet. In the case that self contains only one IP (self._prefixlen == 32 for IPv4 or self._prefixlen == 128 for IPv6), yield an iterator with just ourself. Args: prefixlen_diff: An integer, the amount the prefix length should be increased by. This should not be set if new_prefix is also set. new_prefix: The desired new prefix length. This must be a larger number (smaller prefix) than the existing prefix. This should not be set if prefixlen_diff is also set. Returns: An iterator of IPv(4|6) objects. Raises: ValueError: The prefixlen_diff is too small or too large. OR prefixlen_diff and new_prefix are both set or new_prefix is a smaller number than the current prefix (smaller number means a larger network) """ if self._prefixlen == self._max_prefixlen: yield self return if new_prefix is not None: if new_prefix < self._prefixlen: raise ValueError('new prefix must be longer') if prefixlen_diff != 1: raise ValueError('cannot set prefixlen_diff and new_prefix') prefixlen_diff = new_prefix - self._prefixlen if prefixlen_diff < 0: raise ValueError('prefix length diff must be > 0') new_prefixlen = self._prefixlen + prefixlen_diff if new_prefixlen > self._max_prefixlen: raise ValueError( 'prefix length diff %d is invalid for netblock %s' % ( new_prefixlen, self)) first = self.__class__('%s/%s' % (self.network_address, self._prefixlen + prefixlen_diff)) yield first current = first while True: broadcast = current.broadcast_address if broadcast == self.broadcast_address: return new_addr = self._address_class(int(broadcast) + 1) current = self.__class__('%s/%s' % (new_addr, new_prefixlen)) yield current def supernet(self, prefixlen_diff=1, new_prefix=None): """The supernet containing the current network. Args: prefixlen_diff: An integer, the amount the prefix length of the network should be decreased by. For example, given a /24 network and a prefixlen_diff of 3, a supernet with a /21 netmask is returned. Returns: An IPv4 network object. Raises: ValueError: If self.prefixlen - prefixlen_diff < 0. I.e., you have a negative prefix length. OR If prefixlen_diff and new_prefix are both set or new_prefix is a larger number than the current prefix (larger number means a smaller network) """ if self._prefixlen == 0: return self if new_prefix is not None: if new_prefix > self._prefixlen: raise ValueError('new prefix must be shorter') if prefixlen_diff != 1: raise ValueError('cannot set prefixlen_diff and new_prefix') prefixlen_diff = self._prefixlen - new_prefix if self.prefixlen - prefixlen_diff < 0: raise ValueError( 'current prefixlen is %d, cannot have a prefixlen_diff of %d' % (self.prefixlen, prefixlen_diff)) # TODO (pmoody): optimize this. t = self.__class__('%s/%d' % (self.network_address, self.prefixlen - prefixlen_diff), strict=False) return t.__class__('%s/%d' % (t.network_address, t.prefixlen)) @property def is_multicast(self): """Test if the address is reserved for multicast use. Returns: A boolean, True if the address is a multicast address. See RFC 2373 2.7 for details. """ return (self.network_address.is_multicast and self.broadcast_address.is_multicast) @property def is_reserved(self): """Test if the address is otherwise IETF reserved. 
Returns: A boolean, True if the address is within one of the reserved IPv6 Network ranges. """ return (self.network_address.is_reserved and self.broadcast_address.is_reserved) @property def is_link_local(self): """Test if the address is reserved for link-local. Returns: A boolean, True if the address is reserved per RFC 4291. """ return (self.network_address.is_link_local and self.broadcast_address.is_link_local) @property def is_private(self): """Test if this address is allocated for private networks. Returns: A boolean, True if the address is reserved per iana-ipv4-special-registry or iana-ipv6-special-registry. """ return (self.network_address.is_private and self.broadcast_address.is_private) @property def is_global(self): """Test if this address is allocated for public networks. Returns: A boolean, True if the address is not reserved per iana-ipv4-special-registry or iana-ipv6-special-registry. """ return not self.is_private @property def is_unspecified(self): """Test if the address is unspecified. Returns: A boolean, True if this is the unspecified address as defined in RFC 2373 2.5.2. """ return (self.network_address.is_unspecified and self.broadcast_address.is_unspecified) @property def is_loopback(self): """Test if the address is a loopback address. Returns: A boolean, True if the address is a loopback address as defined in RFC 2373 2.5.3. """ return (self.network_address.is_loopback and self.broadcast_address.is_loopback) class _BaseV4(object): """Base IPv4 object. The following methods are used by IPv4 objects in both single IP addresses and networks. """ # Equivalent to 255.255.255.255 or 32 bits of 1's. _ALL_ONES = (2**IPV4LENGTH) - 1 _DECIMAL_DIGITS = frozenset('0123456789') # the valid octets for host and netmasks. only useful for IPv4. _valid_mask_octets = frozenset((255, 254, 252, 248, 240, 224, 192, 128, 0)) def __init__(self, address): self._version = 4 self._max_prefixlen = IPV4LENGTH def _explode_shorthand_ip_string(self): return str(self) def _ip_int_from_string(self, ip_str): """Turn the given IP string into an integer for comparison. Args: ip_str: A string, the IP ip_str. Returns: The IP ip_str as an integer. Raises: AddressValueError: if ip_str isn't a valid IPv4 Address. """ if not ip_str: raise AddressValueError('Address cannot be empty') octets = ip_str.split('.') if len(octets) != 4: raise AddressValueError("Expected 4 octets in %r" % ip_str) try: return _int_from_bytes(map(self._parse_octet, octets), 'big') except ValueError as exc: raise AddressValueError("%s in %r" % (exc, ip_str)) def _parse_octet(self, octet_str): """Convert a decimal octet into an integer. Args: octet_str: A string, the number to parse. Returns: The octet as an integer. Raises: ValueError: if the octet isn't strictly a decimal from [0..255]. """ if not octet_str: raise ValueError("Empty octet not permitted") # Whitelist the characters, since int() allows a lot of bizarre stuff. 
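        # (for example, int() accepts surrounding whitespace and a leading
        # '+' or '-', none of which belong in a dotted-quad octet)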
if not self._DECIMAL_DIGITS.issuperset(octet_str): msg = "Only decimal digits permitted in %r" raise ValueError(msg % octet_str) # We do the length check second, since the invalid character error # is likely to be more informative for the user if len(octet_str) > 3: msg = "At most 3 characters permitted in %r" raise ValueError(msg % octet_str) # Convert to integer (we know digits are legal) octet_int = int(octet_str, 10) # Any octets that look like they *might* be written in octal, # and which don't look exactly the same in both octal and # decimal are rejected as ambiguous if octet_int > 7 and octet_str[0] == '0': msg = "Ambiguous (octal/decimal) value in %r not permitted" raise ValueError(msg % octet_str) if octet_int > 255: raise ValueError("Octet %d (> 255) not permitted" % octet_int) return octet_int def _string_from_ip_int(self, ip_int): """Turns a 32-bit integer into dotted decimal notation. Args: ip_int: An integer, the IP address. Returns: The IP address as a string in dotted decimal notation. """ return '.'.join(map(str, _int_to_bytes(ip_int, 4, 'big'))) def _is_valid_netmask(self, netmask): """Verify that the netmask is valid. Args: netmask: A string, either a prefix or dotted decimal netmask. Returns: A boolean, True if the prefix represents a valid IPv4 netmask. """ mask = netmask.split('.') if len(mask) == 4: try: for x in mask: if int(x) not in self._valid_mask_octets: return False except ValueError: # Found something that isn't an integer or isn't valid return False for idx, y in enumerate(mask): if idx > 0 and y > mask[idx - 1]: return False return True try: netmask = int(netmask) except ValueError: return False return 0 <= netmask <= self._max_prefixlen def _is_hostmask(self, ip_str): """Test if the IP string is a hostmask (rather than a netmask). Args: ip_str: A string, the potential hostmask. Returns: A boolean, True if the IP string is a hostmask. """ bits = ip_str.split('.') try: parts = [x for x in map(int, bits) if x in self._valid_mask_octets] except ValueError: return False if len(parts) != len(bits): return False if parts[0] < parts[-1]: return True return False @property def max_prefixlen(self): return self._max_prefixlen @property def version(self): return self._version class IPv4Address(_BaseV4, _BaseAddress): """Represent and manipulate single IPv4 Addresses.""" def __init__(self, address): """ Args: address: A string or integer representing the IP Additionally, an integer can be passed, so IPv4Address('192.0.2.1') == IPv4Address(3221225985). or, more generally IPv4Address(int(IPv4Address('192.0.2.1'))) == IPv4Address('192.0.2.1') Raises: AddressValueError: If ipaddress isn't a valid IPv4 address. """ _BaseAddress.__init__(self, address) _BaseV4.__init__(self, address) # Efficient constructor from integer. if isinstance(address, int): self._check_int_address(address) self._ip = address return # Constructing from a packed address if isinstance(address, bytes): self._check_packed_address(address, 4) self._ip = _int_from_bytes(address, 'big') return # Assume input argument to be string or any object representation # which converts into a formatted IP string. addr_str = str(address) self._ip = self._ip_int_from_string(addr_str) @property def packed(self): """The binary representation of this address.""" return v4_int_to_packed(self._ip) @property def is_reserved(self): """Test if the address is otherwise IETF reserved. Returns: A boolean, True if the address is within the reserved IPv4 Network range. 
""" reserved_network = IPv4Network('240.0.0.0/4') return self in reserved_network @property def is_private(self): """Test if this address is allocated for private networks. Returns: A boolean, True if the address is reserved per iana-ipv4-special-registry. """ return (self in IPv4Network('0.0.0.0/8') or self in IPv4Network('10.0.0.0/8') or self in IPv4Network('127.0.0.0/8') or self in IPv4Network('169.254.0.0/16') or self in IPv4Network('172.16.0.0/12') or self in IPv4Network('192.0.0.0/29') or self in IPv4Network('192.0.0.170/31') or self in IPv4Network('192.0.2.0/24') or self in IPv4Network('192.168.0.0/16') or self in IPv4Network('198.18.0.0/15') or self in IPv4Network('198.51.100.0/24') or self in IPv4Network('203.0.113.0/24') or self in IPv4Network('240.0.0.0/4') or self in IPv4Network('255.255.255.255/32')) @property def is_multicast(self): """Test if the address is reserved for multicast use. Returns: A boolean, True if the address is multicast. See RFC 3171 for details. """ multicast_network = IPv4Network('224.0.0.0/4') return self in multicast_network @property def is_unspecified(self): """Test if the address is unspecified. Returns: A boolean, True if this is the unspecified address as defined in RFC 5735 3. """ unspecified_address = IPv4Address('0.0.0.0') return self == unspecified_address @property def is_loopback(self): """Test if the address is a loopback address. Returns: A boolean, True if the address is a loopback per RFC 3330. """ loopback_network = IPv4Network('127.0.0.0/8') return self in loopback_network @property def is_link_local(self): """Test if the address is reserved for link-local. Returns: A boolean, True if the address is link-local per RFC 3927. """ linklocal_network = IPv4Network('169.254.0.0/16') return self in linklocal_network class IPv4Interface(IPv4Address): def __init__(self, address): if isinstance(address, (bytes, int)): IPv4Address.__init__(self, address) self.network = IPv4Network(self._ip) self._prefixlen = self._max_prefixlen return addr = _split_optional_netmask(address) IPv4Address.__init__(self, addr[0]) self.network = IPv4Network(address, strict=False) self._prefixlen = self.network._prefixlen self.netmask = self.network.netmask self.hostmask = self.network.hostmask def __str__(self): return '%s/%d' % (self._string_from_ip_int(self._ip), self.network.prefixlen) def __eq__(self, other): address_equal = IPv4Address.__eq__(self, other) if not address_equal or address_equal is NotImplemented: return address_equal try: return self.network == other.network except AttributeError: # An interface with an associated network is NOT the # same as an unassociated address. That's why the hash # takes the extra info into account. return False def __lt__(self, other): address_less = IPv4Address.__lt__(self, other) if address_less is NotImplemented: return NotImplemented try: return self.network < other.network except AttributeError: # We *do* allow addresses and interfaces to be sorted. The # unassociated address is considered less than all interfaces. 
return False def __hash__(self): return self._ip ^ self._prefixlen ^ int(self.network.network_address) @property def ip(self): return IPv4Address(self._ip) @property def with_prefixlen(self): return '%s/%s' % (self._string_from_ip_int(self._ip), self._prefixlen) @property def with_netmask(self): return '%s/%s' % (self._string_from_ip_int(self._ip), self.netmask) @property def with_hostmask(self): return '%s/%s' % (self._string_from_ip_int(self._ip), self.hostmask) class IPv4Network(_BaseV4, _BaseNetwork): """This class represents and manipulates 32-bit IPv4 network + addresses.. Attributes: [examples for IPv4Network('192.0.2.0/27')] .network_address: IPv4Address('192.0.2.0') .hostmask: IPv4Address('0.0.0.31') .broadcast_address: IPv4Address('192.0.2.32') .netmask: IPv4Address('255.255.255.224') .prefixlen: 27 """ # Class to use when creating address objects _address_class = IPv4Address def __init__(self, address, strict=True): """Instantiate a new IPv4 network object. Args: address: A string or integer representing the IP [& network]. '192.0.2.0/24' '192.0.2.0/255.255.255.0' '192.0.0.2/0.0.0.255' are all functionally the same in IPv4. Similarly, '192.0.2.1' '192.0.2.1/255.255.255.255' '192.0.2.1/32' are also functionally equivalent. That is to say, failing to provide a subnetmask will create an object with a mask of /32. If the mask (portion after the / in the argument) is given in dotted quad form, it is treated as a netmask if it starts with a non-zero field (e.g. /255.0.0.0 == /8) and as a hostmask if it starts with a zero field (e.g. 0.255.255.255 == /8), with the single exception of an all-zero mask which is treated as a netmask == /0. If no mask is given, a default of /32 is used. Additionally, an integer can be passed, so IPv4Network('192.0.2.1') == IPv4Network(3221225985) or, more generally IPv4Interface(int(IPv4Interface('192.0.2.1'))) == IPv4Interface('192.0.2.1') Raises: AddressValueError: If ipaddress isn't a valid IPv4 address. NetmaskValueError: If the netmask isn't valid for an IPv4 address. ValueError: If strict is True and a network address is not supplied. """ _BaseV4.__init__(self, address) _BaseNetwork.__init__(self, address) # Constructing from a packed address if isinstance(address, bytes): self.network_address = IPv4Address(address) self._prefixlen = self._max_prefixlen self.netmask = IPv4Address(self._ALL_ONES) #fixme: address/network test here return # Efficient constructor from integer. if isinstance(address, int): self.network_address = IPv4Address(address) self._prefixlen = self._max_prefixlen self.netmask = IPv4Address(self._ALL_ONES) #fixme: address/network test here. return # Assume input argument to be string or any object representation # which converts into a formatted IP prefix string. addr = _split_optional_netmask(address) self.network_address = IPv4Address(self._ip_int_from_string(addr[0])) if len(addr) == 2: try: # Check for a netmask in prefix length form self._prefixlen = self._prefix_from_prefix_string(addr[1]) except NetmaskValueError: # Check for a netmask or hostmask in dotted-quad form. # This may raise NetmaskValueError. 
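                # (for example, both '255.255.255.0' and its hostmask
                # equivalent '0.0.0.255' resolve to a prefix length of 24)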
self._prefixlen = self._prefix_from_ip_string(addr[1]) else: self._prefixlen = self._max_prefixlen self.netmask = IPv4Address(self._ip_int_from_prefix(self._prefixlen)) if strict: if (IPv4Address(int(self.network_address) & int(self.netmask)) != self.network_address): raise ValueError('%s has host bits set' % self) self.network_address = IPv4Address(int(self.network_address) & int(self.netmask)) if self._prefixlen == (self._max_prefixlen - 1): self.hosts = self.__iter__ @property def is_global(self): """Test if this address is allocated for public networks. Returns: A boolean, True if the address is not reserved per iana-ipv4-special-registry. """ return (not (self.network_address in IPv4Network('100.64.0.0/10') and self.broadcast_address in IPv4Network('100.64.0.0/10')) and not self.is_private) class _BaseV6(object): """Base IPv6 object. The following methods are used by IPv6 objects in both single IP addresses and networks. """ _ALL_ONES = (2**IPV6LENGTH) - 1 _HEXTET_COUNT = 8 _HEX_DIGITS = frozenset('0123456789ABCDEFabcdef') def __init__(self, address): self._version = 6 self._max_prefixlen = IPV6LENGTH def _ip_int_from_string(self, ip_str): """Turn an IPv6 ip_str into an integer. Args: ip_str: A string, the IPv6 ip_str. Returns: An int, the IPv6 address Raises: AddressValueError: if ip_str isn't a valid IPv6 Address. """ if not ip_str: raise AddressValueError('Address cannot be empty') parts = ip_str.split(':') # An IPv6 address needs at least 2 colons (3 parts). _min_parts = 3 if len(parts) < _min_parts: msg = "At least %d parts expected in %r" % (_min_parts, ip_str) raise AddressValueError(msg) # If the address has an IPv4-style suffix, convert it to hexadecimal. if '.' in parts[-1]: try: ipv4_int = IPv4Address(parts.pop())._ip except AddressValueError as exc: raise AddressValueError("%s in %r" % (exc, ip_str)) parts.append('%x' % ((ipv4_int >> 16) & 0xFFFF)) parts.append('%x' % (ipv4_int & 0xFFFF)) # An IPv6 address can't have more than 8 colons (9 parts). # The extra colon comes from using the "::" notation for a single # leading or trailing zero part. _max_parts = self._HEXTET_COUNT + 1 if len(parts) > _max_parts: msg = "At most %d colons permitted in %r" % (_max_parts-1, ip_str) raise AddressValueError(msg) # Disregarding the endpoints, find '::' with nothing in between. # This indicates that a run of zeroes has been skipped. skip_index = None for i in range(1, len(parts) - 1): if not parts[i]: if skip_index is not None: # Can't have more than one '::' msg = "At most one '::' permitted in %r" % ip_str raise AddressValueError(msg) skip_index = i # parts_hi is the number of parts to copy from above/before the '::' # parts_lo is the number of parts to copy from below/after the '::' if skip_index is not None: # If we found a '::', then check if it also covers the endpoints. parts_hi = skip_index parts_lo = len(parts) - skip_index - 1 if not parts[0]: parts_hi -= 1 if parts_hi: msg = "Leading ':' only permitted as part of '::' in %r" raise AddressValueError(msg % ip_str) # ^: requires ^:: if not parts[-1]: parts_lo -= 1 if parts_lo: msg = "Trailing ':' only permitted as part of '::' in %r" raise AddressValueError(msg % ip_str) # :$ requires ::$ parts_skipped = self._HEXTET_COUNT - (parts_hi + parts_lo) if parts_skipped < 1: msg = "Expected at most %d other parts with '::' in %r" raise AddressValueError(msg % (self._HEXTET_COUNT-1, ip_str)) else: # Otherwise, allocate the entire address to parts_hi. The # endpoints could still be empty, but _parse_hextet() will check # for that. 
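            # (i.e. an address such as '2001:db8:0:0:0:0:0:1', where all
            # eight hextets are written out and no '::' shorthand is used)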
if len(parts) != self._HEXTET_COUNT: msg = "Exactly %d parts expected without '::' in %r" raise AddressValueError(msg % (self._HEXTET_COUNT, ip_str)) if not parts[0]: msg = "Leading ':' only permitted as part of '::' in %r" raise AddressValueError(msg % ip_str) # ^: requires ^:: if not parts[-1]: msg = "Trailing ':' only permitted as part of '::' in %r" raise AddressValueError(msg % ip_str) # :$ requires ::$ parts_hi = len(parts) parts_lo = 0 parts_skipped = 0 try: # Now, parse the hextets into a 128-bit integer. ip_int = 0 for i in range(parts_hi): ip_int <<= 16 ip_int |= self._parse_hextet(parts[i]) ip_int <<= 16 * parts_skipped for i in range(-parts_lo, 0): ip_int <<= 16 ip_int |= self._parse_hextet(parts[i]) return ip_int except ValueError as exc: raise AddressValueError("%s in %r" % (exc, ip_str)) def _parse_hextet(self, hextet_str): """Convert an IPv6 hextet string into an integer. Args: hextet_str: A string, the number to parse. Returns: The hextet as an integer. Raises: ValueError: if the input isn't strictly a hex number from [0..FFFF]. """ # Whitelist the characters, since int() allows a lot of bizarre stuff. if not self._HEX_DIGITS.issuperset(hextet_str): raise ValueError("Only hex digits permitted in %r" % hextet_str) # We do the length check second, since the invalid character error # is likely to be more informative for the user if len(hextet_str) > 4: msg = "At most 4 characters permitted in %r" raise ValueError(msg % hextet_str) # Length check means we can skip checking the integer value return int(hextet_str, 16) def _compress_hextets(self, hextets): """Compresses a list of hextets. Compresses a list of strings, replacing the longest continuous sequence of "0" in the list with "" and adding empty strings at the beginning or at the end of the string such that subsequently calling ":".join(hextets) will produce the compressed version of the IPv6 address. Args: hextets: A list of strings, the hextets to compress. Returns: A list of strings. """ best_doublecolon_start = -1 best_doublecolon_len = 0 doublecolon_start = -1 doublecolon_len = 0 for index, hextet in enumerate(hextets): if hextet == '0': doublecolon_len += 1 if doublecolon_start == -1: # Start of a sequence of zeros. doublecolon_start = index if doublecolon_len > best_doublecolon_len: # This is the longest sequence of zeros so far. best_doublecolon_len = doublecolon_len best_doublecolon_start = doublecolon_start else: doublecolon_len = 0 doublecolon_start = -1 if best_doublecolon_len > 1: best_doublecolon_end = (best_doublecolon_start + best_doublecolon_len) # For zeros at the end of the address. if best_doublecolon_end == len(hextets): hextets += [''] hextets[best_doublecolon_start:best_doublecolon_end] = [''] # For zeros at the beginning of the address. if best_doublecolon_start == 0: hextets = [''] + hextets return hextets def _string_from_ip_int(self, ip_int=None): """Turns a 128-bit integer into hexadecimal notation. Args: ip_int: An integer, the IP address. Returns: A string, the hexadecimal representation of the address. Raises: ValueError: The address is bigger than 128 bits of all ones. """ if ip_int is None: ip_int = int(self._ip) if ip_int > self._ALL_ONES: raise ValueError('IPv6 address is too large') hex_str = '%032x' % ip_int hextets = ['%x' % int(hex_str[x:x+4], 16) for x in range(0, 32, 4)] hextets = self._compress_hextets(hextets) return ':'.join(hextets) def _explode_shorthand_ip_string(self): """Expand a shortened IPv6 address. Args: ip_str: A string, the IPv6 address. 
Returns: A string, the expanded IPv6 address. """ if isinstance(self, IPv6Network): ip_str = str(self.network_address) elif isinstance(self, IPv6Interface): ip_str = str(self.ip) else: ip_str = str(self) ip_int = self._ip_int_from_string(ip_str) hex_str = '%032x' % ip_int parts = [hex_str[x:x+4] for x in range(0, 32, 4)] if isinstance(self, (_BaseNetwork, IPv6Interface)): return '%s/%d' % (':'.join(parts), self._prefixlen) return ':'.join(parts) @property def max_prefixlen(self): return self._max_prefixlen @property def version(self): return self._version class IPv6Address(_BaseV6, _BaseAddress): """Represent and manipulate single IPv6 Addresses.""" def __init__(self, address): """Instantiate a new IPv6 address object. Args: address: A string or integer representing the IP Additionally, an integer can be passed, so IPv6Address('2001:db8::') == IPv6Address(42540766411282592856903984951653826560) or, more generally IPv6Address(int(IPv6Address('2001:db8::'))) == IPv6Address('2001:db8::') Raises: AddressValueError: If address isn't a valid IPv6 address. """ _BaseAddress.__init__(self, address) _BaseV6.__init__(self, address) # Efficient constructor from integer. if isinstance(address, int): self._check_int_address(address) self._ip = address return # Constructing from a packed address if isinstance(address, bytes): self._check_packed_address(address, 16) self._ip = _int_from_bytes(address, 'big') return # Assume input argument to be string or any object representation # which converts into a formatted IP string. addr_str = str(address) self._ip = self._ip_int_from_string(addr_str) @property def packed(self): """The binary representation of this address.""" return v6_int_to_packed(self._ip) @property def is_multicast(self): """Test if the address is reserved for multicast use. Returns: A boolean, True if the address is a multicast address. See RFC 2373 2.7 for details. """ multicast_network = IPv6Network('ff00::/8') return self in multicast_network @property def is_reserved(self): """Test if the address is otherwise IETF reserved. Returns: A boolean, True if the address is within one of the reserved IPv6 Network ranges. """ reserved_networks = [IPv6Network('::/8'), IPv6Network('100::/8'), IPv6Network('200::/7'), IPv6Network('400::/6'), IPv6Network('800::/5'), IPv6Network('1000::/4'), IPv6Network('4000::/3'), IPv6Network('6000::/3'), IPv6Network('8000::/3'), IPv6Network('A000::/3'), IPv6Network('C000::/3'), IPv6Network('E000::/4'), IPv6Network('F000::/5'), IPv6Network('F800::/6'), IPv6Network('FE00::/9')] return any(self in x for x in reserved_networks) @property def is_link_local(self): """Test if the address is reserved for link-local. Returns: A boolean, True if the address is reserved per RFC 4291. """ linklocal_network = IPv6Network('fe80::/10') return self in linklocal_network @property def is_site_local(self): """Test if the address is reserved for site-local. Note that the site-local address space has been deprecated by RFC 3879. Use is_private to test if this address is in the space of unique local addresses as defined by RFC 4193. Returns: A boolean, True if the address is reserved per RFC 3513 2.5.6. """ sitelocal_network = IPv6Network('fec0::/10') return self in sitelocal_network @property def is_private(self): """Test if this address is allocated for private networks. Returns: A boolean, True if the address is reserved per iana-ipv6-special-registry. 
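        For example, '::1' and 'fc00::1' are private, while
        '2001:4860:4860::8888' is not.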
""" return (self in IPv6Network('::1/128') or self in IPv6Network('::/128') or self in IPv6Network('::ffff:0:0/96') or self in IPv6Network('100::/64') or self in IPv6Network('2001::/23') or self in IPv6Network('2001:2::/48') or self in IPv6Network('2001:db8::/32') or self in IPv6Network('2001:10::/28') or self in IPv6Network('fc00::/7') or self in IPv6Network('fe80::/10')) @property def is_global(self): """Test if this address is allocated for public networks. Returns: A boolean, true if the address is not reserved per iana-ipv6-special-registry. """ return not self.is_private @property def is_unspecified(self): """Test if the address is unspecified. Returns: A boolean, True if this is the unspecified address as defined in RFC 2373 2.5.2. """ return self._ip == 0 @property def is_loopback(self): """Test if the address is a loopback address. Returns: A boolean, True if the address is a loopback address as defined in RFC 2373 2.5.3. """ return self._ip == 1 @property def ipv4_mapped(self): """Return the IPv4 mapped address. Returns: If the IPv6 address is a v4 mapped address, return the IPv4 mapped address. Return None otherwise. """ if (self._ip >> 32) != 0xFFFF: return None return IPv4Address(self._ip & 0xFFFFFFFF) @property def teredo(self): """Tuple of embedded teredo IPs. Returns: Tuple of the (server, client) IPs or None if the address doesn't appear to be a teredo address (doesn't start with 2001::/32) """ if (self._ip >> 96) != 0x20010000: return None return (IPv4Address((self._ip >> 64) & 0xFFFFFFFF), IPv4Address(~self._ip & 0xFFFFFFFF)) @property def sixtofour(self): """Return the IPv4 6to4 embedded address. Returns: The IPv4 6to4-embedded address if present or None if the address doesn't appear to contain a 6to4 embedded address. """ if (self._ip >> 112) != 0x2002: return None return IPv4Address((self._ip >> 80) & 0xFFFFFFFF) class IPv6Interface(IPv6Address): def __init__(self, address): if isinstance(address, (bytes, int)): IPv6Address.__init__(self, address) self.network = IPv6Network(self._ip) self._prefixlen = self._max_prefixlen return addr = _split_optional_netmask(address) IPv6Address.__init__(self, addr[0]) self.network = IPv6Network(address, strict=False) self.netmask = self.network.netmask self._prefixlen = self.network._prefixlen self.hostmask = self.network.hostmask def __str__(self): return '%s/%d' % (self._string_from_ip_int(self._ip), self.network.prefixlen) def __eq__(self, other): address_equal = IPv6Address.__eq__(self, other) if not address_equal or address_equal is NotImplemented: return address_equal try: return self.network == other.network except AttributeError: # An interface with an associated network is NOT the # same as an unassociated address. That's why the hash # takes the extra info into account. return False def __lt__(self, other): address_less = IPv6Address.__lt__(self, other) if address_less is NotImplemented: return NotImplemented try: return self.network < other.network except AttributeError: # We *do* allow addresses and interfaces to be sorted. The # unassociated address is considered less than all interfaces. 
return False def __hash__(self): return self._ip ^ self._prefixlen ^ int(self.network.network_address) @property def ip(self): return IPv6Address(self._ip) @property def with_prefixlen(self): return '%s/%s' % (self._string_from_ip_int(self._ip), self._prefixlen) @property def with_netmask(self): return '%s/%s' % (self._string_from_ip_int(self._ip), self.netmask) @property def with_hostmask(self): return '%s/%s' % (self._string_from_ip_int(self._ip), self.hostmask) @property def is_unspecified(self): return self._ip == 0 and self.network.is_unspecified @property def is_loopback(self): return self._ip == 1 and self.network.is_loopback class IPv6Network(_BaseV6, _BaseNetwork): """This class represents and manipulates 128-bit IPv6 networks. Attributes: [examples for IPv6('2001:db8::1000/124')] .network_address: IPv6Address('2001:db8::1000') .hostmask: IPv6Address('::f') .broadcast_address: IPv6Address('2001:db8::100f') .netmask: IPv6Address('ffff:ffff:ffff:ffff:ffff:ffff:ffff:fff0') .prefixlen: 124 """ # Class to use when creating address objects _address_class = IPv6Address def __init__(self, address, strict=True): """Instantiate a new IPv6 Network object. Args: address: A string or integer representing the IPv6 network or the IP and prefix/netmask. '2001:db8::/128' '2001:db8:0000:0000:0000:0000:0000:0000/128' '2001:db8::' are all functionally the same in IPv6. That is to say, failing to provide a subnetmask will create an object with a mask of /128. Additionally, an integer can be passed, so IPv6Network('2001:db8::') == IPv6Network(42540766411282592856903984951653826560) or, more generally IPv6Network(int(IPv6Network('2001:db8::'))) == IPv6Network('2001:db8::') strict: A boolean. If true, ensure that we have been passed A true network address, eg, 2001:db8::1000/124 and not an IP address on a network, eg, 2001:db8::1/124. Raises: AddressValueError: If address isn't a valid IPv6 address. NetmaskValueError: If the netmask isn't valid for an IPv6 address. ValueError: If strict was True and a network address was not supplied. """ _BaseV6.__init__(self, address) _BaseNetwork.__init__(self, address) # Efficient constructor from integer. if isinstance(address, int): self.network_address = IPv6Address(address) self._prefixlen = self._max_prefixlen self.netmask = IPv6Address(self._ALL_ONES) return # Constructing from a packed address if isinstance(address, bytes): self.network_address = IPv6Address(address) self._prefixlen = self._max_prefixlen self.netmask = IPv6Address(self._ALL_ONES) return # Assume input argument to be string or any object representation # which converts into a formatted IP prefix string. addr = _split_optional_netmask(address) self.network_address = IPv6Address(self._ip_int_from_string(addr[0])) if len(addr) == 2: # This may raise NetmaskValueError self._prefixlen = self._prefix_from_prefix_string(addr[1]) else: self._prefixlen = self._max_prefixlen self.netmask = IPv6Address(self._ip_int_from_prefix(self._prefixlen)) if strict: if (IPv6Address(int(self.network_address) & int(self.netmask)) != self.network_address): raise ValueError('%s has host bits set' % self) self.network_address = IPv6Address(int(self.network_address) & int(self.netmask)) if self._prefixlen == (self._max_prefixlen - 1): self.hosts = self.__iter__ @property def is_site_local(self): """Test if the address is reserved for site-local. Note that the site-local address space has been deprecated by RFC 3879. Use is_private to test if this address is in the space of unique local addresses as defined by RFC 4193. 
Returns: A boolean, True if the address is reserved per RFC 3513 2.5.6. """ return (self.network_address.is_site_local and self.broadcast_address.is_site_local) JordanMilne-Advocate-6d699ae/advocate/poolmanager.py000066400000000000000000000025501370334562200225440ustar00rootroot00000000000000import collections import functools from urllib3 import PoolManager from urllib3.poolmanager import _default_key_normalizer, PoolKey from .connectionpool import ( ValidatingHTTPSConnectionPool, ValidatingHTTPConnectionPool, ) pool_classes_by_scheme = { "http": ValidatingHTTPConnectionPool, "https": ValidatingHTTPSConnectionPool, } AdvocatePoolKey = collections.namedtuple('AdvocatePoolKey', PoolKey._fields + ('key_validator',)) def key_normalizer(key_class, request_context): request_context = request_context.copy() # TODO: add ability to serialize validator rules to dict, # allowing pool to be shared between sessions with the same # rules. request_context["validator"] = id(request_context["validator"]) return _default_key_normalizer(key_class, request_context) key_fn_by_scheme = { 'http': functools.partial(key_normalizer, AdvocatePoolKey), 'https': functools.partial(key_normalizer, AdvocatePoolKey), } class ValidatingPoolManager(PoolManager): def __init__(self, *args, **kwargs): super(ValidatingPoolManager, self).__init__(*args, **kwargs) # Make sure the API hasn't changed assert (hasattr(self, 'pool_classes_by_scheme')) self.pool_classes_by_scheme = pool_classes_by_scheme self.key_fn_by_scheme = key_fn_by_scheme.copy() JordanMilne-Advocate-6d699ae/examples/000077500000000000000000000000001370334562200177145ustar00rootroot00000000000000JordanMilne-Advocate-6d699ae/examples/hashurl.py000066400000000000000000000011701370334562200217330ustar00rootroot00000000000000import hashlib from flask import Flask, request import advocate import requests app = Flask(__name__) @app.route('/') def get_hash(): url = request.args.get("url") if not url: return "Please specify a url!" try: headers = {"User-Agent": "Hashifier 0.1"} resp = advocate.get(url, headers=headers) except advocate.UnacceptableAddressException: return "That URL points to a forbidden resource" except requests.RequestException: return "Failed to connect to the specified URL" return hashlib.sha256(resp.content).hexdigest() if __name__ == '__main__': app.run() JordanMilne-Advocate-6d699ae/pytest.ini000066400000000000000000000000411370334562200201220ustar00rootroot00000000000000[pytest] addopts = -p no:warningsJordanMilne-Advocate-6d699ae/requests_pytest_plugin.py000066400000000000000000000034441370334562200233160ustar00rootroot00000000000000import socket import doctest import pytest import requests import advocate import advocate.api from advocate.exceptions import MountDisabledException, ProxyDisabledException from advocate.packages import ipaddress from test.monkeypatching import CheckedSocket SKIP_EXCEPTIONS = (MountDisabledException, ProxyDisabledException) def pytest_runtestloop(): validator = advocate.AddrValidator( ip_whitelist={ # requests needs to be able to hit these for its tests! 
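            # (loopback addresses for the locally-run test servers, plus
            # 10.255.255.1, which requests' timeout tests treat as an
            # unreachable "tarpit" address)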
ipaddress.ip_network("127.0.0.1"), ipaddress.ip_network("127.0.1.1"), ipaddress.ip_network("10.255.255.1"), }, # the `httpbin` fixture uses a random port, we need to allow all ports port_whitelist=set(range(0, 65535)), ) # this will yell at us if we failed to patch something socket.socket = CheckedSocket # requests' tests rely on being able to pickle a `Session` advocate.api.RequestsAPIWrapper.SUPPORT_WRAPPER_PICKLING = True wrapper = advocate.api.RequestsAPIWrapper(validator) for attr in advocate.api.__all__: setattr(requests, attr, getattr(wrapper, attr)) def pytest_runtest_makereport(item, call): # This is necessary because we pull in requests' test suite, # which sometimes tests `session.mount()`. We disable that # method, so we need to ignore tests that use it. from _pytest.runner import pytest_runtest_makereport as mr report = mr(item, call) if call.excinfo is not None: exc = call.excinfo.value if isinstance(exc, doctest.UnexpectedException): exc = call.excinfo.value.exc_info[1] if isinstance(exc, SKIP_EXCEPTIONS): report.outcome = 'skipped' report.wasxfail = "reason: Advocate is not meant to support this" return report JordanMilne-Advocate-6d699ae/requirements-test.txt000066400000000000000000000011551370334562200223410ustar00rootroot00000000000000atomicwrites==1.4.0 attrs==19.3.0 blinker==1.4 brotlipy==0.7.0 certifi==2020.4.5.1 cffi==1.14.0 chardet==3.0.4 click==7.1.2 coverage==5.1 cryptography==2.9.2 decorator==4.4.2 Flask==1.1.2 httpbin==0.7.0 idna==2.6 importlib-metadata==1.6.0 itsdangerous==1.1.0 Jinja2==2.11.2 MarkupSafe==1.1.1 mock==3.0.5 more-itertools==5.0.0 netifaces==0.10.9 pluggy==0.13.1 py==1.8.1 pycparser==2.20 pygments==2.5.2 pyOpenSSL==19.1.0 pysocks==1.7.1 pytest==3.10.1 pytest-cov==2.8.1 pytest-httpbin==0.3.0 pytest-mock==2.0.0 raven==6.10.0 requests-futures==1.0.0 requests-mock==1.8.0 six==1.14.0 urllib3==1.22 werkzeug==1.0.1 zipp==1.2.0 JordanMilne-Advocate-6d699ae/setup.cfg000066400000000000000000000005001370334562200177120ustar00rootroot00000000000000[wheel] universal = 1 [flake8] # E123,E133,E226,E241,E242 are the default ignores ignore = E702,E712,E902,N802,F401 max-line-length = 95 exclude = env/,build/,docs/,.eggs/,.git/,packages/,dev_packages/,dist/,*.egg_info/,.cache/ [tool:pytest] norecursedirs = env build docs .eggs .git packages dev_packages dist .cache JordanMilne-Advocate-6d699ae/setup.py000066400000000000000000000032651370334562200176160ustar00rootroot00000000000000import re import setuptools from codecs import open requires = [ 'requests <3.0, >=2.18.0', 'urllib3 <2.0, >=1.22', 'six', "pyasn1", "pyopenssl", "ndg-httpsclient", 'netifaces>=0.10.5', ] packages = [ "advocate", "advocate.packages", "advocate.packages.ipaddress" ] version = '' with open('advocate/__init__.py', 'r') as fd: version = re.search(r'^__version__\s*=\s*[\'"]([^\'"]*)[\'"]', fd.read(), re.MULTILINE).group(1) with open('README.rst', 'r', 'utf-8') as f: readme = f.read() setuptools.setup( name='advocate', version=version, packages=packages, install_requires=requires, tests_require=[ "mock", "pytest", "pytest-cov", "requests-futures", "requests-mock", ], url='https://github.com/JordanMilne/Advocate', license='Apache 2', author='Jordan Milne', author_email='advocate@saynotolinux.com', keywords="http requests security ssrf proxy rebinding advocate", description=('A wrapper around the requests library for safely ' 'making HTTP requests on behalf of a third party'), long_description=readme, classifiers=[ 'Development Status :: 4 - Beta', 'Intended Audience :: Developers', 'Natural 
Language :: English', 'License :: OSI Approved :: Apache Software License', 'Programming Language :: Python', 'Programming Language :: Python :: 2.7', 'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: 3.8', 'Topic :: Security', 'Topic :: Internet :: WWW/HTTP', ], ) JordanMilne-Advocate-6d699ae/test/000077500000000000000000000000001370334562200170555ustar00rootroot00000000000000JordanMilne-Advocate-6d699ae/test/__init__.py000066400000000000000000000000001370334562200211540ustar00rootroot00000000000000JordanMilne-Advocate-6d699ae/test/monkeypatching.py000066400000000000000000000022431370334562200224500ustar00rootroot00000000000000import contextlib import os.path import socket import traceback class DisallowedConnectException(Exception): pass class CheckedSocket(socket.socket): CONNECT_ALLOWED_FUNCS = {"validating_create_connection"} # `test_testserver.py` makes raw connections to the test server to ensure it works CONNECT_ALLOWED_FILES = {"test_testserver.py"} _checks_enabled = True @classmethod @contextlib.contextmanager def bypass_checks(cls): try: cls._checks_enabled = False yield finally: cls._checks_enabled = True @classmethod def _check_frame_allowed(cls, frame): if os.path.basename(frame[0]) in cls.CONNECT_ALLOWED_FILES: return True if frame[2] in cls.CONNECT_ALLOWED_FUNCS: return True return False def connect(self, *args, **kwargs): if self._checks_enabled: stack = traceback.extract_stack() if not any(self._check_frame_allowed(frame) for frame in stack): raise DisallowedConnectException("calling socket.connect() unsafely!") return super(CheckedSocket, self).connect(*args, **kwargs) JordanMilne-Advocate-6d699ae/test/test_advocate.py000066400000000000000000000567461370334562200222760ustar00rootroot00000000000000# coding=utf-8 from __future__ import division import pickle import socket import unittest # This needs to be done before third-party imports to make sure they all use # our wrapped socket class, especially in case of subclasses. from .monkeypatching import CheckedSocket, DisallowedConnectException socket.socket = CheckedSocket from mock import patch import requests import requests_mock import six.moves import advocate from advocate import AddrValidator from advocate.addrvalidator import canonicalize_hostname from advocate.api import RequestsAPIWrapper from advocate.connection import advocate_getaddrinfo from advocate.exceptions import ( MountDisabledException, NameserverException, UnacceptableAddressException, ) from advocate.packages import ipaddress from advocate.futures import FuturesSession # We use port 1 for testing because nothing is likely to legitimately listen # on it. AddrValidator.DEFAULT_PORT_WHITELIST.add(1) RequestsAPIWrapper.SUPPORT_WRAPPER_PICKLING = True global_wrapper = RequestsAPIWrapper(validator=AddrValidator(ip_whitelist={ ipaddress.ip_network("127.0.0.1"), })) RequestsAPIWrapper.SUPPORT_WRAPPER_PICKLING = False class _WrapperSubclass(global_wrapper.Session): def good_method(self): return "foo" def canonname_supported(): """Check if the nameserver supports the AI_CANONNAME flag travis-ci.org's Python 3 env doesn't seem to support it, so don't try any of the test that rely on it. 
""" addrinfo = advocate_getaddrinfo("example.com", 0, get_canonname=True) assert addrinfo return addrinfo[0][3] == b"example.com" def permissive_validator(**kwargs): default_options = dict( ip_blacklist=None, port_whitelist=None, port_blacklist=None, hostname_blacklist=None, allow_ipv6=True, allow_teredo=True, allow_6to4=True, allow_dns64=True, autodetect_local_addresses=False, ) default_options.update(**kwargs) return AddrValidator(**default_options) # Test our test wrappers to make sure they're testy class TestWrapperTests(unittest.TestCase): def test_unsafe_connect_raises(self): self.assertRaises( DisallowedConnectException, requests.get, "http://example.org/" ) class ValidateIPTests(unittest.TestCase): def _test_ip_kind_blocked(self, ip, **kwargs): validator = permissive_validator(**kwargs) self.assertFalse(validator.is_ip_allowed(ip)) def test_manual_ip_blacklist(self): """Test manually blacklisting based on IP""" validator = AddrValidator( allow_ipv6=True, ip_blacklist=( ipaddress.ip_network("132.0.5.0/24"), ipaddress.ip_network("152.0.0.0/8"), ipaddress.ip_network("::1"), ), ) self.assertFalse(validator.is_ip_allowed("132.0.5.1")) self.assertFalse(validator.is_ip_allowed("152.254.90.1")) self.assertTrue(validator.is_ip_allowed("178.254.90.1")) self.assertFalse(validator.is_ip_allowed("::1")) # Google, found via `dig google.com AAAA` self.assertTrue(validator.is_ip_allowed("2607:f8b0:400a:807::200e")) def test_ip_whitelist(self): """Test manually whitelisting based on IP""" validator = AddrValidator( ip_whitelist=( ipaddress.ip_network("127.0.0.1"), ), ) self.assertTrue(validator.is_ip_allowed("127.0.0.1")) def test_ip_whitelist_blacklist_conflict(self): """Manual whitelist should take precedence over manual blacklist""" validator = AddrValidator( ip_whitelist=( ipaddress.ip_network("127.0.0.1"), ), ip_blacklist=( ipaddress.ip_network("127.0.0.1"), ), ) self.assertTrue(validator.is_ip_allowed("127.0.0.1")) @unittest.skip("takes half an hour or so to run") def test_safecurl_blacklist(self): """Test that we at least disallow everything SafeCurl does""" # All IPs that SafeCurl would disallow bad_netblocks = (ipaddress.ip_network(x) for x in ( '0.0.0.0/8', '10.0.0.0/8', '100.64.0.0/10', '127.0.0.0/8', '169.254.0.0/16', '172.16.0.0/12', '192.0.0.0/29', '192.0.2.0/24', '192.88.99.0/24', '192.168.0.0/16', '198.18.0.0/15', '198.51.100.0/24', '203.0.113.0/24', '224.0.0.0/4', '240.0.0.0/4' )) i = 0 validator = AddrValidator() for bad_netblock in bad_netblocks: num_ips = bad_netblock.num_addresses # Don't test *every* IP in large netblocks step_size = int(min(max(num_ips / 255, 1), 128)) for ip_idx in six.moves.range(0, num_ips, step_size): i += 1 bad_ip = bad_netblock[ip_idx] bad_ip_allowed = validator.is_ip_allowed(bad_ip) if bad_ip_allowed: print(i, bad_ip) self.assertFalse(bad_ip_allowed) # TODO: something like the above for IPv6? def test_ipv4_mapped(self): self._test_ip_kind_blocked("::ffff:192.168.2.1") def test_teredo(self): # 192.168.2.1 as the client address self._test_ip_kind_blocked("2001:0000:4136:e378:8000:63bf:3f57:fdf2") # This should be disallowed even if teredo is allowed. self._test_ip_kind_blocked( "2001:0000:4136:e378:8000:63bf:3f57:fdf2", allow_teredo=False, ) def test_ipv6(self): self._test_ip_kind_blocked("2002:C0A8:FFFF::", allow_ipv6=False) def test_sixtofour(self): # 192.168.XXX.XXX self._test_ip_kind_blocked("2002:C0A8:FFFF::") self._test_ip_kind_blocked("2002:C0A8:FFFF::", allow_6to4=False) def test_dns64(self): # XXX: Don't even know if this is an issue, TBH. 
Seems to be related # to DNS64/NAT64, but not a lot of easy-to-understand info: # https://tools.ietf.org/html/rfc6052 self._test_ip_kind_blocked("64:ff9b::192.168.2.1") self._test_ip_kind_blocked("64:ff9b::192.168.2.1", allow_dns64=False) def test_link_local(self): # 169.254.XXX.XXX, AWS uses these for autoconfiguration self._test_ip_kind_blocked("169.254.1.1") def test_site_local(self): self._test_ip_kind_blocked("FEC0:CCCC::") def test_loopback(self): self._test_ip_kind_blocked("127.0.0.1") self._test_ip_kind_blocked("::1") def test_multicast(self): self._test_ip_kind_blocked("227.1.1.1") def test_private(self): self._test_ip_kind_blocked("192.168.2.1") self._test_ip_kind_blocked("10.5.5.5") self._test_ip_kind_blocked("0.0.0.0") self._test_ip_kind_blocked("0.1.1.1") self._test_ip_kind_blocked("100.64.0.0") def test_reserved(self): self._test_ip_kind_blocked("255.255.255.255") self._test_ip_kind_blocked("::ffff:192.168.2.1") # 6to4 relay self._test_ip_kind_blocked("192.88.99.0") def test_unspecified(self): self._test_ip_kind_blocked("0.0.0.0") def test_parsed(self): validator = permissive_validator() self.assertFalse(validator.is_ip_allowed( ipaddress.ip_address("0.0.0.0") )) self.assertTrue(validator.is_ip_allowed( ipaddress.ip_address("144.1.1.1") )) class AddrInfoTests(unittest.TestCase): def _is_addrinfo_allowed(self, host, port, **kwargs): validator = permissive_validator(**kwargs) allowed = False for res in advocate_getaddrinfo(host, port): if validator.is_addrinfo_allowed(res): allowed = True return allowed def test_simple(self): self.assertFalse( self._is_addrinfo_allowed("192.168.0.1", 80) ) def test_malformed_addrinfo(self): # Alright, the addrinfo format is probably never going to change, # but *what if it did?* vl = permissive_validator() addrinfo = advocate_getaddrinfo("example.com", 80)[0] + (1,) self.assertRaises(Exception, lambda: vl.is_addrinfo_allowed(addrinfo)) def test_unexpected_proto(self): # What if addrinfo returns info about a protocol we don't understand? 
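        # Each getaddrinfo() entry is a 5-tuple of
        # (family, type, proto, canonname, sockaddr); appending an extra
        # element to the sockaddr below gives it a shape the validator
        # has no rules for.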
vl = permissive_validator() addrinfo = list(advocate_getaddrinfo("example.com", 80)[0]) addrinfo[4] = addrinfo[4] + (1,) self.assertRaises(Exception, lambda: vl.is_addrinfo_allowed(addrinfo)) def test_default_port_whitelist(self): self.assertTrue( self._is_addrinfo_allowed("200.1.1.1", 8080) ) self.assertTrue( self._is_addrinfo_allowed("200.1.1.1", 80) ) self.assertFalse( self._is_addrinfo_allowed("200.1.1.1", 99) ) def test_port_whitelist(self): wl = (80, 10) self.assertTrue( self._is_addrinfo_allowed("200.1.1.1", 80, port_whitelist=wl) ) self.assertTrue( self._is_addrinfo_allowed("200.1.1.1", 10, port_whitelist=wl) ) self.assertFalse( self._is_addrinfo_allowed("200.1.1.1", 99, port_whitelist=wl) ) def test_port_blacklist(self): bl = (80, 10) self.assertFalse( self._is_addrinfo_allowed("200.1.1.1", 80, port_blacklist=bl) ) self.assertFalse( self._is_addrinfo_allowed("200.1.1.1", 10, port_blacklist=bl) ) self.assertTrue( self._is_addrinfo_allowed("200.1.1.1", 99, port_blacklist=bl) ) @patch("advocate.addrvalidator.determine_local_addresses") def test_local_address_handling(self, mock_determine_local_addresses): fake_addresses = [ipaddress.ip_network("200.1.1.1")] mock_determine_local_addresses.return_value = fake_addresses self.assertFalse(self._is_addrinfo_allowed( "200.1.1.1", 80, autodetect_local_addresses=True )) # Check that `is_ip_allowed` didn't make its own call to determine # local addresses mock_determine_local_addresses.assert_called_once_with() mock_determine_local_addresses.reset_mock() self.assertTrue(self._is_addrinfo_allowed( "200.1.1.1", 80, autodetect_local_addresses=False, )) mock_determine_local_addresses.assert_not_called() @unittest.skipIf( not canonname_supported(), "Nameserver doesn't support AI_CANONNAME, skipping hostname tests" ) class HostnameTests(unittest.TestCase): def setUp(self): self._canonname_supported = canonname_supported() def _is_hostname_allowed(self, host, fake_lookup=False, **kwargs): validator = permissive_validator(**kwargs) if fake_lookup: results = [(2, 1, 6, canonicalize_hostname(host).encode("utf8"), ('1.2.3.4', 80))] else: results = advocate_getaddrinfo(host, 80, get_canonname=True) for res in results: if validator.is_addrinfo_allowed(res): return True return False def test_no_blacklist(self): self.assertTrue(self._is_hostname_allowed("example.com")) def test_idn(self): # test some basic globs self.assertFalse(self._is_hostname_allowed( u"中国.example.org", fake_lookup=True, hostname_blacklist={"*.org"} )) # case insensitive, please self.assertFalse(self._is_hostname_allowed( u"中国.example.oRg", fake_lookup=True, hostname_blacklist={"*.Org"} )) self.assertFalse(self._is_hostname_allowed( u"中国.example.org", fake_lookup=True, hostname_blacklist={"xn--fiqs8s.*.org"} )) self.assertFalse(self._is_hostname_allowed( "xn--fiqs8s.example.org", fake_lookup=True, hostname_blacklist={u"中国.*.org"} )) self.assertTrue(self._is_hostname_allowed( u"example.org", fake_lookup=True, hostname_blacklist={u"中国.*.org"} )) self.assertTrue(self._is_hostname_allowed( u"example.com", fake_lookup=True, hostname_blacklist={u"中国.*.org"} )) self.assertTrue(self._is_hostname_allowed( u"foo.example.org", fake_lookup=True, hostname_blacklist={u"中国.*.org"} )) def test_missing_canonname(self): addrinfo = socket.getaddrinfo( "127.0.0.1", 1, 0, socket.SOCK_STREAM, ) self.assertTrue(addrinfo) # Should throw an error if we're using hostname blacklisting and the # addrinfo record we passed in doesn't have a canonname validator = permissive_validator(hostname_blacklist={"foo"}) 
        self.assertRaises(
            NameserverException,
            validator.is_addrinfo_allowed,
            addrinfo[0]
        )

    def test_embedded_null(self):
        vl = permissive_validator(hostname_blacklist={"*.baz.com"})
        # Things get a little screwy with embedded nulls. Try to emulate any
        # possible null termination when checking if the hostname is allowed.
        self.assertFalse(vl.is_hostname_allowed("foo.baz.com\x00.example.com"))
        self.assertFalse(vl.is_hostname_allowed("foo.example.com\x00.baz.com"))
        self.assertFalse(vl.is_hostname_allowed(u"foo.baz.com\x00.example.com"))
        self.assertFalse(vl.is_hostname_allowed(u"foo.example.com\x00.baz.com"))


class ConnectionPoolingTests(unittest.TestCase):
    @patch("advocate.connection.ValidatingHTTPConnection._new_conn")
    def test_connection_reuse(self, mock_new_conn):
        # Just because you can use an existing connection doesn't mean you
        # should. The disadvantage of us working at the socket level means that
        # we get bitten if a connection pool is shared between regular requests
        # and advocate.
        # This can never happen with requests, but let's set a good example :)
        with CheckedSocket.bypass_checks():
            # HTTPBin supports `keep-alive`, so it's a good test subject
            requests.get("http://httpbin.org/")
        try:
            advocate.get("http://httpbin.org/")
        except Exception:
            pass
        # Requests may retry several times, but our mock doesn't return a real
        # socket. Just check that it tried to create one.
        mock_new_conn.assert_any_call()


class AdvocateWrapperTests(unittest.TestCase):
    def test_get(self):
        self.assertEqual(advocate.get("http://example.com").status_code, 200)
        self.assertEqual(advocate.get("https://example.com").status_code, 200)

    def test_validator(self):
        self.assertRaises(
            UnacceptableAddressException,
            advocate.get, "http://127.0.0.1/"
        )
        self.assertRaises(
            UnacceptableAddressException,
            advocate.get, "http://localhost/"
        )
        self.assertRaises(
            UnacceptableAddressException,
            advocate.get, "https://localhost/"
        )

    @unittest.skipIf(
        not canonname_supported(),
        "Nameserver doesn't support AI_CANONNAME, skipping hostname tests"
    )
    def test_blacklist_hostname(self):
        self.assertRaises(
            UnacceptableAddressException,
            advocate.get,
            "https://google.com/",
            validator=AddrValidator(hostname_blacklist={"google.com"})
        )

    # Disabled for now because the redirection endpoint appears to be broken.
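    # The skipped test below documents the expected behaviour: each hop of a
    # redirect chain is validated, so a redirect pointing at 127.0.0.1 should
    # raise UnacceptableAddressException rather than being followed.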
    @unittest.skip("the httpbin redirection endpoint appears to be broken")
    def test_redirect(self):
        # Make sure httpbin even works
        test_url = "http://httpbin.org/status/204"
        self.assertEqual(advocate.get(test_url).status_code, 204)

        redir_url = "http://httpbin.org/redirect-to?url=http://127.0.0.1/"
        self.assertRaises(
            UnacceptableAddressException,
            advocate.get,
            redir_url
        )

    def test_mount_disabled(self):
        sess = advocate.Session()
        self.assertRaises(
            MountDisabledException,
            sess.mount,
            "foo://",
            None,
        )

    def test_advocate_requests_api_wrapper(self):
        wrapper = RequestsAPIWrapper(validator=AddrValidator())
        local_validator = AddrValidator(ip_whitelist={
            ipaddress.ip_network("127.0.0.1"),
        })
        local_wrapper = RequestsAPIWrapper(validator=local_validator)
        self.assertRaises(
            UnacceptableAddressException,
            wrapper.get, "http://127.0.0.1:1/"
        )

        with self.assertRaises(Exception) as cm:
            local_wrapper.get("http://127.0.0.1:1/")

        # Check that we got a connection exception instead of a validation one
        # This might be either exception depending on the requests version
        self.assertRegexpMatches(
            cm.exception.__class__.__name__,
            r"\A(Connection|Protocol)Error",
        )

        self.assertRaises(
            UnacceptableAddressException,
            wrapper.get, "http://localhost:1/"
        )
        self.assertRaises(
            UnacceptableAddressException,
            wrapper.get, "https://localhost:1/"
        )

    def test_wrapper_session_pickle(self):
        # Make sure the validator still works after a pickle round-trip
        sess_instance = pickle.loads(pickle.dumps(global_wrapper.Session()))

        with self.assertRaises(Exception) as cm:
            sess_instance.get("http://127.0.0.1:1/")
        self.assertRegexpMatches(
            cm.exception.__class__.__name__,
            r"\A(Connection|Protocol)Error",
        )
        self.assertRaises(
            UnacceptableAddressException,
            sess_instance.get, "http://127.0.1.1:1/"
        )

    def test_wrapper_session_subclass(self):
        # Make sure pickle doesn't explode if we try to pickle a subclass
        # of `global_wrapper.Session`
        def _check_instance(instance):
            self.assertEqual(instance.good_method(), "foo")
            with self.assertRaises(Exception) as cm:
                instance.get("http://127.0.0.1:1/")
            self.assertRegexpMatches(
                cm.exception.__class__.__name__,
                r"\A(Connection|Protocol)Error",
            )
            self.assertRaises(
                UnacceptableAddressException,
                instance.get, "http://127.0.1.1:1/"
            )

        sess = _WrapperSubclass()
        _check_instance(sess)
        sess_unpickled = pickle.loads(pickle.dumps(sess))
        _check_instance(sess_unpickled)

    @unittest.skipIf(
        not canonname_supported(),
        "Nameserver doesn't support AI_CANONNAME, skipping hostname tests"
    )
    def test_advocate_requests_api_wrapper_hostnames(self):
        wrapper = RequestsAPIWrapper(validator=AddrValidator(
            hostname_blacklist={"google.com"},
        ))
        self.assertRaises(
            UnacceptableAddressException,
            wrapper.get, "https://google.com/",
        )

    def test_advocate_requests_api_wrapper_req_methods(self):
        # Make sure all the convenience methods make requests with the correct
        # methods
        wrapper = RequestsAPIWrapper(AddrValidator())
        request_methods = (
            "get", "options", "head", "post", "put", "patch", "delete"
        )
        for method_name in request_methods:
            with requests_mock.mock() as request_mock:
                # This will fail if the request expected by `request_mock`
                # isn't sent when calling the wrapper method
                request_mock.request(method_name, "http://example.com/foo")
                getattr(wrapper, method_name)("http://example.com/foo")

    def test_wrapper_getattr_fallback(self):
        # Make sure wrappers include everything in Advocate's `__init__.py`
        wrapper = RequestsAPIWrapper(AddrValidator())
        self.assertIsNotNone(wrapper.PreparedRequest)

    def test_proxy_attempt_throws(self):
        # Advocate can't do anything useful when you use a proxy, the proxy
        # is the one that ultimately makes the connection
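        # (With a proxy configured, DNS resolution and the TCP connection to
        # the target happen on the proxy itself, so Advocate's socket-level
        # checks never see the final peer address; the request is refused
        # outright with NotImplementedError instead.)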
        self.assertRaises(
            NotImplementedError,
            advocate.get,
            "http://example.org/",
            proxies={
                "http": "http://example.org:3128",
                "https": "http://example.org:1080",
            },
        )

    @patch("advocate.addrvalidator.determine_local_addresses")
    def test_connect_without_local_addresses(self, mock_determine_local_addresses):
        fake_addresses = [ipaddress.ip_network("200.1.1.1")]
        mock_determine_local_addresses.return_value = fake_addresses

        validator = permissive_validator(autodetect_local_addresses=True)
        advocate.get("http://example.com/", validator=validator)
        # Check that `is_ip_allowed` didn't make its own call to determine
        # local addresses
        mock_determine_local_addresses.assert_called_once_with()

        mock_determine_local_addresses.reset_mock()
        validator = permissive_validator(autodetect_local_addresses=False)
        advocate.get("http://example.com", validator=validator)
        mock_determine_local_addresses.assert_not_called()


class AdvocateFuturesTest(unittest.TestCase):
    def test_get(self):
        sess = FuturesSession()
        assert 200 == sess.get("http://example.org/").result().status_code

    def test_custom_validator(self):
        validator = AddrValidator(hostname_blacklist={"example.org"})
        sess = FuturesSession(validator=validator)
        self.assertRaises(
            UnacceptableAddressException,
            lambda: sess.get("http://example.org").result()
        )

    def test_many_workers(self):
        sess = FuturesSession(max_workers=50)
        self.assertRaises(
            UnacceptableAddressException,
            lambda: sess.get("http://127.0.0.1:1/").result()
        )

    def test_passing_session(self):
        try:
            FuturesSession(session=requests.Session())
            assert False
        except NotImplementedError:
            pass

        sess = FuturesSession()
        try:
            sess.session = requests.Session()
            assert False
        except NotImplementedError:
            pass

        sess.session = advocate.Session()

    def test_advocate_wrapper_futures(self):
        wrapper = RequestsAPIWrapper(validator=AddrValidator())
        local_validator = AddrValidator(ip_whitelist={
            ipaddress.ip_network("127.0.0.1"),
        })
        local_wrapper = RequestsAPIWrapper(validator=local_validator)

        with self.assertRaises(UnacceptableAddressException):
            sess = wrapper.FuturesSession()
            sess.get("http://127.0.0.1/").result()

        with self.assertRaises(Exception) as cm:
            sess = local_wrapper.FuturesSession()
            sess.get("http://127.0.0.1:1/").result()
        # Check that we got a connection exception instead of a validation one
        # This might be either exception depending on the requests version
        self.assertRegexpMatches(
            cm.exception.__class__.__name__,
            r"\A(Connection|Protocol)Error",
        )

        with self.assertRaises(UnacceptableAddressException):
            sess = wrapper.FuturesSession()
            sess.get("http://localhost:1/").result()

        with self.assertRaises(UnacceptableAddressException):
            sess = wrapper.FuturesSession()
            sess.get("https://localhost:1/").result()


if __name__ == '__main__':
    unittest.main()
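# Handy while iterating on a single class or test; the file name below is an
# assumption about where this module lives in the checkout, so adjust as
# needed:
#
#   python test_advocate.py AdvocateWrapperTests
#   python test_advocate.py AddrInfoTests.test_port_whitelist
#
# unittest.main() treats positional command-line arguments as test names, so
# both forms work without any extra tooling.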