# adal-0.4.6/0000755000175000017500000000000013133203245013260 5ustar travistravis00000000000000adal-0.4.6/adal/0000755000175000017500000000000013133203245014161 5ustar travistravis00000000000000adal-0.4.6/adal/__init__.py0000644000175000017500000000343613133203172016277 0ustar travistravis00000000000000
#------------------------------------------------------------------------------
#
# Copyright (c) Microsoft Corporation.
# All rights reserved.
#
# This code is licensed under the MIT License.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files(the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and / or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions :
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
#------------------------------------------------------------------------------

# pylint: disable=wrong-import-position

__version__ = '0.4.6'

import logging

from .authentication_context import AuthenticationContext
from .token_cache import TokenCache
from .log import (set_logging_options,
                  get_logging_options,
                  ADAL_LOGGER_NAME)
from .adal_error import AdalError

# to avoid "No handler found" warnings.
logging.getLogger(ADAL_LOGGER_NAME).addHandler(logging.NullHandler())
# adal-0.4.6/adal/adal_error.py0000644000175000017500000000276213133203172016653 0ustar travistravis00000000000000
#------------------------------------------------------------------------------
#
# Copyright (c) Microsoft Corporation.
# All rights reserved.
#
# This code is licensed under the MIT License.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files(the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and / or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions :
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
#------------------------------------------------------------------------------

class AdalError(Exception):
    '''Error raised by ADAL operations.

    :param error_msg: Human-readable message, passed on to ``Exception``.
    :param error_response: (optional) Additional error details (for example a
        parsed server error response); stored unchanged as ``error_response``.
    '''
    def __init__(self, error_msg, error_response=None):
        super(AdalError, self).__init__(error_msg)
        self.error_response = error_response
# adal-0.4.6/adal/argument.py0000644000175000017500000000374713133203172016367 0ustar travistravis00000000000000
#------------------------------------------------------------------------------
#
# Copyright (c) Microsoft Corporation.
# All rights reserved.
#
# This code is licensed under the MIT License.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files(the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and / or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions :
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
#------------------------------------------------------------------------------

from .constants import OAuth2DeviceCodeResponseParameters

def validate_user_code_info(user_code_info):
    '''Checks that a device-code info dict carries the fields needed for polling.

    :param dict user_code_info: The code info returned by "acquire_user_code".
    :raises ValueError: if the dict is missing or empty, or if its device_code,
        interval or expires_in entry is missing or falsy.
    '''
    if not user_code_info:
        raise ValueError("the user_code_info parameter is required")

    if not user_code_info.get(OAuth2DeviceCodeResponseParameters.DEVICE_CODE):
        raise ValueError("the user_code_info is missing device_code")

    if not user_code_info.get(OAuth2DeviceCodeResponseParameters.INTERVAL):
        raise ValueError("the user_code_info is missing interval")

    if not user_code_info.get(OAuth2DeviceCodeResponseParameters.EXPIRES_IN):
        raise ValueError("the user_code_info is missing expires_in")
# adal-0.4.6/adal/authentication_context.py0000644000175000017500000003270613133203172021325 0ustar travistravis00000000000000
#------------------------------------------------------------------------------
#
# Copyright (c) Microsoft Corporation.
# All rights reserved.
#
# This code is licensed under the MIT License.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files(the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and / or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions :
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
#------------------------------------------------------------------------------

import os
import threading
import warnings

from .authority import Authority
from . import argument
from .code_request import CodeRequest
from .token_request import TokenRequest
from .token_cache import TokenCache
from . import log
from .constants import OAuth2DeviceCodeResponseParameters

# Options dict referenced by every AuthenticationContext whose ``options``
# property has not been reassigned.
GLOBAL_ADAL_OPTIONS = {}


class AuthenticationContext(object):
    '''Retrieves authentication tokens from Azure Active Directory.

    For usages, check out the "sample" folder at:
        https://github.com/AzureAD/azure-activedirectory-library-for-python
    '''

    def __init__(
            self, authority, validate_authority=None, cache=None,
            api_version='1.0'):
        '''Creates a new AuthenticationContext object.

        By default the authority will be checked against a list of known Azure
        Active Directory authorities. If the authority is not recognized as
        one of these well known authorities then token acquisition will fail.
        This behavior can be turned off via the validate_authority parameter
        below.

        :param str authority: A URL that identifies a token authority.
        :param bool validate_authority: (optional) Turns authority validation
            on or off. This parameter defaults to true.
        :param TokenCache cache: (optional) Sets the token cache used by this
            AuthenticationContext instance. If this parameter is not set, then
            a default is used. Cache instances are only used by that instance
            of the AuthenticationContext and are not shared unless they have
            been manually passed during the construction of other
            AuthenticationContexts.
        :param api_version: (optional) Specifies the API version used on the
            wire. Historically it has a hardcoded default value as "1.0".
            Developers are now encouraged to set it as None explicitly,
            which means the underlying API version will be automatically chosen.
            In next major release, this default value will be changed to None.
        '''
        # None (the default) means "validate"; only an explicit falsy value
        # turns validation off.
        self.authority = Authority(authority, validate_authority is None or validate_authority)
        self._oauth2client = None
        self.correlation_id = None
        env_value = os.environ.get('ADAL_PYTHON_SSL_NO_VERIFY')
        if api_version is not None:
            warnings.warn(
                """The default behavior of including api-version=1.0 on the wire
                is now deprecated.
                Future version of ADAL will change the default value to None.

                To ensure a smooth transition, you are recommended to explicitly
                set it to None in your code now, and test out the new behavior.

                    context = AuthenticationContext(..., api_version=None)
                """, DeprecationWarning)
        self._call_context = {
            'options': GLOBAL_ADAL_OPTIONS,
            'api_version': api_version,
            # Variable unset -> None; any non-empty value -> False (do not
            # verify SSL); empty string -> True. Mainly for tracing through
            # a proxy.
            'verify_ssl': None if env_value is None else not env_value,
            }
        # Maps device_code -> in-flight TokenRequest so that
        # cancel_request_to_get_token_with_device_code can find it.
        self._token_requests_with_user_code = {}
        # Compare against None so a caller-supplied cache object is honoured
        # even if it happens to evaluate falsy.
        self.cache = cache if cache is not None else TokenCache()
        self._lock = threading.RLock()

    @property
    def options(self):
        '''The options dict stored in this context's call context.'''
        return self._call_context['options']

    @options.setter
    def options(self, val):
        self._call_context['options'] = val

    def _acquire_token(self, token_func):
        '''Refreshes the log context, validates the authority, then runs token_func.

        :param token_func: Callable taking this context and returning the token.
        :returns: Whatever token_func returns.
        '''
        self._call_context['log_context'] = log.create_log_context(self.correlation_id)
        self.authority.validate(self._call_context)
        return token_func(self)

    def acquire_token(self, resource, user_id, client_id):
        '''Gets a token for a given resource via cached tokens.

        :param str resource: A URI that identifies the resource for which the
            token is valid.
        :param str user_id: The username of the user on behalf this application
            is authenticating.
        :param str client_id: The OAuth client id of the calling application.
        :returns: dict with several keys, include "accessToken" and
            "refreshToken".
        '''
        def token_func(self):
            token_request = TokenRequest(self._call_context, self, client_id, resource)
            return token_request.get_token_from_cache_with_refresh(user_id)

        return self._acquire_token(token_func)

    def acquire_token_with_username_password(self, resource, username, password, client_id):
        '''Gets a token for a given resource via user credentials.

        :param str resource: A URI that identifies the resource for which the
            token is valid.
        :param str username: The username of the user on behalf this
            application is authenticating.
        :param str password: The password of the user named in the username
            parameter.
        :param str client_id: The OAuth client id of the calling application.
        :returns: dict with several keys, include "accessToken" and
            "refreshToken".
        '''
        def token_func(self):
            token_request = TokenRequest(self._call_context, self, client_id, resource)
            return token_request.get_token_with_username_password(username, password)

        return self._acquire_token(token_func)

    def acquire_token_with_client_credentials(self, resource, client_id, client_secret):
        '''Gets a token for a given resource via client credentials.

        :param str resource: A URI that identifies the resource for which the
            token is valid.
        :param str client_id: The OAuth client id of the calling application.
        :param str client_secret: The OAuth client secret of the calling application.
        :returns: dict with several keys, include "accessToken".
        '''
        def token_func(self):
            token_request = TokenRequest(self._call_context, self, client_id, resource)
            return token_request.get_token_with_client_credentials(client_secret)

        return self._acquire_token(token_func)

    def acquire_token_with_authorization_code(self, authorization_code,
                                              redirect_uri, resource,
                                              client_id, client_secret):
        '''Gets a token for a given resource via authorization code for a
        server app.

        :param str authorization_code: An authorization code returned from a
            client.
        :param str redirect_uri: the redirect uri that was used in the
            authorize call.
        :param str resource: A URI that identifies the resource for which the
            token is valid.
        :param str client_id: The OAuth client id of the calling application.
        :param str client_secret: The OAuth client secret of the calling
            application.
        :returns: dict with several keys, include "accessToken" and
            "refreshToken".
        '''
        def token_func(self):
            token_request = TokenRequest(
                self._call_context,
                self,
                client_id,
                resource,
                redirect_uri)
            return token_request.get_token_with_authorization_code(
                authorization_code,
                client_secret)

        return self._acquire_token(token_func)

    def acquire_token_with_refresh_token(self, refresh_token, client_id,
                                         resource, client_secret=None):
        '''Gets a token for a given resource via refresh tokens

        :param str refresh_token: A refresh token returned in a token response
            from a previous invocation of acquireToken.
        :param str client_id: The OAuth client id of the calling application.
        :param str resource: A URI that identifies the resource for which the
            token is valid.
        :param str client_secret: (optional) The OAuth client secret of the
            calling application.
        :returns: dict with several keys, include "accessToken" and
            "refreshToken".
        '''
        def token_func(self):
            token_request = TokenRequest(self._call_context, self, client_id, resource)
            return token_request.get_token_with_refresh_token(refresh_token, client_secret)

        return self._acquire_token(token_func)

    def acquire_token_with_client_certificate(self, resource, client_id,
                                              certificate, thumbprint):
        '''Gets a token for a given resource via certificate credentials

        :param str resource: A URI that identifies the resource for which the
            token is valid.
        :param str client_id: The OAuth client id of the calling application.
        :param str certificate: A PEM encoded certificate private key.
        :param str thumbprint: hex encoded thumbprint of the certificate.
        :returns: dict with several keys, include "accessToken".
        '''
        def token_func(self):
            token_request = TokenRequest(self._call_context, self, client_id, resource)
            return token_request.get_token_with_certificate(certificate, thumbprint)

        return self._acquire_token(token_func)

    def acquire_user_code(self, resource, client_id, language=None):
        '''Gets the user code info which contains user_code, device_code for
        authenticating user on device.

        :param str resource: A URI that identifies the resource for which the
            device_code and user_code is valid for.
        :param str client_id: The OAuth client id of the calling application.
        :param str language: The language code specifying how the message
            should be localized to.
        :returns: dict contains code and uri for users to login through browser.
        '''
        self._call_context['log_context'] = log.create_log_context(self.correlation_id)
        self.authority.validate(self._call_context)
        code_request = CodeRequest(self._call_context, self, client_id, resource)
        return code_request.get_user_code_info(language)

    def acquire_token_with_device_code(self, resource, user_code_info, client_id):
        '''Gets a new access token via a device code.

        While polling, the request is registered under its device_code so that
        cancel_request_to_get_token_with_device_code can stop it. The
        registration is removed when polling ends, whether it succeeds or
        raises.

        :param str resource: A URI that identifies the resource for which the
            token is valid.
        :param dict user_code_info: The code info from the invocation of
            "acquire_user_code"
        :param str client_id: The OAuth client id of the calling application.
        :returns: dict with several keys, include "accessToken" and
            "refreshToken".
        '''
        def token_func(self):
            token_request = TokenRequest(self._call_context, self, client_id, resource)

            key = user_code_info[OAuth2DeviceCodeResponseParameters.DEVICE_CODE]
            with self._lock:
                self._token_requests_with_user_code[key] = token_request

            try:
                return token_request.get_token_with_device_code(user_code_info)
            finally:
                # Deregister even on failure; otherwise a failed poll would
                # leave a stale entry behind for this device_code.
                with self._lock:
                    self._token_requests_with_user_code.pop(key, None)

        return self._acquire_token(token_func)

    def cancel_request_to_get_token_with_device_code(self, user_code_info):
        '''Cancels the polling request to get token with device code.

        :param dict user_code_info: The code info from the invocation of
            "acquire_user_code"
        :returns: None
        :raises ValueError: if user_code_info is invalid or no polling request
            is registered for its device_code.
        '''
        argument.validate_user_code_info(user_code_info)

        key = user_code_info[OAuth2DeviceCodeResponseParameters.DEVICE_CODE]
        with self._lock:
            request = self._token_requests_with_user_code.get(key)

            if not request:
                raise ValueError('No acquire_token_with_device_code existed to be cancelled')

            request.cancel_token_request_with_device_code()
            self._token_requests_with_user_code.pop(key, None)
# adal-0.4.6/adal/authentication_parameters.py0000644000175000017500000002000313133203172021767 0ustar travistravis00000000000000
#------------------------------------------------------------------------------
#
# Copyright (c) Microsoft Corporation.
# All rights reserved.
#
# This code is licensed under the MIT License.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files(the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and / or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions :
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
#------------------------------------------------------------------------------

# Note: this module does not appear to be used anywhere.

import re

import requests

from . import util
from . import log

from .constants import HttpError

AUTHORIZATION_URI = 'authorization_uri'
RESOURCE = 'resource'
WWW_AUTHENTICATE_HEADER = 'www-authenticate'

# pylint: disable=too-few-public-methods

class AuthenticationParameters(object):
    '''Value holder for the authorization_uri and resource found in a 401 challenge.'''

    def __init__(self, authorization_uri, resource):
        self.authorization_uri = authorization_uri
        self.resource = resource


# The 401 challenge is a standard defined in RFC6750, which is based in part on RFC2617.
# The challenge has the following form.
# WWW-Authenticate : Bearer
#     authorization_uri="https://login.windows.net/mytenant.com/oauth2/authorize",
#     Resource_id="00000002-0000-0000-c000-000000000000"

# This regex is used to validate the structure of the challenge header.
# Match whole structure: ^\s*Bearer\s+([^,\s="]+?)="([^"]*?)"\s*(,\s*([^,\s="]+?)="([^"]*?)"\s*)*$
# ^                     Start at the beginning of the string.
# \s*Bearer\s+          Match 'Bearer', optionally preceded by whitespace and followed
#                       by at least one whitespace character.
# ([^,\s="]+?)          This captures the key, which is composed of any characters except
#                       comma, whitespace, the = sign or quotes.
# =                     Match the = sign.
# "([^"]*?)"            Captures the value, which can be any number of non-quote characters.
#                       At this point only the first key value pair has been captured.
# \s*                   There can be any amount of whitespace after the first key value pair.
# (                     Start a capture group to retrieve the rest of the key value
#                       pairs that are separated by commas.
#    ,                  There must be a comma (whitespace before it is consumed by the
#                       preceding \s*).
#    \s*                There can be any amount of whitespace after the comma.
#    ([^,\s="]+?)       This will capture the key that comes after the comma. It's made
#                       of a series of any character except comma, whitespace, = or quotes.
#    =                  Match the equal sign between the key and value.
#    "                  Match the opening quote of the value.
#    ([^"]*?)           This will capture the value, which can be any number of non-quote
#                       characters.
#    "                  Match the value's closing quote.
#    \s*                There can be any amount of whitespace before the next comma.
# )*                    Close the capture group for key value pairs. There can be any
#                       number of these.
# $                     The rest of the string can be whitespace but nothing else up to
#                       the end of the string.
#
# This regex checks the structure of the whole challenge header. The complete
# header needs to be checked for validity before we can be certain that
# we will succeed in pulling out the individual parts.
#
# All patterns below are raw strings: '\s' is not a valid string escape, and
# relying on it passing through unchanged triggers Deprecation/SyntaxWarnings.
bearer_challenge_structure_validation = re.compile(
    r"""^\s*Bearer\s+([^,\s="]+?)="([^"]*?)"\s*(,\s*([^,\s="]+?)="([^"]*?)"\s*)*$""")
# This regex pulls out the key and value from the very first pair.
first_key_value_pair_regex = re.compile(r"""^\s*Bearer\s+([^,\s="]+?)="([^"]*?)"\s*""")

# This regex is used to pull out all of the key value pairs after the first one.
# All of these begin with a comma.
all_other_key_value_pair_regex = re.compile(r"""(?:,\s*([^,\s="]+?)="([^"]*?)"\s*)""")


def parse_challenge(challenge):
    '''Parses an RFC6750 "Bearer" challenge into a dict of its key/value pairs.

    :param str challenge: The value of a WWW-Authenticate header.
    :returns: dict mapping each parameter name to its (unquoted) value.
    :raises ValueError: if the header does not match the expected structure.
    '''
    if not bearer_challenge_structure_validation.search(challenge):
        raise ValueError("The challenge is not parseable as an RFC6750 OAuth2 challenge")

    challenge_parameters = {}
    match = first_key_value_pair_regex.search(challenge)
    if match:
        challenge_parameters[match.group(1)] = match.group(2)

    for match in all_other_key_value_pair_regex.finditer(challenge):
        challenge_parameters[match.group(1)] = match.group(2)

    return challenge_parameters


def create_authentication_parameters_from_header(challenge):
    '''Builds AuthenticationParameters from a WWW-Authenticate header value.

    :param str challenge: The value of a WWW-Authenticate header.
    :returns: AuthenticationParameters; resource is None if absent.
    :raises ValueError: if the challenge is malformed or has no authorization_uri.
    '''
    challenge_parameters = parse_challenge(challenge)
    authorization_uri = challenge_parameters.get(AUTHORIZATION_URI)

    if not authorization_uri:
        raise ValueError("Could not find 'authorization_uri' in challenge header.")

    resource = challenge_parameters.get(RESOURCE)
    return AuthenticationParameters(authorization_uri, resource)


def create_authentication_parameters_from_response(response):
    '''Builds AuthenticationParameters from a 401 HTTP response.

    :param response: An HTTP response object with status_code and headers.
    :returns: AuthenticationParameters parsed from its WWW-Authenticate header.
    :raises AttributeError: if response is None or lacks status_code/headers.
    :raises ValueError: if the status is not 401 or the challenge header is
        missing or malformed.
    '''
    if response is None:
        raise AttributeError('Missing required parameter: response')

    if not hasattr(response, 'status_code') or not response.status_code:
        raise AttributeError('The response parameter does not have the expected HTTP status_code field')

    if not hasattr(response, 'headers') or not response.headers:
        raise AttributeError('There were no headers found in the response.')

    if response.status_code != HttpError.UNAUTHORIZED:
        raise ValueError('The response status code does not correspond to an OAuth challenge.  '
                         'The statusCode is expected to be 401 but is: {}'.format(response.status_code))

    challenge = response.headers.get(WWW_AUTHENTICATE_HEADER)
    if not challenge:
        raise ValueError("The response does not contain a WWW-Authenticate header that can be "
                         "used to determine the authority_uri and resource.")

    return create_authentication_parameters_from_header(challenge)


def validate_url_object(url):
    '''Raises AttributeError unless url is truthy and has a geturl() method.'''
    if not url or not hasattr(url, 'geturl'):
        raise AttributeError('Parameter is of wrong type: url')


def create_authentication_parameters_from_url(url, correlation_id=None):
    '''Issues a GET against url and parses the 401 challenge it returns.

    :param url: A URL string, or a parsed URL object exposing geturl().
    :param correlation_id: (optional) Correlation id used for the log context.
    :returns: AuthenticationParameters parsed from the response.
    :raises: Any error from the HTTP request or from parsing the response;
        both are logged before being re-raised.
    '''
    if isinstance(url, str):
        challenge_url = url
    else:
        validate_url_object(url)
        challenge_url = url.geturl()

    log_context = log.create_log_context(correlation_id)
    logger = log.Logger('AuthenticationParameters', log_context)

    logger.debug("Attempting to retrieve authentication parameters from: %s", challenge_url)

    # Minimal stand-in exposing only ``_call_context`` for
    # util.create_request_options (presumably all it reads — verify in util).
    class _options(object):
        _call_context = {'log_context': log_context}

    options = util.create_request_options(_options())
    try:
        response = requests.get(challenge_url, headers=options['headers'])
    except Exception:
        logger.info("Authentication parameters http get failed.")
        raise

    try:
        return create_authentication_parameters_from_response(response)
    except Exception:
        logger.info("Unable to parse response in to authentication parameters.")
        raise
# adal-0.4.6/adal/authority.py0000644000175000017500000001537413133203172016574 0ustar
# travistravis00000000000000
#------------------------------------------------------------------------------
#
# Copyright (c) Microsoft Corporation.
# All rights reserved.
#
# This code is licensed under the MIT License.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files(the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and / or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions :
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
#------------------------------------------------------------------------------

try:
    from urllib.parse import quote, urlparse
except ImportError:
    from urllib import quote # pylint: disable=no-name-in-module
    from urlparse import urlparse # pylint: disable=import-error,ungrouped-imports

import requests

from .constants import AADConstants
from .adal_error import AdalError
from . import log
from . import util


class Authority(object):
    '''Parses an authority URL, optionally validates it via instance
    discovery, and derives its token and device-code endpoints.

    :param str authority_url: An https URL whose first path segment is the
        tenant (e.g. "https://<host>/<tenant>").
    :param bool validate_authority: When falsy, validation is treated as
        already done and instance discovery is skipped.
    :raises ValueError: if the URL is not https, has a query string, or has
        no tenant path segment.
    '''

    def __init__(self, authority_url, validate_authority=True):

        self._log = None
        self._call_context = None
        self._url = urlparse(authority_url)

        self._validate_authority_url()
        self._validated = not validate_authority

        self._host = None
        self._tenant = None
        self._parse_authority()

        self._authorization_endpoint = None
        self.token_endpoint = None
        self.device_code_endpoint = None
        self.is_adfs_authority = self._tenant.lower() == 'adfs'

    @property
    def url(self):
        '''The authority URL, re-assembled from its parsed form.'''
        return self._url.geturl()

    def _validate_authority_url(self):

        if self._url.scheme != 'https':
            raise ValueError("The authority url must be an https endpoint.")

        if self._url.query:
            raise ValueError("The authority url must not have a query string.")

    def _parse_authority(self):
        '''Extracts host and tenant (the first path segment) from the URL.'''
        self._host = self._url.hostname

        path_parts = self._url.path.split('/')
        # path_parts[0] is the empty string before the leading '/'. A missing
        # OR empty tenant segment (e.g. "https://host/") cannot produce valid
        # endpoints, so both are rejected.
        tenant = path_parts[1] if len(path_parts) > 1 else ''
        if not tenant:
            raise ValueError("Could not determine tenant.")
        self._tenant = tenant

    def _perform_static_instance_discovery(self):
        '''Returns True if the host is in the well-known authority host list.'''

        self._log.debug("Performing static instance discovery")

        if self._url.hostname not in AADConstants.WELL_KNOWN_AUTHORITY_HOSTS:
            return False

        self._log.debug("Authority validated via static instance discovery")
        return True

    def _create_authority_url(self):
        return "https://{}/{}{}".format(self._url.hostname,
                                        self._tenant,
                                        AADConstants.AUTHORIZE_ENDPOINT_PATH)

    def _create_instance_discovery_endpoint_from_template(self, authority_host):
        '''Fills the instance discovery template for authority_host; returns a parsed URL.'''

        discovery_endpoint = AADConstants.INSTANCE_DISCOVERY_ENDPOINT_TEMPLATE
        discovery_endpoint = discovery_endpoint.replace('{authorize_host}', authority_host)
        discovery_endpoint = discovery_endpoint.replace('{authorize_endpoint}',
                                                        quote(self._create_authority_url(),
                                                              safe='~()*!.\''))
        return urlparse(discovery_endpoint)

    def _perform_dynamic_instance_discovery(self):
        '''Queries the instance discovery endpoint for this authority.

        :returns: The tenant_discovery_endpoint from the response.
        :raises AdalError: on a non-success HTTP status, or if the response
            lacks tenant_discovery_endpoint.
        '''
        discovery_endpoint = self._create_instance_discovery_endpoint_from_template(
            AADConstants.WORLD_WIDE_AUTHORITY)
        get_options = util.create_request_options(self)

        operation = "Instance Discovery"
        self._log.debug("Attempting instance discover at: %s", discovery_endpoint.geturl())

        try:
            resp = requests.get(discovery_endpoint.geturl(), headers=get_options['headers'],
                                verify=self._call_context.get('verify_ssl', None))
            util.log_return_correlation_id(self._log, operation, resp)
        except Exception:
            self._log.info("%s request failed", operation)
            raise

        if not util.is_http_success(resp.status_code):
            return_error_string = u"{} request returned http error: {}".format(operation,
                                                                               resp.status_code)
            error_response = ""
            if resp.text:
                return_error_string = u"{} and server response: {}".format(return_error_string,
                                                                          resp.text)
                try:
                    error_response = resp.json()
                except ValueError:
                    # Body is not JSON; keep the textual message only.
                    pass

            raise AdalError(return_error_string, error_response)

        discovery_resp = resp.json()
        if discovery_resp.get('tenant_discovery_endpoint'):
            return discovery_resp['tenant_discovery_endpoint']
        raise AdalError('Failed to parse instance discovery response')

    def _validate_via_instance_discovery(self):
        # Static discovery first; dynamic discovery raises on failure.
        valid = self._perform_static_instance_discovery()
        if not valid:
            self._perform_dynamic_instance_discovery()

    def _get_oauth_endpoints(self):
        if (not self.token_endpoint) or (not self.device_code_endpoint):
            self.token_endpoint = self._url.geturl() + AADConstants.TOKEN_ENDPOINT_PATH
            self.device_code_endpoint = self._url.geturl() + AADConstants.DEVICE_ENDPOINT_PATH

    def validate(self, call_context):
        '''Validates the authority once per instance, then sets the OAuth endpoints.

        :param dict call_context: Must contain 'log_context'; it is also kept
            for the discovery request (which reads 'verify_ssl').
        '''
        self._log = log.Logger('Authority', call_context['log_context'])
        self._call_context = call_context
        if not self._validated:
            self._log.debug("Performing instance discovery: %s", self._url.geturl())
            self._validate_via_instance_discovery()
            self._validated = True
        else:
            self._log.debug(
                "Instance discovery/validation has either already been completed or is turned off: %s",
                self._url.geturl())

        self._get_oauth_endpoints()
# adal-0.4.6/adal/cache_driver.py0000644000175000017500000002316213133203172017154 0ustar travistravis00000000000000
#------------------------------------------------------------------------------
#
#
Copyright (c) Microsoft Corporation. # All rights reserved. # # This code is licensed under the MIT License. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files(the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and / or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions : # # The above copyright notice and this permission notice shall be included in # all copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. # #------------------------------------------------------------------------------ import base64 import copy import hashlib import json from datetime import datetime, timedelta from dateutil import parser from .adal_error import AdalError from .constants import TokenResponseFields, Misc from . 
import log #surppress warnings: like accces to a protected member of "_AUTHORITY", etc # pylint: disable=W0212 def _create_token_hash(token): hash_object = hashlib.sha256() hash_object.update(token.encode('utf8')) return base64.b64encode(hash_object.digest()) def _create_token_id_message(entry): access_token_hash = _create_token_hash(entry[TokenResponseFields.ACCESS_TOKEN]) message = 'AccessTokenId: ' + str(access_token_hash) if entry.get(TokenResponseFields.REFRESH_TOKEN): refresh_token_hash = _create_token_hash(entry[TokenResponseFields.REFRESH_TOKEN]) message += ', RefreshTokenId: ' + str(refresh_token_hash) return message def _is_mrrt(entry): return bool(entry.get(TokenResponseFields.RESOURCE, None)) def _entry_has_metadata(entry): return (TokenResponseFields._CLIENT_ID in entry and TokenResponseFields._AUTHORITY in entry) class CacheDriver(object): def __init__(self, call_context, authority, resource, client_id, cache, refresh_function): self._call_context = call_context self._log = log.Logger("OAuth2Client", call_context['log_context']) self._authority = authority self._resource = resource self._client_id = client_id self._cache = cache self._refresh_function = refresh_function def _get_potential_entries(self, query): potential_entries_query = {} if query.get(TokenResponseFields._CLIENT_ID): potential_entries_query[TokenResponseFields._CLIENT_ID] = query[TokenResponseFields._CLIENT_ID] if query.get(TokenResponseFields.USER_ID): potential_entries_query[TokenResponseFields.USER_ID] = query[TokenResponseFields.USER_ID] self._log.debug('Looking for potential cache entries:') self._log.debug(json.dumps(potential_entries_query)) entries = self._cache.find(potential_entries_query) self._log.debug('Found %s potential entries.', len(entries)) return entries def _find_mrrt_tokens_for_user(self, user): return self._cache.find({ TokenResponseFields.IS_MRRT: True, TokenResponseFields.USER_ID: user, TokenResponseFields._CLIENT_ID : self._client_id }) def 
_load_single_entry_from_cache(self, query): return_val = [] is_resource_tenant_specific = False potential_entries = self._get_potential_entries(query) if potential_entries: resource_tenant_specific_entries = [ x for x in potential_entries if x[TokenResponseFields.RESOURCE] == self._resource and x[TokenResponseFields._AUTHORITY] == self._authority] if not resource_tenant_specific_entries: self._log.debug('No resource specific cache entries found.') #There are no resource specific entries. Find an MRRT token. mrrt_tokens = (x for x in potential_entries if x[TokenResponseFields.IS_MRRT]) token = next(mrrt_tokens, None) if token: self._log.debug('Found an MRRT token.') return_val = token else: self._log.debug('No MRRT tokens found.') elif len(resource_tenant_specific_entries) == 1: self._log.debug('Resource specific token found.') return_val = resource_tenant_specific_entries[0] is_resource_tenant_specific = True else: raise AdalError('More than one token matches the criteria. The result is ambiguous.') if return_val: self._log.debug('Returning token from cache lookup, %s', _create_token_id_message(return_val)) return return_val, is_resource_tenant_specific def _create_entry_from_refresh(self, entry, refresh_response): new_entry = copy.deepcopy(entry) new_entry.update(refresh_response) if entry[TokenResponseFields.IS_MRRT] and self._authority != entry[TokenResponseFields._AUTHORITY]: new_entry[TokenResponseFields._AUTHORITY] = self._authority self._log.debug('Created new cache entry from refresh response.') return new_entry def _replace_entry(self, entry_to_replace, new_entry): self.remove(entry_to_replace) self.add(new_entry) def _refresh_expired_entry(self, entry): token_response = self._refresh_function(entry, None) new_entry = self._create_entry_from_refresh(entry, token_response) self._replace_entry(entry, new_entry) self._log.info('Returning token refreshed after expiry.') return new_entry def _acquire_new_token_from_mrrt(self, entry): token_response = 
self._refresh_function(entry, self._resource) new_entry = self._create_entry_from_refresh(entry, token_response) self.add(new_entry) self._log.info('Returning token derived from mrrt refresh.') return new_entry def _refresh_entry_if_necessary(self, entry, is_resource_specific): expiry_date = parser.parse(entry[TokenResponseFields.EXPIRES_ON]) now = datetime.now(expiry_date.tzinfo) # Add some buffer in to the time comparison to account for clock skew or latency. now_plus_buffer = now + timedelta(minutes=Misc.CLOCK_BUFFER) if is_resource_specific and now_plus_buffer > expiry_date: if TokenResponseFields.REFRESH_TOKEN in entry: self._log.info('Cached token is expired. Refreshing: %s', expiry_date) return self._refresh_expired_entry(entry) else: self.remove(entry) return None elif not is_resource_specific and entry.get(TokenResponseFields.IS_MRRT): if TokenResponseFields.REFRESH_TOKEN in entry: self._log.info('Acquiring new access token from MRRT token.') return self._acquire_new_token_from_mrrt(entry) else: self.remove(entry) return None else: return entry def find(self, query): if query is None: query = {} self._log.debug('finding with query: %s', json.dumps(query)) entry, is_resource_tenant_specific = self._load_single_entry_from_cache(query) if entry: return self._refresh_entry_if_necessary(entry, is_resource_tenant_specific) else: return None def remove(self, entry): self._log.debug('Removing entry.') self._cache.remove([entry]) def _remove_many(self, entries): self._log.debug('Remove many:%s', len(entries)) self._cache.remove(entries) def _add_many(self, entries): self._log.debug('Add many: %s', len(entries)) self._cache.add(entries) def _update_refresh_tokens(self, entry): if _is_mrrt(entry) and entry.get(TokenResponseFields.REFRESH_TOKEN): mrrt_tokens = self._find_mrrt_tokens_for_user(entry.get(TokenResponseFields.USER_ID)) if mrrt_tokens: self._log.debug('Updating %s cached refresh tokens', len(mrrt_tokens)) self._remove_many(mrrt_tokens) for t in mrrt_tokens: 
t[TokenResponseFields.REFRESH_TOKEN] = entry[TokenResponseFields.REFRESH_TOKEN] self._add_many(mrrt_tokens) def _argument_entry_with_cached_metadata(self, entry): if _entry_has_metadata(entry): return if _is_mrrt(entry): self._log.debug('Added entry is MRRT') entry[TokenResponseFields.IS_MRRT] = True else: entry[TokenResponseFields.RESOURCE] = self._resource entry[TokenResponseFields._CLIENT_ID] = self._client_id entry[TokenResponseFields._AUTHORITY] = self._authority def add(self, entry): self._log.debug('Adding entry %s', _create_token_id_message(entry)) self._argument_entry_with_cached_metadata(entry) self._update_refresh_tokens(entry) self._cache.add([entry]) adal-0.4.6/adal/code_request.py0000644000175000017500000000516613133203172017224 0ustar travistravis00000000000000#------------------------------------------------------------------------------ # # Copyright (c) Microsoft Corporation. # All rights reserved. # # This code is licensed under the MIT License. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files(the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and / or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions : # # The above copyright notice and this permission notice shall be included in # all copies or substantial portions of the Software. 
# # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. # #------------------------------------------------------------------------------ from . import constants from . import log from . import oauth2_client OAUTH2_PARAMETERS = constants.OAuth2.Parameters class CodeRequest(object): """description of class""" def __init__(self, call_context, authentication_context, client_id, resource): self._log = log.Logger("CodeRequest", call_context['log_context']) self._call_context = call_context self._authentication_context = authentication_context self._client_id = client_id self._resource = resource def _get_user_code_info(self, oauth_parameters): client = self._create_oauth2_client() return client.get_user_code_info(oauth_parameters) def _create_oauth2_client(self): return oauth2_client.OAuth2Client( self._call_context, self._authentication_context.authority) def _create_oauth_parameters(self): return { OAUTH2_PARAMETERS.CLIENT_ID: self._client_id, OAUTH2_PARAMETERS.RESOURCE: self._resource } def get_user_code_info(self, language): self._log.info('Getting user code info.') oauth_parameters = self._create_oauth_parameters() if language: oauth_parameters[OAUTH2_PARAMETERS.LANGUAGE] = language return self._get_user_code_info(oauth_parameters) adal-0.4.6/adal/constants.py0000644000175000017500000001643513133203172016557 0ustar travistravis00000000000000#------------------------------------------------------------------------------ # # Copyright (c) Microsoft Corporation. # All rights reserved. # # This code is licensed under the MIT License. 
# # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files(the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and / or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions : # # The above copyright notice and this permission notice shall be included in # all copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. # #------------------------------------------------------------------------------ # pylint: disable=too-few-public-methods,old-style-class,no-init class Errors: # Constants ERROR_VALUE_NONE = '{} should not be None.' ERROR_VALUE_EMPTY_STRING = '{} should not be "".' ERROR_RESPONSE_MALFORMED_XML = 'The provided response string is not well formed XML.' 
class OAuth2Parameters(object): GRANT_TYPE = 'grant_type' CLIENT_ASSERTION = 'client_assertion' CLIENT_ASSERTION_TYPE = 'client_assertion_type' CLIENT_ID = 'client_id' CLIENT_SECRET = 'client_secret' REDIRECT_URI = 'redirect_uri' RESOURCE = 'resource' CODE = 'code' SCOPE = 'scope' ASSERTION = 'assertion' AAD_API_VERSION = 'api-version' USERNAME = 'username' PASSWORD = 'password' REFRESH_TOKEN = 'refresh_token' LANGUAGE = 'mkt' DEVICE_CODE = 'device_code' class OAuth2GrantType(object): AUTHORIZATION_CODE = 'authorization_code' REFRESH_TOKEN = 'refresh_token' CLIENT_CREDENTIALS = 'client_credentials' JWT_BEARER = 'urn:ietf:params:oauth:client-assertion-type:jwt-bearer' PASSWORD = 'password' SAML1 = 'urn:ietf:params:oauth:grant-type:saml1_1-bearer' SAML2 = 'urn:ietf:params:oauth:grant-type:saml2-bearer' DEVICE_CODE = 'device_code' class OAuth2ResponseParameters(object): CODE = 'code' TOKEN_TYPE = 'token_type' ACCESS_TOKEN = 'access_token' ID_TOKEN = 'id_token' REFRESH_TOKEN = 'refresh_token' CREATED_ON = 'created_on' EXPIRES_ON = 'expires_on' EXPIRES_IN = 'expires_in' RESOURCE = 'resource' ERROR = 'error' ERROR_DESCRIPTION = 'error_description' class OAuth2DeviceCodeResponseParameters: USER_CODE = 'user_code' DEVICE_CODE = 'device_code' VERIFICATION_URL = 'verification_url' EXPIRES_IN = 'expires_in' INTERVAL = 'interval' MESSAGE = 'message' ERROR = 'error' ERROR_DESCRIPTION = 'error_description' class OAuth2Scope(object): OPENID = 'openid' class OAuth2(object): Parameters = OAuth2Parameters() GrantType = OAuth2GrantType() ResponseParameters = OAuth2ResponseParameters() DeviceCodeResponseParameters = OAuth2DeviceCodeResponseParameters() Scope = OAuth2Scope() IdTokenMap = { 'tid' : 'tenantId', 'given_name' : 'givenName', 'family_name' : 'familyName', 'idp' : 'identityProvider', 'oid' : 'oid' } class TokenResponseFields(object): TOKEN_TYPE = 'tokenType' ACCESS_TOKEN = 'accessToken' REFRESH_TOKEN = 'refreshToken' CREATED_ON = 'createdOn' EXPIRES_ON = 'expiresOn' 
EXPIRES_IN = 'expiresIn' RESOURCE = 'resource' USER_ID = 'userId' ERROR = 'error' ERROR_DESCRIPTION = 'errorDescription' # not from the wire, but amends for token cache _AUTHORITY = '_authority' _CLIENT_ID = '_clientId' IS_MRRT = 'isMRRT' class IdTokenFields(object): USER_ID = 'userId' IS_USER_ID_DISPLAYABLE = 'isUserIdDisplayable' TENANT_ID = 'tenantId' GIVE_NAME = 'givenName' FAMILY_NAME = 'familyName' IDENTITY_PROVIDER = 'identityProvider' class Misc(object): MAX_DATE = 0xffffffff CLOCK_BUFFER = 5 # In minutes. class Jwt(object): SELF_SIGNED_JWT_LIFETIME = 10 # 10 mins in mins AUDIENCE = 'aud' ISSUER = 'iss' SUBJECT = 'sub' NOT_BEFORE = 'nbf' EXPIRES_ON = 'exp' JWT_ID = 'jti' class UserRealm(object): federation_protocol_type = { 'WSFederation' : 'wstrust', 'SAML2' : 'saml20', 'Unknown' : 'unknown' } account_type = { 'Federated' : 'federated', 'Managed' : 'managed', 'Unknown' : 'unknown' } class Saml(object): TokenTypeV1 = 'urn:oasis:names:tc:SAML:1.0:assertion' TokenTypeV2 = 'urn:oasis:names:tc:SAML:2.0:assertion' class XmlNamespaces(object): namespaces = { 'wsdl' :'http://schemas.xmlsoap.org/wsdl/', 'sp' :'http://docs.oasis-open.org/ws-sx/ws-securitypolicy/200702', 'sp2005' :'http://schemas.xmlsoap.org/ws/2005/07/securitypolicy', 'wsu' :'http://docs.oasis-open.org/wss/2004/01/oasis-200401-wss-wssecurity-utility-1.0.xsd', 'wsa10' :'http://www.w3.org/2005/08/addressing', 'http' :'http://schemas.microsoft.com/ws/06/2004/policy/http', 'soap12' :'http://schemas.xmlsoap.org/wsdl/soap12/', 'wsp' :'http://schemas.xmlsoap.org/ws/2004/09/policy', 's' :'http://www.w3.org/2003/05/soap-envelope', 'wsa' :'http://www.w3.org/2005/08/addressing', 'wst' :'http://docs.oasis-open.org/ws-sx/ws-trust/200512', 'trust' : "http://docs.oasis-open.org/ws-sx/ws-trust/200512", 'saml' : "urn:oasis:names:tc:SAML:1.0:assertion", 't' : 'http://schemas.xmlsoap.org/ws/2005/02/trust' } class Cache(object): HASH_ALGORITHM = 'sha256' class HttpError(object): UNAUTHORIZED = 401 class 
AADConstants(object): WORLD_WIDE_AUTHORITY = 'login.windows.net' WELL_KNOWN_AUTHORITY_HOSTS = [ 'login.windows.net', 'login.microsoftonline.com', 'login.chinacloudapi.cn', 'login-us.microsoftonline.com', 'login.microsoftonline.us', 'login.microsoftonline.de', ] INSTANCE_DISCOVERY_ENDPOINT_TEMPLATE = 'https://{authorize_host}/common/discovery/instance?authorization_endpoint={authorize_endpoint}&api-version=1.0' # pylint: disable=invalid-name AUTHORIZE_ENDPOINT_PATH = '/oauth2/authorize' TOKEN_ENDPOINT_PATH = '/oauth2/token' DEVICE_ENDPOINT_PATH = '/oauth2/devicecode' class AdalIdParameters(object): SKU = 'x-client-SKU' VERSION = 'x-client-Ver' OS = 'x-client-OS' # pylint: disable=invalid-name CPU = 'x-client-CPU' PYTHON_SKU = 'Python' class WSTrustVersion(object): UNDEFINED = 'undefined' WSTRUST13 = 'wstrust13' WSTRUST2005 = 'wstrust2005' adal-0.4.6/adal/log.py0000644000175000017500000000773213133203172015324 0ustar travistravis00000000000000#------------------------------------------------------------------------------ # # Copyright (c) Microsoft Corporation. # All rights reserved. # # This code is licensed under the MIT License. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files(the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and / or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions : # # The above copyright notice and this permission notice shall be included in # all copies or substantial portions of the Software. 
# # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. # #------------------------------------------------------------------------------ import logging import uuid import traceback ADAL_LOGGER_NAME = 'adal-python' def create_log_context(correlation_id=None): return {'correlation_id' : correlation_id or str(uuid.uuid4())} def set_logging_options(options=None): '''Configure adal logger, including level and handler spec'd by python logging module. Basic Usages:: >>>adal.set_logging_options({ >>> 'level': 'DEBUG' >>> 'handler': logging.FileHandler('adal.log') >>>}) ''' if options is None: options = {} logger = logging.getLogger(ADAL_LOGGER_NAME) logger.setLevel(options.get('level', logging.ERROR)) handler = options.get('handler') if handler: handler.setLevel(logger.level) logger.addHandler(handler) def get_logging_options(): '''Get logging options :returns: a dict, with a key of 'level' for logging level. 
''' logger = logging.getLogger(ADAL_LOGGER_NAME) level = logger.getEffectiveLevel() return { 'level': logging.getLevelName(level) } class Logger(object): '''wrapper around python built-in logging to log correlation_id, and stack trace through keyword argument of 'log_stack_trace' ''' def __init__(self, component_name, log_context): if not log_context: raise AttributeError('Logger: log_context is a required parameter') self._component_name = component_name self.log_context = log_context self._logging = logging.getLogger(ADAL_LOGGER_NAME) def _log_message(self, msg, log_stack_trace=None): correlation_id = self.log_context.get("correlation_id", "") formatted = "{} - {}:{}".format( correlation_id, self._component_name, msg) if log_stack_trace: formatted += "\nStack:\n{}".format(traceback.format_stack()) return formatted def warn(self, msg, *args, **kwargs): log_stack_trace = kwargs.pop('log_stack_trace', None) msg = self._log_message(msg, log_stack_trace) self._logging.warning(msg, *args, **kwargs) def info(self, msg, *args, **kwargs): log_stack_trace = kwargs.pop('log_stack_trace', None) msg = self._log_message(msg, log_stack_trace) self._logging.info(msg, *args, **kwargs) def debug(self, msg, *args, **kwargs): log_stack_trace = kwargs.pop('log_stack_trace', None) msg = self._log_message(msg, log_stack_trace) self._logging.debug(msg, *args, **kwargs) adal-0.4.6/adal/mex.py0000644000175000017500000002566213133203172015336 0ustar travistravis00000000000000#------------------------------------------------------------------------------ # # Copyright (c) Microsoft Corporation. # All rights reserved. # # This code is licensed under the MIT License. 
# # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files(the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and / or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions : # # The above copyright notice and this permission notice shall be included in # all copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. # #------------------------------------------------------------------------------ try: from urllib.parse import urlparse except ImportError: from urlparse import urlparse # pylint: disable=import-error try: from xml.etree import cElementTree as ET except ImportError: from xml.etree import ElementTree as ET import requests from . import log from . import util from . 
import xmlutil from .constants import XmlNamespaces, WSTrustVersion from .adal_error import AdalError TRANSPORT_BINDING_XPATH = 'wsp:ExactlyOne/wsp:All/sp:TransportBinding' TRANSPORT_BINDING_2005_XPATH = 'wsp:ExactlyOne/wsp:All/sp2005:TransportBinding' #pylint: disable=invalid-name SOAP_ACTION_XPATH = 'wsdl:operation/soap12:operation' RST_SOAP_ACTION_13 = 'http://docs.oasis-open.org/ws-sx/ws-trust/200512/RST/Issue' RST_SOAP_ACTION_2005 = 'http://schemas.xmlsoap.org/ws/2005/02/trust/RST/Issue' #pylint: disable=invalid-name SOAP_TRANSPORT_XPATH = 'soap12:binding' SOAP_HTTP_TRANSPORT_VALUE = 'http://schemas.xmlsoap.org/soap/http' PORT_XPATH = 'wsdl:service/wsdl:port' ADDRESS_XPATH = 'wsa10:EndpointReference/wsa10:Address' def _url_is_secure(endpoint_url): parsed = urlparse(endpoint_url) return parsed.scheme == 'https' class Mex(object): def __init__(self, call_context, url): self._log = log.Logger("MEX", call_context.get('log_context')) self._call_context = call_context self._url = url self._dom = None self._parents = None self._mex_doc = None self.username_password_policy = {} self._log.debug("Mex created with url: %s", self._url) def discover(self): self._log.debug("Retrieving mex at: %s", self._url) options = util.create_request_options(self, {'headers': {'Content-Type': 'application/soap+xml'}}) try: operation = "Mex Get" resp = requests.get(self._url, headers=options['headers'], verify=self._call_context.get('verify_ssl', None)) util.log_return_correlation_id(self._log, operation, resp) except Exception: self._log.info("%s request failed", operation) raise if not util.is_http_success(resp.status_code): return_error_string = u"{} request returned http error: {}".format(operation, resp.status_code) error_response = "" if resp.text: return_error_string = u"{} and server response: {}".format(return_error_string, resp.text) try: error_response = resp.json() except ValueError: pass raise AdalError(return_error_string, error_response) else: try: self._mex_doc = 
resp.text #options = {'errorHandler':self._log.error} self._dom = ET.fromstring(self._mex_doc) self._parents = {c:p for p in self._dom.iter() for c in p} self._parse() except Exception: self._log.info('Failed to parse mex response in to DOM') raise def _check_policy(self, policy_node): policy_id = policy_node.attrib["{{{}}}Id".format(XmlNamespaces.namespaces['wsu'])] # Try with Transport Binding XPath transport_binding_nodes = xmlutil.xpath_find(policy_node, TRANSPORT_BINDING_XPATH) # If unsuccessful, try again with 2005 XPath if not transport_binding_nodes: transport_binding_nodes = xmlutil.xpath_find(policy_node, TRANSPORT_BINDING_2005_XPATH) # If we did not find any binding, this is potentially bad. if not transport_binding_nodes: self._log.debug( "Potential policy did not match required transport binding: %s", policy_id) else: self._log.debug("Found matching policy id: %s", policy_id) return policy_id def _select_username_password_polices(self, xpath): policies = {} username_token_nodes = xmlutil.xpath_find(self._dom, xpath) if not username_token_nodes: self._log.warn("No username token policy nodes found.") return for node in username_token_nodes: policy_node = self._parents[self._parents[self._parents[self._parents[self._parents[self._parents[self._parents[node]]]]]]] policy_id = self._check_policy(policy_node) if policy_id: id_ref = '#' + policy_id policies[id_ref] = {id:id_ref} return policies if policies else None def _check_soap_action_and_transport(self, binding_node): soap_action = "" soap_transport = "" name = binding_node.get('name') soap_transport_attributes = "" soap_action_attributes = xmlutil.xpath_find(binding_node, SOAP_ACTION_XPATH)[0].attrib['soapAction'] if soap_action_attributes: soap_action = soap_action_attributes soap_transport_attributes = xmlutil.xpath_find(binding_node, SOAP_TRANSPORT_XPATH)[0].attrib['transport'] if soap_transport_attributes: soap_transport = soap_transport_attributes if soap_transport == SOAP_HTTP_TRANSPORT_VALUE: if 
soap_action == RST_SOAP_ACTION_13: self._log.debug('found binding matching Action and Transport: %s', name) return WSTrustVersion.WSTRUST13 elif soap_action == RST_SOAP_ACTION_2005: self._log.debug('found binding matching Action and Transport: %s', name) return WSTrustVersion.WSTRUST2005 self._log.debug('binding node did not match soap Action or Transport: %s', name) return WSTrustVersion.UNDEFINED def _get_matching_bindings(self, policies): bindings = {} binding_policy_ref_nodes = xmlutil.xpath_find(self._dom, 'wsdl:binding/wsp:PolicyReference') for node in binding_policy_ref_nodes: uri = node.get('URI') policy = policies.get(uri) if policy: binding_node = self._parents[node] binding_name = binding_node.get('name') version = self._check_soap_action_and_transport(binding_node) if version != WSTrustVersion.UNDEFINED: bindings[binding_name] = { 'url': uri, 'version': version } return bindings if bindings else None def _get_ports_for_policy_bindings(self, bindings, policies): port_nodes = xmlutil.xpath_find(self._dom, PORT_XPATH) if not port_nodes: self._log.warn("No ports found") for node in port_nodes: binding_id = node.get('binding') binding_id = binding_id.split(':')[-1] trust_policy = bindings.get(binding_id) if trust_policy: binding_policy = policies.get(trust_policy.get('url')) if binding_policy and not binding_policy.get('url', None): binding_policy['version'] = trust_policy['version'] address_node = node.find(ADDRESS_XPATH, XmlNamespaces.namespaces) if address_node is None: raise AdalError("No address nodes on port") address = xmlutil.find_element_text(address_node) if _url_is_secure(address): binding_policy['url'] = address else: self._log.warn("Skipping insecure endpoint: %s", address) def _select_single_matching_policy(self, policies): matching_policies = [p for p in policies.values() if p.get('url')] if not matching_policies: self._log.warn("No policies found with a url.") return wstrust13_policy = None wstrust2005_policy = None for policy in 
matching_policies: version = policy.get('version', None) if version == WSTrustVersion.WSTRUST13: wstrust13_policy = policy elif version == WSTrustVersion.WSTRUST2005: wstrust2005_policy = policy if wstrust13_policy is None and wstrust2005_policy is None: self._log.warn('No policies found for either wstrust13 or wstrust2005') self.username_password_policy = wstrust13_policy or wstrust2005_policy def _parse(self): policies = self._select_username_password_polices( 'wsp:Policy/wsp:ExactlyOne/wsp:All/sp:SignedEncryptedSupportingTokens/wsp:Policy/sp:UsernameToken/wsp:Policy/sp:WssUsernameToken10') xpath2005 = 'wsp:Policy/wsp:ExactlyOne/wsp:All/sp2005:SignedSupportingTokens/wsp:Policy/sp2005:UsernameToken/wsp:Policy/sp2005:WssUsernameToken10' if policies: policies2005 = self._select_username_password_polices(xpath2005) if policies2005: policies.update(policies2005) else: policies = self._select_username_password_polices(xpath2005) if not policies: raise AdalError("No matching policies.") bindings = self._get_matching_bindings(policies) if not bindings: raise AdalError("No matching bindings.") self._get_ports_for_policy_bindings(bindings, policies) self._select_single_matching_policy(policies) if not self._url: raise AdalError("No ws-trust endpoints match requirements.") adal-0.4.6/adal/oauth2_client.py0000644000175000017500000003362213133203172017300 0ustar travistravis00000000000000#------------------------------------------------------------------------------ # # Copyright (c) Microsoft Corporation. # All rights reserved. # # This code is licensed under the MIT License. 
# # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files(the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and / or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions : # # The above copyright notice and this permission notice shall be included in # all copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. # #------------------------------------------------------------------------------ from datetime import datetime, timedelta import math import re import json import time import uuid try: from urllib.parse import urlencode, urlparse except ImportError: from urllib import urlencode # pylint: disable=no-name-in-module from urlparse import urlparse # pylint: disable=import-error,ungrouped-imports import requests from . import log from . 
import util from .constants import OAuth2, TokenResponseFields, IdTokenFields from .adal_error import AdalError TOKEN_RESPONSE_MAP = { OAuth2.ResponseParameters.TOKEN_TYPE : TokenResponseFields.TOKEN_TYPE, OAuth2.ResponseParameters.ACCESS_TOKEN : TokenResponseFields.ACCESS_TOKEN, OAuth2.ResponseParameters.REFRESH_TOKEN : TokenResponseFields.REFRESH_TOKEN, OAuth2.ResponseParameters.CREATED_ON : TokenResponseFields.CREATED_ON, OAuth2.ResponseParameters.EXPIRES_ON : TokenResponseFields.EXPIRES_ON, OAuth2.ResponseParameters.EXPIRES_IN : TokenResponseFields.EXPIRES_IN, OAuth2.ResponseParameters.RESOURCE : TokenResponseFields.RESOURCE, OAuth2.ResponseParameters.ERROR : TokenResponseFields.ERROR, OAuth2.ResponseParameters.ERROR_DESCRIPTION : TokenResponseFields.ERROR_DESCRIPTION, } _REQ_OPTION = {'headers' : {'content-type': 'application/x-www-form-urlencoded'}} _ERROR_TEMPLATE = u"{} request returned http error: {}" def map_fields(in_obj, map_to): return dict((map_to[k], v) for k, v in in_obj.items() if k in map_to) def _get_user_id(id_token): user_id = None is_displayable = False if id_token.get('upn'): user_id = id_token['upn'] is_displayable = True elif id_token.get('email'): user_id = id_token['email'] is_displayable = True elif id_token.get('sub'): user_id = id_token['sub'] if not user_id: user_id = str(uuid.uuid4()) user_id_vals = {} user_id_vals[IdTokenFields.USER_ID] = user_id if is_displayable: user_id_vals[IdTokenFields.IS_USER_ID_DISPLAYABLE] = True return user_id_vals def _extract_token_values(id_token): extracted_values = {} extracted_values = map_fields(id_token, OAuth2.IdTokenMap) extracted_values.update(_get_user_id(id_token)) return extracted_values class OAuth2Client(object): def __init__(self, call_context, authority): self._token_endpoint = authority.token_endpoint self._device_code_endpoint = authority.device_code_endpoint self._log = log.Logger("OAuth2Client", call_context['log_context']) self._call_context = call_context 
self._cancel_polling_request = False def _create_token_url(self): parameters = {} if self._call_context.get('api_version'): parameters[OAuth2.Parameters.AAD_API_VERSION] = self._call_context[ 'api_version'] return urlparse('{}?{}'.format(self._token_endpoint, urlencode(parameters))) def _create_device_code_url(self): parameters = {} parameters[OAuth2.Parameters.AAD_API_VERSION] = '1.0' return urlparse('{}?{}'.format(self._device_code_endpoint, urlencode(parameters))) def _parse_optional_ints(self, obj, keys): for key in keys: try: obj[key] = int(obj[key]) except ValueError: self._log.info("%s could not be parsed as an int", key) raise except KeyError: # if the key isn't present we can just continue pass def _parse_id_token(self, encoded_token): cracked_token = self._open_jwt(encoded_token) if not cracked_token: return try: b64_id_token = cracked_token['JWSPayload'] b64_decoded = util.base64_urlsafe_decode(b64_id_token) if not b64_decoded: self._log.warn('The returned id_token could not be base64 url safe decoded.') return id_token = json.loads(b64_decoded.decode('utf-8')) except ValueError: self._log.info("The returned id_token could not be decoded: %s", encoded_token) raise return _extract_token_values(id_token) def _open_jwt(self, jwt_token): id_token_parts_reg = r"^([^\.\s]*)\.([^\.\s]+)\.([^\.\s]*)$" matches = re.search(id_token_parts_reg, jwt_token) if not matches or len(matches.groups()) < 3: self._log.warn('The token was not parsable.') return {} return { 'header': matches.group(1), 'JWSPayload': matches.group(2), 'JWSSig': matches.group(3) } def _validate_token_response(self, body): try: wire_response = json.loads(body) except ValueError: self._log.info( 'The token response from the server is unparseable as JSON: %s', body) raise int_keys = [ OAuth2.ResponseParameters.EXPIRES_ON, OAuth2.ResponseParameters.EXPIRES_IN, OAuth2.ResponseParameters.CREATED_ON ] self._parse_optional_ints(wire_response, int_keys) expires_in = 
wire_response.get(OAuth2.ResponseParameters.EXPIRES_IN) if expires_in: now = datetime.now() soon = timedelta(seconds=expires_in) wire_response[OAuth2.ResponseParameters.EXPIRES_ON] = str(now + soon) created_on = wire_response.get(OAuth2.ResponseParameters.CREATED_ON) if created_on: temp_date = datetime.fromtimestamp(created_on) wire_response[OAuth2.ResponseParameters.CREATED_ON] = str(temp_date) if not wire_response.get(OAuth2.ResponseParameters.TOKEN_TYPE): raise AdalError('wire_response is missing token_type', wire_response) if not wire_response.get(OAuth2.ResponseParameters.ACCESS_TOKEN): raise AdalError('wire_response is missing access_token', wire_response) token_response = map_fields(wire_response, TOKEN_RESPONSE_MAP) if wire_response.get(OAuth2.ResponseParameters.ID_TOKEN): id_token = self._parse_id_token(wire_response[OAuth2.ResponseParameters.ID_TOKEN]) if id_token: token_response.update(id_token) return token_response def _validate_device_code_response(self, body): try: wire_response = json.loads(body) except ValueError: self._log.info('The device code response returned from the server is unparseable as JSON:') raise int_keys = [ OAuth2.DeviceCodeResponseParameters.EXPIRES_IN, OAuth2.DeviceCodeResponseParameters.INTERVAL ] self._parse_optional_ints(wire_response, int_keys) if not wire_response.get(OAuth2.DeviceCodeResponseParameters.EXPIRES_IN): raise AdalError('wire_response is missing expires_in', wire_response) if not wire_response.get(OAuth2.DeviceCodeResponseParameters.DEVICE_CODE): raise AdalError('wire_response is missing device_code', wire_response) if not wire_response.get(OAuth2.DeviceCodeResponseParameters.USER_CODE): raise AdalError('wire_response is missing user_code', wire_response) #skip field naming tweak, becasue names from wire are python style already return wire_response def _handle_get_token_response(self, body): try: return self._validate_token_response(body) except Exception: self._log.info("Error validating get token response 
'%s'", body) raise def _handle_get_device_code_response(self, body): try: return self._validate_device_code_response(body) except Exception: self._log.info("Error validating get user code response '%s'", body) raise def get_token(self, oauth_parameters): token_url = self._create_token_url() url_encoded_token_request = urlencode(oauth_parameters) post_options = util.create_request_options(self, _REQ_OPTION) operation = "Get Token" try: resp = requests.post(token_url.geturl(), data=url_encoded_token_request, headers=post_options['headers'], verify=self._call_context.get('verify_ssl', None)) util.log_return_correlation_id(self._log, operation, resp) except Exception: self._log.info("%s request failed", operation) raise if util.is_http_success(resp.status_code): return self._handle_get_token_response(resp.text) else: return_error_string = _ERROR_TEMPLATE.format(operation, resp.status_code) error_response = "" if resp.text: return_error_string = u"{} and server response: {}".format(return_error_string, resp.text) try: error_response = resp.json() except ValueError: pass raise AdalError(return_error_string, error_response) def get_user_code_info(self, oauth_parameters): device_code_url = self._create_device_code_url() url_encoded_code_request = urlencode(oauth_parameters) post_options = util.create_request_options(self, _REQ_OPTION) operation = "Get Device Code" try: resp = requests.post(device_code_url.geturl(), data=url_encoded_code_request, headers=post_options['headers'], verify=self._call_context.get('verify_ssl', None)) util.log_return_correlation_id(self._log, operation, resp) except Exception: self._log.info("%s request failed", operation) raise if util.is_http_success(resp.status_code): return self._handle_get_device_code_response(resp.text) else: return_error_string = _ERROR_TEMPLATE.format(operation, resp.status_code) error_response = "" if resp.text: return_error_string = u"{} and server response: {}".format(return_error_string, resp.text) try: error_response 
= resp.json() except ValueError: pass raise AdalError(return_error_string, error_response) def get_token_with_polling(self, oauth_parameters, refresh_internal, expires_in): token_url = self._create_token_url() url_encoded_code_request = urlencode(oauth_parameters) post_options = util.create_request_options(self, _REQ_OPTION) operation = "Get token with device code" max_times_for_retry = math.floor(expires_in/refresh_internal) for _ in range(int(max_times_for_retry)): if self._cancel_polling_request: raise AdalError('Polling_Request_Cancelled') resp = requests.post( token_url.geturl(), data=url_encoded_code_request, headers=post_options['headers'], verify=self._call_context.get('verify_ssl', None)) util.log_return_correlation_id(self._log, operation, resp) wire_response = {} if not util.is_http_success(resp.status_code): # on error, the body should be json already wire_response = json.loads(resp.text) error = wire_response.get(OAuth2.DeviceCodeResponseParameters.ERROR) if error == 'authorization_pending': time.sleep(refresh_internal) continue elif error: raise AdalError('Unexpected polling state {}'.format(error), wire_response) else: try: return self._validate_token_response(resp.text) except Exception: self._log.info(u"Error validating get token response '%s'", resp.text) raise raise AdalError('Timeout from "get_token_with_polling"') def cancel_polling_request(self): self._cancel_polling_request = True adal-0.4.6/adal/self_signed_jwt.py0000644000175000017500000001132713133203172017704 0ustar travistravis00000000000000#------------------------------------------------------------------------------ # # Copyright (c) Microsoft Corporation. # All rights reserved. # # This code is licensed under the MIT License. 
# # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files(the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and / or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions : # # The above copyright notice and this permission notice shall be included in # all copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. 
# #------------------------------------------------------------------------------ import time import datetime import uuid import base64 import binascii import re import jwt from .constants import Jwt from .log import Logger from .adal_error import AdalError def _get_date_now(): return datetime.datetime.now() def _get_new_jwt_id(): return str(uuid.uuid4()) def _create_x5t_value(thumbprint): hex_val = binascii.a2b_hex(thumbprint) return base64.urlsafe_b64encode(hex_val).decode() def _sign_jwt(header, payload, certificate): try: encoded_jwt = _encode_jwt(payload, certificate, header) except Exception as exp: raise AdalError("Error:Invalid Certificate: Expected Start of Certificate to be '-----BEGIN RSA PRIVATE KEY-----'", exp) _raise_on_invalid_jwt_signature(encoded_jwt) return encoded_jwt def _encode_jwt(payload, certificate, header): return jwt.encode(payload, certificate, algorithm='RS256', headers=header).decode() def _raise_on_invalid_jwt_signature(encoded_jwt): segments = encoded_jwt.split('.') if len(segments) < 3 or not segments[2]: raise AdalError('Failed to sign JWT. This is most likely due to an invalid certificate.') class SelfSignedJwt(object): NumCharIn128BitHexString = 128/8*2 numCharIn160BitHexString = 160/8*2 ThumbprintRegEx = r"^[a-f\d]*$" def __init__(self, call_context, authority, client_id): self._log = Logger('SelfSignedJwt', call_context['log_context']) self._call_context = call_context self._authortiy = authority self._token_endpoint = authority.token_endpoint self._client_id = client_id def _create_header(self, thumbprint): x5t = _create_x5t_value(thumbprint) header = {'typ':'JWT', 'alg':'RS256', 'x5t':x5t} self._log.debug("Creating self signed JWT header. x5t: %s", x5t) return header def _create_payload(self): now = _get_date_now() minutes = datetime.timedelta(0, 0, 0, 0, Jwt.SELF_SIGNED_JWT_LIFETIME) expires = now + minutes self._log.debug( 'Creating self signed JWT payload. 
Expires: %s NotBefore: %s', expires, now) jwt_payload = {} jwt_payload[Jwt.AUDIENCE] = self._token_endpoint jwt_payload[Jwt.ISSUER] = self._client_id jwt_payload[Jwt.SUBJECT] = self._client_id jwt_payload[Jwt.NOT_BEFORE] = int(time.mktime(now.timetuple())) jwt_payload[Jwt.EXPIRES_ON] = int(time.mktime(expires.timetuple())) jwt_payload[Jwt.JWT_ID] = _get_new_jwt_id() return jwt_payload def _raise_on_invalid_thumbprint(self, thumbprint): thumbprint_sizes = [self.NumCharIn128BitHexString, self.numCharIn160BitHexString] size_ok = len(thumbprint) in thumbprint_sizes if not size_ok or not re.search(self.ThumbprintRegEx, thumbprint): raise AdalError("The thumbprint does not match a known format") def _reduce_thumbprint(self, thumbprint): canonical = thumbprint.lower().replace(' ', '').replace(':', '') self._raise_on_invalid_thumbprint(canonical) return canonical def create(self, certificate, thumbprint): thumbprint = self._reduce_thumbprint(thumbprint) header = self._create_header(thumbprint) payload = self._create_payload() return _sign_jwt(header, payload, certificate) adal-0.4.6/adal/token_cache.py0000644000175000017500000001103713133203172016777 0ustar travistravis00000000000000#------------------------------------------------------------------------------ # # Copyright (c) Microsoft Corporation. # All rights reserved. # # This code is licensed under the MIT License. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files(the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and / or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions : # # The above copyright notice and this permission notice shall be included in # all copies or substantial portions of the Software. 
# # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. # #------------------------------------------------------------------------------ import json import threading from .constants import TokenResponseFields def _string_cmp(str1, str2): '''Case insensitive comparison. Return true if both are None''' str1 = str1 if str1 is not None else '' str2 = str2 if str2 is not None else '' return str1.lower() == str2.lower() class TokenCacheKey(object): # pylint: disable=too-few-public-methods def __init__(self, authority, resource, client_id, user_id): self.authority = authority self.resource = resource self.client_id = client_id self.user_id = user_id def __hash__(self): return hash((self.authority, self.resource, self.client_id, self.user_id)) def __eq__(self, other): return _string_cmp(self.authority, other.authority) and \ _string_cmp(self.resource, other.resource) and \ _string_cmp(self.client_id, other.client_id) and \ _string_cmp(self.user_id, other.user_id) # pylint: disable=protected-access def _get_cache_key(entry): return TokenCacheKey( entry.get(TokenResponseFields._AUTHORITY), entry.get(TokenResponseFields.RESOURCE), entry.get(TokenResponseFields._CLIENT_ID), entry.get(TokenResponseFields.USER_ID)) class TokenCache(object): def __init__(self, state=None): self._cache = {} self._lock = threading.RLock() if state: self.deserialize(state) self.has_state_changed = False def find(self, query): with self._lock: return self._query_cache( query.get(TokenResponseFields.IS_MRRT), query.get(TokenResponseFields.USER_ID), 
query.get(TokenResponseFields._CLIENT_ID)) def remove(self, entries): with self._lock: for e in entries: key = _get_cache_key(e) self._cache.pop(key) self.has_state_changed = True def add(self, entries): with self._lock: for e in entries: key = _get_cache_key(e) self._cache[key] = e self.has_state_changed = True def serialize(self): with self._lock: return json.dumps(list(self._cache.values())) def deserialize(self, state): with self._lock: self._cache.clear() if state: tokens = json.loads(state) for t in tokens: key = _get_cache_key(t) self._cache[key] = t def read_items(self): '''output list of tuples in (key, authentication-result)''' with self._lock: return self._cache.items() def _query_cache(self, is_mrrt, user_id, client_id): matches = [] for k in self._cache: v = self._cache[k] #None value will be taken as wildcard match #pylint: disable=too-many-boolean-expressions if ((is_mrrt is None or is_mrrt == v.get(TokenResponseFields.IS_MRRT)) and (user_id is None or _string_cmp(user_id, v.get(TokenResponseFields.USER_ID))) and (client_id is None or _string_cmp(client_id, v.get(TokenResponseFields._CLIENT_ID)))): matches.append(v) return matches adal-0.4.6/adal/token_request.py0000644000175000017500000004244713133203172017435 0ustar travistravis00000000000000#------------------------------------------------------------------------------ # # Copyright (c) Microsoft Corporation. # All rights reserved. # # This code is licensed under the MIT License. 
# # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files(the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and / or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions : # # The above copyright notice and this permission notice shall be included in # all copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. # #------------------------------------------------------------------------------ from base64 import b64encode import re from . import constants from . import log from . import mex from . import oauth2_client from . import self_signed_jwt from . import user_realm from . 
import wstrust_request from .adal_error import AdalError from .cache_driver import CacheDriver from .constants import WSTrustVersion OAUTH2_PARAMETERS = constants.OAuth2.Parameters TOKEN_RESPONSE_FIELDS = constants.TokenResponseFields OAUTH2_GRANT_TYPE = constants.OAuth2.GrantType OAUTH2_SCOPE = constants.OAuth2.Scope OAUTH2_DEVICE_CODE_RESPONSE_PARAMETERS = constants.OAuth2.DeviceCodeResponseParameters SAML = constants.Saml ACCOUNT_TYPE = constants.UserRealm.account_type USER_ID = constants.TokenResponseFields.USER_ID _CLIENT_ID = constants.TokenResponseFields._CLIENT_ID #pylint: disable=protected-access def add_parameter_if_available(parameters, key, value): if value: parameters[key] = value def _get_saml_grant_type(wstrust_response): token_type = wstrust_response.token_type if token_type == SAML.TokenTypeV1: return OAUTH2_GRANT_TYPE.SAML1 elif token_type == SAML.TokenTypeV2: return OAUTH2_GRANT_TYPE.SAML2 else: raise AdalError("RSTR returned unknown token type: {}".format(token_type)) class TokenRequest(object): def __init__(self, call_context, authentication_context, client_id, resource, redirect_uri=None): self._log = log.Logger("TokenRequest", call_context['log_context']) self._call_context = call_context self._authentication_context = authentication_context self._resource = resource self._client_id = client_id self._redirect_uri = redirect_uri self._cache_driver = None # should be set at the beginning of get_token # functions that have a user_id self._user_id = None self._user_realm = None # should be set when acquire token using device flow self._polling_client = None def _create_user_realm_request(self, username): return user_realm.UserRealm(self._call_context, username, self._authentication_context.authority.url) def _create_mex(self, mex_endpoint): return mex.Mex(self._call_context, mex_endpoint) def _create_wstrust_request(self, wstrust_endpoint, applies_to, wstrust_endpoint_version): return wstrust_request.WSTrustRequest(self._call_context, 
wstrust_endpoint, applies_to, wstrust_endpoint_version) def _create_oauth2_client(self): return oauth2_client.OAuth2Client(self._call_context, self._authentication_context.authority) def _create_self_signed_jwt(self): return self_signed_jwt.SelfSignedJwt(self._call_context, self._authentication_context.authority, self._client_id) def _oauth_get_token(self, oauth_parameters): client = self._create_oauth2_client() return client.get_token(oauth_parameters) def _create_cache_driver(self): return CacheDriver( self._call_context, self._authentication_context.authority.url, self._resource, self._client_id, self._authentication_context.cache, self._get_token_with_token_response ) def _find_token_from_cache(self): self._cache_driver = self._create_cache_driver() cache_query = self._create_cache_query() return self._cache_driver.find(cache_query) def _add_token_into_cache(self, token): cache_driver = self._create_cache_driver() self._log.debug('Storing retrieved token into cache') cache_driver.add(token) def _get_token_with_token_response(self, entry, resource): self._log.debug("called to refresh a token from the cache") refresh_token = entry[TOKEN_RESPONSE_FIELDS.REFRESH_TOKEN] return self._get_token_with_refresh_token(refresh_token, resource, None) def _create_cache_query(self): query = {_CLIENT_ID : self._client_id} if self._user_id: query[USER_ID] = self._user_id else: self._log.debug("No user_id passed for cache query") return query def _create_oauth_parameters(self, grant_type): oauth_parameters = {} oauth_parameters[OAUTH2_PARAMETERS.GRANT_TYPE] = grant_type if (OAUTH2_GRANT_TYPE.AUTHORIZATION_CODE != grant_type and OAUTH2_GRANT_TYPE.CLIENT_CREDENTIALS != grant_type and OAUTH2_GRANT_TYPE.REFRESH_TOKEN != grant_type and OAUTH2_GRANT_TYPE.DEVICE_CODE != grant_type): oauth_parameters[OAUTH2_PARAMETERS.SCOPE] = OAUTH2_SCOPE.OPENID add_parameter_if_available(oauth_parameters, OAUTH2_PARAMETERS.CLIENT_ID, self._client_id) add_parameter_if_available(oauth_parameters, 
OAUTH2_PARAMETERS.RESOURCE, self._resource) add_parameter_if_available(oauth_parameters, OAUTH2_PARAMETERS.REDIRECT_URI, self._redirect_uri) return oauth_parameters def _get_token_username_password_managed(self, username, password): self._log.debug('Acquiring token with username password for managed user') oauth_parameters = self._create_oauth_parameters(OAUTH2_GRANT_TYPE.PASSWORD) oauth_parameters[OAUTH2_PARAMETERS.PASSWORD] = password oauth_parameters[OAUTH2_PARAMETERS.USERNAME] = username return self._oauth_get_token(oauth_parameters) def _perform_wstrust_assertion_oauth_exchange(self, wstrust_response): self._log.debug("Performing OAuth assertion grant type exchange.") oauth_parameters = {} grant_type = _get_saml_grant_type(wstrust_response) token_bytes = wstrust_response.token assertion = b64encode(token_bytes) oauth_parameters = self._create_oauth_parameters(grant_type) oauth_parameters[OAUTH2_PARAMETERS.ASSERTION] = assertion return self._oauth_get_token(oauth_parameters) def _perform_wstrust_exchange(self, wstrust_endpoint, wstrust_endpoint_version, username, password): wstrust = self._create_wstrust_request(wstrust_endpoint, "urn:federation:MicrosoftOnline", wstrust_endpoint_version) result = wstrust.acquire_token(username, password) if not result.token: err_template = "Unsuccessful RSTR.\n\terror code: {}\n\tfaultMessage: {}" error_msg = err_template.format(result.error_code, result.fault_message) self._log.info(error_msg) raise AdalError(error_msg) return result def _perform_username_password_for_access_token_exchange(self, wstrust_endpoint, wstrust_endpoint_version, username, password): wstrust_response = self._perform_wstrust_exchange(wstrust_endpoint, wstrust_endpoint_version, username, password) return self._perform_wstrust_assertion_oauth_exchange(wstrust_response) def _get_token_username_password_federated(self, username, password): self._log.debug("Acquiring token with username password for federated user") if not 
self._user_realm.federation_metadata_url: self._log.warn("Unable to retrieve federationMetadataUrl from AAD. " "Attempting fallback to AAD supplied endpoint.") if not self._user_realm.federation_active_auth_url: raise AdalError('AAD did not return a WSTrust endpoint. Unable to proceed.') wstrust_version = TokenRequest._parse_wstrust_version_from_federation_active_authurl( self._user_realm.federation_active_auth_url) self._log.debug('wstrust endpoint version is: %s', wstrust_version) return self._perform_username_password_for_access_token_exchange( self._user_realm.federation_active_auth_url, wstrust_version, username, password) else: mex_endpoint = self._user_realm.federation_metadata_url self._log.debug("Attempting mex at: %s", mex_endpoint) mex_instance = self._create_mex(mex_endpoint) wstrust_version = WSTrustVersion.UNDEFINED try: mex_instance.discover() wstrust_endpoint = mex_instance.username_password_policy['url'] wstrust_version = mex_instance.username_password_policy['version'] except Exception: #pylint: disable=broad-except warn_template = ("MEX exchange failed for %s. " "Attempting fallback to AAD supplied endpoint.") self._log.warn(warn_template, mex_endpoint) wstrust_endpoint = self._user_realm.federation_active_auth_url wstrust_version = TokenRequest._parse_wstrust_version_from_federation_active_authurl( self._user_realm.federation_active_auth_url) if not wstrust_endpoint: raise AdalError('AAD did not return a WSTrust endpoint. Unable to proceed.') return self._perform_username_password_for_access_token_exchange(wstrust_endpoint, wstrust_version, username, password) @staticmethod def _parse_wstrust_version_from_federation_active_authurl(federation_active_authurl): wstrust2005_regex = r'[/trust]?[2005][/usernamemixed]?' wstrust13_regex = r'[/trust]?[13][/usernamemixed]?' 
if re.search(wstrust2005_regex, federation_active_authurl): return WSTrustVersion.WSTRUST2005 elif re.search(wstrust13_regex, federation_active_authurl): return WSTrustVersion.WSTRUST13 return WSTrustVersion.UNDEFINED def get_token_with_username_password(self, username, password): self._log.info("Acquiring token with username password.") self._user_id = username try: token = self._find_token_from_cache() if token: return token except AdalError as exp: self._log.warn( 'Attempt to look for token in cache resulted in Error: %s', exp, log_stack_trace=True) if not self._authentication_context.authority.is_adfs_authority: self._user_realm = self._create_user_realm_request(username) self._user_realm.discover() try: if self._user_realm.account_type == ACCOUNT_TYPE['Managed']: token = self._get_token_username_password_managed(username, password) elif self._user_realm.account_type == ACCOUNT_TYPE['Federated']: token = self._get_token_username_password_federated(username, password) else: raise AdalError( "Server returned an unknown AccountType: {}".format(self._user_realm.account_type)) self._log.debug("Successfully retrieved token from authority.") except Exception: self._log.info("get_token_func returned with error") raise else: self._log.info('Skipping user realm discovery for ADFS authority') token = self._get_token_username_password_managed(username, password) self._cache_driver.add(token) return token def get_token_with_client_credentials(self, client_secret): self._log.info("Getting token with client credentials.") try: token = self._find_token_from_cache() if token: return token except AdalError as exp: self._log.warn( 'Attempt to look for token in cache resulted in Error: %s', exp, log_stack_trace=True) oauth_parameters = self._create_oauth_parameters(OAUTH2_GRANT_TYPE.CLIENT_CREDENTIALS) oauth_parameters[OAUTH2_PARAMETERS.CLIENT_SECRET] = client_secret token = self._oauth_get_token(oauth_parameters) self._cache_driver.add(token) return token def 
get_token_with_authorization_code(self, authorization_code, client_secret): self._log.info("Getting token with auth code.") oauth_parameters = self._create_oauth_parameters(OAUTH2_GRANT_TYPE.AUTHORIZATION_CODE) oauth_parameters[OAUTH2_PARAMETERS.CODE] = authorization_code oauth_parameters[OAUTH2_PARAMETERS.CLIENT_SECRET] = client_secret return self._oauth_get_token(oauth_parameters) def _get_token_with_refresh_token(self, refresh_token, resource, client_secret): self._log.info("Getting a new token from a refresh token") oauth_parameters = self._create_oauth_parameters(OAUTH2_GRANT_TYPE.REFRESH_TOKEN) if resource: oauth_parameters[OAUTH2_PARAMETERS.RESOURCE] = resource if client_secret: oauth_parameters[OAUTH2_PARAMETERS.CLIENT_SECRET] = client_secret oauth_parameters[OAUTH2_PARAMETERS.REFRESH_TOKEN] = refresh_token return self._oauth_get_token(oauth_parameters) def get_token_with_refresh_token(self, refresh_token, client_secret): return self._get_token_with_refresh_token(refresh_token, None, client_secret) def get_token_from_cache_with_refresh(self, user_id): self._log.info("Getting token from cache with refresh if necessary.") self._user_id = user_id return self._find_token_from_cache() def _create_jwt(self, certificate, thumbprint): ssj = self._create_self_signed_jwt() jwt = ssj.create(certificate, thumbprint) if not jwt: raise AdalError("Failed to create JWT.") return jwt def get_token_with_certificate(self, certificate, thumbprint): self._log.info("Getting a token via certificate.") jwt = self._create_jwt(certificate, thumbprint) oauth_parameters = self._create_oauth_parameters(OAUTH2_GRANT_TYPE.CLIENT_CREDENTIALS) oauth_parameters[OAUTH2_PARAMETERS.CLIENT_ASSERTION_TYPE] = OAUTH2_GRANT_TYPE.JWT_BEARER oauth_parameters[OAUTH2_PARAMETERS.CLIENT_ASSERTION] = jwt try: token = self._find_token_from_cache() if token: return token except AdalError as exp: self._log.warn( 'Attempt to look for token in cache resulted in Error: %s', exp, log_stack_trace=True) return 
self._oauth_get_token(oauth_parameters) def get_token_with_device_code(self, user_code_info): self._log.info("Getting a token via device code") oauth_parameters = self._create_oauth_parameters(OAUTH2_GRANT_TYPE.DEVICE_CODE) oauth_parameters[OAUTH2_PARAMETERS.CODE] = user_code_info[OAUTH2_DEVICE_CODE_RESPONSE_PARAMETERS.DEVICE_CODE] interval = user_code_info[OAUTH2_DEVICE_CODE_RESPONSE_PARAMETERS.INTERVAL] expires_in = user_code_info[OAUTH2_DEVICE_CODE_RESPONSE_PARAMETERS.EXPIRES_IN] if interval <= 0: raise AdalError('invalid refresh interval') client = self._create_oauth2_client() self._polling_client = client token = client.get_token_with_polling(oauth_parameters, interval, expires_in) self._add_token_into_cache(token) return token def cancel_token_request_with_device_code(self): self._polling_client.cancel_polling_request() adal-0.4.6/adal/user_realm.py0000644000175000017500000001433213133203172016673 0ustar travistravis00000000000000#------------------------------------------------------------------------------ # # Copyright (c) Microsoft Corporation. # All rights reserved. # # This code is licensed under the MIT License. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files(the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and / or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions : # # The above copyright notice and this permission notice shall be included in # all copies or substantial portions of the Software. 
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
#------------------------------------------------------------------------------

import json

try:
    from urllib.parse import quote, urlencode
    from urllib.parse import urlunparse
except ImportError:
    from urllib import quote, urlencode #pylint: disable=no-name-in-module
    from urlparse import urlunparse #pylint: disable=import-error

import requests

from . import constants
from . import log
from . import util
from .adal_error import AdalError

# Path of the user realm discovery endpoint; '<user>' is substituted with the
# URL-encoded user principal in UserRealm._get_user_realm_url.
USER_REALM_PATH_TEMPLATE = 'common/UserRealm/<user>'

ACCOUNT_TYPE = constants.UserRealm.account_type
FEDERATION_PROTOCOL_TYPE = constants.UserRealm.federation_protocol_type


class UserRealm(object):
    """Performs user realm discovery for one user principal.

    After discover() succeeds, account_type (and, for federated accounts,
    federation_protocol / federation_metadata_url /
    federation_active_auth_url) hold the values parsed from the server's
    JSON response.
    """

    def __init__(self, call_context, user_principle, authority_url):
        self._log = log.Logger("UserRealm", call_context['log_context'])
        self._call_context = call_context
        self.api_version = '1.0'
        self.federation_protocol = None
        self.account_type = None
        self.federation_metadata_url = None
        self.federation_active_auth_url = None
        self._user_principle = user_principle
        self._authority_url = authority_url

    def _get_user_realm_url(self):
        """Return the parsed discovery URL derived from the authority URL.

        The authority's path is replaced by the discovery path for the encoded
        user and its query by ``api-version=<self.api_version>``.
        """
        url_components = list(util.copy_url(self._authority_url))
        url_encoded_user = quote(self._user_principle, safe='~()*!.\'')
        # Substitute an explicit placeholder: str.replace('' , x) would insert
        # x between every character of the template instead.
        url_components[2] = '/' + USER_REALM_PATH_TEMPLATE.replace('<user>', url_encoded_user)

        user_realm_query = {'api-version': self.api_version}
        url_components[4] = urlencode(user_realm_query)  # index 4 == query
        return util.copy_url(urlunparse(url_components))

    @staticmethod
    def _validate_constant_value(value_dic, value, case_sensitive=False):
        """Return *value* (lower-cased unless case_sensitive) if it is one of
        value_dic's values, otherwise False."""
        if not value:
            return False

        if not case_sensitive:
            value = value.lower()

        return value if value in value_dic.values() else False

    @staticmethod
    def _validate_account_type(account_type):
        return UserRealm._validate_constant_value(ACCOUNT_TYPE, account_type)

    @staticmethod
    def _validate_federation_protocol(protocol):
        return UserRealm._validate_constant_value(FEDERATION_PROTOCOL_TYPE, protocol)

    def _log_parsed_response(self):
        self._log.debug('UserRealm response:')
        self._log.debug(' AccountType:             %s', self.account_type)
        self._log.debug(' FederationProtocol:      %s', self.federation_protocol)
        self._log.debug(' FederationMetatdataUrl:  %s', self.federation_metadata_url)
        self._log.debug(' FederationActiveAuthUrl: %s', self.federation_active_auth_url)

    def _parse_discovery_response(self, body):
        """Parse the discovery JSON *body* into this instance's attributes.

        Raises:
            ValueError: *body* is not valid JSON (re-raised after logging).
            AdalError: the account type, or for federated accounts the
                federation protocol, is not a recognised value.
        """
        self._log.debug("Discovery response:\n %s", body)

        try:
            response = json.loads(body)
        except ValueError:
            error_template = ("Parsing realm discovery response JSON failed "
                              "for body: '{}'")
            self._log.info(error_template.format(body))
            raise

        raw_account_type = response['account_type']
        account_type = UserRealm._validate_account_type(raw_account_type)
        if not account_type:
            # Report what the server sent; the validated value is False here.
            raise AdalError('Cannot parse account_type: {}'.format(raw_account_type))
        self.account_type = account_type

        if self.account_type == ACCOUNT_TYPE['Federated']:
            raw_protocol = response['federation_protocol']
            protocol = UserRealm._validate_federation_protocol(raw_protocol)

            if not protocol:
                raise AdalError('Cannot parse federation protocol: {}'.format(raw_protocol))

            self.federation_protocol = protocol
            self.federation_metadata_url = response['federation_metadata_url']
            self.federation_active_auth_url = response['federation_active_auth_url']

        self._log_parsed_response()

    def discover(self):
        """Issue the user realm discovery GET request and parse the reply.

        Raises:
            AdalError: the server answered with a non-2xx status; the parsed
                JSON error body (or "") is attached as error_response.
        """
        options = util.create_request_options(self, {'headers': {'Accept':'application/json'}})

        user_realm_url = self._get_user_realm_url()
        self._log.debug("Performing user realm discovery at: %s", user_realm_url.geturl())

        operation = 'User Realm Discovery'
        resp = requests.get(user_realm_url.geturl(), headers=options['headers'],
                            verify=self._call_context.get('verify_ssl', None))
        util.log_return_correlation_id(self._log, operation, resp)

        if not util.is_http_success(resp.status_code):
            return_error_string = u"{} request returned http error: {}".format(operation, resp.status_code)
            error_response = ""
            if resp.text:
                return_error_string = u"{} and server response: {}".format(return_error_string, resp.text)
                try:
                    error_response = resp.json()
                except ValueError:
                    pass  # non-JSON error body: keep "" as error_response

            raise AdalError(return_error_string, error_response)

        else:
            self._parse_discovery_response(resp.text)

# adal-0.4.6/adal/util.py0000644000175000017500000000675713133203172015526 0ustar travistravis00000000000000
#------------------------------------------------------------------------------
#
# Copyright (c) Microsoft Corporation.
# All rights reserved.
#
# This code is licensed under the MIT License.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files(the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and / or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions :
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
#------------------------------------------------------------------------------

import sys
import base64
try:
    from urllib.parse import urlparse
except ImportError:
    from urlparse import urlparse #pylint: disable=import-error

import adal

from .constants import AdalIdParameters


def is_http_success(status_code):
    """Return True when *status_code* is in the 2xx range."""
    return 200 <= status_code < 300


def add_default_request_headers(self, options):
    """Ensure options['headers'] exists and fill in ADAL's standard headers.

    Mutates *options* (and its 'headers' dict) in place. An existing
    'Accept-Charset' is kept; the correlation id and the ADAL id headers
    (SKU, version, OS, CPU) are always overwritten.
    """
    if not options.get('headers'):
        options['headers'] = {}

    headers = options['headers']
    if not headers.get('Accept-Charset'):
        headers['Accept-Charset'] = 'utf-8'

    #pylint: disable=protected-access
    headers['client-request-id'] = self._call_context['log_context']['correlation_id']
    headers['return-client-request-id'] = 'true'

    headers[AdalIdParameters.SKU] = AdalIdParameters.PYTHON_SKU
    headers[AdalIdParameters.VERSION] = adal.__version__
    headers[AdalIdParameters.OS] = sys.platform
    headers[AdalIdParameters.CPU] = 'x64' if sys.maxsize > 2 ** 32 else 'x86'


def create_request_options(self, *options):
    """Merge *options* dicts (later ones win, shallowly), then any
    call-context HTTP options, then add the default request headers."""
    merged_options = {}

    for option in options:
        merged_options.update(option)

    #pylint: disable=protected-access
    if self._call_context.get('options') and self._call_context['options'].get('http'):
        merged_options.update(self._call_context['options']['http'])

    add_default_request_headers(self, merged_options)
    return merged_options


def log_return_correlation_id(log, operation_message, response):
    """Log the 'client-request-id' header echoed back by the server, if any."""
    # Compare against None explicitly: requests.Response.__bool__ returns
    # response.ok, so a truthiness test skipped every 4xx/5xx response --
    # exactly the ones whose correlation id is needed for diagnosis.
    if response is None or not response.headers:
        return

    correlation_id = response.headers.get('client-request-id')
    if correlation_id:
        log.info("{} Server returned this correlation_id: {}".format(
            operation_message,
            correlation_id))


def copy_url(url_source):
    """Return a fresh urlparse() result for a URL string or parsed URL."""
    if hasattr(url_source, 'geturl'):
        url_source = url_source.geturl()
    return urlparse(url_source)

# urlsafe_b64decode requires correct padding. AAD does not include padding so
# the string needs to be correctly padded before decoding.
def base64_urlsafe_decode(b64string):
    """Decode a base64url string whose trailing '=' padding may be missing."""
    # Pad up to the next multiple of 4. The previous "4 - len % 4" formula
    # appended a whole '====' group to input that was already aligned.
    b64string += '=' * (-len(b64string) % 4)
    return base64.urlsafe_b64decode(b64string.encode('ascii'))

# adal-0.4.6/adal/wstrust_request.py0000644000175000017500000002032213133203172020034 0ustar travistravis00000000000000
#------------------------------------------------------------------------------
#
# Copyright (c) Microsoft Corporation.
# All rights reserved.
#
# This code is licensed under the MIT License.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files(the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and / or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions :
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
#------------------------------------------------------------------------------

import uuid
from datetime import datetime, timedelta

import requests

from . import log
from . import util
from . import wstrust_response
from .adal_error import AdalError
from .constants import WSTrustVersion

_USERNAME_PLACEHOLDER = '{UsernamePlaceHolder}'
_PASSWORD_PLACEHOLDER = '{PasswordPlaceHolder}'


class WSTrustRequest(object):
    """Builds a WS-Trust RequestSecurityToken (RST) for a username/password
    and posts it to a WS-Trust endpoint."""

    def __init__(self, call_context, watrust_endpoint_url, applies_to, wstrust_endpoint_version):
        self._log = log.Logger('WSTrustRequest', call_context['log_context'])
        self._call_context = call_context
        self._wstrust_endpoint_url = watrust_endpoint_url
        self._applies_to = applies_to
        self._wstrust_endpoint_version = wstrust_endpoint_version

    @staticmethod
    def _build_security_header():
        """Return the security-header fragment with a 10-minute validity window
        and the username/password placeholders still in place."""
        time_now = datetime.utcnow()
        expire_time = time_now + timedelta(minutes=10)

        # Millisecond precision with a 'Z' suffix. strftime('%f') always yields
        # six digits, whereas isoformat() omits the fraction when
        # microsecond == 0, which made the old isoformat()[:-3] cut off the
        # seconds instead of the extra fraction digits.
        time_now_str = time_now.strftime('%Y-%m-%dT%H:%M:%S.%f')[:-3] + 'Z'
        expire_time_str = expire_time.strftime('%Y-%m-%dT%H:%M:%S.%f')[:-3] + 'Z'

        # NOTE(review): every literal below is empty, so only the timestamps and
        # placeholders reach the header; the WS-Security XML markup appears to
        # be missing from this copy of the source -- confirm against upstream.
        security_header_xml = (""
                               ""
                               "" + time_now_str + ""
                               "" + expire_time_str + ""
                               ""
                               ""
                               "" + _USERNAME_PLACEHOLDER + ""
                               "" + _PASSWORD_PLACEHOLDER + ""
                               ""
                               "")

        return security_header_xml

    @staticmethod
    def _populate_rst_username_password(template, username, password):
        """Substitute the XML-escaped credentials into the RST *template*."""
        # Both values land in XML text content, so both must be escaped; an
        # unescaped '&' or '<' in the username breaks (or injects into) the
        # envelope just as it would in the password.
        username = WSTrustRequest._escape_password(username)
        password = WSTrustRequest._escape_password(password)
        return template.replace(_USERNAME_PLACEHOLDER, username).replace(_PASSWORD_PLACEHOLDER, password)

    @staticmethod
    def _escape_password(password):
        """Escape the five XML special characters in *password*."""
        # '&' goes first so the entities produced afterwards are not re-escaped.
        return (password.replace('&', '&amp;')
                .replace('"', '&quot;')
                .replace("'", '&apos;')
                .replace('<', '&lt;')
                .replace('>', '&gt;'))

    def _build_rst(self, username, password):
        """Return the RST message for this endpoint version with credentials filled in."""
        message_id = str(uuid.uuid4())

        schema_location = 'http://docs.oasis-open.org/wss/2004/01/oasis-200401-wss-wssecurity-utility-1.0.xsd'
        soap_action = 'http://docs.oasis-open.org/ws-sx/ws-trust/200512/RST/Issue'
        rst_trust_namespace = 'http://docs.oasis-open.org/ws-sx/ws-trust/200512'
        key_type = 'http://docs.oasis-open.org/ws-sx/ws-trust/200512/Bearer'
        request_type = 'http://docs.oasis-open.org/ws-sx/ws-trust/200512/Issue'

        if self._wstrust_endpoint_version == WSTrustVersion.WSTRUST2005:
            soap_action = 'http://schemas.xmlsoap.org/ws/2005/02/trust/RST/Issue'
            rst_trust_namespace = 'http://schemas.xmlsoap.org/ws/2005/02/trust'
            key_type = 'http://schemas.xmlsoap.org/ws/2005/05/identity/NoProofKey'
            request_type = 'http://schemas.xmlsoap.org/ws/2005/02/trust/Issue'

        # NOTE(review): the envelope literals below are empty, and
        # "".format(schema_location) / "".format(rst_trust_namespace) discard
        # their arguments; the SOAP/WS-Trust markup appears to be missing from
        # this copy of the source -- confirm against upstream.
        rst_template = ("".format(schema_location) +
                        "" +
                        "{}".format(soap_action) +
                        "urn:uuid:{}".format(message_id) +
                        "" +
                        "http://www.w3.org/2005/08/addressing/anonymous" +
                        "" +
                        "{}".format(self._wstrust_endpoint_url) +
                        WSTrustRequest._build_security_header() +
                        "" +
                        "" +
                        "".format(rst_trust_namespace) +
                        "" +
                        "" +
                        "{}".format(self._applies_to) +
                        "" +
                        "" +
                        "{}".format(key_type) +
                        "{}".format(request_type) +
                        "" +
                        "" +
                        "")

        # Logged before the credentials are substituted, so no secrets leak.
        self._log.debug('Created RST: \n' + rst_template)

        return WSTrustRequest._populate_rst_username_password(rst_template, username, password)

    def _handle_rstr(self, body):
        wstrust_resp = wstrust_response.WSTrustResponse(self._call_context, body, self._wstrust_endpoint_version)
        wstrust_resp.parse()
        return wstrust_resp

    def acquire_token(self, username, password):
        """POST the RST and return the parsed WSTrustResponse.

        Raises:
            AdalError: unsupported endpoint version, or a non-2xx HTTP status
                (with the parsed JSON error body, or "", as error_response).
        """
        if self._wstrust_endpoint_version == WSTrustVersion.UNDEFINED:
            raise AdalError('Unsupported wstrust endpoint version. Current support version is wstrust2005 or wstrust13.')

        rst = self._build_rst(username, password)
        if self._wstrust_endpoint_version == WSTrustVersion.WSTRUST2005:
            soap_action = 'http://schemas.xmlsoap.org/ws/2005/02/trust/RST/Issue'
        else:
            soap_action = 'http://docs.oasis-open.org/ws-sx/ws-trust/200512/RST/Issue'

        headers = {'headers': {'Content-type':'application/soap+xml; charset=utf-8',
                               'SOAPAction': soap_action},
                   'body': rst}
        options = util.create_request_options(self, headers)
        self._log.debug("Sending RST to: %s", self._wstrust_endpoint_url)

        operation = "WS-Trust RST"
        resp = requests.post(self._wstrust_endpoint_url, headers=options['headers'], data=rst,
                             allow_redirects=True,
                             verify=self._call_context.get('verify_ssl', None))

        util.log_return_correlation_id(self._log, operation, resp)

        if not util.is_http_success(resp.status_code):
            return_error_string = u"{} request returned http error: {}".format(operation, resp.status_code)
            error_response = ""
            if resp.text:
                return_error_string = u"{} and server response: {}".format(return_error_string, resp.text)
                try:
                    error_response = resp.json()
                except ValueError:
                    pass  # non-JSON error body: keep "" as error_response

            raise AdalError(return_error_string, error_response)

        else:
            return self._handle_rstr(resp.text)

# adal-0.4.6/adal/wstrust_response.py0000644000175000017500000002105713133203172020210 0ustar travistravis00000000000000
#------------------------------------------------------------------------------
#
# Copyright (c) Microsoft Corporation.
# All rights reserved.
#
# This code is licensed under the MIT License.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files(the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and / or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions :
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
#------------------------------------------------------------------------------

try:
    from xml.etree import cElementTree as ET
except ImportError:
    from xml.etree import ElementTree as ET

import re

from . import xmlutil
from . import log
from .adal_error import AdalError
from .constants import WSTrustVersion

# Creates a log message that contains the RSTR scrubbed of the actual SAML assertion.
def scrub_rstr_log_message(response_str):
    """Return 'RSTR Response: ' + the response on one line, with the body of
    the first matched SAML Assertion replaced by a redaction marker."""
    # A regular expression for finding the SAML Assertion in an response_str. Used to remove the SAML
    # assertion when logging the response_str.
    assertion_regex = r'RequestedSecurityToken.*?((<.*?:Assertion.*?>).*<\/.*?Assertion>).*?'
    single_line_rstr, _ = re.subn(r'(\r\n|\n|\r)', '', response_str)

    match = re.search(assertion_regex, single_line_rstr)
    if not match:
        #No Assertion was matched so just return the response_str as is.
        scrubbed_rstr = single_line_rstr
    else:
        saml_assertion = match.group(1)
        saml_assertion_start_tag = match.group(2)
        scrubbed_rstr = single_line_rstr.replace(
            saml_assertion, saml_assertion_start_tag + 'ASSERTION CONTENTS REDACTED')

    return 'RSTR Response: ' + scrubbed_rstr


class WSTrustResponse(object):
    """Parses a WS-Trust RSTR body, extracting either a SOAP fault or the
    first returned security token (serialized) and its token type."""

    def __init__(self, call_context, response, wstrust_version):
        self._log = log.Logger("WSTrustResponse", call_context['log_context'])
        self._call_context = call_context
        self._response = response
        self._dom = None
        self._parents = None  # child element -> parent element, built in parse()
        self.error_code = None
        self.fault_message = None
        self.token_type = None
        self.token = None
        self._wstrust_version = wstrust_version

        if response:
            self._log.debug(scrub_rstr_log_message(response))

    # Sample error message. Only the text content of the sample survives in
    # this copy (its XML tags are missing); presumably 'a:RequestFailed' is the
    # Fault/Code/Subcode/Value and the MSIS3127 line the Fault/Reason/Text that
    # _parse_error reads:
    #   http://www.w3.org/2005/08/addressing/soap/fault
    #   2013-07-30T00:32:21.989Z
    #   2013-07-30T00:37:21.989Z
    #   s:Sender
    #   a:RequestFailed
    #   MSIS3127: The specified request failed.

    def _parse_error(self):
        """Read the SOAP fault, if any, into fault_message / error_code.

        Returns True when a non-empty fault reason text was found.
        Raises AdalError if more than one top-level Subcode value is present.
        """
        error_found = False

        fault_node = xmlutil.xpath_find(self._dom, 's:Body/s:Fault/s:Reason/s:Text')
        if fault_node:
            self.fault_message = fault_node[0].text

            if self.fault_message:
                error_found = True

        # Subcode has minoccurs=0 and maxoccurs=1(default) according to the http://www.w3.org/2003/05/soap-envelope
        # Subcode may have another subcode as well. This is only targetting at top level subcode.
        # Subcode value may have different messages not always uses http://docs.oasis-open.org/wss/2004/01/oasis-200401-wss-wssecurity-secext-1.0.xsd.
        # text inside the value is not possible to select without prefix, so substring is necessary
        subnode = xmlutil.xpath_find(self._dom, 's:Body/s:Fault/s:Code/s:Subcode/s:Value')
        if len(subnode) > 1:
            raise AdalError("Found too many fault code values: {}".format(len(subnode)))

        if subnode:
            error_code = subnode[0].text or ''
            # Keep the local part of the QName ('a:RequestFailed' -> 'RequestFailed').
            # rpartition also copes with an unprefixed or empty value, where
            # split(':')[1] raised IndexError / AttributeError.
            self.error_code = error_code.rpartition(':')[2] or None

        return error_found

    def _parse_token(self):
        """Store the first RequestedSecurityToken's child element (serialized
        with ET.tostring) in self.token and its TokenType text in self.token_type.

        Raises AdalError when no TokenType node or no usable token is found,
        or when a TokenType has several RequestedSecurityToken siblings.
        """
        if self._wstrust_version == WSTrustVersion.WSTRUST2005:
            token_type_nodes_xpath = 's:Body/t:RequestSecurityTokenResponse/t:TokenType'
            security_token_xpath = 't:RequestedSecurityToken'
        else:
            token_type_nodes_xpath = 's:Body/wst:RequestSecurityTokenResponseCollection/wst:RequestSecurityTokenResponse/wst:TokenType'
            security_token_xpath = 'wst:RequestedSecurityToken'

        token_type_nodes = xmlutil.xpath_find(self._dom, token_type_nodes_xpath)
        if not token_type_nodes:
            raise AdalError("No TokenType nodes found in RSTR")

        for node in token_type_nodes:
            if self.token:
                self._log.warn("Found more than one returned token. Using the first.")
                break

            token_type = xmlutil.find_element_text(node)
            if not token_type:
                self._log.warn("Could not find token type in RSTR token.")

            # The RequestedSecurityToken is a sibling of TokenType, so search
            # from the TokenType's parent.
            requested_token_node = xmlutil.xpath_find(self._parents[node], security_token_xpath)
            if len(requested_token_node) > 1:
                raise AdalError("Found too many RequestedSecurityToken nodes for token type: {}".format(token_type))

            if not requested_token_node:
                self._log.warn(
                    "Unable to find RequestsSecurityToken element associated with TokenType element: %s",
                    token_type)
                continue

            # An empty RequestedSecurityToken used to raise IndexError on [0][0];
            # ET.tostring never returns None, so this is the real "no token" case.
            token_elements = list(requested_token_node[0])
            if not token_elements:
                self._log.warn(
                    "Unable to find token associated with TokenType element: %s",
                    token_type)
                continue

            # Adjust namespaces (without this they are autogenerated) so this is understood
            # by the receiver. Then make a string repr of the element tree node.
            ET.register_namespace('saml', 'urn:oasis:names:tc:SAML:1.0:assertion')
            ET.register_namespace('ds', 'http://www.w3.org/2000/09/xmldsig#')

            self.token = ET.tostring(token_elements[0])
            self.token_type = token_type

            self._log.info("Found token of type: %s", self.token_type)

        if self.token is None:
            raise AdalError("Unable to find any tokens in RSTR.")

    def parse(self):
        """Parse the RSTR body, raising AdalError on an empty body, unparsable
        XML, a SOAP fault, or a missing token."""
        if not self._response:
            raise AdalError("Received empty RSTR response body.")

        try:
            # NOTE(review): the Python docs state xml.etree is not secure against
            # maliciously constructed data (e.g. entity expansion); this body
            # comes from a remote federation server -- consider defusedxml.
            self._dom = ET.fromstring(self._response)
        except Exception as exp:
            raise AdalError('Failed to parse RSTR in to DOM', exp)

        try:
            self._parents = {c:p for p in self._dom.iter() for c in p}
            error_found = self._parse_error()
            if error_found:
                str_error_code = self.error_code or 'NONE'
                str_fault_message = self.fault_message or 'NONE'
                error_template = 'Server returned error in RSTR - ErrorCode: {} : FaultMessage: {}'
                raise AdalError(error_template.format(str_error_code, str_fault_message))
            self._parse_token()
        finally:
            # Drop the DOM and parent map once parsing is done.
            self._dom = None
            self._parents = None

# adal-0.4.6/adal/xmlutil.py0000644000175000017500000000536613133203172016242 0ustar travistravis00000000000000
#------------------------------------------------------------------------------
#
# Copyright (c) Microsoft Corporation.
# All rights reserved.
#
# This code is licensed under the MIT License.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files(the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and / or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions :
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
#------------------------------------------------------------------------------

try:
    from xml.etree import cElementTree as ET
except ImportError:
    from xml.etree import ElementTree as ET

from . import constants

XPATH_PATH_TEMPLATE = '*[local-name() = \'LOCAL_NAME\' and namespace-uri() = \'NAMESPACE\']'


def expand_q_names(xpath):
    """Rewrite each 'prefix:name' step of *xpath* into a namespace-agnostic
    local-name()/namespace-uri() predicate.

    Prefixes are resolved through constants.XmlNamespaces.namespaces.
    Raises IndexError for a step containing more than one ':'.
    """
    prefix_map = constants.XmlNamespaces.namespaces

    def _expand_step(step):
        if ':' not in step:
            return step
        pieces = step.split(':')
        if len(pieces) != 2:
            raise IndexError("Unable to parse XPath string: {} with QName: {}".format(xpath, step))
        prefix, local_name = pieces
        return XPATH_PATH_TEMPLATE.replace('LOCAL_NAME', local_name).replace('NAMESPACE', prefix_map[prefix])

    return '/'.join(_expand_step(step) for step in xpath.split('/'))


def xpath_find(dom, xpath):
    """findall() *xpath* under *dom* using the shared namespace prefix map."""
    return dom.findall(xpath, constants.XmlNamespaces.namespaces)


def serialize_node_children(node):
    """Concatenate the serialized XML of the elements yielded by node.iter().

    Returns None when nothing was serialized.
    NOTE(review): Element.iter() yields *node* itself plus every descendant,
    so nested elements appear both inside their ancestors' output and again on
    their own -- confirm that duplication is intended for a "children" helper.
    """
    fragments = []
    for element in node.iter():
        if not is_element_node(element):
            continue
        serialized = ET.tostring(element)
        fragments.append(serialized if isinstance(serialized, str) else serialized.decode())
    doc = ''.join(fragments)
    return doc or None


def is_element_node(node):
    """True when *node* has a 'tag' attribute (i.e. looks like an element)."""
    return hasattr(node, 'tag')


def find_element_text(node):
    """Return the first non-empty .text found in a depth-first walk of
    *node* (itself included), or None."""
    return next((element.text for element in node.iter() if element.text), None)

# adal-0.4.6/adal.egg-info/0000755000175000017500000000000013133203245015653 5ustar travistravis00000000000000
# adal-0.4.6/adal.egg-info/PKG-INFO0000644000175000017500000000160413133203245016751 0ustar travistravis00000000000000
# Metadata-Version: 1.1
# Name: adal
# Version: 0.4.6
# Summary: The ADAL for Python library makes it easy for python application to authenticate to Azure Active Directory (AAD) in order to access AAD protected web resources.
Home-page: https://github.com/AzureAD/azure-activedirectory-library-for-python Author: Microsoft Corporation Author-email: nugetaad@microsoft.com License: MIT Description: UNKNOWN Platform: UNKNOWN Classifier: Development Status :: 3 - Alpha Classifier: Programming Language :: Python Classifier: Programming Language :: Python :: 2 Classifier: Programming Language :: Python :: 2.7 Classifier: Programming Language :: Python :: 3 Classifier: Programming Language :: Python :: 3.3 Classifier: Programming Language :: Python :: 3.4 Classifier: Programming Language :: Python :: 3.5 Classifier: Programming Language :: Python :: 3.6 Classifier: License :: OSI Approved :: MIT License adal-0.4.6/adal.egg-info/SOURCES.txt0000644000175000017500000000106213133203245017536 0ustar travistravis00000000000000setup.cfg setup.py adal/__init__.py adal/adal_error.py adal/argument.py adal/authentication_context.py adal/authentication_parameters.py adal/authority.py adal/cache_driver.py adal/code_request.py adal/constants.py adal/log.py adal/mex.py adal/oauth2_client.py adal/self_signed_jwt.py adal/token_cache.py adal/token_request.py adal/user_realm.py adal/util.py adal/wstrust_request.py adal/wstrust_response.py adal/xmlutil.py adal.egg-info/PKG-INFO adal.egg-info/SOURCES.txt adal.egg-info/dependency_links.txt adal.egg-info/requires.txt adal.egg-info/top_level.txtadal-0.4.6/adal.egg-info/dependency_links.txt0000644000175000017500000000000113133203245021721 0ustar travistravis00000000000000 adal-0.4.6/adal.egg-info/requires.txt0000644000175000017500000000011013133203245020243 0ustar travistravis00000000000000PyJWT>=1.0.0 requests>=2.0.0 python-dateutil>=2.1.0 cryptography>=1.1.0 adal-0.4.6/adal.egg-info/top_level.txt0000644000175000017500000000000513133203245020400 0ustar travistravis00000000000000adal adal-0.4.6/setup.py0000644000175000017500000000566113133203172015001 0ustar travistravis00000000000000#!/usr/bin/env python 
#------------------------------------------------------------------------------ # # Copyright (c) Microsoft Corporation. # All rights reserved. # # This code is licensed under the MIT License. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files(the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and / or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions : # # The above copyright notice and this permission notice shall be included in # all copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. 
#
#------------------------------------------------------------------------------

from setuptools import setup
import re, io

# setup.py shall not import adal; the version is read out of the package
# source instead. The `with` block closes the file handle, which was
# previously left for the garbage collector (ResourceWarning).
with io.open('adal/__init__.py', encoding='utf_8_sig') as init_file:
    __version__ = re.search(
        r'__version__\s*=\s*[\'"]([^\'"]*)[\'"]',  # It excludes inline comment too
        init_file.read()
        ).group(1)

# To build:
#   python setup.py sdist
#   python setup.py bdist_wheel
#
# To install:
#   python setup.py install
#
# To register (only needed once):
#   python setup.py register
#
# To upload:
#   python setup.py sdist upload
#   python setup.py bdist_wheel upload

setup(
    name='adal',
    version=__version__,
    description=('The ADAL for Python library makes it easy for python ' +
                 'application to authenticate to Azure Active Directory ' +
                 '(AAD) in order to access AAD protected web resources.'),
    license='MIT',
    author='Microsoft Corporation',
    author_email='nugetaad@microsoft.com',
    url='https://github.com/AzureAD/azure-activedirectory-library-for-python',
    classifiers=[
        'Development Status :: 3 - Alpha',
        'Programming Language :: Python',
        'Programming Language :: Python :: 2',
        'Programming Language :: Python :: 2.7',
        'Programming Language :: Python :: 3',
        'Programming Language :: Python :: 3.3',
        'Programming Language :: Python :: 3.4',
        'Programming Language :: Python :: 3.5',
        'Programming Language :: Python :: 3.6',
        'License :: OSI Approved :: MIT License',
    ],
    packages=['adal'],
    install_requires=[
        'PyJWT>=1.0.0',
        'requests>=2.0.0',
        'python-dateutil>=2.1.0',
        'cryptography>=1.1.0'
    ]
)

# adal-0.4.6/PKG-INFO0000644000175000017500000000160413133203245014356 0ustar travistravis00000000000000
# Metadata-Version: 1.1
# Name: adal
# Version: 0.4.6
# Summary: The ADAL for Python library makes it easy for python application to authenticate to Azure Active Directory (AAD) in order to access AAD protected web resources.
Home-page: https://github.com/AzureAD/azure-activedirectory-library-for-python Author: Microsoft Corporation Author-email: nugetaad@microsoft.com License: MIT Description: UNKNOWN Platform: UNKNOWN Classifier: Development Status :: 3 - Alpha Classifier: Programming Language :: Python Classifier: Programming Language :: Python :: 2 Classifier: Programming Language :: Python :: 2.7 Classifier: Programming Language :: Python :: 3 Classifier: Programming Language :: Python :: 3.3 Classifier: Programming Language :: Python :: 3.4 Classifier: Programming Language :: Python :: 3.5 Classifier: Programming Language :: Python :: 3.6 Classifier: License :: OSI Approved :: MIT License adal-0.4.6/setup.cfg0000644000175000017500000000020113133203245015072 0ustar travistravis00000000000000[bdist_wheel] universal = 1 [metadata] description-file = README.md [egg_info] tag_build = tag_date = 0 tag_svn_revision = 0