pax_global_header00006660000000000000000000000064144202636240014515gustar00rootroot0000000000000052 comment=4d3329f156bd00305db380a0250535af7e752799 microsoft-authentication-library-for-go-1.0.0/000077500000000000000000000000001442026362400213665ustar00rootroot00000000000000microsoft-authentication-library-for-go-1.0.0/.github/000077500000000000000000000000001442026362400227265ustar00rootroot00000000000000microsoft-authentication-library-for-go-1.0.0/.github/ISSUE_TEMPLATE/000077500000000000000000000000001442026362400251115ustar00rootroot00000000000000microsoft-authentication-library-for-go-1.0.0/.github/ISSUE_TEMPLATE/bug_report.md000066400000000000000000000032631442026362400276070ustar00rootroot00000000000000--- name: Bug report about: Please do NOT file bugs without filling in this form. title: '[Bug] ' labels: '' assignees: '' --- **Which version of MSAL Go are you using?** Note that to get help, you need to run the latest version. **Where is the issue?** * Public client * [ ] Device code flow * [ ] Username/Password (ROPC grant) * [ ] Authorization code flow * Confidential client * [ ] Authorization code flow * [ ] Client credentials: * [ ] client secret * [ ] client certificate * Token cache serialization * [ ] In-memory cache * Other (please describe) **Is this a new or an existing app?** **What version of Go are you using (`go version`)?**
$ go version
**What operating system and processor architecture are you using (`go env`)?**
go env Output
$ go env

**Repro** var your = (code) => here; **Expected behavior** A clear and concise description of what you expected to happen (or code). **Actual behavior** A clear and concise description of what happens, e.g. an exception is thrown, UI freezes. **Possible solution** **Additional context / logs / screenshots** Add any other context about the problem here, such as logs and screenshots. microsoft-authentication-library-for-go-1.0.0/.github/ISSUE_TEMPLATE/documentation.md000066400000000000000000000007551442026362400303130ustar00rootroot00000000000000--- name: Documentation about: Suggest a change to the documentation. title: '[Documentation] ' labels: documentation assignees: '' --- ### Documentation related to component ### Please check all that apply - [ ] typo - [ ] documentation doesn't exist - [ ] documentation needs clarification - [ ] error(s) in the example - [ ] needs an example ### Description of the issue microsoft-authentication-library-for-go-1.0.0/.github/ISSUE_TEMPLATE/feature_request.md000066400000000000000000000012011442026362400306300ustar00rootroot00000000000000--- name: Feature request about: Suggest an idea for this project. title: "[Feature Request] " labels: enhancement, Feature Request assignees: '' --- **Is your feature request related to a problem? Please describe.** A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]. **Describe the solution you'd like** A clear and concise description of what you want to happen. **Describe alternatives you've considered** A clear and concise description of any alternative solutions or features you've considered. **Additional context** Add any other context or screenshots about the feature request here. 
microsoft-authentication-library-for-go-1.0.0/.github/workflows/000077500000000000000000000000001442026362400247635ustar00rootroot00000000000000microsoft-authentication-library-for-go-1.0.0/.github/workflows/go.yml000066400000000000000000000026211442026362400261140ustar00rootroot00000000000000name: Go on: push: branches: [dev] pull_request: # This guards against unknown PRs until a community member vets and labels them. types: [ labeled ] jobs: build: name: Build runs-on: ubuntu-latest strategy: matrix: go: ["1.19", "1.20"] steps: - name: Set up Go 1.x uses: actions/setup-go@v2 with: go-version: ${{ matrix.go }} id: go - name: Check out code into the Go module directory uses: actions/checkout@v2 - name: Get dependencies run: go get -v -t -d ./... # Designed to run only on Linux. - name: Format Check run: if [ $(gofmt -l -s . | wc -l) -ne 0 ]; then echo "fmt failed"; exit 1; fi - name: Build run: go build ./apps/... - name: Unit Tests run: go test -race -short ./apps/cache/... ./apps/confidential/... ./apps/public/... ./apps/internal/... - name: Integration Tests run: go test -race ./apps/tests/integration/... env : clientId: ${{ secrets.LAB_APP_CLIENT_ID }} clientSecret: ${{ secrets.LAB_APP_CLIENT_SECRET }} oboConfidentialClientId: ${{ secrets.OBO_CONFIDENTIAL_APP_CLIENT_ID }} oboConfidentialClientSecret: ${{ secrets.OBO_CONFIDENTIAL_APP_CLIENT_SECRET }} oboPublicClientId: ${{ secrets.OBO_PUBLIC_APP_CLIENT_ID }} CI: ${{secrets.ENABLECI}} microsoft-authentication-library-for-go-1.0.0/.github/workflows/golangci-lint.yml000066400000000000000000000014101442026362400302310ustar00rootroot00000000000000name: golangci-lint on: push: tags: - v* branches: - master - main - dev pull_request: # This guards against unknown PRs until a community member vets and labels them.
types: [ labeled ] jobs: golangci: name: lint runs-on: ubuntu-latest steps: - uses: actions/setup-go@v3 with: go-version: "1.20" - uses: actions/checkout@v3 - name: golangci-lint uses: golangci/golangci-lint-action@v3 with: # Required: the version of golangci-lint is required and must be specified without patch version: we always use the latest patch version. version: v1.51 # Optional: golangci-lint command line arguments. # args: --issues-exit-code=0 microsoft-authentication-library-for-go-1.0.0/.gitignore000066400000000000000000000004571442026362400233640ustar00rootroot00000000000000# Binaries for programs and plugins *.exe *.exe~ *.dll *.so *.dylib *.golangci.yml *.swp *.pprof # OSX specific os files *.DS_Store # Test binary, build with `go test -c` *.test # Output of the go coverage tool, specifically when used with LiteIDE *.out pkg/ github.com/ golang.org/ .vscode/ .idea/ microsoft-authentication-library-for-go-1.0.0/.golangci.yml000066400000000000000000000004231442026362400237510ustar00rootroot00000000000000linters: # enabled in addition to default enable: - gosec issues: # Excluding configuration per-path, per-linter, per-text and per-source exclude-rules: # Exclude some linters from running on tests files. - path: _test\.go linters: - gosec microsoft-authentication-library-for-go-1.0.0/CODE_OF_CONDUCT.md000066400000000000000000000007051442026362400241670ustar00rootroot00000000000000# Microsoft Open Source Code of Conduct This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 
Resources: - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns microsoft-authentication-library-for-go-1.0.0/CONTRIBUTING.md000066400000000000000000000077151442026362400236310ustar00rootroot00000000000000# Microsoft Authentication Library for Go welcomes new contributors This document will guide you through the process. ## Contributor License Agreement Please visit [https://cla.microsoft.com/](https://cla.microsoft.com/) and sign the Contributor License Agreement. You only need to do that once. We cannot look at your code until you've submitted this request. ## FORK Fork the project [on GitHub](https://github.com/AzureAD/microsoft-authentication-library-for-go) and check out your copy. Example for MSAL Go: ``` $ git clone git@github.com:username/microsoft-authentication-library-for-go.git $ cd microsoft-authentication-library-for-go $ git remote add upstream git@github.com:AzureAD/microsoft-authentication-library-for-go.git ``` ## Setup, Building and Testing Please see the [Build & Run](https://github.com/AzureAD/microsoft-authentication-library-for-go/wiki/build-and-test) wiki page. ## Decide on which branch to create **Bug fixes for the current stable version need to go to the 'master' branch.** If you need to contribute to a different branch, please contact us first (open an issue). All details after this point are standard - make sure your commits have nice messages, and prefer rebase to merge. In case of doubt, please open an issue in the [issue tracker](https://github.com/AzureAD/microsoft-authentication-library-for-go/issues). Especially do so if you plan to work on a major change in functionality. Nothing is more frustrating than seeing your hard work go to waste because your vision does not align with our goals for the SDK.
## Branch Okay, so you have decided on the proper branch. Create a feature branch and start hacking: ``` $ git checkout -b my-feature-branch ``` ## Commit Make sure git knows your name and email address: ``` $ git config --global user.name "J. Random User" $ git config --global user.email "j.random.user@example.com" ``` Writing good commit logs is important. A commit log should describe what changed and why. Follow these guidelines when writing one: 1. The first line should be 50 characters or less and contain a short description of the change prefixed with the name of the changed subsystem (e.g. "net: add localAddress and localPort to Socket"). 2. Keep the second line blank. 3. Wrap all other lines at 72 columns. A good commit log looks like this: ``` fix: explaining the commit in one line Body of commit message is a few lines of text, explaining things in more detail, possibly giving some background about the issue being fixed, etc etc. The body of the commit message can be several paragraphs, and please do proper word-wrap and keep columns shorter than about 72 characters or so. That way `git log` will show things nicely even when it is indented. ``` The header line should be meaningful; it is what other people see when they run `git shortlog` or `git log --oneline`. Check the output of `git log --oneline files_that_you_changed` to find out what directories your changes touch. ### Rebase Use `git rebase` (not `git merge`) to sync your work from time to time. ``` $ git fetch upstream $ git rebase upstream/v0.1 # or upstream/master ``` ### Tests It's all standard stuff, but please note that you won't be able to run integration tests locally because they connect to a KeyVault to fetch some test users and passwords. The CI will run them for you. ### Push ``` $ git push origin my-feature-branch ``` Go to `https://github.com/username/microsoft-authentication-library-for-go` and select your feature branch. Click the 'Pull Request' button and fill out the form. 
Pull requests are usually reviewed within a few days. If there are comments to address, apply your changes in a separate commit and push that to your feature branch. Post a comment in the pull request afterwards; GitHub does not send out notifications when you add commits. [on GitHub]: https://github.com/AzureAD/microsoft-authentication-library-for-go [issue tracker]: https://github.com/AzureAD/microsoft-authentication-library-for-go/issues microsoft-authentication-library-for-go-1.0.0/LICENSE000066400000000000000000000022121442026362400223700ustar00rootroot00000000000000 MIT License Copyright (c) Microsoft Corporation. Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE microsoft-authentication-library-for-go-1.0.0/README.md000066400000000000000000000210671442026362400226530ustar00rootroot00000000000000# Microsoft Authentication Library (MSAL) for Go **MSAL for Go is a new addition to the MSAL family of libraries and has been made available as a production-ready preview to gauge customer interest and to gather feedback from the community. We welcome all contributors (see [CONTRIBUTING.md](https://github.com/AzureAD/microsoft-authentication-library-for-go/blob/dev/CONTRIBUTING.md)) to help us grow our list of supported MSAL SDKs.** The Microsoft Authentication Library (MSAL) for Go is part of the [Microsoft identity platform for developers](https://aka.ms/aaddevv2) (formerly named Azure AD) v2.0. It allows you to sign in users or apps with Microsoft identities ([Azure AD](https://azure.microsoft.com/services/active-directory/) and [Microsoft Accounts](https://account.microsoft.com)) and obtain tokens to call Microsoft APIs such as [Microsoft Graph](https://graph.microsoft.io/) or your own APIs registered with the Microsoft identity platform. It is built using the industry-standard OAuth2 and OpenID Connect protocols. The latest code resides in the `dev` branch.
Quick links: | [Getting Started](https://docs.microsoft.com/azure/active-directory/develop/#quickstarts) | [GoDoc](https://pkg.go.dev/github.com/AzureAD/microsoft-authentication-library-for-go/apps) | [Wiki](https://github.com/AzureAD/microsoft-authentication-library-for-go/wiki) | [Samples](https://github.com/AzureAD/microsoft-authentication-library-for-go/tree/dev/apps/tests/devapps) | [Support](README.md#community-help-and-support) | [Feedback](https://forms.office.com/r/s4waBAytFJ) | | ------------------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------ | ------------------------------------------------------------------------------------------------------------------ | ------------------------------------------------------------------------------------------------------------------ | ------------------------------------------------------------------------------------------------------- | ## Build Status ![Go](https://github.com/AzureAD/microsoft-authentication-library-for-go/workflows/Go/badge.svg?branch=dev) ## Installation ### Setting up Go To install Go, visit [this link](https://golang.org/dl/). ### Installing MSAL Go `go get -u github.com/AzureAD/microsoft-authentication-library-for-go/` ## Usage Before using MSAL Go, you will need to [register your application with the Microsoft identity platform](https://docs.microsoft.com/azure/active-directory/develop/quickstart-v2-register-an-app). ### Public Surface The Public API of the library can be found in the following directories under `apps`. 
``` apps/ - Contains all our code confidential/ - The confidential application API public/ - The public application API cache/ - The cache interface that can be implemented to provide persistent cache storage of credentials ``` Acquiring tokens with MSAL Go follows this general three-step pattern. There might be some slight differences for other token acquisition flows. Here is a basic example: 1. MSAL separates [public and confidential client applications](https://tools.ietf.org/html/rfc6749#section-2.1). So, you would create an instance of either a `PublicClientApplication` or a `ConfidentialClientApplication` and use it throughout the lifetime of your application. * Initializing a public client: ```go publicClientApp, err := public.New("client_id", public.WithAuthority("https://login.microsoft.com/Enter_The_Tenant_Name_Here")) ``` * Initializing a confidential client: ```go // Initializing the client credential cred, err := confidential.NewCredFromSecret("client_secret") if err != nil { return nil, fmt.Errorf("could not create a cred from a secret: %w", err) } confidentialClientApp, err := confidential.New("client_id", cred, confidential.WithAuthority("https://login.microsoft.com/Enter_The_Tenant_Name_Here")) ``` 1. MSAL comes packaged with an in-memory cache. Utilizing the cache is optional, but we would highly recommend it. ```go var userAccount public.Account accounts := publicClientApp.Accounts() if len(accounts) > 0 { // Assuming the user wanted the first account userAccount = accounts[0] // found a cached account, now see if an applicable token has been cached result, err := publicClientApp.AcquireTokenSilent(context.Background(), []string{"your_scope"}, public.WithSilentAccount(userAccount)) if err == nil { accessToken := result.AccessToken } } ``` 1. If there is no suitable token in the cache, or you choose to skip this step, send a request to AAD to obtain a token.
```go result, err := publicClientApp.AcquireToken"ByOneofTheActualMethods"([]string{"your_scope"}, ...(other parameters depending on the function)) if err != nil { log.Fatal(err) } accessToken := result.AccessToken ``` See the [dev apps](https://github.com/AzureAD/microsoft-authentication-library-for-go/tree/dev/apps/tests/devapps) for examples of how to use MSAL Go with various application types in various scenarios. For more detailed information, please refer to the [wiki](https://github.com/AzureAD/microsoft-authentication-library-for-go/wiki). # Releases See the list of [releases](https://github.com/AzureAD/microsoft-authentication-library-for-go/releases). ## Roadmap This is a preview library. Details of the roadmap will come soon in the [wiki pages](https://github.com/AzureAD/microsoft-authentication-library-for-go/wiki), along with release notes. ## Community Help and Support We use [Stack Overflow](http://stackoverflow.com/questions/tagged/msal) to work with the community on supporting Azure Active Directory and its SDKs, including this one! We highly recommend you ask your questions on Stack Overflow (we're all on there!). Also, browse existing issues to see if someone has had your question before. Please use the "msal" tag when asking your questions. If you find a bug or have a feature request, please raise the issue on [GitHub Issues](https://github.com/AzureAD/microsoft-authentication-library-for-go/issues). ## Submit Feedback We'd like your thoughts on this library. Please complete [this short survey.](https://forms.office.com/r/s4waBAytFJ) ## Contributing This project welcomes contributions and suggestions. Most contributions require you to agree to a Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
When you submit a pull request, a CLA bot will automatically determine whether you need to provide a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions provided by the bot. You will only need to do this once across all repos using our CLA. This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. ## Security Library This library controls how users sign-in and access services. We recommend you always take the latest version of our library in your app when possible. We use [semantic versioning](http://semver.org) so you can control the risk associated with updating your app. As an example, always downloading the latest minor version number (e.g. x.*y*.x) ensures you get the latest security and feature enhancements but our API surface remains the same. You can always see the latest version and release notes under the Releases tab of GitHub. ## Security Reporting If you find a security issue with our libraries or services please report it to [secure@microsoft.com](mailto:secure@microsoft.com) with as much detail as possible. Your submission may be eligible for a bounty through the [Microsoft Bounty](http://aka.ms/bugbounty) program. Please do not post security issues to GitHub Issues or any other public site. We will contact you shortly upon receiving the information. We encourage you to get notifications of when security incidents occur by visiting [this page](https://technet.microsoft.com/en-us/security/dd252948) and subscribing to Security Advisory Alerts. Copyright (c) Microsoft Corporation. All rights reserved. Licensed under the MIT License (the "License"). 
microsoft-authentication-library-for-go-1.0.0/RELEASES.md000066400000000000000000000145231442026362400231200ustar00rootroot00000000000000# Microsoft Identity SDK Versioning and Servicing FAQ We have adopted the semantic versioning flow that is industry standard for OSS projects. It gives the maximum amount of control on what risk you take with what versions. If you know how semantic versioning works with node.js, java, and ruby, none of this will be new. ## Semantic Versioning and API stability promises Microsoft Identity libraries are independent open source libraries that are used by partners both internal and external to Microsoft. As with the rest of Microsoft, we have moved to a rapid iteration model where bugs are fixed daily and new versions are produced as required. To communicate these frequent changes to external partners and customers, we use semantic versioning for all our public Microsoft Identity SDK libraries. This follows the practices of other open source libraries on the internet. This allows us to support our downstream partners who lock on certain versions for stability purposes, as well as providing for the distribution over NuGet, CocoaPods, and Maven. The semantics are: MAJOR.MINOR.PATCH (example 1.1.5) We will update our code distributions to use the latest PATCH semantic version number in order to make sure our customers and partners get the latest bug fixes. Downstream partners need to pull the latest PATCH version. Most partners should try to lock on the latest MINOR version number in their builds and accept any updates in the PATCH number. Examples: Using CocoaPods, the following in the Podfile will take the latest ADALiOS build that is at least 1.1.0 but lower than 1.2. ``` pod 'ADALiOS', '~> 1.1.0' ``` Using NuGet, this ensures all 1.1.0 to 1.1.x updates are included when building your code, but not 1.2.
``` ``` | Version | Description | Example | |:-------:|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:|:---------------------------------------------------------------------------------------------------------:| | x.x.x | PATCH version number. Incrementing these numbers is for bug fixes and updates but do not introduce new features. This is used for close partners who build on our platform release (ex. Azure AD Fabric, Office, etc.),In addition, Cocoapods, NuGet, and Maven use this number to deliver the latest release to customers.,This will update frequently (sometimes within the same day),There is no new features, and no regressions or API surface changes. Code will continue to work unless affected by a particular code fix. | ADAL for iOS 1.0.10,(this was a fix for the Storyboard display that was fixed for a specific Office team) | | x.x | MINOR version numbers. Incrementing these second numbers are for new feature additions that do not impact existing features or introduce regressions. They are purely additive, but may require testing to ensure nothing is impacted.,All x.x.x bug fixes will also roll up in to this number.,There is no regressions or API surface changes. Code will continue to work unless affected by a particular code fix or needs this new feature. | ADAL for iOS 1.1.0,(this added WPJ capability to ADAL, and rolled all the updates from 1.0.0 to 1.0.12) | | x | MAJOR version numbers. 
This should be considered a new, supported version of Microsoft Identity SDK and begins the Azure two-year support cycle anew. Major new features are introduced and API changes can occur.,This should only be used after a large amount of testing and used only if those features are needed.,We will continue to service MAJOR version numbers with bug fixes up to the two-year support cycle. | ADAL for iOS 1.0,(our first official release of ADAL) | ## Serviceability When we release a new MINOR version, the previous MINOR version is abandoned. When we release a new MAJOR version, we will continue to apply bug fixes to the existing features in the previous MAJOR version for up to the 2-year support cycle for Azure. Example: Suppose we release ADALiOS 2.0 in the future, which supports unified Auth for AAD and MSA. Later, we then have a fix in Conditional Access for ADALiOS. Since that feature exists both in ADALiOS 1.1 and ADALiOS 2.0, we will fix both. It will roll up in a PATCH number for each. Customers that are still locked down on ADALiOS 1.1 will receive the benefit of this fix. ## Microsoft Identity SDKs and Azure Active Directory Microsoft Identity SDK major versions will maintain backwards compatibility with Azure Active Directory web services through the support period. This means that the API surface area defined in a MAJOR version will continue to work for 2 years after release. We will respond quickly to bugs from our partners and customers submitted through GitHub and through our private alias (tellaad@microsoft.com) for security issues and update the PATCH version number. We will also submit a change summary for each PATCH number. Occasionally, there will be security bugs or breaking bugs from our partners that will require an immediate fix and a publish of an update to all partners and customers.
When this occurs, we will do an emergency roll up to a PATCH version number and update all our distribution methods to the latest.microsoft-authentication-library-for-go-1.0.0/SECURITY.md000066400000000000000000000054611442026362400231650ustar00rootroot00000000000000 ## Security Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)), please report it to us as described below. ## Reporting Security Issues **Please do not report security vulnerabilities through public GitHub issues.** Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report). If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc). You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc).
Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) * Full paths of source file(s) related to the manifestation of the issue * The location of the affected source code (tag/branch/commit or direct URL) * Any special configuration required to reproduce the issue * Step-by-step instructions to reproduce the issue * Proof-of-concept or exploit code (if possible) * Impact of the issue, including how an attacker might exploit the issue This information will help us triage your report more quickly. If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs. ## Preferred Languages We prefer all communications to be in English. ## Policy Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd). microsoft-authentication-library-for-go-1.0.0/apps/000077500000000000000000000000001442026362400223315ustar00rootroot00000000000000microsoft-authentication-library-for-go-1.0.0/apps/cache/000077500000000000000000000000001442026362400233745ustar00rootroot00000000000000microsoft-authentication-library-for-go-1.0.0/apps/cache/cache.go000066400000000000000000000040211442026362400247630ustar00rootroot00000000000000// Copyright (c) Microsoft Corporation. // Licensed under the MIT license. /* Package cache allows third parties to implement external storage for caching token data for distributed systems or for access by multiple local applications. The data stored and extracted will represent the entire cache. Therefore, it is recommended to use one MSAL instance per user. This data is considered opaque and there are no guarantees to implementers on the format being passed.
*/ package cache import "context" // Marshaler marshals data from an internal cache to bytes that can be stored. type Marshaler interface { Marshal() ([]byte, error) } // Unmarshaler unmarshals data from a storage medium into the internal cache, overwriting it. type Unmarshaler interface { Unmarshal([]byte) error } // Serializer combines Marshaler and Unmarshaler: it can marshal the cache to bytes and unmarshal bytes back into the cache. type Serializer interface { Marshaler Unmarshaler } // ExportHints are suggestions for storing data. type ExportHints struct { // PartitionKey is a suggested key for partitioning the cache. PartitionKey string } // ReplaceHints are suggestions for loading data. type ReplaceHints struct { // PartitionKey is a suggested key for partitioning the cache. PartitionKey string } // ExportReplace exports and replaces in-memory cache data. It doesn't support nil Context or // define the outcome of passing one. A Context without a timeout must receive a default timeout // specified by the implementor. Any retries must be performed by the implementation itself. type ExportReplace interface { // Replace replaces the cache with what is in external storage. Implementors should honor // Context cancellations and return context.Canceled or context.DeadlineExceeded in those cases. Replace(ctx context.Context, cache Unmarshaler, hints ReplaceHints) error // Export writes the binary representation of the cache (cache.Marshal()) to external storage. // The written data is opaque. Context cancellations should be honored as in Replace. Export(ctx context.Context, cache Marshaler, hints ExportHints) error } microsoft-authentication-library-for-go-1.0.0/apps/confidential/000077500000000000000000000000001442026362400247705ustar00rootroot00000000000000microsoft-authentication-library-for-go-1.0.0/apps/confidential/confidential.go000066400000000000000000000564551442026362400277750ustar00rootroot00000000000000// Copyright (c) Microsoft Corporation. // Licensed under the MIT license.
/* Package confidential provides a client for authentication of "confidential" applications. A "confidential" application is defined as an app that run on servers. They are considered difficult to access and for that reason capable of keeping an application secret. Confidential clients can hold configuration-time secrets. */ package confidential import ( "context" "crypto" "crypto/rsa" "crypto/x509" "encoding/base64" "encoding/pem" "errors" "fmt" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/base" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/exported" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/options" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/shared" ) /* Design note: confidential.Client uses base.Client as an embedded type. base.Client statically assigns its attributes during creation. As it doesn't have any pointers in it, anything borrowed from it, such as Base.AuthParams is a copy that is free to be manipulated here. Duplicate Calls shared between public.Client and this package: There is some duplicate call options provided here that are the same as in public.Client . This is a design choices. Go proverb(https://www.youtube.com/watch?v=PAAkCSZUG1c&t=9m28s): "a little copying is better than a little dependency". Yes, we could have another package with shared options (fail). That divides like 2 options from all others which makes the user look through more docs. 
We can have all clients in one package, but separate packages make for better naming
(public.Client vs client.PublicClient), so a little duplication was chosen.

.Net developers, take note on X509:
This uses x509.Certificates and private keys. x509 does not store private keys. .Net has an
x509.Certificate2 type that carries a private key, but that is a .Net-specific addition with
no equivalent in Go's crypto/x509. As such, this package provides a PEM decoder, CertFromPEM.
*/

// TODO(msal): This should have example code for each method on client using Go's example doc framework.
// base usage details should be included in the package documentation.

// AuthResult contains the results of one token acquisition operation.
// For details see https://aka.ms/msal-net-authenticationresult
type AuthResult = base.AuthResult

// Account is an account known to the token cache; see [Client.Account] and [WithSilentAccount].
type Account = shared.Account

// CertFromPEM converts a PEM file (.pem or .key) for use with [NewCredFromCert]. The file
// must contain the public certificate and the private key. Private keys may be PKCS#8
// ("PRIVATE KEY") or PKCS#1 ("RSA PRIVATE KEY") blocks; blocks of other types are ignored.
// Blocks encrypted with legacy PEM encryption (RFC 1423) are decrypted using password.
// Multiple certs are due to certificate chaining for use cases like TLS that sign from root to leaf.
//
// An error is returned when a block can't be decrypted or parsed, when the data contains more
// than one private key, or when it contains no certificate or no private key.
func CertFromPEM(pemData []byte, password string) ([]*x509.Certificate, crypto.PrivateKey, error) {
	var certs []*x509.Certificate
	var priv crypto.PrivateKey
	for {
		block, rest := pem.Decode(pemData)
		if block == nil {
			break
		}

		//nolint:staticcheck // x509.IsEncryptedPEMBlock and x509.DecryptPEMBlock are deprecated. They are used here only to support a usecase.
		if x509.IsEncryptedPEMBlock(block) {
			der, err := x509.DecryptPEMBlock(block, []byte(password))
			if err != nil {
				return nil, nil, fmt.Errorf("could not decrypt encrypted PEM block: %w", err)
			}
			// DecryptPEMBlock returns the decrypted DER bytes, not PEM text, so pem.Decode would
			// never find a PEM header in its output. Keep the block's type and swap in the
			// plaintext payload so the switch below parses it like an unencrypted block.
			block = &pem.Block{Type: block.Type, Bytes: der}
		}

		switch block.Type {
		case "CERTIFICATE":
			cert, err := x509.ParseCertificate(block.Bytes)
			if err != nil {
				return nil, nil, fmt.Errorf("block labelled 'CERTIFICATE' could not be parsed by x509: %w", err)
			}
			certs = append(certs, cert)
		case "PRIVATE KEY":
			if priv != nil {
				return nil, nil, errors.New("found multiple private key blocks")
			}
			var err error
			priv, err = x509.ParsePKCS8PrivateKey(block.Bytes)
			if err != nil {
				return nil, nil, fmt.Errorf("could not decode private key: %w", err)
			}
		case "RSA PRIVATE KEY":
			if priv != nil {
				return nil, nil, errors.New("found multiple private key blocks")
			}
			var err error
			priv, err = x509.ParsePKCS1PrivateKey(block.Bytes)
			if err != nil {
				return nil, nil, fmt.Errorf("could not decode private key: %w", err)
			}
		}
		pemData = rest
	}

	if len(certs) == 0 {
		return nil, nil, errors.New("no certificates found")
	}
	if priv == nil {
		return nil, nil, errors.New("no private key found")
	}
	return certs, priv, nil
}

// AssertionRequestOptions has required information for client assertion claims
type AssertionRequestOptions = exported.AssertionRequestOptions

// Credential represents the credential used in confidential client flows. Create one with
// [NewCredFromSecret], [NewCredFromCert], [NewCredFromAssertionCallback] or
// [NewCredFromTokenProvider]; [New] rejects the zero value.
type Credential struct {
	secret string

	// cert and key are set by NewCredFromCert. x5c holds the base64-encoded DER of each
	// non-nil certificate passed to it, with the signing certificate first.
	cert *x509.Certificate
	key  crypto.PrivateKey
	x5c  []string

	assertionCallback func(context.Context, AssertionRequestOptions) (string, error)
	tokenProvider     func(context.Context, TokenProviderParameters) (TokenProviderResult, error)
}

// toInternal returns the accesstokens.Credential that is used internally. The current structure of the
// code requires that client.go, requests.go and confidential.go share a credential type without
// having import recursion.
That requires the type used between is in a shared package. Therefore // we have this. func (c Credential) toInternal() (*accesstokens.Credential, error) { if c.secret != "" { return &accesstokens.Credential{Secret: c.secret}, nil } if c.cert != nil { if c.key == nil { return nil, errors.New("missing private key for certificate") } return &accesstokens.Credential{Cert: c.cert, Key: c.key, X5c: c.x5c}, nil } if c.key != nil { return nil, errors.New("missing certificate for private key") } if c.assertionCallback != nil { return &accesstokens.Credential{AssertionCallback: c.assertionCallback}, nil } if c.tokenProvider != nil { return &accesstokens.Credential{TokenProvider: c.tokenProvider}, nil } return nil, errors.New("invalid credential") } // NewCredFromSecret creates a Credential from a secret. func NewCredFromSecret(secret string) (Credential, error) { if secret == "" { return Credential{}, errors.New("secret can't be empty string") } return Credential{secret: secret}, nil } // NewCredFromAssertionCallback creates a Credential that invokes a callback to get assertions // authenticating the application. The callback must be thread safe. func NewCredFromAssertionCallback(callback func(context.Context, AssertionRequestOptions) (string, error)) Credential { return Credential{assertionCallback: callback} } // NewCredFromCert creates a Credential from a certificate or chain of certificates and an RSA private key // as returned by [CertFromPEM]. 
func NewCredFromCert(certs []*x509.Certificate, key crypto.PrivateKey) (Credential, error) { cred := Credential{key: key} k, ok := key.(*rsa.PrivateKey) if !ok { return cred, errors.New("key must be an RSA key") } for _, cert := range certs { if cert == nil { // not returning an error here because certs may still contain a sufficient cert/key pair continue } certKey, ok := cert.PublicKey.(*rsa.PublicKey) if ok && k.E == certKey.E && k.N.Cmp(certKey.N) == 0 { // We know this is the signing cert because its public key matches the given private key. // This cert must be first in x5c. cred.cert = cert cred.x5c = append([]string{base64.StdEncoding.EncodeToString(cert.Raw)}, cred.x5c...) } else { cred.x5c = append(cred.x5c, base64.StdEncoding.EncodeToString(cert.Raw)) } } if cred.cert == nil { return cred, errors.New("key doesn't match any certificate") } return cred, nil } // TokenProviderParameters is the authentication parameters passed to token providers type TokenProviderParameters = exported.TokenProviderParameters // TokenProviderResult is the authentication result returned by custom token providers type TokenProviderResult = exported.TokenProviderResult // NewCredFromTokenProvider creates a Credential from a function that provides access tokens. The function // must be concurrency safe. This is intended only to allow the Azure SDK to cache MSI tokens. It isn't // useful to applications in general because the token provider must implement all authentication logic. func NewCredFromTokenProvider(provider func(context.Context, TokenProviderParameters) (TokenProviderResult, error)) Credential { return Credential{tokenProvider: provider} } // AutoDetectRegion instructs MSAL Go to auto detect region for Azure regional token service. func AutoDetectRegion() string { return "TryAutoDetect" } // Client is a representation of authentication client for confidential applications as defined in the // package doc. A new Client should be created PER SERVICE USER. 
// For more information, visit https://docs.microsoft.com/azure/active-directory/develop/msal-client-applications type Client struct { base base.Client cred *accesstokens.Credential } // clientOptions are optional settings for New(). These options are set using various functions // returning Option calls. type clientOptions struct { accessor cache.ExportReplace authority, azureRegion string capabilities []string disableInstanceDiscovery, sendX5C bool httpClient ops.HTTPClient } // Option is an optional argument to New(). type Option func(o *clientOptions) // WithCache provides an accessor that will read and write authentication data to an externally managed cache. func WithCache(accessor cache.ExportReplace) Option { return func(o *clientOptions) { o.accessor = accessor } } // WithClientCapabilities allows configuring one or more client capabilities such as "CP1" func WithClientCapabilities(capabilities []string) Option { return func(o *clientOptions) { // there's no danger of sharing the slice's underlying memory with the application because // this slice is simply passed to base.WithClientCapabilities, which copies its data o.capabilities = capabilities } } // WithHTTPClient allows for a custom HTTP client to be set. func WithHTTPClient(httpClient ops.HTTPClient) Option { return func(o *clientOptions) { o.httpClient = httpClient } } // WithX5C specifies if x5c claim(public key of the certificate) should be sent to STS to enable Subject Name Issuer Authentication. func WithX5C() Option { return func(o *clientOptions) { o.sendX5C = true } } // WithInstanceDiscovery set to false to disable authority validation (to support private cloud scenarios) func WithInstanceDiscovery(enabled bool) Option { return func(o *clientOptions) { o.disableInstanceDiscovery = !enabled } } // WithAzureRegion sets the region(preferred) or Confidential.AutoDetectRegion() for auto detecting region. // Region names as per https://azure.microsoft.com/en-ca/global-infrastructure/geographies/. 
// See https://aka.ms/region-map for more details on region names. // The region value should be short region name for the region where the service is deployed. // For example "centralus" is short name for region Central US. // Not all auth flows can use the regional token service. // Service To Service (client credential flow) tokens can be obtained from the regional service. // Requires configuration at the tenant level. // Auto-detection works on a limited number of Azure artifacts (VMs, Azure functions). // If auto-detection fails, the non-regional endpoint will be used. // If an invalid region name is provided, the non-regional endpoint MIGHT be used or the token request MIGHT fail. func WithAzureRegion(val string) Option { return func(o *clientOptions) { o.azureRegion = val } } // New is the constructor for Client. authority is the URL of a token authority such as "https://login.microsoftonline.com/". // If the Client will connect directly to AD FS, use "adfs" for the tenant. clientID is the application's client ID (also called its // "application ID"). func New(authority, clientID string, cred Credential, options ...Option) (Client, error) { internalCred, err := cred.toInternal() if err != nil { return Client{}, err } opts := clientOptions{ authority: authority, // if the caller specified a token provider, it will handle all details of authentication, using Client only as a token cache disableInstanceDiscovery: cred.tokenProvider != nil, httpClient: shared.DefaultClient, } for _, o := range options { o(&opts) } baseOpts := []base.Option{ base.WithCacheAccessor(opts.accessor), base.WithClientCapabilities(opts.capabilities), base.WithInstanceDiscovery(!opts.disableInstanceDiscovery), base.WithRegionDetection(opts.azureRegion), base.WithX5C(opts.sendX5C), } base, err := base.New(clientID, opts.authority, oauth.New(opts.httpClient), baseOpts...) 
if err != nil { return Client{}, err } base.AuthParams.IsConfidentialClient = true return Client{base: base, cred: internalCred}, nil } // authCodeURLOptions contains options for AuthCodeURL type authCodeURLOptions struct { claims, loginHint, tenantID, domainHint string } // AuthCodeURLOption is implemented by options for AuthCodeURL type AuthCodeURLOption interface { authCodeURLOption() } // AuthCodeURL creates a URL used to acquire an authorization code. Users need to call CreateAuthorizationCodeURLParameters and pass it in. // // Options: [WithClaims], [WithDomainHint], [WithLoginHint], [WithTenantID] func (cca Client) AuthCodeURL(ctx context.Context, clientID, redirectURI string, scopes []string, opts ...AuthCodeURLOption) (string, error) { o := authCodeURLOptions{} if err := options.ApplyOptions(&o, opts); err != nil { return "", err } ap, err := cca.base.AuthParams.WithTenant(o.tenantID) if err != nil { return "", err } ap.Claims = o.claims ap.LoginHint = o.loginHint ap.DomainHint = o.domainHint return cca.base.AuthCodeURL(ctx, clientID, redirectURI, scopes, ap) } // WithLoginHint pre-populates the login prompt with a username. func WithLoginHint(username string) interface { AuthCodeURLOption options.CallOption } { return struct { AuthCodeURLOption options.CallOption }{ CallOption: options.NewCallOption( func(a any) error { switch t := a.(type) { case *authCodeURLOptions: t.loginHint = username default: return fmt.Errorf("unexpected options type %T", a) } return nil }, ), } } // WithDomainHint adds the IdP domain as domain_hint query parameter in the auth url. 
func WithDomainHint(domain string) interface { AuthCodeURLOption options.CallOption } { return struct { AuthCodeURLOption options.CallOption }{ CallOption: options.NewCallOption( func(a any) error { switch t := a.(type) { case *authCodeURLOptions: t.domainHint = domain default: return fmt.Errorf("unexpected options type %T", a) } return nil }, ), } } // WithClaims sets additional claims to request for the token, such as those required by conditional access policies. // Use this option when Azure AD returned a claims challenge for a prior request. The argument must be decoded. // This option is valid for any token acquisition method. func WithClaims(claims string) interface { AcquireByAuthCodeOption AcquireByCredentialOption AcquireOnBehalfOfOption AcquireSilentOption AuthCodeURLOption options.CallOption } { return struct { AcquireByAuthCodeOption AcquireByCredentialOption AcquireOnBehalfOfOption AcquireSilentOption AuthCodeURLOption options.CallOption }{ CallOption: options.NewCallOption( func(a any) error { switch t := a.(type) { case *acquireTokenByAuthCodeOptions: t.claims = claims case *acquireTokenByCredentialOptions: t.claims = claims case *acquireTokenOnBehalfOfOptions: t.claims = claims case *acquireTokenSilentOptions: t.claims = claims case *authCodeURLOptions: t.claims = claims default: return fmt.Errorf("unexpected options type %T", a) } return nil }, ), } } // WithTenantID specifies a tenant for a single authentication. It may be different than the tenant set in [New]. // This option is valid for any token acquisition method. 
func WithTenantID(tenantID string) interface { AcquireByAuthCodeOption AcquireByCredentialOption AcquireOnBehalfOfOption AcquireSilentOption AuthCodeURLOption options.CallOption } { return struct { AcquireByAuthCodeOption AcquireByCredentialOption AcquireOnBehalfOfOption AcquireSilentOption AuthCodeURLOption options.CallOption }{ CallOption: options.NewCallOption( func(a any) error { switch t := a.(type) { case *acquireTokenByAuthCodeOptions: t.tenantID = tenantID case *acquireTokenByCredentialOptions: t.tenantID = tenantID case *acquireTokenOnBehalfOfOptions: t.tenantID = tenantID case *acquireTokenSilentOptions: t.tenantID = tenantID case *authCodeURLOptions: t.tenantID = tenantID default: return fmt.Errorf("unexpected options type %T", a) } return nil }, ), } } // acquireTokenSilentOptions are all the optional settings to an AcquireTokenSilent() call. // These are set by using various AcquireTokenSilentOption functions. type acquireTokenSilentOptions struct { account Account claims, tenantID string } // AcquireSilentOption is implemented by options for AcquireTokenSilent type AcquireSilentOption interface { acquireSilentOption() } // WithSilentAccount uses the passed account during an AcquireTokenSilent() call. func WithSilentAccount(account Account) interface { AcquireSilentOption options.CallOption } { return struct { AcquireSilentOption options.CallOption }{ CallOption: options.NewCallOption( func(a any) error { switch t := a.(type) { case *acquireTokenSilentOptions: t.account = account default: return fmt.Errorf("unexpected options type %T", a) } return nil }, ), } } // AcquireTokenSilent acquires a token from either the cache or using a refresh token. 
// // Options: [WithClaims], [WithSilentAccount], [WithTenantID] func (cca Client) AcquireTokenSilent(ctx context.Context, scopes []string, opts ...AcquireSilentOption) (AuthResult, error) { o := acquireTokenSilentOptions{} if err := options.ApplyOptions(&o, opts); err != nil { return AuthResult{}, err } if o.claims != "" { return AuthResult{}, errors.New("call another AcquireToken method to request a new token having these claims") } silentParameters := base.AcquireTokenSilentParameters{ Scopes: scopes, Account: o.account, RequestType: accesstokens.ATConfidential, Credential: cca.cred, IsAppCache: o.account.IsZero(), TenantID: o.tenantID, } return cca.base.AcquireTokenSilent(ctx, silentParameters) } // acquireTokenByAuthCodeOptions contains the optional parameters used to acquire an access token using the authorization code flow. type acquireTokenByAuthCodeOptions struct { challenge, claims, tenantID string } // AcquireByAuthCodeOption is implemented by options for AcquireTokenByAuthCode type AcquireByAuthCodeOption interface { acquireByAuthCodeOption() } // WithChallenge allows you to provide a challenge for the .AcquireTokenByAuthCode() call. func WithChallenge(challenge string) interface { AcquireByAuthCodeOption options.CallOption } { return struct { AcquireByAuthCodeOption options.CallOption }{ CallOption: options.NewCallOption( func(a any) error { switch t := a.(type) { case *acquireTokenByAuthCodeOptions: t.challenge = challenge default: return fmt.Errorf("unexpected options type %T", a) } return nil }, ), } } // AcquireTokenByAuthCode is a request to acquire a security token from the authority, using an authorization code. // The specified redirect URI must be the same URI that was used when the authorization code was requested. 
// // Options: [WithChallenge], [WithClaims], [WithTenantID] func (cca Client) AcquireTokenByAuthCode(ctx context.Context, code string, redirectURI string, scopes []string, opts ...AcquireByAuthCodeOption) (AuthResult, error) { o := acquireTokenByAuthCodeOptions{} if err := options.ApplyOptions(&o, opts); err != nil { return AuthResult{}, err } params := base.AcquireTokenAuthCodeParameters{ Scopes: scopes, Code: code, Challenge: o.challenge, Claims: o.claims, AppType: accesstokens.ATConfidential, Credential: cca.cred, // This setting differs from public.Client.AcquireTokenByAuthCode RedirectURI: redirectURI, TenantID: o.tenantID, } return cca.base.AcquireTokenByAuthCode(ctx, params) } // acquireTokenByCredentialOptions contains optional configuration for AcquireTokenByCredential type acquireTokenByCredentialOptions struct { claims, tenantID string } // AcquireByCredentialOption is implemented by options for AcquireTokenByCredential type AcquireByCredentialOption interface { acquireByCredOption() } // AcquireTokenByCredential acquires a security token from the authority, using the client credentials grant. 
// // Options: [WithClaims], [WithTenantID] func (cca Client) AcquireTokenByCredential(ctx context.Context, scopes []string, opts ...AcquireByCredentialOption) (AuthResult, error) { o := acquireTokenByCredentialOptions{} err := options.ApplyOptions(&o, opts) if err != nil { return AuthResult{}, err } authParams, err := cca.base.AuthParams.WithTenant(o.tenantID) if err != nil { return AuthResult{}, err } authParams.Scopes = scopes authParams.AuthorizationType = authority.ATClientCredentials authParams.Claims = o.claims token, err := cca.base.Token.Credential(ctx, authParams, cca.cred) if err != nil { return AuthResult{}, err } return cca.base.AuthResultFromToken(ctx, authParams, token, true) } // acquireTokenOnBehalfOfOptions contains optional configuration for AcquireTokenOnBehalfOf type acquireTokenOnBehalfOfOptions struct { claims, tenantID string } // AcquireOnBehalfOfOption is implemented by options for AcquireTokenOnBehalfOf type AcquireOnBehalfOfOption interface { acquireOBOOption() } // AcquireTokenOnBehalfOf acquires a security token for an app using middle tier apps access token. // Refer https://docs.microsoft.com/en-us/azure/active-directory/develop/v2-oauth2-on-behalf-of-flow. // // Options: [WithClaims], [WithTenantID] func (cca Client) AcquireTokenOnBehalfOf(ctx context.Context, userAssertion string, scopes []string, opts ...AcquireOnBehalfOfOption) (AuthResult, error) { o := acquireTokenOnBehalfOfOptions{} if err := options.ApplyOptions(&o, opts); err != nil { return AuthResult{}, err } params := base.AcquireTokenOnBehalfOfParameters{ Scopes: scopes, UserAssertion: userAssertion, Claims: o.claims, Credential: cca.cred, TenantID: o.tenantID, } return cca.base.AcquireTokenOnBehalfOf(ctx, params) } // Account gets the account in the token cache with the specified homeAccountID. 
func (cca Client) Account(ctx context.Context, accountID string) (Account, error) { return cca.base.Account(ctx, accountID) } // RemoveAccount signs the account out and forgets account from token cache. func (cca Client) RemoveAccount(ctx context.Context, account Account) error { return cca.base.RemoveAccount(ctx, account) } microsoft-authentication-library-for-go-1.0.0/apps/confidential/confidential_test.go000066400000000000000000001213661442026362400310260ustar00rootroot00000000000000// Copyright (c) Microsoft Corporation. // Licensed under the MIT license. package confidential import ( "context" "crypto" "crypto/rsa" "crypto/x509" "encoding/base64" "encoding/json" "errors" "fmt" "io" "net/http" "net/url" "os" "path/filepath" "strings" "testing" "time" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/exported" internalTime "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json/types/time" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/mock" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/fake" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority" "github.com/golang-jwt/jwt/v4" "github.com/kylelemons/godebug/pretty" ) // errorClient is an HTTP client for tests that should fail when confidential.Client sends a request type errorClient struct{} func (*errorClient) Do(req *http.Request) (*http.Response, error) { return nil, fmt.Errorf("expected no requests but received one for %s", req.URL.String()) } func (*errorClient) CloseIdleConnections() {} func TestCertFromPEM(t *testing.T) { f, err := os.Open(filepath.Clean("../testdata/test-cert.pem")) if err != nil { t.Fatal(err) } defer f.Close() pemData, err := io.ReadAll(f) if err != nil { t.Fatal(err) } 
// Decoding test-cert.pem must yield exactly one certificate and a non-nil private key.
	certs, key, err := CertFromPEM(pemData, "")
	if err != nil {
		t.Fatalf("TestCertFromPEM: got err == %s, want err == nil", err)
	}
	if len(certs) != 1 {
		t.Fatalf("TestCertFromPEM: got %d certs, want 1 cert", len(certs))
	}
	if key == nil {
		t.Fatalf("TestCertFromPEM: got nil key, want key != nil")
	}
}

// Fixed values shared by the tests in this file.
const (
	authorityFmt      = "https://%s/%s"
	fakeAuthority     = "https://fake_authority/fake"
	fakeClientID      = "fake_client_id"
	fakeSecret        = "fake_secret"
	fakeTokenEndpoint = "https://fake_authority/fake/token"
	localhost         = "http://localhost"
	refresh           = "fake_refresh"
	token             = "fake_token"
)

// tokenScope is the scope requested, and granted, in these tests.
var tokenScope = []string{"the_scope"}

// fakeClient creates a Client with New and then replaces its access token, authority,
// endpoint resolution and WS-Trust dependencies with fakes. The fake access token
// dependency is configured with tk as its response; the fake authority reports
// fake_authority instance metadata and the resolver returns fakeTokenEndpoint.
func fakeClient(tk accesstokens.TokenResponse, credential Credential, options ...Option) (Client, error) {
	client, err := New(fakeAuthority, fakeClientID, credential, options...)
	if err != nil {
		return Client{}, err
	}
	client.base.Token.AccessTokens = &fake.AccessTokens{
		AccessToken: tk,
	}
	client.base.Token.Authority = &fake.Authority{
		InstanceResp: authority.InstanceDiscoveryResponse{
			TenantDiscoveryEndpoint: "https://fake_authority/fake/discovery/endpoint",
			Metadata: []authority.InstanceDiscoveryMetadata{
				{
					PreferredNetwork: "fake_authority",
					PreferredCache:   "fake_cache",
					Aliases: []string{
						"fake_authority",
						"fake_auth",
						"fk_au",
					},
				},
			},
			AdditionalFields: map[string]interface{}{
				"api-version": "2020-02-02",
			},
		},
	}
	client.base.Token.Resolver = &fake.ResolveEndpoints{
		Endpoints: authority.NewEndpoints("https://fake_authority/fake/auth",
			fakeTokenEndpoint, "https://fake_authority/fake/jwt", "fake_authority"),
	}
	client.base.Token.WSTrust = &fake.WSTrust{}
	return client, nil
}

// TestAcquireTokenByCredential checks that silent auth fails on an empty cache, that
// AcquireTokenByCredential returns the fake token, and that silent auth then returns it from the cache.
func TestAcquireTokenByCredential(t *testing.T) {
	// NOTE(review): both cases build the credential with NewCredFromSecret, so the
	// "Signed Assertion" case exercises a second secret value, not an assertion credential.
	tests := []struct {
		desc string
		cred string
	}{
		{
			desc: "Secret",
			cred: "fake_secret",
		},
		{
			desc: "Signed Assertion",
			cred: "fake_assertion",
		},
	}
	for _, test := range tests {
		cred, err := NewCredFromSecret(test.cred)
		if err != nil {
			t.Fatal(err)
		}
		client, err := fakeClient(accesstokens.TokenResponse{
			AccessToken:   token,
			ExpiresOn:     internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)},
			ExtExpiresOn:  internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)},
			GrantedScopes: accesstokens.Scopes{Slice: tokenScope},
		}, cred)
		if err != nil {
			t.Fatal(err)
		}
		_, err = client.AcquireTokenSilent(context.Background(), tokenScope)
		// first attempt should fail
		if err == nil {
			t.Errorf("TestAcquireTokenByCredential(%s): unexpected nil error from AcquireTokenSilent", test.desc)
		}
		tk, err := client.AcquireTokenByCredential(context.Background(), tokenScope)
		if err != nil {
			t.Errorf("TestAcquireTokenByCredential(%s): got err == %s, want err == nil", test.desc, err)
		}
		if tk.AccessToken != token {
			t.Errorf("TestAcquireTokenByCredential(%s): unexpected access token %s", test.desc, tk.AccessToken)
		}
		// second attempt should return the cached token
		tk, err = client.AcquireTokenSilent(context.Background(), tokenScope)
		if err != nil {
			t.Errorf("TestAcquireTokenByCredential(%s): got err == %s, want err == nil", test.desc, err)
		}
		if tk.AccessToken != token {
			t.Errorf("TestAcquireTokenByCredential(%s): unexpected access token %s", test.desc, tk.AccessToken)
		}
	}
}

// TestAcquireTokenOnBehalfOf checks that a repeated user assertion is served from the cache
// and that a different assertion triggers a new token request.
func TestAcquireTokenOnBehalfOf(t *testing.T) {
	// this test is an offline version of TestOnBehalfOf in integration_test.go
	cred, err := NewCredFromSecret(fakeSecret)
	if err != nil {
		t.Fatal(err)
	}
	lmo := "login.microsoftonline.com"
	tenant := "tenant"
	assertion := "assertion"
	// NOTE(review): mock.Client appears to serve appended responses in order, so the order of
	// the AppendResponse calls below (and of the later token2 response) matters.
	mockClient := mock.Client{}
	// TODO: OBO does instance discovery twice before first token request https://github.com/AzureAD/microsoft-authentication-library-for-go/issues/351
	mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(lmo, tenant)))
	mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, tenant)))
	mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(token, "", "rt", "", 3600)))
	client, err := New(fmt.Sprintf(authorityFmt, lmo, tenant), fakeClientID, cred, WithHTTPClient(&mockClient))
	if err != nil {
		t.Fatal(err)
	}
	tk, err := client.AcquireTokenOnBehalfOf(context.Background(), assertion, tokenScope)
	if err != nil {
		t.Fatal(err)
	}
	if tk.AccessToken != token {
		t.Fatalf("wanted %q, got %q", token, tk.AccessToken)
	}
	// should return the cached access token
	tk, err = client.AcquireTokenOnBehalfOf(context.Background(), assertion, tokenScope)
	if err != nil {
		t.Fatal(err)
	}
	if tk.AccessToken != token {
		t.Fatalf("wanted %q, got %q", token, tk.AccessToken)
	}
	// new assertion should trigger new token request
	token2 := token + "2"
	mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(token2, "", "rt", "", 3600)))
	tk, err = client.AcquireTokenOnBehalfOf(context.Background(), assertion+"2", tokenScope)
	if err != nil {
		t.Fatal(err)
	}
	if tk.AccessToken != token2 {
		t.Fatal("expected a new token")
	}
}

// TestAcquireTokenByAssertionCallback checks that the assertion callback receives the caller's
// context, the client ID and the token endpoint, that it's invoked once per
// AcquireTokenByCredential call, and that its error is returned to the caller.
func TestAcquireTokenByAssertionCallback(t *testing.T) {
	calls := 0
	key := struct{}{}
	ctx := context.WithValue(context.Background(), key, true)
	// the callback succeeds on its first three calls and fails on the fourth
	getAssertion := func(c context.Context, o AssertionRequestOptions) (string, error) {
		if v := c.Value(key); v == nil || !v.(bool) {
			t.Fatal("callback received unexpected context")
		}
		if o.ClientID != fakeClientID {
			t.Fatalf(`unexpected client ID "%s"`, o.ClientID)
		}
		if o.TokenEndpoint != fakeTokenEndpoint {
			t.Fatalf(`unexpected token endpoint "%s"`, o.TokenEndpoint)
		}
		calls++
		if calls < 4 {
			return "assertion", nil
		}
		return "", errors.New("expected error")
	}
	cred := NewCredFromAssertionCallback(getAssertion)
	client, err := fakeClient(accesstokens.TokenResponse{}, cred)
	if err != nil {
		t.Fatal(err)
	}
	for i := 0; i < 3; i++ {
		if calls != i {
			t.Fatalf("expected %d calls, got %d", i, calls)
		}
		_, err = client.AcquireTokenByCredential(ctx, tokenScope)
		if err != nil {
			t.Fatal(err)
		}
	}
	_, err = client.AcquireTokenByCredential(ctx, tokenScope)
	if err == nil || err.Error() != "expected error" {
		t.Fatalf("expected an error from the callback, got %v", err)
	}
}

// TestAcquireTokenByAuthCode redeems an auth code, then checks the cached account and a silent
// token lookup for that account. It runs once with an ID token carrying only
// preferred_username plus a client-info UTID, and once with only a UPN and no UTID.
func TestAcquireTokenByAuthCode(t *testing.T) {
	cred, err := NewCredFromSecret(fakeSecret)
	if err != nil {
		t.Fatal(err)
	}
	for _, params := range []struct {
		upn, preferredUsername, utid string
	}{
		{"", "fakeuser@fakeplace.fake", "fake"},
		{"fakeuser@fakeplace.fake", "", ""},
	} {
		t.Run("", func(t *testing.T) {
			tr := accesstokens.TokenResponse{
				AccessToken:   token,
				RefreshToken:  refresh,
				ExpiresOn:     internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)},
				ExtExpiresOn:  internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)},
				GrantedScopes: accesstokens.Scopes{Slice: tokenScope},
				IDToken: accesstokens.IDToken{
					PreferredUsername: params.preferredUsername,
					UPN:               params.upn,
					Name:              "fake person",
					Oid:               "123-456",
					TenantID:          "fake",
					Subject:           "nothing",
					Issuer:            "https://fake_authority/fake",
					Audience:          "abc-123",
					ExpirationTime:    time.Now().Add(time.Hour).Unix(),
					IssuedAt:          time.Now().Add(-5 * time.Minute).Unix(),
					NotBefore:         time.Now().Add(-5 * time.Minute).Unix(),
					// NOTE: this is an invalid JWT however this doesn't cause a failure.
					// it simply falls back to calling Token.Refresh() which will obviously succeed.
					RawToken: "fake.raw.token",
				},
				ClientInfo: accesstokens.ClientInfo{
					UID:  "123-456",
					UTID: params.utid,
				},
			}
			client, err := fakeClient(tr, cred)
			if err != nil {
				t.Fatal(err)
			}
			_, err = client.AcquireTokenSilent(context.Background(), tokenScope)
			// first attempt should fail
			if err == nil {
				t.Fatal("unexpected nil error from AcquireTokenSilent")
			}
			tk, err := client.AcquireTokenByAuthCode(context.Background(), "fake_auth_code", "fake_redirect_uri", tokenScope)
			if err != nil {
				t.Fatal(err)
			}
			if tk.AccessToken != token {
				t.Fatalf("unexpected access token %s", tk.AccessToken)
			}
			account, err := client.Account(context.Background(), tk.Account.HomeAccountID)
			if err != nil {
				t.Fatal(err)
			}
			// with no UTID, the expected home account ID repeats the UID in both positions
			if params.utid == "" {
				if actual := account.HomeAccountID; actual != "123-456.123-456" {
					t.Fatalf("expected %q, got %q", "123-456.123-456", actual)
				}
			} else {
				if actual := account.HomeAccountID; actual != "123-456.fake" {
					t.Fatalf("expected %q, got %q", "123-456.fake", actual)
				}
			}
			// in both cases the account's PreferredUsername is expected to be the one username that was set
			if account.PreferredUsername != "fakeuser@fakeplace.fake" {
				t.Fatal("Unexpected Account.PreferredUsername")
			}
			// second attempt should return the cached token
			tk, err = client.AcquireTokenSilent(context.Background(), tokenScope, WithSilentAccount(account))
			if err != nil {
				t.Fatal(err)
			}
			if tk.AccessToken != token {
				t.Fatalf("unexpected access token %s", tk.AccessToken)
			}
		})
	}
}

// TestAcquireTokenSilentTenants caches an app token per tenant via WithTenantID and checks
// that silent auth returns the token belonging to the requested tenant.
func TestAcquireTokenSilentTenants(t *testing.T) {
	cred, err := NewCredFromSecret(fakeSecret)
	if err != nil {
		t.Fatal(err)
	}
	tenants := []string{"a", "b"}
	lmo := "login.microsoftonline.com"
	mockClient := mock.Client{}
	mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(lmo, tenants[0])))
	client, err := New(fmt.Sprintf(authorityFmt, lmo, tenants[0]), fakeClientID, cred, WithHTTPClient(&mockClient))
	if err != nil {
		t.Fatal(err)
	}
	ctx := context.Background()
	// cache an access token for each tenant. To simplify determining their provenance below, the value of each token is the ID of the tenant that provided it.
for _, tenant := range tenants { if _, err = client.AcquireTokenSilent(ctx, tokenScope, WithTenantID(tenant)); err == nil { t.Fatal("silent auth should fail because the cache is empty") } mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, tenant))) mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(tenant, "", "", "", 3600))) if _, err := client.AcquireTokenByCredential(ctx, tokenScope, WithTenantID(tenant)); err != nil { t.Fatal(err) } } // cache should return the correct access token for each tenant for _, tenant := range tenants { ar, err := client.AcquireTokenSilent(ctx, tokenScope, WithTenantID(tenant)) if err != nil { t.Fatal(err) } if ar.AccessToken != tenant { t.Fatalf(`expected "%s", got "%s"`, tenant, ar.AccessToken) } } } func TestAuthorityValidation(t *testing.T) { cred, err := NewCredFromSecret(fakeSecret) if err != nil { t.Fatal(err) } for _, a := range []string{"", "https://login.microsoftonline.com", "http://login.microsoftonline.com/tenant"} { t.Run(a, func(t *testing.T) { _, err := New(a, fakeClientID, cred) if err == nil || !strings.Contains(err.Error(), "authority") { t.Fatalf("expected an error about the invalid authority, got %v", err) } }) } } func TestInvalidCredential(t *testing.T) { for _, cred := range []Credential{ {}, NewCredFromAssertionCallback(nil), } { t.Run("", func(t *testing.T) { _, err := New(fakeAuthority, fakeClientID, cred) if err == nil { t.Fatal("expected an error") } }) } } func TestNewCredFromCert(t *testing.T) { for _, file := range []struct { path string numCerts int }{ {"../testdata/test-cert.pem", 1}, {"../testdata/test-cert-chain.pem", 2}, {"../testdata/test-cert-chain-reverse.pem", 2}, } { f, err := os.Open(filepath.Clean(file.path)) if err != nil { t.Fatal(err) } defer f.Close() pemData, err := io.ReadAll(f) if err != nil { t.Fatal(err) } certs, key, err := CertFromPEM(pemData, "") if err != nil { t.Fatal(err) } if len(certs) != file.numCerts { t.Fatalf("expected %d certs, got 
%d", file.numCerts, len(certs)) } expectedCerts := make(map[string]struct{}, len(certs)) for _, cert := range certs { expectedCerts[base64.StdEncoding.EncodeToString(cert.Raw)] = struct{}{} } k, ok := key.(*rsa.PrivateKey) if !ok { t.Fatal("expected an RSA private key") } verifyingKey := &k.PublicKey cred, err := NewCredFromCert(certs, key) if err != nil { t.Fatal(err) } for _, sendX5c := range []bool{false, true} { opts := []Option{} if sendX5c { opts = append(opts, WithX5C()) } t.Run(fmt.Sprintf("%s/%v", filepath.Base(file.path), sendX5c), func(t *testing.T) { client, err := fakeClient(accesstokens.TokenResponse{ AccessToken: token, ExpiresOn: internalTime.DurationTime{T: time.Now().Add(time.Hour)}, GrantedScopes: accesstokens.Scopes{Slice: tokenScope}, }, cred, opts...) if err != nil { t.Fatal(err) } // the test fake passes assertions generated by the credential to this function validated := false client.base.Token.AccessTokens.(*fake.AccessTokens).ValidateAssertion = func(s string) { validated = true tk, err := jwt.Parse(s, func(tk *jwt.Token) (interface{}, error) { if signingMethod, ok := tk.Method.(*jwt.SigningMethodRSA); !ok { t.Fatalf("unexpected signing method %T", signingMethod) } return verifyingKey, nil }) if err != nil { t.Fatal(err) } if err = tk.Claims.Valid(); err != nil { t.Fatal(err) } // x5c header should be set iff the sendX5c is true if x5c, ok := tk.Header["x5c"]; ok != sendX5c { t.Fatal("x5c should be set only when application passed WithX5C option") } else if ok { if x := len(x5c.([]interface{})); x > file.numCerts { t.Fatalf("x5c contains %d certs; expected %d", x, file.numCerts) } // x5c must contain all the file's certs, signing cert first for i, cert := range x5c.([]interface{}) { s := cert.(string) if _, ok := expectedCerts[s]; ok { delete(expectedCerts, s) } else { t.Fatal("x5c contains an unexpected cert") } if i == 0 { decoded, err := base64.StdEncoding.DecodeString(s) if err != nil { t.Fatal(err) } parsed, err := 
x509.ParseCertificate(decoded) if err != nil { t.Fatal(err) } if !verifyingKey.Equal(parsed.PublicKey) { t.Fatal("signing cert must appear first in x5c") } } } if len(expectedCerts) > 0 { t.Fatal("x5c header is missing a cert") } } } tk, err := client.AcquireTokenByCredential(context.Background(), tokenScope) if err != nil { t.Fatal(err) } if tk.AccessToken != token { t.Fatalf("unexpected access token %s", tk.AccessToken) } if !validated { t.Fatal("assertion validation function wasn't called") } }) } } } func TestNewCredFromCertError(t *testing.T) { data, err := os.ReadFile("../testdata/test-cert.pem") if err != nil { t.Fatal(err) } certs, key, err := CertFromPEM(data, "") if err != nil { t.Fatal(err) } for _, test := range []struct { certs []*x509.Certificate key crypto.PrivateKey }{ {nil, nil}, {certs, nil}, {nil, key}, {[]*x509.Certificate{}, nil}, {[]*x509.Certificate{}, key}, {[]*x509.Certificate{nil}, nil}, {[]*x509.Certificate{nil}, key}, } { t.Run("", func(t *testing.T) { _, err := NewCredFromCert(test.certs, test.key) if err == nil { t.Fatal("expected an error") } }) } // the key in this file doesn't match the cert loaded above if data, err = os.ReadFile("../testdata/test-cert-chain.pem"); err != nil { t.Fatal(err) } if _, key, err = CertFromPEM(data, ""); err != nil { t.Fatal(err) } if _, err = NewCredFromCert(certs, key); err == nil { t.Fatal("expected an error because key doesn't match certs") } } func TestNewCredFromTokenProvider(t *testing.T) { expectedToken := "expected token" called := false expiresIn := 4200 key := struct{}{} ctx := context.WithValue(context.Background(), key, true) cred := NewCredFromTokenProvider(func(c context.Context, tp exported.TokenProviderParameters) (exported.TokenProviderResult, error) { if called { t.Fatal("expected exactly one token provider invocation") } called = true if v := c.Value(key); v == nil || !v.(bool) { t.Fatal("callback received unexpected context") } if tp.CorrelationID == "" { t.Fatal("expected 
CorrelationID") } if v := fmt.Sprint(tp.Scopes); v != fmt.Sprint(tokenScope) { t.Fatalf(`unexpected scopes "%v"`, v) } return exported.TokenProviderResult{ AccessToken: expectedToken, ExpiresInSeconds: expiresIn, }, nil }) client, err := New(fakeAuthority, fakeClientID, cred, WithHTTPClient(&errorClient{})) if err != nil { t.Fatal(err) } ar, err := client.AcquireTokenByCredential(ctx, tokenScope) if err != nil { t.Fatal(err) } if !called { t.Fatal("token provider wasn't invoked") } if v := int(time.Until(ar.ExpiresOn).Seconds()); v < expiresIn-2 || v > expiresIn { t.Fatalf("expected ExpiresOn ~= %d seconds, got %d", expiresIn, v) } if ar.AccessToken != expectedToken { t.Fatalf(`unexpected token "%s"`, ar.AccessToken) } ar, err = client.AcquireTokenSilent(context.Background(), tokenScope) if err != nil { t.Fatal(err) } if ar.AccessToken != expectedToken { t.Fatalf(`unexpected token "%s"`, ar.AccessToken) } } func TestNewCredFromTokenProviderError(t *testing.T) { expectedError := "something went wrong" cred := NewCredFromTokenProvider(func(ctx context.Context, tpp exported.TokenProviderParameters) (exported.TokenProviderResult, error) { return exported.TokenProviderResult{}, errors.New(expectedError) }) client, err := New(fakeAuthority, fakeClientID, cred) if err != nil { t.Fatal(err) } _, err = client.AcquireTokenByCredential(context.Background(), tokenScope) if err == nil || !strings.Contains(err.Error(), expectedError) { t.Fatalf(`unexpected error "%v"`, err) } } func TestTokenProviderOptions(t *testing.T) { accessToken, claims, tenant := "at", "claims", "tenant" cred := NewCredFromTokenProvider(func(ctx context.Context, tpp TokenProviderParameters) (TokenProviderResult, error) { if tpp.Claims != claims { t.Fatalf(`unexpected claims "%s"`, tpp.Claims) } if tpp.TenantID != tenant { t.Fatalf(`unexpected tenant "%s"`, tpp.TenantID) } return TokenProviderResult{AccessToken: accessToken, ExpiresInSeconds: 3600}, nil }) client, err := New(fakeAuthority, fakeClientID, 
cred, WithHTTPClient(&errorClient{})) if err != nil { t.Fatal(err) } ar, err := client.AcquireTokenByCredential(context.Background(), tokenScope, WithClaims(claims), WithTenantID(tenant)) if err != nil { t.Fatal(err) } if ar.AccessToken != accessToken { t.Fatalf(`unexpected access token "%s"`, ar.AccessToken) } } // testCache is a simple in-memory cache.ExportReplace implementation type testCache map[string][]byte func (c testCache) Export(ctx context.Context, m cache.Marshaler, h cache.ExportHints) error { if v, err := m.Marshal(); err == nil { c[h.PartitionKey] = v } return nil } func (c testCache) Replace(ctx context.Context, u cache.Unmarshaler, h cache.ReplaceHints) error { if v, has := c[h.PartitionKey]; has { _ = u.Unmarshal(v) } return nil } func TestWithCache(t *testing.T) { cache := make(testCache) accessToken := "*" lmo := "login.microsoftonline.com" tenantA, tenantB := "a", "b" authorityA, authorityB := fmt.Sprintf(authorityFmt, lmo, tenantA), fmt.Sprintf(authorityFmt, lmo, tenantB) mockClient := mock.Client{} mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, tenantA))) mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(accessToken, mock.GetIDToken(tenantA, authorityA), "", "", 3600))) cred, err := NewCredFromSecret(fakeSecret) if err != nil { t.Fatal(err) } client, err := New(authorityA, fakeClientID, cred, WithCache(&cache), WithHTTPClient(&mockClient)) if err != nil { t.Fatal(err) } // The particular flow isn't important, we just need to populate the cache. 
Auth code is the simplest for this test ar, err := client.AcquireTokenByAuthCode(context.Background(), "code", "https://localhost", tokenScope) if err != nil { t.Fatal(err) } if ar.AccessToken != accessToken { t.Fatalf(`unexpected access token "%s"`, ar.AccessToken) } account := ar.Account if actual := account.Realm; actual != tenantA { t.Fatalf(`unexpected realm "%s"`, actual) } // a client configured for a different tenant should be able to authenticate silently with the shared cache's data client, err = New(authorityB, fakeClientID, cred, WithCache(&cache), WithHTTPClient(&mockClient)) if err != nil { t.Fatal(err) } // this should succeed because the cache contains an access token from tenantA mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(lmo, tenantA))) ar, err = client.AcquireTokenSilent(context.Background(), tokenScope, WithSilentAccount(account), WithTenantID(tenantA)) if err != nil { t.Fatal(err) } if ar.AccessToken != accessToken { t.Fatalf(`unexpected access token "%s"`, ar.AccessToken) } // this should fail because the cache doesn't contain an access token from tenantB ar, err = client.AcquireTokenSilent(context.Background(), tokenScope, WithSilentAccount(account)) if err == nil { t.Fatal("expected an error because the cache doesn't have an appropriate access token") } } func TestWithClaims(t *testing.T) { cred, err := NewCredFromSecret(fakeSecret) if err != nil { t.Fatal(err) } accessToken := "at" lmo, tenant := "login.microsoftonline.com", "tenant" authority := fmt.Sprintf(authorityFmt, lmo, tenant) for _, test := range []struct { capabilities []string claims, expected string }{ {}, { capabilities: []string{"cp1"}, expected: `{"access_token":{"xms_cc":{"values":["cp1"]}}}`, }, { claims: `{"id_token":{"auth_time":{"essential":true}}}`, expected: `{"id_token":{"auth_time":{"essential":true}}}`, }, { capabilities: []string{"cp1", "cp2"}, claims: `{"access_token":{"nbf":{"essential":true, "value":"42"}}}`, expected: 
`{"access_token":{"nbf":{"essential":true, "value":"42"}, "xms_cc":{"values":["cp1","cp2"]}}}`, }, } { var expected map[string]any if err := json.Unmarshal([]byte(test.expected), &expected); err != nil && test.expected != "" { t.Fatal("test bug: the expected result must be JSON or an empty string") } validate := func(t *testing.T, v url.Values) { if test.expected == "" { if v.Has("claims") { t.Fatal("claims shouldn't be set") } return } claims, ok := v["claims"] if !ok { t.Fatal("claims should be set") } if len(claims) != 1 { t.Fatalf("expected 1 value for claims, got %d", len(claims)) } var actual map[string]any if err := json.Unmarshal([]byte(claims[0]), &actual); err != nil { t.Fatal(err) } if diff := pretty.Compare(expected, actual); diff != "" { t.Fatal(diff) } } for _, method := range []string{"authcode", "authcodeURL", "credential", "obo"} { t.Run(method, func(t *testing.T) { mockClient := mock.Client{} clientInfo, idToken, refreshToken := "", "", "" if method == "obo" { clientInfo = base64.RawStdEncoding.EncodeToString([]byte(`{"uid":"uid","utid":"utid"}`)) idToken = mock.GetIDToken(tenant, authority) refreshToken = "rt" // TODO: OBO does instance discovery twice before first token request https://github.com/AzureAD/microsoft-authentication-library-for-go/issues/351 mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(lmo, tenant))) } mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(lmo, tenant))) mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, tenant))) mockClient.AppendResponse( mock.WithBody(mock.GetAccessTokenBody(accessToken, idToken, refreshToken, clientInfo, 3600)), mock.WithCallback(func(r *http.Request) { if err := r.ParseForm(); err != nil { t.Fatal(err) } validate(t, r.Form) }), ) client, err := New(authority, fakeClientID, cred, WithClientCapabilities(test.capabilities), WithHTTPClient(&mockClient)) if err != nil { t.Fatal(err) } if _, err = 
client.AcquireTokenSilent(context.Background(), tokenScope); err == nil { t.Fatal("silent authentication should fail because the cache is empty") } ctx := context.Background() var ar AuthResult switch method { case "authcode": ar, err = client.AcquireTokenByAuthCode(ctx, "code", localhost, tokenScope, WithClaims(test.claims)) case "authcodeURL": u := "" if u, err = client.AuthCodeURL(ctx, fakeClientID, localhost, tokenScope, WithClaims(test.claims)); err == nil { var parsed *url.URL if parsed, err = url.Parse(u); err == nil { validate(t, parsed.Query()) return // didn't acquire a token, no need for further validation } } case "credential": ar, err = client.AcquireTokenByCredential(ctx, tokenScope, WithClaims(test.claims)) case "obo": ar, err = client.AcquireTokenOnBehalfOf(ctx, "assertion", tokenScope, WithClaims(test.claims)) default: t.Fatalf("test bug: no test for " + method) } if err != nil { t.Fatal(err) } if ar.AccessToken != accessToken { t.Fatalf(`unexpected access token "%s"`, ar.AccessToken) } // silent auth should now succeed, provided no claims are requested, because the client has cached an access token if method == "obo" { ar, err = client.AcquireTokenOnBehalfOf(ctx, "assertion", tokenScope) } else { ar, err = client.AcquireTokenSilent(ctx, tokenScope) } if err != nil { t.Fatal(err) } if ar.AccessToken != accessToken { t.Fatalf(`unexpected access token "%s"`, ar.AccessToken) } if test.claims != "" { if _, err = client.AcquireTokenSilent(ctx, tokenScope, WithClaims(test.claims)); err == nil { t.Fatal("AcquireTokenSilent should fail when given claims") } if method == "obo" { // client has cached access and refresh tokens. When given claims, it should redeem a refresh token for a new access token. 
newToken := "new-access-token" mockClient.AppendResponse( mock.WithBody(mock.GetAccessTokenBody(newToken, idToken, "", clientInfo, 3600)), mock.WithCallback(func(r *http.Request) { if err := r.ParseForm(); err != nil { t.Fatal(err) } // all token requests should include any specified claims validate(t, r.Form) if actual := r.Form.Get("refresh_token"); actual != refreshToken { t.Fatalf(`unexpected refresh token "%s"`, actual) } }), ) ar, err = client.AcquireTokenOnBehalfOf(ctx, "assertion", tokenScope, WithClaims(test.claims)) if err != nil { t.Fatal(err) } if ar.AccessToken != newToken { t.Fatalf(`unexpected access token "%s"`, ar.AccessToken) } } } }) } } } func TestWithTenantID(t *testing.T) { accessToken := "*" uuid1 := "00000000-0000-0000-0000-000000000000" uuid2 := strings.ReplaceAll(uuid1, "0", "1") lmo := "login.microsoftonline.com" host := fmt.Sprintf("https://%s/", lmo) for _, test := range []struct { authority, expectedAuthority, tenant string expectError bool }{ {authority: host + "common", tenant: uuid1, expectedAuthority: host + uuid1}, {authority: host + "organizations", tenant: uuid1, expectedAuthority: host + uuid1}, {authority: host + uuid1, tenant: uuid2, expectedAuthority: host + uuid2}, {authority: host + uuid1, tenant: "common", expectError: true}, {authority: host + uuid1, tenant: "organizations", expectError: true}, {authority: host + "consumers", tenant: uuid1, expectError: true}, } { for _, method := range []string{"authcode", "authcodeURL", "credential", "obo"} { t.Run(method, func(t *testing.T) { cred, err := NewCredFromSecret(fakeSecret) if err != nil { t.Fatal(err) } idToken, refreshToken, URL := "", "", "" mockClient := mock.Client{} if method == "obo" { idToken = mock.GetIDToken(test.tenant, test.authority) refreshToken = "refresh-token" // TODO: OBO does instance discovery twice before first token request https://github.com/AzureAD/microsoft-authentication-library-for-go/issues/351 
mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(lmo, test.tenant))) } mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(lmo, test.tenant))) mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, test.tenant))) mockClient.AppendResponse( mock.WithBody(mock.GetAccessTokenBody(accessToken, idToken, refreshToken, "", 3600)), mock.WithCallback(func(r *http.Request) { URL = r.URL.String() }), ) client, err := New(test.authority, fakeClientID, cred, WithHTTPClient(&mockClient)) if err != nil { t.Fatal(err) } ctx := context.Background() if _, err = client.AcquireTokenSilent(ctx, tokenScope, WithTenantID(test.tenant)); err == nil { t.Fatal("silent auth should fail because the cache is empty") } var ar AuthResult switch method { case "authcode": ar, err = client.AcquireTokenByAuthCode(ctx, "auth code", localhost, tokenScope, WithTenantID(test.tenant)) case "authcodeURL": URL, err = client.AuthCodeURL(ctx, fakeClientID, localhost, tokenScope, WithTenantID(test.tenant)) case "credential": ar, err = client.AcquireTokenByCredential(ctx, tokenScope, WithTenantID(test.tenant)) case "obo": ar, err = client.AcquireTokenOnBehalfOf(ctx, "assertion", tokenScope, WithTenantID(test.tenant)) default: t.Fatalf("test bug: no test for " + method) } if err != nil { if test.expectError { return } t.Fatal(err) } else if test.expectError { t.Fatal("expected an error") } if !strings.HasPrefix(URL, test.expectedAuthority) { t.Fatalf(`expected "%s", got "%s"`, test.expectedAuthority, URL) } if method == "authcodeURL" { // didn't acquire a token, no need to test silent auth return } if ar.AccessToken != accessToken { t.Fatalf(`unexpected access token "%s"`, ar.AccessToken) } // silent authentication should now succeed for the given tenant... 
if method == "obo" { if ar, err = client.AcquireTokenOnBehalfOf(ctx, "assertion", tokenScope, WithTenantID(test.tenant)); err != nil { t.Fatal(err) } } else if ar, err = client.AcquireTokenSilent(ctx, tokenScope, WithTenantID(test.tenant)); err != nil { t.Fatal(err) } if ar.AccessToken != accessToken { t.Fatal("cached access token should match the one returned by AcquireToken...") } // ...but fail for another tenant unless we're authenticating OBO, in which case we have a refresh token otherTenant := "not-" + test.tenant if method == "obo" { mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, test.tenant))) mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(accessToken, idToken, refreshToken, "", 3600))) if _, err = client.AcquireTokenOnBehalfOf(ctx, "assertion", tokenScope, WithTenantID(otherTenant)); err != nil { t.Fatal(err) } } else if _, err = client.AcquireTokenSilent(ctx, tokenScope, WithTenantID(otherTenant)); err == nil { t.Fatal("expected an error") } }) } } // if every auth call specifies a different tenant, Client shouldn't send requests to its configured authority t.Run("enables fake authority", func(t *testing.T) { host := "host" defaultTenant := "default" cred, err := NewCredFromSecret(fakeSecret) if err != nil { t.Fatal(err) } URL := "" mockClient := mock.Client{} client, err := New(fmt.Sprintf(authorityFmt, host, defaultTenant), fakeClientID, cred, WithHTTPClient(&mockClient)) if err != nil { t.Fatal(err) } checkForWrongTenant := func(r *http.Request) { if u := r.URL.String(); strings.Contains(u, defaultTenant) { t.Fatalf("unexpected request to the default authority: %q", u) } } ctx := context.Background() for i := 0; i < 3; i++ { tenant := fmt.Sprint(i) expected := fmt.Sprintf(authorityFmt, host, tenant) // TODO: prevent redundant discovery requests https://github.com/AzureAD/microsoft-authentication-library-for-go/issues/351 mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(host, tenant)), 
mock.WithCallback(checkForWrongTenant)) mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(host, tenant)), mock.WithCallback(checkForWrongTenant)) mockClient.AppendResponse( mock.WithBody(mock.GetAccessTokenBody(accessToken, "", "", "", 3600)), mock.WithCallback(func(r *http.Request) { URL = r.URL.String() }), ) if i == 0 { // TODO: see above (first silent auth rediscovers instance metadata) mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(host, tenant)), mock.WithCallback(checkForWrongTenant)) } ar, err := client.AcquireTokenByAuthCode(ctx, "auth code", localhost, tokenScope, WithTenantID(tenant)) if err != nil { t.Fatal(err) } if !strings.HasPrefix(URL, expected) { t.Fatalf(`expected "%s", got "%s"`, expected, URL) } if ar.AccessToken != accessToken { t.Fatalf("unexpected access token %q", ar.AccessToken) } // silent authentication should now succeed for the given tenant... if ar, err = client.AcquireTokenSilent(ctx, tokenScope, WithTenantID(tenant)); err != nil { t.Fatal(err) } if ar.AccessToken != accessToken { t.Fatal("cached access token should match the one returned by AcquireToken...") } // ...but fail for another tenant otherTenant := "not-" + tenant if _, err = client.AcquireTokenSilent(ctx, tokenScope, WithTenantID(otherTenant)); err == nil { t.Fatal("expected an error") } } }) } func TestWithInstanceDiscovery(t *testing.T) { accessToken := "*" host := "stack.local" stackurl := fmt.Sprintf("https://%s/", host) for _, tenant := range []string{ "adfs", "98b8267d-e97f-426e-8b3f-7956511fd63f", } { for _, method := range []string{"authcode", "credential", "obo"} { t.Run(method, func(t *testing.T) { authority := stackurl + tenant cred, err := NewCredFromSecret(fakeSecret) if err != nil { t.Fatal(err) } idToken, refreshToken := "", "" mockClient := mock.Client{} if method == "obo" { idToken = mock.GetIDToken(tenant, authority) refreshToken = "refresh-token" } 
mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(stackurl, tenant))) mockClient.AppendResponse( mock.WithBody(mock.GetAccessTokenBody(accessToken, idToken, refreshToken, "", 3600)), ) client, err := New(authority, fakeClientID, cred, WithHTTPClient(&mockClient), WithInstanceDiscovery(false)) if err != nil { t.Fatal(err) } ctx := context.Background() if _, err = client.AcquireTokenSilent(ctx, tokenScope); err == nil { t.Fatal("silent auth should fail because the cache is empty") } var ar AuthResult switch method { case "authcode": ar, err = client.AcquireTokenByAuthCode(ctx, "auth code", localhost, tokenScope) case "credential": ar, err = client.AcquireTokenByCredential(ctx, tokenScope) case "obo": ar, err = client.AcquireTokenOnBehalfOf(ctx, "assertion", tokenScope) default: t.Fatal("test bug: no test for " + method) } if err != nil { t.Fatal(err) } if ar.AccessToken != accessToken { t.Fatalf(`unexpected access token "%s"`, ar.AccessToken) } if method == "obo" { if ar, err = client.AcquireTokenOnBehalfOf(ctx, "assertion", tokenScope); err != nil { t.Fatal(err) } } else if ar, err = client.AcquireTokenSilent(ctx, tokenScope); err != nil { t.Fatal(err) } if ar.AccessToken != accessToken { t.Fatal("cached access token should match the one returned by AcquireToken...") } }) } } } func TestWithPortAuthority(t *testing.T) { accessToken := "*" sl := "stack.local" port := ":3001" host := sl + port tenant := "00000000-0000-0000-0000-000000000000" authority := fmt.Sprintf("https://%s%s/%s", sl, port, tenant) cred, err := NewCredFromSecret(fakeSecret) if err != nil { t.Fatal(err) } idToken, refreshToken, URL := "", "", "" mockClient := mock.Client{} //2 calls to instance discovery are made because Host is not trusted mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(host, tenant))) mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(host, tenant))) 
mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(host, tenant))) mockClient.AppendResponse( mock.WithBody(mock.GetAccessTokenBody(accessToken, idToken, refreshToken, "", 3600)), mock.WithCallback(func(r *http.Request) { URL = r.URL.String() }), ) client, err := New(authority, fakeClientID, cred, WithHTTPClient(&mockClient)) if err != nil { t.Fatal(err) } ctx := context.Background() if _, err = client.AcquireTokenSilent(ctx, tokenScope); err == nil { t.Fatal("silent auth should fail because the cache is empty") } var ar AuthResult ar, err = client.AcquireTokenByAuthCode(ctx, "auth code", localhost, tokenScope) if err != nil { t.Fatal(err) } if !strings.HasPrefix(URL, authority) { t.Fatalf(`expected "%s", got "%s"`, authority, URL) } if ar.AccessToken != accessToken { t.Fatalf(`unexpected access token "%s"`, ar.AccessToken) } if ar, err = client.AcquireTokenSilent(ctx, tokenScope); err != nil { t.Fatal(err) } if ar.AccessToken != accessToken { t.Fatal("cached access token should match the one returned by AcquireToken...") } } func TestWithLoginHint(t *testing.T) { upn := "user@localhost" cred, err := NewCredFromSecret(fakeSecret) if err != nil { t.Fatal(err) } client, err := New(fakeAuthority, fakeClientID, cred, WithHTTPClient(&errorClient{})) if err != nil { t.Fatal(err) } client.base.Token.Resolver = &fake.ResolveEndpoints{} for _, expectHint := range []bool{true, false} { t.Run(fmt.Sprint(expectHint), func(t *testing.T) { opts := []AuthCodeURLOption{} if expectHint { opts = append(opts, WithLoginHint(upn)) } u, err := client.AuthCodeURL(context.Background(), "id", localhost, tokenScope, opts...) 
if err != nil { t.Fatal(err) } parsed, err := url.Parse(u) if err != nil { t.Fatal(err) } if !parsed.Query().Has("login_hint") { if !expectHint { return } t.Fatal("expected a login hint") } else if !expectHint { t.Fatal("expected no login hint") } if actual := parsed.Query()["login_hint"]; len(actual) != 1 || actual[0] != upn { t.Fatalf(`unexpected login_hint "%v"`, actual) } }) } } func TestWithDomainHint(t *testing.T) { domain := "contoso.com" cred, err := NewCredFromSecret(fakeSecret) if err != nil { t.Fatal(err) } client, err := New(fakeAuthority, fakeClientID, cred, WithHTTPClient(&errorClient{})) if err != nil { t.Fatal(err) } client.base.Token.Resolver = &fake.ResolveEndpoints{} for _, expectHint := range []bool{true, false} { t.Run(fmt.Sprint(expectHint), func(t *testing.T) { var opts []AuthCodeURLOption if expectHint { opts = append(opts, WithDomainHint(domain)) } u, err := client.AuthCodeURL(context.Background(), "id", localhost, tokenScope, opts...) if err != nil { t.Fatal(err) } parsed, err := url.Parse(u) if err != nil { t.Fatal(err) } if !parsed.Query().Has("domain_hint") { if !expectHint { return } t.Fatal("expected a domain hint") } else if !expectHint { t.Fatal("expected no domain hint") } if actual := parsed.Query()["domain_hint"]; len(actual) != 1 || actual[0] != domain { t.Fatalf(`unexpected domain_hint "%v"`, actual) } }) } } microsoft-authentication-library-for-go-1.0.0/apps/confidential/examples_test.go000066400000000000000000000013371442026362400302000ustar00rootroot00000000000000// Copyright (c) Microsoft Corporation. // Licensed under the MIT license. package confidential_test import ( "fmt" "log" "os" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential" ) func ExampleNewCredFromCert_pem() { b, err := os.ReadFile("key.pem") if err != nil { log.Fatal(err) } // This extracts our public certificates and private key from the PEM file. If it is // encrypted, the second argument must be password to decode. 
certs, priv, err := confidential.CertFromPEM(b, "") if err != nil { log.Fatal(err) } cred, err := confidential.NewCredFromCert(certs, priv) if err != nil { log.Fatal(err) } fmt.Println(cred) // Simply here so cred is used, otherwise won't compile. } microsoft-authentication-library-for-go-1.0.0/apps/design/000077500000000000000000000000001442026362400236025ustar00rootroot00000000000000microsoft-authentication-library-for-go-1.0.0/apps/design/design.md000066400000000000000000000151261442026362400254020ustar00rootroot00000000000000# MSAL Go Design Guide Author: John Doak(jdoak@microsoft.com) Contributors: - Keegan Caruso(Keegan.Caruso@microsoft.com) - Joel Hendrix(jhendrix@microsoft.com) - Santiago Gonzalez(Santiago.Gonzalez@microsoft.com) - Bogdan Gavril (bogavril@microsoft.com) ## History The original code submitted for Go MSAL was a translation of either Java or .Net code. This was done as a best effort by an intern who was attempting their first crack at Go. It had a very interesting structure that didn't fit into Go style and made it difficult to understand or change. It used global locks, global variables, base type classes (mimicing inheritence), ... This probably should have be re-written from scratch, but we decided to try and do it in pieces. The lesson to be learned from this is that this type of refactor leads to re-writing the code 7 or 8 times instead of once. Much of this lead to a re-write where we were not seeing the forrest because of the trees. Every small change would inevitably become some 60 file refactor and have much larger ramifications than intended. The work could not be divided up, because the API and the internals were linked across logical boundaries. What has resulted should be a design that divides code into logical layers and splits the public API from the internal structure. 
## General Structure Public Surface: ``` apps/ - Contains all our code confidential/ - The confidential application API public/ - The public application API cache/ - The cache interface that can be implemented to provide persistent cache storage of credentials ``` Internals: ``` apps/ internal/ client/ - Shared package for common calls that Public and Confidential apps share json/ - Our own json encoder/decoder for special needs shared/ - Holds types that need to be in multiple packages and can't be moved into a single one due to import cycles requests/ - The package used to communicate with services to get tokens ``` ### Use of the Go special internal/ directory In Go, a directory called internal/ contains packages that should only be used by other packages rooted at the same location. This is documented here: https://golang.org/doc/go1.4#internalpackages For example, a package .../a/b/c/internal/d/e/f can be imported only by code in the directory tree rooted at .../a/b/c. It cannot be imported by code in .../a/b/g or in any other repository. We use this feature quite liberally to make clear what is using an internal package. For example: ``` apps/internal/base - Can only be used by packages rooted at apps/ apps/internal/base/internal/storage - Can only be used by packages rooted at apps/internal/base ``` ## Public API The public API is encapsulated in apps/. apps/ has 3 packages of interest to users: - public/ - This is what MSAL calls the Public Application Client (service client) - confidential/ - This is what MSAL calls the Confidential Application Client (service) - cache/ - This provides the interfaces that must be implemented to create persistent caches for any MSAL client ## Internals In this section we will be talking about internal/. ### JSON Handling JSON must be handled specially in our app. The basic rule is: if we receive fields that our structs do not contain, we cannot drop them. We must send them back to the service.
To handle this, we use our own custom json package. See the design at: [Design](https://github.com/AzureAD/microsoft-authentication-library-for-go/blob/dev/internal/json/design.md) ### Backend communication Communication with the backends is done via the requests/ package. oauth.Token is the client for all communication. oauth.Token communicates via REST calls that are encapsulated in the ops/ client. ## Adding A Feature This is the general way to add a new feature to MSAL: - Add the REST calls to ops.REST - Add the higher level manipulations to oauth.Token - Add your logic to the appropriate client package under apps/ (public/ or confidential/) and access the services via your oauth.Token ## Notable Differences To Other Clients ### TBD: Confidential applications need to handle multiple users without one big cache The MSAL caching design is rather simple. These design decisions, and the fact that multiple applications in different languages can share a cache, mean it cannot be easily changed. The entire cache contents of a confidential.Client are read from and written to an external cache on almost any action. It is not clear to a user that a confidential client should be created per user to prevent scaling problems. We cannot change the MSAL cache design at this time, therefore it should be made clear that a confidential.Client should be created per user. This must go beyond a simple doc entry that can be ignored. It's great to say "we told you in the doc", but that happens AFTER a support call. TBD ... ### Use of x509.Certificate and CertFromPEM() function The original version of this package used a thumbprint and a private key to do authorizations based on a certificate. But there wasn't a real way to get a thumbprint. A thumbprint is defined in the OAuth spec, which we had to track down. It is a SHA-1 hash of the x509 certificate's DER-encoded ASN.1 bytes. Since the user was going to need the x509 certificate anyway, we moved to having the user provide the x509.Certificate object. We wrote the thumbprint creator for the internals.
Since we also require the private key and it is not straightforward to get, we added a CertFromPEM() function that will extract the x509.Certificate and private key. We also support encrypted PEM. It should be noted that Keyvault stores things in PKCS12 and PEM. Keyvault is not straightforward in how it works. Frankly, I seriously doubt that a regular Go user can get certs out of Keyvault's Go API. Before I began working on MSAL I was re-writing the Keyvault Go API: https://github.com/element-of-surprise/keyvault . It does the right things to extract certs for TLS now. I was still working on the Cert() API and hadn't exposed the public surface when I stopped. Since we have representation from the Go SDK team, we might have them help bridge this gap in the current implementation using some of that code, so that it's possible for our users to store the cert in Keyvault. ## Logging For errors, see [error design](../errors/error_design.md). This library does not log personally identifiable information (PII). For a definition of PII, see https://www.microsoft.com/en-us/trust-center/privacy/customer-data-definitions. MSAL Go does not log any of the 3 data categories listed there. The library may log information related to your organization, such as tenant ID, authority, client ID, etc., as well as information that cannot be tied to a user, such as request correlation ID, HTTP status codes, etc.
microsoft-authentication-library-for-go-1.0.0/apps/errors/000077500000000000000000000000001442026362400236455ustar00rootroot00000000000000microsoft-authentication-library-for-go-1.0.0/apps/errors/error_design.md000066400000000000000000000105421442026362400266530ustar00rootroot00000000000000# MSAL Error Design Author: Abhidnya Patil(abhidnya.patil@microsoft.com) Contributors: - John Doak(jdoak@microsoft.com) - Keegan Caruso(Keegan.Caruso@microsoft.com) - Joel Hendrix(jhendrix@microsoft.com) ## Background Errors in MSAL are intended for app developers to troubleshoot and not for displaying to end-users. ### Go error handling vs other MSAL languages Most modern languages use exception based errors. Simply put, you "throw" an exception and it must be caught at some routine in the upper stack or it will eventually crash the program. Go doesn't use exceptions, instead it relies on multiple return values, one of which can be the builtin error interface type. It is up to the user to decide what to do. ### Go custom error types Errors can be created in Go by simply using errors.New() or fmt.Errorf() to create an "error". Custom errors can be created in multiple ways. One of the more robust ways is simply to satisfy the error interface: ```go type MyCustomErr struct { Msg string } func (m MyCustomErr) Error() string { // This implements "error" return m.Msg } ``` ### MSAL Error Goals - Provide diagnostics to the user and for tickets that can be used to track down bugs or client misconfigurations - Detect errors that are transitory and can be retried - Allow the user to identify certain errors that the program can respond to, such a informing the user for the need to do an enrollment ## Implementing Client Side Errors Client side errors indicate a misconfiguration or passing of bad arguments that is non-recoverable. Retrying isn't possible. These errors can simply be standard Go errors created by errors.New() or fmt.Errorf(). 
If down the line we need a custom error, we can introduce it, but for now the error messages just need to be clear about what the issue was. ## Implementing Service Side Errors Service side errors occur when an external RPC either responds with an HTTP error code or returns a message that includes an error. These errors can be transitory (please slow down) or permanent (HTTP 404). To meet our diagnostic goals, we require the ability to differentiate these errors from other errors. The current implementation includes a specialized type that captures any error from the server: ```go // CallErr represents an HTTP call error. Has a Verbose() method that allows getting the // http.Request and Response objects. Implements error. type CallErr struct { Req *http.Request Resp *http.Response Err error } // Error implements error.Error(). func (e CallErr) Error() string { return e.Err.Error() } // Verbose prints a verbose error message with the request or response. func (e CallErr) Verbose() string { e.Resp.Request = nil // This brings in a bunch of TLS stuff we don't need e.Resp.TLS = nil // Same return fmt.Sprintf("%s:\nRequest:\n%s\nResponse:\n%s", e.Err, prettyConf.Sprint(e.Req), prettyConf.Sprint(e.Resp)) } ``` A user will always receive the most concise error we provide. They can tell if it is a server side error using the Go errors package: ```go var callErr CallErr if errors.As(err, &callErr) { ... } ``` We provide a Verbose() function that can retrieve the most verbose message from any error we provide: ```go fmt.Println(errors.Verbose(err)) ``` If further differentiation is required, we can add custom errors that use Go error wrapping on top of CallErr to achieve our diagnostic goals (such as detecting when to retry a call due to transient errors).
CallErr is always thrown from the comm package (which handles all http requests) and looks similar to: ```go return nil, errors.CallErr{ Req: req, Resp: reply, Err: fmt.Errorf("http call(%s)(%s) error: reply status code was %d:\n%s", req.URL.String(), req.Method, reply.StatusCode, ErrorResponse), //ErrorResponse is the json body extracted from the http response } ``` ## Future Decisions The ability to retry calls needs to have centralized responsibility. Either the user is doing it or the client is doing it. If the user should be responsible, our errors package will include a CanRetry() function that will inform the user if the error provided to them is retryable. This is based on the http error code and possibly the type of error that was returned. It would also include a sleep time if the server returned an amount of time to wait. Otherwise we will do this internally and retries will be left to us. microsoft-authentication-library-for-go-1.0.0/apps/errors/errors.go000066400000000000000000000040561442026362400255150ustar00rootroot00000000000000// Copyright (c) Microsoft Corporation. // Licensed under the MIT license. package errors import ( "errors" "fmt" "io" "net/http" "reflect" "strings" "github.com/kylelemons/godebug/pretty" ) var prettyConf = &pretty.Config{ IncludeUnexported: false, SkipZeroFields: true, TrackCycles: true, Formatter: map[reflect.Type]interface{}{ reflect.TypeOf((*io.Reader)(nil)).Elem(): func(r io.Reader) string { b, err := io.ReadAll(r) if err != nil { return "could not read io.Reader content" } return string(b) }, }, } type verboser interface { Verbose() string } // Verbose prints the most verbose error that the error message has. func Verbose(err error) string { build := strings.Builder{} for { if err == nil { break } if v, ok := err.(verboser); ok { build.WriteString(v.Verbose()) } else { build.WriteString(err.Error()) } err = errors.Unwrap(err) } return build.String() } // New is equivalent to errors.New(). 
func New(text string) error { return errors.New(text) } // CallErr represents an HTTP call error. Has a Verbose() method that allows getting the // http.Request and Response objects. Implements error. type CallErr struct { Req *http.Request // Resp contains response body Resp *http.Response Err error } // Errors implements error.Error(). func (e CallErr) Error() string { return e.Err.Error() } // Verbose prints a versbose error message with the request or response. func (e CallErr) Verbose() string { e.Resp.Request = nil // This brings in a bunch of TLS crap we don't need e.Resp.TLS = nil // Same return fmt.Sprintf("%s:\nRequest:\n%s\nResponse:\n%s", e.Err, prettyConf.Sprint(e.Req), prettyConf.Sprint(e.Resp)) } // Is reports whether any error in errors chain matches target. func Is(err, target error) bool { return errors.Is(err, target) } // As finds the first error in errors chain that matches target, // and if so, sets target to that error value and returns true. // Otherwise, it returns false. func As(err error, target interface{}) bool { return errors.As(err, target) } microsoft-authentication-library-for-go-1.0.0/apps/internal/000077500000000000000000000000001442026362400241455ustar00rootroot00000000000000microsoft-authentication-library-for-go-1.0.0/apps/internal/base/000077500000000000000000000000001442026362400250575ustar00rootroot00000000000000microsoft-authentication-library-for-go-1.0.0/apps/internal/base/base.go000066400000000000000000000371411442026362400263260ustar00rootroot00000000000000// Package base contains a "Base" client that is used by the external public.Client and confidential.Client. // Base holds shared attributes that must be available to both clients and methods that act as // shared calls. 
package base import ( "context" "errors" "fmt" "net/url" "reflect" "strings" "sync" "time" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/base/internal/storage" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/shared" ) const ( // AuthorityPublicCloud is the default AAD authority host AuthorityPublicCloud = "https://login.microsoftonline.com/common" scopeSeparator = " " ) // manager provides an internal cache. It is defined to allow faking the cache in tests. // In production it's a *storage.Manager or *storage.PartitionedManager. type manager interface { cache.Serializer Read(context.Context, authority.AuthParams) (storage.TokenResponse, error) Write(authority.AuthParams, accesstokens.TokenResponse) (shared.Account, error) } // accountManager is a manager that also caches accounts. In production it's a *storage.Manager. type accountManager interface { manager AllAccounts() []shared.Account Account(homeAccountID string) shared.Account RemoveAccount(account shared.Account, clientID string) } // AcquireTokenSilentParameters contains the parameters to acquire a token silently (from cache). type AcquireTokenSilentParameters struct { Scopes []string Account shared.Account RequestType accesstokens.AppType Credential *accesstokens.Credential IsAppCache bool TenantID string UserAssertion string AuthorizationType authority.AuthorizeType Claims string } // AcquireTokenAuthCodeParameters contains the parameters required to acquire an access token using the auth code flow. // To use PKCE, set the CodeChallengeParameter. 
// Code challenges are used to secure authorization code grants; for more information, visit // https://tools.ietf.org/html/rfc7636. type AcquireTokenAuthCodeParameters struct { Scopes []string Code string Challenge string Claims string RedirectURI string AppType accesstokens.AppType Credential *accesstokens.Credential TenantID string } type AcquireTokenOnBehalfOfParameters struct { Scopes []string Claims string Credential *accesstokens.Credential TenantID string UserAssertion string } // AuthResult contains the results of one token acquisition operation in PublicClientApplication // or ConfidentialClientApplication. For details see https://aka.ms/msal-net-authenticationresult type AuthResult struct { Account shared.Account IDToken accesstokens.IDToken AccessToken string ExpiresOn time.Time GrantedScopes []string DeclinedScopes []string } // AuthResultFromStorage creates an AuthResult from a storage token response (which is generated from the cache). func AuthResultFromStorage(storageTokenResponse storage.TokenResponse) (AuthResult, error) { if err := storageTokenResponse.AccessToken.Validate(); err != nil { return AuthResult{}, fmt.Errorf("problem with access token in StorageTokenResponse: %w", err) } account := storageTokenResponse.Account accessToken := storageTokenResponse.AccessToken.Secret grantedScopes := strings.Split(storageTokenResponse.AccessToken.Scopes, scopeSeparator) // Checking if there was an ID token in the cache; this will throw an error in the case of confidential client applications. var idToken accesstokens.IDToken if !storageTokenResponse.IDToken.IsZero() { err := idToken.UnmarshalJSON([]byte(storageTokenResponse.IDToken.Secret)) if err != nil { return AuthResult{}, fmt.Errorf("problem decoding JWT token: %w", err) } } return AuthResult{account, idToken, accessToken, storageTokenResponse.AccessToken.ExpiresOn.T, grantedScopes, nil}, nil } // NewAuthResult creates an AuthResult. 
func NewAuthResult(tokenResponse accesstokens.TokenResponse, account shared.Account) (AuthResult, error) { if len(tokenResponse.DeclinedScopes) > 0 { return AuthResult{}, fmt.Errorf("token response failed because declined scopes are present: %s", strings.Join(tokenResponse.DeclinedScopes, ",")) } return AuthResult{ Account: account, IDToken: tokenResponse.IDToken, AccessToken: tokenResponse.AccessToken, ExpiresOn: tokenResponse.ExpiresOn.T, GrantedScopes: tokenResponse.GrantedScopes.Slice, }, nil } // Client is a base client that provides access to common methods and primatives that // can be used by multiple clients. type Client struct { Token *oauth.Client manager accountManager // *storage.Manager or fakeManager in tests // pmanager is a partitioned cache for OBO authentication. *storage.PartitionedManager or fakeManager in tests pmanager manager AuthParams authority.AuthParams // DO NOT EVER MAKE THIS A POINTER! See "Note" in New(). cacheAccessor cache.ExportReplace cacheAccessorMu *sync.RWMutex } // Option is an optional argument to the New constructor. type Option func(c *Client) error // WithCacheAccessor allows you to set some type of cache for storing authentication tokens. 
func WithCacheAccessor(ca cache.ExportReplace) Option { return func(c *Client) error { if ca != nil { c.cacheAccessor = ca } return nil } } // WithClientCapabilities allows configuring one or more client capabilities such as "CP1" func WithClientCapabilities(capabilities []string) Option { return func(c *Client) error { var err error if len(capabilities) > 0 { cc, err := authority.NewClientCapabilities(capabilities) if err == nil { c.AuthParams.Capabilities = cc } } return err } } // WithKnownAuthorityHosts specifies hosts Client shouldn't validate or request metadata for because they're known to the user func WithKnownAuthorityHosts(hosts []string) Option { return func(c *Client) error { cp := make([]string, len(hosts)) copy(cp, hosts) c.AuthParams.KnownAuthorityHosts = cp return nil } } // WithX5C specifies if x5c claim(public key of the certificate) should be sent to STS to enable Subject Name Issuer Authentication. func WithX5C(sendX5C bool) Option { return func(c *Client) error { c.AuthParams.SendX5C = sendX5C return nil } } func WithRegionDetection(region string) Option { return func(c *Client) error { c.AuthParams.AuthorityInfo.Region = region return nil } } func WithInstanceDiscovery(instanceDiscoveryEnabled bool) Option { return func(c *Client) error { c.AuthParams.AuthorityInfo.ValidateAuthority = instanceDiscoveryEnabled c.AuthParams.AuthorityInfo.InstanceDiscoveryDisabled = !instanceDiscoveryEnabled return nil } } // New is the constructor for Base. func New(clientID string, authorityURI string, token *oauth.Client, options ...Option) (Client, error) { //By default, validateAuthority is set to true and instanceDiscoveryDisabled is set to false authInfo, err := authority.NewInfoFromAuthorityURI(authorityURI, true, false) if err != nil { return Client{}, err } authParams := authority.NewAuthParams(clientID, authInfo) client := Client{ // Note: Hey, don't even THINK about making Base into *Base. 
See "design notes" in public.go and confidential.go Token: token, AuthParams: authParams, cacheAccessorMu: &sync.RWMutex{}, manager: storage.New(token), pmanager: storage.NewPartitionedManager(token), } for _, o := range options { if err = o(&client); err != nil { break } } return client, err } // AuthCodeURL creates a URL used to acquire an authorization code. func (b Client) AuthCodeURL(ctx context.Context, clientID, redirectURI string, scopes []string, authParams authority.AuthParams) (string, error) { endpoints, err := b.Token.ResolveEndpoints(ctx, authParams.AuthorityInfo, "") if err != nil { return "", err } baseURL, err := url.Parse(endpoints.AuthorizationEndpoint) if err != nil { return "", err } claims, err := authParams.MergeCapabilitiesAndClaims() if err != nil { return "", err } v := url.Values{} v.Add("client_id", clientID) v.Add("response_type", "code") v.Add("redirect_uri", redirectURI) v.Add("scope", strings.Join(scopes, scopeSeparator)) if authParams.State != "" { v.Add("state", authParams.State) } if claims != "" { v.Add("claims", claims) } if authParams.CodeChallenge != "" { v.Add("code_challenge", authParams.CodeChallenge) } if authParams.CodeChallengeMethod != "" { v.Add("code_challenge_method", authParams.CodeChallengeMethod) } if authParams.LoginHint != "" { v.Add("login_hint", authParams.LoginHint) } if authParams.Prompt != "" { v.Add("prompt", authParams.Prompt) } if authParams.DomainHint != "" { v.Add("domain_hint", authParams.DomainHint) } // There were left over from an implementation that didn't use any of these. We may // need to add them later, but as of now aren't needed. 
/* if p.ResponseMode != "" { urlParams.Add("response_mode", p.ResponseMode) } */ baseURL.RawQuery = v.Encode() return baseURL.String(), nil } func (b Client) AcquireTokenSilent(ctx context.Context, silent AcquireTokenSilentParameters) (AuthResult, error) { ar := AuthResult{} // when tenant == "", the caller didn't specify a tenant and WithTenant will choose the client's configured tenant tenant := silent.TenantID authParams, err := b.AuthParams.WithTenant(tenant) if err != nil { return ar, err } authParams.Scopes = silent.Scopes authParams.HomeAccountID = silent.Account.HomeAccountID authParams.AuthorizationType = silent.AuthorizationType authParams.Claims = silent.Claims authParams.UserAssertion = silent.UserAssertion m := b.pmanager if authParams.AuthorizationType != authority.ATOnBehalfOf { authParams.AuthorizationType = authority.ATRefreshToken m = b.manager } if b.cacheAccessor != nil { key := authParams.CacheKey(silent.IsAppCache) b.cacheAccessorMu.RLock() err = b.cacheAccessor.Replace(ctx, m, cache.ReplaceHints{PartitionKey: key}) b.cacheAccessorMu.RUnlock() } if err != nil { return ar, err } storageTokenResponse, err := m.Read(ctx, authParams) if err != nil { return ar, err } // ignore cached access tokens when given claims if silent.Claims == "" { ar, err = AuthResultFromStorage(storageTokenResponse) if err == nil { return ar, err } } // redeem a cached refresh token, if available if reflect.ValueOf(storageTokenResponse.RefreshToken).IsZero() { return ar, errors.New("no token found") } var cc *accesstokens.Credential if silent.RequestType == accesstokens.ATConfidential { cc = silent.Credential } token, err := b.Token.Refresh(ctx, silent.RequestType, authParams, cc, storageTokenResponse.RefreshToken) if err != nil { return ar, err } return b.AuthResultFromToken(ctx, authParams, token, true) } func (b Client) AcquireTokenByAuthCode(ctx context.Context, authCodeParams AcquireTokenAuthCodeParameters) (AuthResult, error) { authParams, err := 
b.AuthParams.WithTenant(authCodeParams.TenantID) if err != nil { return AuthResult{}, err } authParams.Claims = authCodeParams.Claims authParams.Scopes = authCodeParams.Scopes authParams.Redirecturi = authCodeParams.RedirectURI authParams.AuthorizationType = authority.ATAuthCode var cc *accesstokens.Credential if authCodeParams.AppType == accesstokens.ATConfidential { cc = authCodeParams.Credential authParams.IsConfidentialClient = true } req, err := accesstokens.NewCodeChallengeRequest(authParams, authCodeParams.AppType, cc, authCodeParams.Code, authCodeParams.Challenge) if err != nil { return AuthResult{}, err } token, err := b.Token.AuthCode(ctx, req) if err != nil { return AuthResult{}, err } return b.AuthResultFromToken(ctx, authParams, token, true) } // AcquireTokenOnBehalfOf acquires a security token for an app using middle tier apps access token. func (b Client) AcquireTokenOnBehalfOf(ctx context.Context, onBehalfOfParams AcquireTokenOnBehalfOfParameters) (AuthResult, error) { var ar AuthResult silentParameters := AcquireTokenSilentParameters{ Scopes: onBehalfOfParams.Scopes, RequestType: accesstokens.ATConfidential, Credential: onBehalfOfParams.Credential, UserAssertion: onBehalfOfParams.UserAssertion, AuthorizationType: authority.ATOnBehalfOf, TenantID: onBehalfOfParams.TenantID, Claims: onBehalfOfParams.Claims, } ar, err := b.AcquireTokenSilent(ctx, silentParameters) if err == nil { return ar, err } authParams, err := b.AuthParams.WithTenant(onBehalfOfParams.TenantID) if err != nil { return AuthResult{}, err } authParams.AuthorizationType = authority.ATOnBehalfOf authParams.Claims = onBehalfOfParams.Claims authParams.Scopes = onBehalfOfParams.Scopes authParams.UserAssertion = onBehalfOfParams.UserAssertion token, err := b.Token.OnBehalfOf(ctx, authParams, onBehalfOfParams.Credential) if err == nil { ar, err = b.AuthResultFromToken(ctx, authParams, token, true) } return ar, err } func (b Client) AuthResultFromToken(ctx context.Context, authParams 
authority.AuthParams, token accesstokens.TokenResponse, cacheWrite bool) (AuthResult, error) { if !cacheWrite { return NewAuthResult(token, shared.Account{}) } var m manager = b.manager if authParams.AuthorizationType == authority.ATOnBehalfOf { m = b.pmanager } key := token.CacheKey(authParams) if b.cacheAccessor != nil { b.cacheAccessorMu.Lock() defer b.cacheAccessorMu.Unlock() err := b.cacheAccessor.Replace(ctx, m, cache.ReplaceHints{PartitionKey: key}) if err != nil { return AuthResult{}, err } } account, err := m.Write(authParams, token) if err != nil { return AuthResult{}, err } ar, err := NewAuthResult(token, account) if err == nil && b.cacheAccessor != nil { err = b.cacheAccessor.Export(ctx, b.manager, cache.ExportHints{PartitionKey: key}) } return ar, err } func (b Client) AllAccounts(ctx context.Context) ([]shared.Account, error) { if b.cacheAccessor != nil { b.cacheAccessorMu.RLock() defer b.cacheAccessorMu.RUnlock() key := b.AuthParams.CacheKey(false) err := b.cacheAccessor.Replace(ctx, b.manager, cache.ReplaceHints{PartitionKey: key}) if err != nil { return nil, err } } return b.manager.AllAccounts(), nil } func (b Client) Account(ctx context.Context, homeAccountID string) (shared.Account, error) { if b.cacheAccessor != nil { b.cacheAccessorMu.RLock() defer b.cacheAccessorMu.RUnlock() authParams := b.AuthParams // This is a copy, as we don't have a pointer receiver and .AuthParams is not a pointer. authParams.AuthorizationType = authority.AccountByID authParams.HomeAccountID = homeAccountID key := b.AuthParams.CacheKey(false) err := b.cacheAccessor.Replace(ctx, b.manager, cache.ReplaceHints{PartitionKey: key}) if err != nil { return shared.Account{}, err } } return b.manager.Account(homeAccountID), nil } // RemoveAccount removes all the ATs, RTs and IDTs from the cache associated with this account. 
func (b Client) RemoveAccount(ctx context.Context, account shared.Account) error { if b.cacheAccessor == nil { b.manager.RemoveAccount(account, b.AuthParams.ClientID) return nil } b.cacheAccessorMu.Lock() defer b.cacheAccessorMu.Unlock() key := b.AuthParams.CacheKey(false) err := b.cacheAccessor.Replace(ctx, b.manager, cache.ReplaceHints{PartitionKey: key}) if err != nil { return err } b.manager.RemoveAccount(account, b.AuthParams.ClientID) return b.cacheAccessor.Export(ctx, b.manager, cache.ExportHints{PartitionKey: key}) } microsoft-authentication-library-for-go-1.0.0/apps/internal/base/base_test.go000066400000000000000000000340041442026362400273600ustar00rootroot00000000000000// Copyright (c) Microsoft Corporation. // Licensed under the MIT license. package base import ( "context" "errors" "fmt" "reflect" "testing" "time" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/base/internal/storage" internalTime "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json/types/time" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/fake" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/shared" "github.com/kylelemons/godebug/pretty" ) const ( fakeAccessToken = "fake-access-token" fakeAuthority = "fake_authority" fakeClientID = "fake-client-id" fakeRefreshToken = "fake-refresh-token" fakeTenantID = "fake-tenant-id" fakeUsername = "fake-username" ) var ( fakeIDToken = accesstokens.IDToken{ Oid: "oid", PreferredUsername: fakeUsername, RawToken: "x.e30", TenantID: fakeTenantID, UPN: fakeUsername, } testScopes = []string{"scope"} ) func 
fakeClient(t *testing.T, opts ...Option) Client { client, err := New(fakeClientID, fmt.Sprintf("https://%s/%s", fakeAuthority, fakeTenantID), &oauth.Client{}, opts...) if err != nil { t.Fatal(err) } client.Token.AccessTokens = &fake.AccessTokens{ AccessToken: accesstokens.TokenResponse{ AccessToken: fakeAccessToken, ExpiresOn: internalTime.DurationTime{T: time.Now().Add(time.Hour)}, FamilyID: "family-id", GrantedScopes: accesstokens.Scopes{Slice: testScopes}, IDToken: fakeIDToken, RefreshToken: fakeRefreshToken, }, } client.Token.Authority = &fake.Authority{ InstanceResp: authority.InstanceDiscoveryResponse{ Metadata: []authority.InstanceDiscoveryMetadata{ {Aliases: []string{fakeAuthority}, PreferredNetwork: fakeAuthority}, }, TenantDiscoveryEndpoint: fmt.Sprintf("https://%s/fake/discovery/endpoint", fakeAuthority), }, } client.Token.Resolver = &fake.ResolveEndpoints{ Endpoints: authority.NewEndpoints( fmt.Sprintf("https://%s/fake/auth", fakeAuthority), fmt.Sprintf("https://%s/fake/token", fakeAuthority), fmt.Sprintf("https://%s/fake/jwt", fakeAuthority), fakeAuthority, ), } return client } func TestAcquireTokenSilentEmptyCache(t *testing.T) { client := fakeClient(t) _, err := client.AcquireTokenSilent(context.Background(), AcquireTokenSilentParameters{ Account: shared.NewAccount("homeAccountID", "env", "realm", "localAccountID", authority.AAD, "username"), Scopes: testScopes, }) if err == nil { t.Fatal("expected an error because the cache is empty") } } func TestAcquireTokenSilentScopes(t *testing.T) { // ensure fakeIDToken.RawToken unmarshals (doesn't matter to what) because an unmarshalling // error can conceal a test bug by making an "err != nil" check true for the wrong reason var idToken accesstokens.IDToken if err := idToken.UnmarshalJSON([]byte(fakeIDToken.RawToken)); err != nil { t.Fatal(err) } for _, test := range []struct { desc string cachedTokenScopes []string }{ {"expired access token", testScopes}, {"no access token", []string{"other-" + 
testScopes[0]}}, } { t.Run(test.desc, func(t *testing.T) { client := fakeClient(t) validated := false client.Token.AccessTokens.(*fake.AccessTokens).FromRefreshTokenCallback = func(at accesstokens.AppType, ap authority.AuthParams, cc *accesstokens.Credential, rt string) { validated = true if !reflect.DeepEqual(ap.Scopes, testScopes) { t.Fatalf("unexpected scopes: %v", ap.Scopes) } if cc != nil { t.Fatal("client shouldn't have a credential") } if rt != fakeRefreshToken { t.Fatal("unexpected refresh token") } } // cache a refresh token and an expired access token for the given scopes // (testing only the public client code path) storage.FakeValidate = func(storage.AccessToken) error { return nil } account, err := client.manager.Write( authority.AuthParams{ AuthorityInfo: authority.Info{ AuthorityType: authority.AAD, Host: fakeAuthority, Tenant: fakeIDToken.TenantID, }, ClientID: fakeClientID, Scopes: test.cachedTokenScopes, Username: fakeIDToken.PreferredUsername, }, accesstokens.TokenResponse{ AccessToken: fakeAccessToken, ClientInfo: accesstokens.ClientInfo{UID: "uid", UTID: "utid"}, ExpiresOn: internalTime.DurationTime{T: time.Now().Add(-time.Hour)}, GrantedScopes: accesstokens.Scopes{Slice: test.cachedTokenScopes}, IDToken: fakeIDToken, RefreshToken: fakeRefreshToken, }, ) storage.FakeValidate = nil if err != nil { t.Fatal(err) } // AcquireTokenSilent should redeem the refresh token for a new access token ar, err := client.AcquireTokenSilent(context.Background(), AcquireTokenSilentParameters{Account: account, Scopes: testScopes}) if err != nil { t.Fatal(err) } if ar.AccessToken != fakeAccessToken { t.Fatal("unexpected access token") } if !validated { t.Fatal("FromRefreshTokenCallback wasn't called") } }) } } func TestAcquireTokenSilentGrantedScopes(t *testing.T) { client := fakeClient(t) grantedScopes := []string{"scope1", "scope2"} expectedToken := "not-" + fakeAccessToken account, err := client.manager.Write( authority.AuthParams{ AuthorityInfo: authority.Info{ 
AuthorityType: authority.AAD, Host: fakeAuthority, Tenant: fakeIDToken.TenantID, }, ClientID: fakeClientID, Scopes: grantedScopes[1:], }, accesstokens.TokenResponse{ AccessToken: expectedToken, ExpiresOn: internalTime.DurationTime{T: time.Now().Add(time.Hour)}, GrantedScopes: accesstokens.Scopes{Slice: grantedScopes}, }, ) 
	if err != nil {
		t.Fatal(err)
	}
	for _, scope := range grantedScopes {
		ar, err := client.AcquireTokenSilent(context.Background(), AcquireTokenSilentParameters{Account: account, Scopes: []string{scope}})
		if err != nil {
			t.Fatal(err)
		}
		if ar.AccessToken != expectedToken {
			t.Fatal("unexpected access token")
		}
	}
}

// failCache helps tests inject cache I/O errors. Export records that it was
// called and returns exportErr; Replace returns replaceErr.
type failCache struct {
	exported              bool
	exportErr, replaceErr error
}

func (c *failCache) Export(context.Context, cache.Marshaler, cache.ExportHints) error {
	c.exported = true
	return c.exportErr
}

// Replace uses a pointer receiver like Export so the method set is consistent;
// failCache is only ever handed to the client as *failCache.
func (c *failCache) Replace(context.Context, cache.Unmarshaler, cache.ReplaceHints) error {
	return c.replaceErr
}

func TestCacheIOErrors(t *testing.T) {
	ctx := context.Background()
	expected := errors.New("cache error")
	for _, export := range []bool{true, false} {
		name := "replace"
		cache := failCache{}
		if export {
			cache.exportErr = expected
			name = "export"
		} else {
			cache.replaceErr = expected
		}
		t.Run(name, func(t *testing.T) {
			client := fakeClient(t, WithCacheAccessor(&cache))
			if !export {
				// Account and AllAccounts don't export the cache, and AcquireTokenSilent does so
				// only after redeeming a refresh token, so we test them only for replace errors.
				// (In this branch cache.replaceErr is always set, so no extra guard is needed.)
				_, actual := client.Account(ctx, "...")
				if !errors.Is(actual, expected) {
					t.Fatalf(`expected "%v", got "%v"`, expected, actual)
				}
				_, actual = client.AllAccounts(ctx)
				if !errors.Is(actual, expected) {
					t.Fatalf(`expected "%v", got "%v"`, expected, actual)
				}
				_, actual = client.AcquireTokenSilent(ctx, AcquireTokenSilentParameters{Scopes: testScopes})
				if !errors.Is(actual, expected) {
					t.Fatalf(`expected "%v", got "%v"`, expected, actual)
				}
			}
			_, actual := client.AcquireTokenByAuthCode(ctx, AcquireTokenAuthCodeParameters{AppType: accesstokens.ATConfidential, Scopes: testScopes})
			if !errors.Is(actual, expected) {
				t.Fatalf(`expected "%v", got "%v"`, expected, actual)
			}
			_, actual = client.AcquireTokenOnBehalfOf(ctx, AcquireTokenOnBehalfOfParameters{Credential: &accesstokens.Credential{Secret: "..."}, Scopes: testScopes})
			if !errors.Is(actual, expected) {
				t.Fatalf(`expected "%v", got "%v"`, expected, actual)
			}
			_, actual = client.AuthResultFromToken(ctx, authority.AuthParams{}, accesstokens.TokenResponse{}, true)
			if !errors.Is(actual, expected) {
				t.Fatalf(`expected "%v", got "%v"`, expected, actual)
			}
			actual = client.RemoveAccount(ctx, shared.Account{})
			if !errors.Is(actual, expected) {
				t.Fatalf(`expected "%v", got "%v"`, expected, actual)
			}
		})
	}

	// testing that AcquireTokenSilent propagates errors from Export requires more elaborate
	// setup because that method exports the cache only after acquiring a new access token
	t.Run("silent auth export error", func(t *testing.T) {
		cache := failCache{}
		hid := "uid.utid"
		client := fakeClient(t, WithCacheAccessor(&cache))
		// cache fake tokens and app metadata
		_, err := client.AuthResultFromToken(ctx,
			authority.AuthParams{
				AuthorityInfo: authority.Info{Host: fakeAuthority},
				ClientID:      fakeClientID,
				HomeAccountID: hid,
				Scopes:        testScopes,
			},
			accesstokens.TokenResponse{
				AccessToken:   "at",
				ClientInfo:    accesstokens.ClientInfo{UID: "uid", UTID: "utid"},
				GrantedScopes: accesstokens.Scopes{Slice: testScopes},
				IDToken:       fakeIDToken,
				RefreshToken:  "rt",
			},
			true,
		)
		if err != nil {
			t.Fatal(err)
		}
		// AcquireTokenSilent should return this error after redeeming a refresh token
		cache.exportErr = expected
		_, actual := client.AcquireTokenSilent(ctx,
			AcquireTokenSilentParameters{
				Account: shared.NewAccount(hid, fakeAuthority, "realm", "id", authority.AAD, "upn"),
				Scopes:  []string{"not-" + testScopes[0]},
			},
		)
		if !errors.Is(actual, expected) {
			t.Fatalf(`expected "%v", got "%v"`, expected, actual)
		}
	})

	// when the client fails to acquire a token, it should return an error instead of exporting the cache
	t.Run("auth error", func(t *testing.T) {
		cache := failCache{}
		client := fakeClient(t, WithCacheAccessor(&cache))
		client.Token.AccessTokens.(*fake.AccessTokens).Err = true
		_, err := client.AcquireTokenByAuthCode(ctx, AcquireTokenAuthCodeParameters{AppType: accesstokens.ATConfidential})
		if err == nil || cache.exported {
			t.Fatal("client should have returned an error instead of exporting the cache")
		}
		_, err = client.AcquireTokenOnBehalfOf(ctx, AcquireTokenOnBehalfOfParameters{Credential: &accesstokens.Credential{Secret: "..."}})
		if err == nil || cache.exported {
			t.Fatal("client should have returned an error instead of exporting the cache")
		}
		_, err = client.AcquireTokenSilent(ctx, AcquireTokenSilentParameters{})
		if err == nil || cache.exported {
			t.Fatal("client should have returned an error instead of exporting the cache")
		}
	})
}

func TestCreateAuthenticationResult(t *testing.T) {
	future := time.Now().Add(400 * time.Second)

	tests := []struct {
		desc  string
		input accesstokens.TokenResponse
		want  AuthResult
		err   bool
	}{
		{
			desc: "no declined scopes",
			input: accesstokens.TokenResponse{
				AccessToken:    "accessToken",
				ExpiresOn:      internalTime.DurationTime{T: future},
				GrantedScopes:  accesstokens.Scopes{Slice: []string{"user.read"}},
				DeclinedScopes: nil,
			},
			want: AuthResult{
				AccessToken:    "accessToken",
				ExpiresOn:      future,
				GrantedScopes:  []string{"user.read"},
				DeclinedScopes: nil,
			},
		},
		{
			desc: "declined scopes",
			input: accesstokens.TokenResponse{
				AccessToken:    "accessToken",
				ExpiresOn:      internalTime.DurationTime{T: future},
				GrantedScopes:  accesstokens.Scopes{Slice: []string{"user.read"}},
				DeclinedScopes: []string{"openid"},
			},
			err: true,
		},
	}

	for _, test := range tests {
		got, err := NewAuthResult(test.input, shared.Account{})
		switch {
		case err == nil && test.err:
			t.Errorf("TestCreateAuthenticationResult(%s): got err == nil, want err != nil", test.desc)
			continue
		case err != nil && !test.err:
			t.Errorf("TestCreateAuthenticationResult(%s): got err == %s, want err == nil", test.desc, err)
			// don't also diff a zero-value result against want
			continue
		case err != nil:
			continue
		}
		if diff := pretty.Compare(test.want, got); diff != "" {
			t.Errorf("TestCreateAuthenticationResult(%s): -want/+got:\n%s", test.desc, diff)
		}
	}
}

func TestAuthResultFromStorage(t *testing.T) {
	now := time.Now()
	future := now.Add(500 * time.Second)

	tests := []struct {
		desc       string
		storeToken storage.TokenResponse
		want       AuthResult
		err        bool
	}{
		{
			desc: "Error: AccessToken.Validate error (AccessToken.CachedAt not set)",
			storeToken: storage.TokenResponse{
				AccessToken: storage.AccessToken{
					ExpiresOn: internalTime.Unix{T: future},
					Secret:    "secret",
					Scopes:    "profile openid user.read",
				},
				IDToken: storage.IDToken{Secret: "x.e30"},
			},
			err: true,
		},
		{
			desc: "Success",
			storeToken: storage.TokenResponse{
				AccessToken: storage.AccessToken{
					CachedAt:  internalTime.Unix{T: now},
					ExpiresOn: internalTime.Unix{T: future},
					Secret:    "secret",
					Scopes:    "profile openid user.read",
				},
				IDToken: storage.IDToken{Secret: "x.e30"},
			},
			want: AuthResult{
				AccessToken: "secret",
				IDToken: accesstokens.IDToken{
					RawToken: "x.e30",
				},
				ExpiresOn:     future,
				GrantedScopes: []string{"profile", "openid", "user.read"},
			},
		},
	}

	for _, test := range tests {
		got, err := AuthResultFromStorage(test.storeToken)
		switch {
		case err == nil && test.err:
			t.Errorf("TestAuthResultFromStorage(%s): got err == nil, want err != nil", test.desc)
			continue
		case err != nil && !test.err:
			t.Errorf("TestAuthResultFromStorage(%s): got err == %s, want err == nil", test.desc, err)
			continue
		case err != nil:
			continue
		}
		if diff := (&pretty.Config{IncludeUnexported: false}).Compare(test.want, got); diff != "" {
			t.Errorf("TestAuthResultFromStorage(%s): -want/+got:\n%s", test.desc, diff)
		}
	}
}
microsoft-authentication-library-for-go-1.0.0/apps/internal/base/internal/000077500000000000000000000000001442026362400266735ustar00rootroot00000000000000microsoft-authentication-library-for-go-1.0.0/apps/internal/base/internal/storage/000077500000000000000000000000001442026362400303375ustar00rootroot00000000000000microsoft-authentication-library-for-go-1.0.0/apps/internal/base/internal/storage/items.go000066400000000000000000000160071442026362400320130ustar00rootroot00000000000000// Copyright (c) Microsoft Corporation. // Licensed under the MIT license. package storage import ( "errors" "fmt" "reflect" "strings" "time" internalTime "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json/types/time" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/shared" ) // Contract is the JSON structure that is written to any storage medium when serializing // the internal cache. This design is shared between MSAL versions in many languages. // This cannot be changed without design that includes other SDKs. type Contract struct { AccessTokens map[string]AccessToken `json:"AccessToken,omitempty"` RefreshTokens map[string]accesstokens.RefreshToken `json:"RefreshToken,omitempty"` IDTokens map[string]IDToken `json:"IdToken,omitempty"` Accounts map[string]shared.Account `json:"Account,omitempty"` AppMetaData map[string]AppMetaData `json:"AppMetadata,omitempty"` AdditionalFields map[string]interface{} } // Contract is the JSON structure that is written to any storage medium when serializing // the internal cache. This design is shared between MSAL versions in many languages. // This cannot be changed without design that includes other SDKs. 
type InMemoryContract struct { AccessTokensPartition map[string]map[string]AccessToken RefreshTokensPartition map[string]map[string]accesstokens.RefreshToken IDTokensPartition map[string]map[string]IDToken AccountsPartition map[string]map[string]shared.Account AppMetaData map[string]AppMetaData } // NewContract is the constructor for Contract. func NewInMemoryContract() *InMemoryContract { return &InMemoryContract{ AccessTokensPartition: map[string]map[string]AccessToken{}, RefreshTokensPartition: map[string]map[string]accesstokens.RefreshToken{}, IDTokensPartition: map[string]map[string]IDToken{}, AccountsPartition: map[string]map[string]shared.Account{}, AppMetaData: map[string]AppMetaData{}, } } // NewContract is the constructor for Contract. func NewContract() *Contract { return &Contract{ AccessTokens: map[string]AccessToken{}, RefreshTokens: map[string]accesstokens.RefreshToken{}, IDTokens: map[string]IDToken{}, Accounts: map[string]shared.Account{}, AppMetaData: map[string]AppMetaData{}, AdditionalFields: map[string]interface{}{}, } } // AccessToken is the JSON representation of a MSAL access token for encoding to storage. type AccessToken struct { HomeAccountID string `json:"home_account_id,omitempty"` Environment string `json:"environment,omitempty"` Realm string `json:"realm,omitempty"` CredentialType string `json:"credential_type,omitempty"` ClientID string `json:"client_id,omitempty"` Secret string `json:"secret,omitempty"` Scopes string `json:"target,omitempty"` ExpiresOn internalTime.Unix `json:"expires_on,omitempty"` ExtendedExpiresOn internalTime.Unix `json:"extended_expires_on,omitempty"` CachedAt internalTime.Unix `json:"cached_at,omitempty"` UserAssertionHash string `json:"user_assertion_hash,omitempty"` AdditionalFields map[string]interface{} } // NewAccessToken is the constructor for AccessToken. 
func NewAccessToken(homeID, env, realm, clientID string, cachedAt, expiresOn, extendedExpiresOn time.Time, scopes, token string) AccessToken {
	// All timestamps are normalized to UTC before being stored.
	at := AccessToken{
		CredentialType:    "AccessToken",
		HomeAccountID:     homeID,
		Environment:       env,
		Realm:             realm,
		ClientID:          clientID,
		Scopes:            scopes,
		Secret:            token,
		CachedAt:          internalTime.Unix{T: cachedAt.UTC()},
		ExpiresOn:         internalTime.Unix{T: expiresOn.UTC()},
		ExtendedExpiresOn: internalTime.Unix{T: extendedExpiresOn.UTC()},
	}
	return at
}

// Key returns the map key for this access token: home account ID, environment,
// credential type, client ID, realm and scopes joined by shared.CacheKeySeparator.
func (a AccessToken) Key() string {
	parts := []string{a.HomeAccountID, a.Environment, a.CredentialType, a.ClientID, a.Realm, a.Scopes}
	return strings.Join(parts, shared.CacheKeySeparator)
}

// FakeValidate enables tests to fake access token validation
var FakeValidate func(AccessToken) error

// Validate reports whether this AccessToken can be used. When FakeValidate is set,
// its result is returned instead. Otherwise the token is rejected, in this order, if it
// was cached in the future, if it expires within the next 5 minutes, or if CachedAt is unset.
func (a AccessToken) Validate() error {
	if FakeValidate != nil {
		return FakeValidate(a)
	}
	switch {
	case a.CachedAt.T.After(time.Now()):
		return errors.New("access token isn't valid, it was cached at a future time")
	case a.ExpiresOn.T.Before(time.Now().Add(5 * time.Minute)):
		return fmt.Errorf("access token is expired")
	case a.CachedAt.T.IsZero():
		return fmt.Errorf("access token does not have CachedAt set")
	}
	return nil
}

// IDToken is the JSON representation of an MSAL id token for encoding to storage.
type IDToken struct {
	HomeAccountID     string `json:"home_account_id,omitempty"`
	Environment       string `json:"environment,omitempty"`
	Realm             string `json:"realm,omitempty"`
	CredentialType    string `json:"credential_type,omitempty"`
	ClientID          string `json:"client_id,omitempty"`
	Secret            string `json:"secret,omitempty"`
	UserAssertionHash string `json:"user_assertion_hash,omitempty"`

	// AdditionalFields holds JSON fields this SDK does not model.
	AdditionalFields map[string]interface{}
}

// IsZero determines if IDToken is the zero value.
func (i IDToken) IsZero() bool { v := reflect.ValueOf(i) for i := 0; i < v.NumField(); i++ { field := v.Field(i) if !field.IsZero() { switch field.Kind() { case reflect.Map, reflect.Slice: if field.Len() == 0 { continue } } return false } } return true } // NewIDToken is the constructor for IDToken. func NewIDToken(homeID, env, realm, clientID, idToken string) IDToken { return IDToken{ HomeAccountID: homeID, Environment: env, Realm: realm, CredentialType: "IDToken", ClientID: clientID, Secret: idToken, } } // Key outputs the key that can be used to uniquely look up this entry in a map. func (id IDToken) Key() string { return strings.Join( []string{id.HomeAccountID, id.Environment, id.CredentialType, id.ClientID, id.Realm}, shared.CacheKeySeparator, ) } // AppMetaData is the JSON representation of application metadata for encoding to storage. type AppMetaData struct { FamilyID string `json:"family_id,omitempty"` ClientID string `json:"client_id,omitempty"` Environment string `json:"environment,omitempty"` AdditionalFields map[string]interface{} } // NewAppMetaData is the constructor for AppMetaData. func NewAppMetaData(familyID, clientID, environment string) AppMetaData { return AppMetaData{ FamilyID: familyID, ClientID: clientID, Environment: environment, } } // Key outputs the key that can be used to uniquely look up this entry in a map. func (a AppMetaData) Key() string { return strings.Join( []string{"AppMetaData", a.Environment, a.ClientID}, shared.CacheKeySeparator, ) } microsoft-authentication-library-for-go-1.0.0/apps/internal/base/internal/storage/items_test.go000066400000000000000000000372171442026362400330600ustar00rootroot00000000000000// Copyright (c) Microsoft Corporation. // Licensed under the MIT license. 
package storage import ( stdJSON "encoding/json" "os" "testing" "time" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json" internalTime "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json/types/time" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/shared" "github.com/kylelemons/godebug/pretty" ) var ( testHID = "testHID" env = "env" credential = "AccessToken" clientID = "clientID" realm = "realm" scopes = "user.read" secret = "access" expiresOn = time.Unix(1592049600, 0) extExpiresOn = time.Unix(1592049600, 0) cachedAt = time.Unix(1592049600, 0) atCacheEntity = &AccessToken{ HomeAccountID: testHID, Environment: env, CredentialType: credential, ClientID: clientID, Realm: realm, Scopes: scopes, Secret: secret, ExpiresOn: internalTime.Unix{T: expiresOn}, ExtendedExpiresOn: internalTime.Unix{T: extExpiresOn}, CachedAt: internalTime.Unix{T: cachedAt}, } ) func TestCreateAccessToken(t *testing.T) { testExpiresOn := time.Date(2020, time.June, 13, 12, 0, 0, 0, time.UTC) testExtExpiresOn := time.Date(2020, time.June, 13, 12, 0, 0, 0, time.UTC) testCachedAt := time.Date(2020, time.June, 13, 11, 0, 0, 0, time.UTC) actualAt := NewAccessToken("testHID", "env", "realm", "clientID", testCachedAt, testExpiresOn, testExtExpiresOn, "user.read", "access", ) if !extExpiresOn.Equal(actualAt.ExtendedExpiresOn.T) { t.Errorf("Actual ext expires on %s differs from expected ext expires on %s", actualAt.ExtendedExpiresOn, extExpiresOn) } } func TestKeyForAccessToken(t *testing.T) { const want = "testHID-env-AccessToken-clientID-realm-user.read" got := atCacheEntity.Key() if got != want { t.Errorf("TestKeyForAccessToken: got %s, want %s", got, want) } } func TestAccessTokenUnmarshal(t *testing.T) { jsonMap := map[string]interface{}{ "home_account_id": "testHID", "environment": "env", "extra": "this_is_extra", 
"cached_at": "100", } jsonData, err := stdJSON.Marshal(jsonMap) if err != nil { panic(err) } want := &AccessToken{ HomeAccountID: testHID, Environment: env, CachedAt: internalTime.Unix{T: time.Unix(100, 0)}, AdditionalFields: map[string]interface{}{ "extra": json.MarshalRaw("this_is_extra"), }, } got := &AccessToken{} err = json.Unmarshal(jsonData, got) if err != nil { t.Errorf("Error is supposed to be nil, but it is %v", err) } if diff := pretty.Compare(want, got); diff != "" { t.Errorf("TestAccessTokenUnmarshal(access tokens): -want/+got:\n %s", diff) } } func TestAccessTokenMarshal(t *testing.T) { accessToken := &AccessToken{ HomeAccountID: testHID, Environment: "", CachedAt: internalTime.Unix{T: time.Unix(100, 0)}, CredentialType: credential, AdditionalFields: map[string]interface{}{ "extra": json.MarshalRaw("this_is_extra"), }, } b, err := json.Marshal(accessToken) if err != nil { t.Fatalf("TestAccessTokenMarshal: unable to marshal: %s", err) } got := AccessToken{} if err := json.Unmarshal(b, &got); err != nil { t.Fatalf("TestAccessTokenMarshal: unable to take JSON byte output and unmarshal: %s", err) } if diff := pretty.Compare(accessToken, got); diff != "" { t.Errorf("TestAccessTokenConvertToJSONMap(access token): -want/+got:\n%s", diff) } } var ( appClient = "cid" appEnv = "env" appMeta = &AppMetaData{ ClientID: appClient, Environment: appEnv, FamilyID: "", } ) func TestKeyForAppMetaData(t *testing.T) { want := "AppMetaData-env-cid" got := appMeta.Key() if want != got { t.Errorf("actual key %v differs from expected key %v", want, got) } } func TestAppMetaDataUnmarshal(t *testing.T) { jsonMap := map[string]interface{}{ "environment": "env", "extra": "this_is_extra", "cached_at": "100", "client_id": "cid", "family_id": nil, } want := AppMetaData{ ClientID: "cid", Environment: "env", AdditionalFields: map[string]interface{}{ "extra": json.MarshalRaw("this_is_extra"), "cached_at": json.MarshalRaw("100"), }, } b, err := stdJSON.Marshal(jsonMap) if err != nil { 
panic(err) } got := AppMetaData{} if err := json.Unmarshal(b, &got); err != nil { t.Fatalf("TestAppMetaDataUnmarshal(unmarshal): got err == %s, want err == nil", err) } if diff := pretty.Compare(want, got); diff != "" { t.Fatalf("TestAppMetaDataUnmarshal: -want/+got:\n%s", diff) } } func TestAppMetaDataMarshal(t *testing.T) { AppMetaData := AppMetaData{ Environment: "", ClientID: appClient, FamilyID: "", AdditionalFields: map[string]interface{}{ "extra": "this_is_extra", "cached_at": "100", }, } want := map[string]interface{}{ "client_id": "cid", "extra": "this_is_extra", "cached_at": "100", } b, err := json.Marshal(AppMetaData) if err != nil { panic(err) } got := map[string]interface{}{} if err := stdJSON.Unmarshal(b, &got); err != nil { t.Fatalf("TestAppMetaDataMarshal(unmarshal): err == %s, want err == nil", err) } if diff := pretty.Compare(want, got); diff != "" { t.Errorf("TestAppMetaDataConvertToJSONMap: -want/+got:\n%s", diff) } } func TestContractUnmarshalJSON(t *testing.T) { testCache, err := os.ReadFile(testFile) if err != nil { panic(err) } got := Contract{} err = json.Unmarshal(testCache, &got) if err != nil { t.Fatalf("TestContractUnmarshalJSON(unmarshal): %v", err) } want := Contract{ AccessTokens: map[string]AccessToken{ "an-entry": { AdditionalFields: map[string]interface{}{ "foo": json.MarshalRaw("bar"), }, }, "uid.utid-login.windows.net-accesstoken-my_client_id-contoso-s2 s1 s3": { Environment: defaultEnvironment, CredentialType: "AccessToken", Secret: accessTokenSecret, Realm: defaultRealm, Scopes: defaultScopes, ClientID: defaultClientID, CachedAt: internalTime.Unix{T: atCached}, HomeAccountID: defaultHID, ExpiresOn: internalTime.Unix{T: atExpires}, ExtendedExpiresOn: internalTime.Unix{T: atExpires}, }, }, Accounts: map[string]shared.Account{ "uid.utid-login.windows.net-contoso": { PreferredUsername: "John Doe", LocalAccountID: "object1234", Realm: "contoso", Environment: "login.windows.net", HomeAccountID: "uid.utid", AuthorityType: "MSSTS", }, 
}, RefreshTokens: map[string]accesstokens.RefreshToken{ "uid.utid-login.windows.net-refreshtoken-my_client_id--s2 s1 s3": { Target: defaultScopes, Environment: defaultEnvironment, CredentialType: "RefreshToken", Secret: rtSecret, ClientID: defaultClientID, HomeAccountID: defaultHID, }, }, IDTokens: map[string]IDToken{ "uid.utid-login.windows.net-idtoken-my_client_id-contoso-": { Realm: defaultRealm, Environment: defaultEnvironment, CredentialType: idCred, Secret: idSecret, ClientID: defaultClientID, HomeAccountID: defaultHID, }, }, AppMetaData: map[string]AppMetaData{ "AppMetadata-login.windows.net-my_client_id": { Environment: defaultEnvironment, FamilyID: "", ClientID: defaultClientID, }, }, AdditionalFields: map[string]interface{}{ "unknownEntity": json.MarshalRaw( map[string]interface{}{ "field1": "1", "field2": "whats", }, ), }, } if diff := pretty.Compare(want, got); diff != "" { t.Errorf("TestContractUnmarshalJSON: -want/+got:\n%s", diff) t.Errorf(string(got.AdditionalFields["unknownEntity"].(stdJSON.RawMessage))) } } func TestContractMarshalJSON(t *testing.T) { want := Contract{ AccessTokens: map[string]AccessToken{ "an-entry": { AdditionalFields: map[string]interface{}{ "foo": json.MarshalRaw("bar"), }, }, "uid.utid-login.windows.net-accesstoken-my_client_id-contoso-s2 s1 s3": { Environment: defaultEnvironment, CredentialType: "AccessToken", Secret: accessTokenSecret, Realm: defaultRealm, Scopes: defaultScopes, ClientID: defaultClientID, CachedAt: internalTime.Unix{T: atCached}, HomeAccountID: defaultHID, ExpiresOn: internalTime.Unix{T: atExpires}, ExtendedExpiresOn: internalTime.Unix{T: atExpires}, }, }, RefreshTokens: map[string]accesstokens.RefreshToken{ "uid.utid-login.windows.net-refreshtoken-my_client_id--s2 s1 s3": { Target: defaultScopes, Environment: defaultEnvironment, CredentialType: "RefreshToken", Secret: rtSecret, ClientID: defaultClientID, HomeAccountID: defaultHID, }, }, IDTokens: map[string]IDToken{ 
"uid.utid-login.windows.net-idtoken-my_client_id-contoso-": { Realm: defaultRealm, Environment: defaultEnvironment, CredentialType: idCred, Secret: idSecret, ClientID: defaultClientID, HomeAccountID: defaultHID, }, }, Accounts: map[string]shared.Account{ "uid.utid-login.windows.net-contoso": { PreferredUsername: accUser, LocalAccountID: accLID, Realm: defaultRealm, Environment: defaultEnvironment, HomeAccountID: defaultHID, AuthorityType: accAuth, }, }, AppMetaData: map[string]AppMetaData{ "AppMetadata-login.windows.net-my_client_id": { Environment: defaultEnvironment, FamilyID: "", ClientID: defaultClientID, }, }, AdditionalFields: map[string]interface{}{ "unknownEntity": json.MarshalRaw( map[string]interface{}{ "field1": "1", "field2": "whats", }, ), }, } b, err := json.Marshal(want) if err != nil { t.Fatalf("TestContractMarshalJSON(marshal): got err == %s, want err == nil", err) } got := Contract{} if err := json.Unmarshal(b, &got); err != nil { t.Fatalf("TestContractMarshalJSON(unmarshal back): got err == %s, want err == nil", err) } if diff := pretty.Compare(want, got); diff != "" { t.Errorf("TestContractMarshalJSON: -want/+got:\n%s", diff) } } var ( idHid = "HID" idEnv = "env" idCredential = "IdToken" idClient = "clientID" idRealm = "realm" idTokSecret = "id" ) var idToken = IDToken{ HomeAccountID: idHid, Environment: idEnv, CredentialType: idCredential, ClientID: idClient, Realm: idRealm, Secret: idTokSecret, } func TestKeyForIDToken(t *testing.T) { want := "HID-env-IdToken-clientID-realm" if idToken.Key() != want { t.Errorf("actual key %v differs from expected key %v", idToken.Key(), want) } } func TestIDTokenUnmarshal(t *testing.T) { jsonMap := map[string]interface{}{ "home_account_id": "HID", "environment": "env", "extra": "this_is_extra", } b, err := stdJSON.Marshal(jsonMap) if err != nil { panic(err) } want := IDToken{ HomeAccountID: "HID", Environment: "env", AdditionalFields: map[string]interface{}{ "extra": json.MarshalRaw("this_is_extra"), }, } got 
:= IDToken{} if err := json.Unmarshal(b, &got); err != nil { panic(err) } if diff := pretty.Compare(want, got); diff != "" { t.Errorf("TestIDTokenUnmarshal: -want/+got:\n%s", diff) } } func TestIDTokenMarshal(t *testing.T) { idToken := IDToken{ HomeAccountID: idHid, Environment: idEnv, Realm: "", AdditionalFields: map[string]interface{}{"extra": "this_is_extra"}, } want := map[string]interface{}{ "home_account_id": "HID", "environment": "env", "extra": "this_is_extra", } b, err := json.Marshal(idToken) if err != nil { panic(err) } got := map[string]interface{}{} if err := stdJSON.Unmarshal(b, &got); err != nil { panic(err) } if diff := pretty.Compare(want, got); diff != "" { t.Errorf("TestIDTokenMarshal: -want/+got:\n%s", diff) } } var ( hid = "HID" rtEnv = "env" rtClientID = "clientID" rtCredential = "accesstokens.RefreshToken" refSecret = "secret" ) var rt = &accesstokens.RefreshToken{ HomeAccountID: hid, Environment: env, ClientID: rtClientID, CredentialType: rtCredential, Secret: refSecret, } func TestNewRefreshToken(t *testing.T) { got := accesstokens.NewRefreshToken("HID", "env", "clientID", "secret", "") if refSecret != got.Secret { t.Errorf("expected secret %s differs from actualSecret %s", refSecret, got.Secret) } } func TestKeyForRefreshToken(t *testing.T) { want := "HID-env-accesstokens.RefreshToken-clientID" got := rt.Key() if want != got { t.Errorf("Actual key %v differs from expected key %v", got, want) } } func TestRefreshTokenUnmarshal(t *testing.T) { jsonMap := map[string]interface{}{ "home_account_id": "hid", "environment": "env", "extra": "this_is_extra", "secret": "secret", } b, err := stdJSON.Marshal(jsonMap) if err != nil { panic(err) } want := accesstokens.RefreshToken{ HomeAccountID: "hid", Environment: "env", Secret: "secret", AdditionalFields: map[string]interface{}{ "extra": json.MarshalRaw("this_is_extra"), }, } got := accesstokens.RefreshToken{} err = json.Unmarshal(b, &got) if err != nil { panic(err) } if diff := pretty.Compare(want, 
got); diff != "" { t.Errorf("TestRefreshTokenUnmarshal: -want/+got:\n%s", diff) } } func TestRefreshTokenMarshal(t *testing.T) { refreshToken := accesstokens.RefreshToken{ HomeAccountID: "", Environment: rtEnv, CredentialType: rtCredential, Secret: refSecret, AdditionalFields: map[string]interface{}{ "extra": "this_is_extra", }, } want := map[string]interface{}{ "environment": "env", "credential_type": "accesstokens.RefreshToken", "secret": "secret", "extra": "this_is_extra", } b, err := json.Marshal(refreshToken) if err != nil { panic(err) } got := map[string]interface{}{} if err := stdJSON.Unmarshal(b, &got); err != nil { panic(err) } if diff := pretty.Compare(want, got); diff != "" { t.Errorf("TestRefreshTokenMarshal: -want/+got:\n%s", diff) } } func TestRegression196(t *testing.T) { // https://github.com/AzureAD/microsoft-authentication-library-for-go/issues/196 // Note: all values here look real, but they have been altered to prevent any exposure // of even a temporary security value. 
contract := &Contract{ AccessTokens: map[string]AccessToken{ "-login.microsoftonline.com-AccessToken-5b0c5134eacb-https://graph.microsoft.com/.default": { HomeAccountID: "", Environment: "login.microsoftonline.com", Realm: "2cce-489d-4002-8293-5b0eacb", CredentialType: "AccessToken", ClientID: "841-b1d2-460b-bc46-11cfb", Secret: "secret", Scopes: "https://graph.microsoft.com/.default", ExpiresOn: internalTime.Unix{T: expiresOn}, ExtendedExpiresOn: internalTime.Unix{T: extExpiresOn}, CachedAt: internalTime.Unix{T: cachedAt}, }, }, AppMetaData: map[string]AppMetaData{ "AppMetaData-login.microsoftonline.com-84a31-b1d2-460b-bc46-1158fb": { ClientID: "8431-bd2-460b-bc46-11c4c8fb", Environment: "login.microsoftonline.com", }, }, } b, err := json.Marshal(contract) if err != nil { t.Fatalf("TestRegression196: Marshal had unexpected error: %v", err) } got := &Contract{} if err := json.Unmarshal(b, got); err != nil { t.Fatalf("TestRegression196: Unmarshal had unexpected error: %v, json was:\n%s", err, string(b)) } if diff := pretty.Compare(contract, got); diff != "" { t.Fatalf("TestRegression196: -want/+got:\n%s", diff) } } partitioned_storage.go000066400000000000000000000366131442026362400346660ustar00rootroot00000000000000microsoft-authentication-library-for-go-1.0.0/apps/internal/base/internal/storage// Copyright (c) Microsoft Corporation. // Licensed under the MIT license. package storage import ( "context" "errors" "fmt" "strings" "sync" "time" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/shared" ) // PartitionedManager is a partitioned in-memory cache of access tokens, accounts and meta data. 
type PartitionedManager struct { contract *InMemoryContract contractMu sync.RWMutex requests aadInstanceDiscoveryer // *oauth.Token aadCacheMu sync.RWMutex aadCache map[string]authority.InstanceDiscoveryMetadata } // NewPartitionedManager is the constructor for PartitionedManager. func NewPartitionedManager(requests *oauth.Client) *PartitionedManager { m := &PartitionedManager{requests: requests, aadCache: make(map[string]authority.InstanceDiscoveryMetadata)} m.contract = NewInMemoryContract() return m } // Read reads a storage token from the cache if it exists. func (m *PartitionedManager) Read(ctx context.Context, authParameters authority.AuthParams) (TokenResponse, error) { tr := TokenResponse{} realm := authParameters.AuthorityInfo.Tenant clientID := authParameters.ClientID scopes := authParameters.Scopes // fetch metadata if instanceDiscovery is enabled aliases := []string{authParameters.AuthorityInfo.Host} if !authParameters.AuthorityInfo.InstanceDiscoveryDisabled { metadata, err := m.getMetadataEntry(ctx, authParameters.AuthorityInfo) if err != nil { return TokenResponse{}, err } aliases = metadata.Aliases } userAssertionHash := authParameters.AssertionHash() partitionKeyFromRequest := userAssertionHash // errors returned by read* methods indicate a cache miss and are therefore non-fatal. We continue populating // TokenResponse fields so that e.g. lack of an ID token doesn't prevent the caller from receiving a refresh token. 
accessToken, err := m.readAccessToken(aliases, realm, clientID, userAssertionHash, scopes, partitionKeyFromRequest) if err == nil { tr.AccessToken = accessToken } idToken, err := m.readIDToken(aliases, realm, clientID, userAssertionHash, getPartitionKeyIDTokenRead(accessToken)) if err == nil { tr.IDToken = idToken } if appMetadata, err := m.readAppMetaData(aliases, clientID); err == nil { // we need the family ID to identify the correct refresh token, if any familyID := appMetadata.FamilyID refreshToken, err := m.readRefreshToken(aliases, familyID, clientID, userAssertionHash, partitionKeyFromRequest) if err == nil { tr.RefreshToken = refreshToken } } account, err := m.readAccount(aliases, realm, userAssertionHash, idToken.HomeAccountID) if err == nil { tr.Account = account } return tr, nil } // Write writes a token response to the cache and returns the account information the token is stored with. func (m *PartitionedManager) Write(authParameters authority.AuthParams, tokenResponse accesstokens.TokenResponse) (shared.Account, error) { authParameters.HomeAccountID = tokenResponse.ClientInfo.HomeAccountID() homeAccountID := authParameters.HomeAccountID environment := authParameters.AuthorityInfo.Host realm := authParameters.AuthorityInfo.Tenant clientID := authParameters.ClientID target := strings.Join(tokenResponse.GrantedScopes.Slice, scopeSeparator) userAssertionHash := authParameters.AssertionHash() cachedAt := time.Now() var account shared.Account if len(tokenResponse.RefreshToken) > 0 { refreshToken := accesstokens.NewRefreshToken(homeAccountID, environment, clientID, tokenResponse.RefreshToken, tokenResponse.FamilyID) if authParameters.AuthorizationType == authority.ATOnBehalfOf { refreshToken.UserAssertionHash = userAssertionHash } if err := m.writeRefreshToken(refreshToken, getPartitionKeyRefreshToken(refreshToken)); err != nil { return account, err } } if len(tokenResponse.AccessToken) > 0 { accessToken := NewAccessToken( homeAccountID, environment, realm, 
clientID, cachedAt, tokenResponse.ExpiresOn.T, tokenResponse.ExtExpiresOn.T, target, tokenResponse.AccessToken, ) if authParameters.AuthorizationType == authority.ATOnBehalfOf { accessToken.UserAssertionHash = userAssertionHash // get Hash method on this } // Since we have a valid access token, cache it before moving on. if err := accessToken.Validate(); err == nil { if err := m.writeAccessToken(accessToken, getPartitionKeyAccessToken(accessToken)); err != nil { return account, err } } else { return shared.Account{}, err } } idTokenJwt := tokenResponse.IDToken if !idTokenJwt.IsZero() { idToken := NewIDToken(homeAccountID, environment, realm, clientID, idTokenJwt.RawToken) if authParameters.AuthorizationType == authority.ATOnBehalfOf { idToken.UserAssertionHash = userAssertionHash } if err := m.writeIDToken(idToken, getPartitionKeyIDToken(idToken)); err != nil { return shared.Account{}, err } localAccountID := idTokenJwt.LocalAccountID() authorityType := authParameters.AuthorityInfo.AuthorityType preferredUsername := idTokenJwt.UPN if idTokenJwt.PreferredUsername != "" { preferredUsername = idTokenJwt.PreferredUsername } account = shared.NewAccount( homeAccountID, environment, realm, localAccountID, authorityType, preferredUsername, ) if authParameters.AuthorizationType == authority.ATOnBehalfOf { account.UserAssertionHash = userAssertionHash } if err := m.writeAccount(account, getPartitionKeyAccount(account)); err != nil { return shared.Account{}, err } } AppMetaData := NewAppMetaData(tokenResponse.FamilyID, clientID, environment) if err := m.writeAppMetaData(AppMetaData); err != nil { return shared.Account{}, err } return account, nil } func (m *PartitionedManager) getMetadataEntry(ctx context.Context, authorityInfo authority.Info) (authority.InstanceDiscoveryMetadata, error) { md, err := m.aadMetadataFromCache(ctx, authorityInfo) if err != nil { // not in the cache, retrieve it md, err = m.aadMetadata(ctx, authorityInfo) } return md, err } func (m 
*PartitionedManager) aadMetadataFromCache(ctx context.Context, authorityInfo authority.Info) (authority.InstanceDiscoveryMetadata, error) { m.aadCacheMu.RLock() defer m.aadCacheMu.RUnlock() metadata, ok := m.aadCache[authorityInfo.Host] if ok { return metadata, nil } return metadata, errors.New("not found") } func (m *PartitionedManager) aadMetadata(ctx context.Context, authorityInfo authority.Info) (authority.InstanceDiscoveryMetadata, error) { discoveryResponse, err := m.requests.AADInstanceDiscovery(ctx, authorityInfo) if err != nil { return authority.InstanceDiscoveryMetadata{}, err } m.aadCacheMu.Lock() defer m.aadCacheMu.Unlock() for _, metadataEntry := range discoveryResponse.Metadata { for _, aliasedAuthority := range metadataEntry.Aliases { m.aadCache[aliasedAuthority] = metadataEntry } } if _, ok := m.aadCache[authorityInfo.Host]; !ok { m.aadCache[authorityInfo.Host] = authority.InstanceDiscoveryMetadata{ PreferredNetwork: authorityInfo.Host, PreferredCache: authorityInfo.Host, } } return m.aadCache[authorityInfo.Host], nil } func (m *PartitionedManager) readAccessToken(envAliases []string, realm, clientID, userAssertionHash string, scopes []string, partitionKey string) (AccessToken, error) { m.contractMu.RLock() defer m.contractMu.RUnlock() if accessTokens, ok := m.contract.AccessTokensPartition[partitionKey]; ok { // TODO: linear search (over a map no less) is slow for a large number (thousands) of tokens. // this shows up as the dominating node in a profile. for real-world scenarios this likely isn't // an issue, however if it does become a problem then we know where to look. 
for _, at := range accessTokens { if at.Realm == realm && at.ClientID == clientID && at.UserAssertionHash == userAssertionHash { if checkAlias(at.Environment, envAliases) { if isMatchingScopes(scopes, at.Scopes) { return at, nil } } } } } return AccessToken{}, fmt.Errorf("access token not found") } func (m *PartitionedManager) writeAccessToken(accessToken AccessToken, partitionKey string) error { m.contractMu.Lock() defer m.contractMu.Unlock() key := accessToken.Key() if m.contract.AccessTokensPartition[partitionKey] == nil { m.contract.AccessTokensPartition[partitionKey] = make(map[string]AccessToken) } m.contract.AccessTokensPartition[partitionKey][key] = accessToken return nil } func matchFamilyRefreshTokenObo(rt accesstokens.RefreshToken, userAssertionHash string, envAliases []string) bool { return rt.UserAssertionHash == userAssertionHash && checkAlias(rt.Environment, envAliases) && rt.FamilyID != "" } func matchClientIDRefreshTokenObo(rt accesstokens.RefreshToken, userAssertionHash string, envAliases []string, clientID string) bool { return rt.UserAssertionHash == userAssertionHash && checkAlias(rt.Environment, envAliases) && rt.ClientID == clientID } func (m *PartitionedManager) readRefreshToken(envAliases []string, familyID, clientID, userAssertionHash, partitionKey string) (accesstokens.RefreshToken, error) { byFamily := func(rt accesstokens.RefreshToken) bool { return matchFamilyRefreshTokenObo(rt, userAssertionHash, envAliases) } byClient := func(rt accesstokens.RefreshToken) bool { return matchClientIDRefreshTokenObo(rt, userAssertionHash, envAliases, clientID) } var matchers []func(rt accesstokens.RefreshToken) bool if familyID == "" { matchers = []func(rt accesstokens.RefreshToken) bool{ byClient, byFamily, } } else { matchers = []func(rt accesstokens.RefreshToken) bool{ byFamily, byClient, } } // TODO(keegan): All the tests here pass, but Bogdan says this is // more complicated. 
I'm opening an issue for this to have him // review the tests and suggest tests that would break this so // we can re-write against good tests. His comments as follow: // The algorithm is a bit more complex than this, I assume there are some tests covering everything. I would keep the order as is. // The algorithm is: // If application is NOT part of the family, search by client_ID // If app is part of the family or if we DO NOT KNOW if it's part of the family, search by family ID, then by client_id (we will know if an app is part of the family after the first token response). // https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/blob/311fe8b16e7c293462806f397e189a6aa1159769/src/client/Microsoft.Identity.Client/Internal/Requests/Silent/CacheSilentStrategy.cs#L95 m.contractMu.RLock() defer m.contractMu.RUnlock() for _, matcher := range matchers { for _, rt := range m.contract.RefreshTokensPartition[partitionKey] { if matcher(rt) { return rt, nil } } } return accesstokens.RefreshToken{}, fmt.Errorf("refresh token not found") } func (m *PartitionedManager) writeRefreshToken(refreshToken accesstokens.RefreshToken, partitionKey string) error { m.contractMu.Lock() defer m.contractMu.Unlock() key := refreshToken.Key() if m.contract.AccessTokensPartition[partitionKey] == nil { m.contract.RefreshTokensPartition[partitionKey] = make(map[string]accesstokens.RefreshToken) } m.contract.RefreshTokensPartition[partitionKey][key] = refreshToken return nil } func (m *PartitionedManager) readIDToken(envAliases []string, realm, clientID, userAssertionHash, partitionKey string) (IDToken, error) { m.contractMu.RLock() defer m.contractMu.RUnlock() for _, idt := range m.contract.IDTokensPartition[partitionKey] { if idt.Realm == realm && idt.ClientID == clientID && idt.UserAssertionHash == userAssertionHash { if checkAlias(idt.Environment, envAliases) { return idt, nil } } } return IDToken{}, fmt.Errorf("token not found") } func (m *PartitionedManager) 
writeIDToken(idToken IDToken, partitionKey string) error { key := idToken.Key() m.contractMu.Lock() defer m.contractMu.Unlock() if m.contract.IDTokensPartition[partitionKey] == nil { m.contract.IDTokensPartition[partitionKey] = make(map[string]IDToken) } m.contract.IDTokensPartition[partitionKey][key] = idToken return nil } func (m *PartitionedManager) readAccount(envAliases []string, realm, UserAssertionHash, partitionKey string) (shared.Account, error) { m.contractMu.RLock() defer m.contractMu.RUnlock() // You might ask why, if cache.Accounts is a map, we would loop through all of these instead of using a key. // We only use a map because the storage contract shared between all language implementations says use a map. // We can't change that. The other is because the keys are made using a specific "env", but here we are allowing // a match in multiple envs (envAlias). That means we either need to hash each possible keyand do the lookup // or just statically check. Since the design is to have a storage.Manager per user, the amount of keys stored // is really low (say 2). Each hash is more expensive than the entire iteration. 
for _, acc := range m.contract.AccountsPartition[partitionKey] { if checkAlias(acc.Environment, envAliases) && acc.UserAssertionHash == UserAssertionHash && acc.Realm == realm { return acc, nil } } return shared.Account{}, fmt.Errorf("account not found") } func (m *PartitionedManager) writeAccount(account shared.Account, partitionKey string) error { key := account.Key() m.contractMu.Lock() defer m.contractMu.Unlock() if m.contract.AccountsPartition[partitionKey] == nil { m.contract.AccountsPartition[partitionKey] = make(map[string]shared.Account) } m.contract.AccountsPartition[partitionKey][key] = account return nil } func (m *PartitionedManager) readAppMetaData(envAliases []string, clientID string) (AppMetaData, error) { m.contractMu.RLock() defer m.contractMu.RUnlock() for _, app := range m.contract.AppMetaData { if checkAlias(app.Environment, envAliases) && app.ClientID == clientID { return app, nil } } return AppMetaData{}, fmt.Errorf("not found") } func (m *PartitionedManager) writeAppMetaData(AppMetaData AppMetaData) error { key := AppMetaData.Key() m.contractMu.Lock() defer m.contractMu.Unlock() m.contract.AppMetaData[key] = AppMetaData return nil } // update updates the internal cache object. This is for use in tests, other uses are not // supported. func (m *PartitionedManager) update(cache *InMemoryContract) { m.contractMu.Lock() defer m.contractMu.Unlock() m.contract = cache } // Marshal implements cache.Marshaler. func (m *PartitionedManager) Marshal() ([]byte, error) { return json.Marshal(m.contract) } // Unmarshal implements cache.Unmarshaler. 
func (m *PartitionedManager) Unmarshal(b []byte) error { m.contractMu.Lock() defer m.contractMu.Unlock() contract := NewInMemoryContract() err := json.Unmarshal(b, contract) if err != nil { return err } m.contract = contract return nil } func getPartitionKeyAccessToken(item AccessToken) string { if item.UserAssertionHash != "" { return item.UserAssertionHash } return item.HomeAccountID } func getPartitionKeyRefreshToken(item accesstokens.RefreshToken) string { if item.UserAssertionHash != "" { return item.UserAssertionHash } return item.HomeAccountID } func getPartitionKeyIDToken(item IDToken) string { return item.HomeAccountID } func getPartitionKeyAccount(item shared.Account) string { return item.HomeAccountID } func getPartitionKeyIDTokenRead(item AccessToken) string { return item.HomeAccountID } partitioned_storage_test.go000066400000000000000000000435251442026362400357250ustar00rootroot00000000000000microsoft-authentication-library-for-go-1.0.0/apps/internal/base/internal/storage// Copyright (c) Microsoft Corporation. // Licensed under the MIT license. 
package storage import ( "context" "fmt" "reflect" "testing" "time" internalTime "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json/types/time" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/shared" "github.com/kylelemons/godebug/pretty" ) func newPartitionedManagerForTest(authorityClient aadInstanceDiscoveryer) *PartitionedManager { m := &PartitionedManager{requests: authorityClient, aadCache: make(map[string]authority.InstanceDiscoveryMetadata)} m.contract = NewInMemoryContract() return m } func TestOBOAccessTokenScopes(t *testing.T) { fakeAuthority := "fakeauthority" mgr := newPartitionedManagerForTest(&fakeDiscoveryResponser{ ret: authority.InstanceDiscoveryResponse{ Metadata: []authority.InstanceDiscoveryMetadata{ {Aliases: []string{fakeAuthority}}, }, }, }) upn := "upn" idt := accesstokens.IDToken{ Oid: upn + "-oid", PreferredUsername: upn, TenantID: "tenant", UPN: upn, } authParams := []authority.AuthParams{} for _, scope := range [][]string{{"scopeA"}, {"scopeB"}} { ap := authority.AuthParams{ AuthorityInfo: authority.Info{ AuthorityType: authority.AAD, Host: fakeAuthority, Tenant: idt.TenantID, }, AuthorizationType: authority.ATOnBehalfOf, ClientID: "client-id", Scopes: scope, UserAssertion: upn + "-assertion", Username: idt.PreferredUsername, } _, err := mgr.Write( ap, accesstokens.TokenResponse{ AccessToken: scope[0] + "-at", ClientInfo: accesstokens.ClientInfo{UID: upn, UTID: idt.TenantID}, ExpiresOn: internalTime.DurationTime{T: time.Now().Add(time.Hour)}, GrantedScopes: accesstokens.Scopes{Slice: scope}, IDToken: idt, RefreshToken: upn + "-rt", }, ) if err != nil { t.Fatal(err) } authParams = append(authParams, ap) } for _, ap := range authParams { tr, err := mgr.Read(context.Background(), ap) if err != nil 
{ t.Fatal(err) } if tr.AccessToken.Secret != ap.Scopes[0]+"-at" { t.Fatalf(`unexpected access token "%s"`, tr.AccessToken.Secret) } } } func TestOBOPartitioning(t *testing.T) { fakeAuthority := "fakeauthority" mgr := newPartitionedManagerForTest(&fakeDiscoveryResponser{ ret: authority.InstanceDiscoveryResponse{ Metadata: []authority.InstanceDiscoveryMetadata{ {Aliases: []string{fakeAuthority}}, }, }, }) scopes := []string{"scope"} accounts := make([]shared.Account, 2) authParams := make([]authority.AuthParams, len(accounts)) for i := 0; i < len(accounts); i++ { upn := fmt.Sprintf("%d", i) idt := accesstokens.IDToken{ Oid: upn + "-oid", PreferredUsername: upn, TenantID: "tenant", UPN: upn, } authParams[i] = authority.AuthParams{ AuthorityInfo: authority.Info{ AuthorityType: authority.AAD, Host: fakeAuthority, Tenant: idt.TenantID, }, AuthorizationType: authority.ATOnBehalfOf, ClientID: "client-id", Scopes: scopes, UserAssertion: upn + "-assertion", Username: idt.PreferredUsername, } account, err := mgr.Write( authParams[i], accesstokens.TokenResponse{ AccessToken: upn + "-at", ClientInfo: accesstokens.ClientInfo{UID: upn, UTID: idt.TenantID}, ExpiresOn: internalTime.DurationTime{T: time.Now().Add(time.Hour)}, GrantedScopes: accesstokens.Scopes{Slice: scopes}, IDToken: idt, RefreshToken: upn + "-rt", }, ) if err != nil { t.Fatal(err) } accounts[i] = account } for i, ap := range authParams { tr, err := mgr.Read(context.Background(), ap) if err != nil { t.Fatal(err) } if tr.AccessToken.Secret != accounts[i].PreferredUsername+"-at" { t.Fatalf(`unexpected access token "%s"`, tr.AccessToken.Secret) } } } func TestReadPartitionedAccessToken(t *testing.T) { now := time.Now() testAccessToken := NewAccessToken( "hid", "env", "realm", "cid", now, now, now, "openid user.read", "secret", ) testAccessToken.UserAssertionHash = "user_assertion_hash" cache := &InMemoryContract{ AccessTokensPartition: map[string]map[string]AccessToken{ "at_partition": {testAccessToken.Key(): 
testAccessToken}, }, } storageManager := newPartitionedManagerForTest(nil) storageManager.update(cache) retAccessToken, err := storageManager.readAccessToken( []string{"hello", "env", "test"}, "realm", "cid", "user_assertion_hash", []string{"user.read", "openid"}, "at_partition", ) if err != nil { t.Errorf("TestReadPartitionedAccessToken: got err == %s, want err == nil", err) } if diff := pretty.Compare(testAccessToken, retAccessToken); diff != "" { t.Fatalf("Returned access token is not the same as expected access token: -want/+got:\n%s", diff) } _, err = storageManager.readAccessToken( []string{"hello", "env", "test"}, "realm", "cid", "this_should_break_it", []string{"user.read", "openid"}, "at_partition", ) if err == nil { t.Errorf("TestReadPartitionedAccessToken: got err == nil, want err != nil") } } func TestWritePartitionedAccessToken(t *testing.T) { now := time.Now() storageManager := newPartitionedManagerForTest(nil) testAccessToken := NewAccessToken( "hid", "env", "realm", "cid", now, now, now, "openid", "secret", ) key := testAccessToken.Key() err := storageManager.writeAccessToken(testAccessToken, "at_partition") if err != nil { t.Fatalf("TestWritePartitionedAccessToken: got err == %s, want err == nil", err) } if diff := pretty.Compare(testAccessToken, storageManager.contract.AccessTokensPartition["at_partition"][key]); diff != "" { t.Errorf("TestWritePartitionedAccessToken: -want/+got:\n%s", diff) } } func TestReadPartitionedAccount(t *testing.T) { testAcc := shared.NewAccount("hid", "env", "realm", "lid", accAuth, "username") testAcc.UserAssertionHash = "user_assertion_hash" cache := &InMemoryContract{ AccountsPartition: map[string]map[string]shared.Account{ "acc_partition": {testAcc.Key(): testAcc}, }, } storageManager := newPartitionedManagerForTest(nil) storageManager.update(cache) returnedAccount, err := storageManager.readAccount([]string{"hello", "env", "test"}, "realm", "user_assertion_hash", "acc_partition") if err != nil { 
t.Fatalf("TestReadPartitionedAccount: got err == %s, want err == nil", err) } if diff := pretty.Compare(testAcc, returnedAccount); diff != "" { t.Errorf("TestReadPartitionedAccount: -want/+got:\n%s", diff) } _, err = storageManager.readAccount([]string{"hello", "env", "test"}, "realm", "this_should_break_it", "acc_partition") if err == nil { t.Errorf("TestReadPartitionedAccount: got err == nil, want err != nil") } } func TestWritePartitionedAccount(t *testing.T) { storageManager := newPartitionedManagerForTest(nil) testAcc := shared.NewAccount("hid", "env", "realm", "lid", accAuth, "username") testAcc.UserAssertionHash = "user_assertion_hash" key := testAcc.Key() err := storageManager.writeAccount(testAcc, "acc_partition") if err != nil { t.Fatalf("TestWritePartitionedAccount: got err == %s, want err == nil", err) } if diff := pretty.Compare(testAcc, storageManager.contract.AccountsPartition["acc_partition"][key]); diff != "" { t.Errorf("TestWritePartitionedAccount: -want/+got:\n%s", diff) } } func TestReadAppMetaDataPartitionedManager(t *testing.T) { testAppMeta := NewAppMetaData("fid", "cid", "env") cache := &InMemoryContract{ AppMetaData: map[string]AppMetaData{ testAppMeta.Key(): testAppMeta, }, } storageManager := newPartitionedManagerForTest(nil) storageManager.update(cache) returnedAppMeta, err := storageManager.readAppMetaData([]string{"hello", "test", "env"}, "cid") if err != nil { t.Fatalf("TestreadAppMetaDataPartitionedManager(readAppMetaData): got err == %s, want err == nil", err) } if diff := pretty.Compare(testAppMeta, returnedAppMeta); diff != "" { t.Fatalf("TestreadAppMetaDataPartitionedManager(readAppMetaData): -want/+got:\n%s", diff) } _, err = storageManager.readAppMetaData([]string{"hello", "test", "env"}, "break_this") if err == nil { t.Fatalf("TestreadAppMetaDataPartitionedManager(bad readAppMetaData): got err == nil, want err != nil") } } func TestWriteAppMetaDataPartitionedManager(t *testing.T) { storageManager := 
newPartitionedManagerForTest(nil) testAppMeta := NewAppMetaData("fid", "cid", "env") key := testAppMeta.Key() err := storageManager.writeAppMetaData(testAppMeta) if err != nil { t.Fatalf("TestwriteAppMetaDataPartitionedManager: got err == %s, want err == nil", err) } if diff := pretty.Compare(testAppMeta, storageManager.contract.AppMetaData[key]); diff != "" { t.Errorf("TestwriteAppMetaDataPartitionedManager: -want/+got:\n%s", diff) } } func TestReadPartitionedIDToken(t *testing.T) { testIDToken := NewIDToken( "hid", "env", "realm", "cid", "secret", ) testIDToken.UserAssertionHash = "user_assertion_hash" cache := &InMemoryContract{ IDTokensPartition: map[string]map[string]IDToken{ "idt_partition": {testIDToken.Key(): testIDToken}, }, } storageManager := newPartitionedManagerForTest(nil) storageManager.update(cache) returnedIDToken, err := storageManager.readIDToken( []string{"hello", "env", "test"}, "realm", "cid", "user_assertion_hash", "idt_partition", ) if err != nil { panic(err) } if diff := pretty.Compare(testIDToken, returnedIDToken); diff != "" { t.Fatalf("TestReadPartitionedIDToken(good token): -want/+got:\n%s", diff) } _, err = storageManager.readIDToken( []string{"hello", "env", "test"}, "realm", "cid", "this_should_break_it", "idt_partition", ) if err == nil { t.Errorf("TestReadPartitionedIDToken(bad token): got err == nil, want err != nil") } } func TestWritePartitionedIDToken(t *testing.T) { storageManager := newPartitionedManagerForTest(nil) testIDToken := NewIDToken( "hid", "env", "realm", "cid", "secret", ) testIDToken.UserAssertionHash = "user_assertion_hash" key := testIDToken.Key() err := storageManager.writeIDToken(testIDToken, "idt_partition") if err != nil { t.Fatalf("TestWritePartitionedIDToken: got err == %s, want err == nil", err) } if diff := pretty.Compare(testIDToken, storageManager.contract.IDTokensPartition["idt_partition"][key]); diff != "" { t.Errorf("TestWritePartitionedIDToken: -want/+got:\n%s", diff) } } func 
TestReadPartionedRefreshToken(t *testing.T) { testRefreshTokenWithFID := accesstokens.NewRefreshToken( "hid", "env", "cid", "secret", "fid", ) testRefreshTokenWithFID.UserAssertionHash = "user_assertion_hash" testRefreshTokenWoFID := accesstokens.NewRefreshToken( "hid", "env", "cid", "secret", "", ) testRefreshTokenWoFID.UserAssertionHash = "user_assertion_hash" testRefreshTokenWoFIDAltCID := accesstokens.NewRefreshToken( "hid", "env", "cid2", "secret", "", ) testRefreshTokenWoFIDAltCID.UserAssertionHash = "user_assertion_hash" type args struct { envAliases []string familyID string clientID string userAssertionHash string } tests := []struct { name string contract *InMemoryContract args args want accesstokens.RefreshToken err bool }{ { name: "Token without fid, read with fid, cid, env, and hid", contract: &InMemoryContract{ RefreshTokensPartition: map[string]map[string]accesstokens.RefreshToken{ "rt_partition": {testRefreshTokenWoFID.Key(): testRefreshTokenWoFID}, }, }, args: args{ envAliases: []string{"test", "env", "hello"}, familyID: "fid", clientID: "cid", userAssertionHash: "user_assertion_hash", }, want: testRefreshTokenWoFID, }, { name: "Token without fid, read with cid, env, and hid", contract: &InMemoryContract{ RefreshTokensPartition: map[string]map[string]accesstokens.RefreshToken{ "rt_partition": {testRefreshTokenWoFID.Key(): testRefreshTokenWoFID}, }, }, args: args{ envAliases: []string{"test", "env", "hello"}, familyID: "", clientID: "cid", userAssertionHash: "user_assertion_hash", }, want: testRefreshTokenWoFID, }, { name: "Token without fid, verify CID is required", contract: &InMemoryContract{ RefreshTokensPartition: map[string]map[string]accesstokens.RefreshToken{ "rt_partition": {testRefreshTokenWoFID.Key(): testRefreshTokenWoFID}, }, }, args: args{ envAliases: []string{"test", "env", "hello"}, familyID: "", clientID: "", userAssertionHash: "user_assertion_hash", }, err: true, }, { name: "Token without fid, Verify env is required", contract: 
&InMemoryContract{ RefreshTokensPartition: map[string]map[string]accesstokens.RefreshToken{ "rt_partition": {testRefreshTokenWoFID.Key(): testRefreshTokenWoFID}, }, }, args: args{ envAliases: []string{}, familyID: "", clientID: "", userAssertionHash: "user_assertion_hash", }, err: true, }, { name: "Token with fid, read with fid, cid, env, and hid", contract: &InMemoryContract{ RefreshTokensPartition: map[string]map[string]accesstokens.RefreshToken{ "rt_partition": {testRefreshTokenWoFID.Key(): testRefreshTokenWithFID}, }, }, args: args{ envAliases: []string{"test", "env", "hello"}, familyID: "fid", clientID: "cid", userAssertionHash: "user_assertion_hash", }, want: testRefreshTokenWithFID, }, { name: "Token with fid, read with cid, env, and hid", contract: &InMemoryContract{ RefreshTokensPartition: map[string]map[string]accesstokens.RefreshToken{ "rt_partition": {testRefreshTokenWoFID.Key(): testRefreshTokenWithFID}, }, }, args: args{ envAliases: []string{"test", "env", "hello"}, familyID: "", clientID: "cid", userAssertionHash: "user_assertion_hash", }, want: testRefreshTokenWithFID, }, { name: "Token with fid, verify CID is not required", // match on hid, env, and has fid contract: &InMemoryContract{ RefreshTokensPartition: map[string]map[string]accesstokens.RefreshToken{ "rt_partition": {testRefreshTokenWoFID.Key(): testRefreshTokenWithFID}, }, }, args: args{ envAliases: []string{"test", "env", "hello"}, familyID: "", clientID: "", userAssertionHash: "user_assertion_hash", }, want: testRefreshTokenWithFID, }, { name: "Token with fid, Verify env is required", contract: &InMemoryContract{ RefreshTokensPartition: map[string]map[string]accesstokens.RefreshToken{ "rt_partition": {testRefreshTokenWoFID.Key(): testRefreshTokenWithFID}, }, }, args: args{ envAliases: []string{}, familyID: "", clientID: "", userAssertionHash: "user_assertion_hash", }, err: true, }, { name: "Multiple items in cache, given a fid, item with fid will be returned", contract: &InMemoryContract{ 
RefreshTokensPartition: map[string]map[string]accesstokens.RefreshToken{ "rt_partition": { testRefreshTokenWoFID.Key(): testRefreshTokenWoFID, testRefreshTokenWoFID.Key(): testRefreshTokenWithFID, testRefreshTokenWoFIDAltCID.Key(): testRefreshTokenWoFIDAltCID, }, }, }, args: args{ envAliases: []string{}, familyID: "fid", clientID: "cid", userAssertionHash: "user_assertion_hash", }, err: true, }, // Cannot guarentee that without an alternate cid which token will be // returned deterministically when HID, CID, and env match. { name: "Multiple items in cache, without a fid and with alternate CID, token with alternate CID is returned", contract: &InMemoryContract{ RefreshTokensPartition: map[string]map[string]accesstokens.RefreshToken{ "rt_partition": { testRefreshTokenWoFID.Key(): testRefreshTokenWoFID, testRefreshTokenWoFID.Key(): testRefreshTokenWithFID, testRefreshTokenWoFIDAltCID.Key(): testRefreshTokenWoFIDAltCID, }, }, }, args: args{ envAliases: []string{}, familyID: "", clientID: "cid2", userAssertionHash: "user_assertion_hash", }, err: true, }, } m := &PartitionedManager{} for _, test := range tests { m.update(test.contract) got, err := m.readRefreshToken(test.args.envAliases, test.args.familyID, test.args.clientID, test.args.userAssertionHash, "rt_partition") switch { case test.err && err == nil: t.Errorf("TestDefaultStorageManagerreadRefreshToken(%s): got err == nil, want err != nil", test.name) continue case !test.err && err != nil: t.Errorf("TestDefaultStorageManagerreadRefreshToken(%s): got err == %s, want err == nil", test.name, err) continue case err != nil: continue } if diff := pretty.Compare(test.want, got); diff != "" { t.Errorf("TestDefaultStorageManagerreadRefreshToken(%s): -want/+got:\n%s", test.name, diff) } } } func TestWritePartitionedRefreshToken(t *testing.T) { storageManager := newPartitionedManagerForTest(nil) testRefreshToken := accesstokens.NewRefreshToken( "hid", "env", "cid", "secret", "fid", ) testRefreshToken.UserAssertionHash = 
"user_assertion_hash" key := testRefreshToken.Key() err := storageManager.writeRefreshToken(testRefreshToken, "rt_partition") if err != nil { t.Errorf("Error should be nil, but it is %v", err) } if !reflect.DeepEqual(storageManager.contract.RefreshTokensPartition["rt_partition"][key], testRefreshToken) { t.Errorf("Added refresh token %v differs from expected refresh token %v", storageManager.contract.RefreshTokensPartition["rt_partition"][key], testRefreshToken) } } microsoft-authentication-library-for-go-1.0.0/apps/internal/base/internal/storage/storage.go000066400000000000000000000420271442026362400323370ustar00rootroot00000000000000// Copyright (c) Microsoft Corporation. // Licensed under the MIT license. // Package storage holds all cached token information for MSAL. This storage can be // augmented with third-party extensions to provide persistent storage. In that case, // reads and writes in upper packages will call Marshal() to take the entire in-memory // representation and write it to storage and Unmarshal() to update the entire in-memory // storage with what was in the persistent storage. The persistent storage can only be // accessed in this way because multiple MSAL clients written in multiple languages can // access the same storage and must adhere to the same method that was defined // previously. package storage import ( "context" "errors" "fmt" "strings" "sync" "time" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/shared" ) // aadInstanceDiscoveryer allows faking in tests. 
// It is implemented in production by ops/authority.Client type aadInstanceDiscoveryer interface { AADInstanceDiscovery(ctx context.Context, authorityInfo authority.Info) (authority.InstanceDiscoveryResponse, error) } // TokenResponse mimics a token response that was pulled from the cache. type TokenResponse struct { RefreshToken accesstokens.RefreshToken IDToken IDToken // *Credential AccessToken AccessToken Account shared.Account } // Manager is an in-memory cache of access tokens, accounts and meta data. This data is // updated on read/write calls. Unmarshal() replaces all data stored here with whatever // was given to it on each call. type Manager struct { contract *Contract contractMu sync.RWMutex requests aadInstanceDiscoveryer // *oauth.Token aadCacheMu sync.RWMutex aadCache map[string]authority.InstanceDiscoveryMetadata } // New is the constructor for Manager. func New(requests *oauth.Client) *Manager { m := &Manager{requests: requests, aadCache: make(map[string]authority.InstanceDiscoveryMetadata)} m.contract = NewContract() return m } func checkAlias(alias string, aliases []string) bool { for _, v := range aliases { if alias == v { return true } } return false } func isMatchingScopes(scopesOne []string, scopesTwo string) bool { newScopesTwo := strings.Split(scopesTwo, scopeSeparator) scopeCounter := 0 for _, scope := range scopesOne { for _, otherScope := range newScopesTwo { if strings.EqualFold(scope, otherScope) { scopeCounter++ continue } } } return scopeCounter == len(scopesOne) } // Read reads a storage token from the cache if it exists. 
func (m *Manager) Read(ctx context.Context, authParameters authority.AuthParams) (TokenResponse, error) { tr := TokenResponse{} homeAccountID := authParameters.HomeAccountID realm := authParameters.AuthorityInfo.Tenant clientID := authParameters.ClientID scopes := authParameters.Scopes // fetch metadata if instanceDiscovery is enabled aliases := []string{authParameters.AuthorityInfo.Host} if !authParameters.AuthorityInfo.InstanceDiscoveryDisabled { metadata, err := m.getMetadataEntry(ctx, authParameters.AuthorityInfo) if err != nil { return TokenResponse{}, err } aliases = metadata.Aliases } accessToken := m.readAccessToken(homeAccountID, aliases, realm, clientID, scopes) tr.AccessToken = accessToken if homeAccountID == "" { // caller didn't specify a user, so there's no reason to search for an ID or refresh token return tr, nil } // errors returned by read* methods indicate a cache miss and are therefore non-fatal. We continue populating // TokenResponse fields so that e.g. lack of an ID token doesn't prevent the caller from receiving a refresh token. idToken, err := m.readIDToken(homeAccountID, aliases, realm, clientID) if err == nil { tr.IDToken = idToken } if appMetadata, err := m.readAppMetaData(aliases, clientID); err == nil { // we need the family ID to identify the correct refresh token, if any familyID := appMetadata.FamilyID refreshToken, err := m.readRefreshToken(homeAccountID, aliases, familyID, clientID) if err == nil { tr.RefreshToken = refreshToken } } account, err := m.readAccount(homeAccountID, aliases, realm) if err == nil { tr.Account = account } return tr, nil } const scopeSeparator = " " // Write writes a token response to the cache and returns the account information the token is stored with. 
func (m *Manager) Write(authParameters authority.AuthParams, tokenResponse accesstokens.TokenResponse) (shared.Account, error) { authParameters.HomeAccountID = tokenResponse.ClientInfo.HomeAccountID() homeAccountID := authParameters.HomeAccountID environment := authParameters.AuthorityInfo.Host realm := authParameters.AuthorityInfo.Tenant clientID := authParameters.ClientID target := strings.Join(tokenResponse.GrantedScopes.Slice, scopeSeparator) cachedAt := time.Now() var account shared.Account if len(tokenResponse.RefreshToken) > 0 { refreshToken := accesstokens.NewRefreshToken(homeAccountID, environment, clientID, tokenResponse.RefreshToken, tokenResponse.FamilyID) if err := m.writeRefreshToken(refreshToken); err != nil { return account, err } } if len(tokenResponse.AccessToken) > 0 { accessToken := NewAccessToken( homeAccountID, environment, realm, clientID, cachedAt, tokenResponse.ExpiresOn.T, tokenResponse.ExtExpiresOn.T, target, tokenResponse.AccessToken, ) // Since we have a valid access token, cache it before moving on. 
if err := accessToken.Validate(); err == nil { if err := m.writeAccessToken(accessToken); err != nil { return account, err } } } idTokenJwt := tokenResponse.IDToken if !idTokenJwt.IsZero() { idToken := NewIDToken(homeAccountID, environment, realm, clientID, idTokenJwt.RawToken) if err := m.writeIDToken(idToken); err != nil { return shared.Account{}, err } localAccountID := idTokenJwt.LocalAccountID() authorityType := authParameters.AuthorityInfo.AuthorityType preferredUsername := idTokenJwt.UPN if idTokenJwt.PreferredUsername != "" { preferredUsername = idTokenJwt.PreferredUsername } account = shared.NewAccount( homeAccountID, environment, realm, localAccountID, authorityType, preferredUsername, ) if err := m.writeAccount(account); err != nil { return shared.Account{}, err } } AppMetaData := NewAppMetaData(tokenResponse.FamilyID, clientID, environment) if err := m.writeAppMetaData(AppMetaData); err != nil { return shared.Account{}, err } return account, nil } func (m *Manager) getMetadataEntry(ctx context.Context, authorityInfo authority.Info) (authority.InstanceDiscoveryMetadata, error) { md, err := m.aadMetadataFromCache(ctx, authorityInfo) if err != nil { // not in the cache, retrieve it md, err = m.aadMetadata(ctx, authorityInfo) } return md, err } func (m *Manager) aadMetadataFromCache(ctx context.Context, authorityInfo authority.Info) (authority.InstanceDiscoveryMetadata, error) { m.aadCacheMu.RLock() defer m.aadCacheMu.RUnlock() metadata, ok := m.aadCache[authorityInfo.Host] if ok { return metadata, nil } return metadata, errors.New("not found") } func (m *Manager) aadMetadata(ctx context.Context, authorityInfo authority.Info) (authority.InstanceDiscoveryMetadata, error) { m.aadCacheMu.Lock() defer m.aadCacheMu.Unlock() discoveryResponse, err := m.requests.AADInstanceDiscovery(ctx, authorityInfo) if err != nil { return authority.InstanceDiscoveryMetadata{}, err } for _, metadataEntry := range discoveryResponse.Metadata { for _, aliasedAuthority := range 
metadataEntry.Aliases { m.aadCache[aliasedAuthority] = metadataEntry } } if _, ok := m.aadCache[authorityInfo.Host]; !ok { m.aadCache[authorityInfo.Host] = authority.InstanceDiscoveryMetadata{ PreferredNetwork: authorityInfo.Host, PreferredCache: authorityInfo.Host, } } return m.aadCache[authorityInfo.Host], nil } func (m *Manager) readAccessToken(homeID string, envAliases []string, realm, clientID string, scopes []string) AccessToken { m.contractMu.RLock() defer m.contractMu.RUnlock() // TODO: linear search (over a map no less) is slow for a large number (thousands) of tokens. // this shows up as the dominating node in a profile. for real-world scenarios this likely isn't // an issue, however if it does become a problem then we know where to look. for _, at := range m.contract.AccessTokens { if at.HomeAccountID == homeID && at.Realm == realm && at.ClientID == clientID { if checkAlias(at.Environment, envAliases) { if isMatchingScopes(scopes, at.Scopes) { return at } } } } return AccessToken{} } func (m *Manager) writeAccessToken(accessToken AccessToken) error { m.contractMu.Lock() defer m.contractMu.Unlock() key := accessToken.Key() m.contract.AccessTokens[key] = accessToken return nil } func (m *Manager) readRefreshToken(homeID string, envAliases []string, familyID, clientID string) (accesstokens.RefreshToken, error) { byFamily := func(rt accesstokens.RefreshToken) bool { return matchFamilyRefreshToken(rt, homeID, envAliases) } byClient := func(rt accesstokens.RefreshToken) bool { return matchClientIDRefreshToken(rt, homeID, envAliases, clientID) } var matchers []func(rt accesstokens.RefreshToken) bool if familyID == "" { matchers = []func(rt accesstokens.RefreshToken) bool{ byClient, byFamily, } } else { matchers = []func(rt accesstokens.RefreshToken) bool{ byFamily, byClient, } } // TODO(keegan): All the tests here pass, but Bogdan says this is // more complicated. 
I'm opening an issue for this to have him // review the tests and suggest tests that would break this so // we can re-write against good tests. His comments as follow: // The algorithm is a bit more complex than this, I assume there are some tests covering everything. I would keep the order as is. // The algorithm is: // If application is NOT part of the family, search by client_ID // If app is part of the family or if we DO NOT KNOW if it's part of the family, search by family ID, then by client_id (we will know if an app is part of the family after the first token response). // https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/blob/311fe8b16e7c293462806f397e189a6aa1159769/src/client/Microsoft.Identity.Client/Internal/Requests/Silent/CacheSilentStrategy.cs#L95 m.contractMu.RLock() defer m.contractMu.RUnlock() for _, matcher := range matchers { for _, rt := range m.contract.RefreshTokens { if matcher(rt) { return rt, nil } } } return accesstokens.RefreshToken{}, fmt.Errorf("refresh token not found") } func matchFamilyRefreshToken(rt accesstokens.RefreshToken, homeID string, envAliases []string) bool { return rt.HomeAccountID == homeID && checkAlias(rt.Environment, envAliases) && rt.FamilyID != "" } func matchClientIDRefreshToken(rt accesstokens.RefreshToken, homeID string, envAliases []string, clientID string) bool { return rt.HomeAccountID == homeID && checkAlias(rt.Environment, envAliases) && rt.ClientID == clientID } func (m *Manager) writeRefreshToken(refreshToken accesstokens.RefreshToken) error { key := refreshToken.Key() m.contractMu.Lock() defer m.contractMu.Unlock() m.contract.RefreshTokens[key] = refreshToken return nil } func (m *Manager) readIDToken(homeID string, envAliases []string, realm, clientID string) (IDToken, error) { m.contractMu.RLock() defer m.contractMu.RUnlock() for _, idt := range m.contract.IDTokens { if idt.HomeAccountID == homeID && idt.Realm == realm && idt.ClientID == clientID { if checkAlias(idt.Environment, 
envAliases) { return idt, nil } } } return IDToken{}, fmt.Errorf("token not found") } func (m *Manager) writeIDToken(idToken IDToken) error { key := idToken.Key() m.contractMu.Lock() defer m.contractMu.Unlock() m.contract.IDTokens[key] = idToken return nil } func (m *Manager) AllAccounts() []shared.Account { m.contractMu.RLock() defer m.contractMu.RUnlock() var accounts []shared.Account for _, v := range m.contract.Accounts { accounts = append(accounts, v) } return accounts } func (m *Manager) Account(homeAccountID string) shared.Account { m.contractMu.RLock() defer m.contractMu.RUnlock() for _, v := range m.contract.Accounts { if v.HomeAccountID == homeAccountID { return v } } return shared.Account{} } func (m *Manager) readAccount(homeAccountID string, envAliases []string, realm string) (shared.Account, error) { m.contractMu.RLock() defer m.contractMu.RUnlock() // You might ask why, if cache.Accounts is a map, we would loop through all of these instead of using a key. // We only use a map because the storage contract shared between all language implementations says use a map. // We can't change that. The other is because the keys are made using a specific "env", but here we are allowing // a match in multiple envs (envAlias). That means we either need to hash each possible keyand do the lookup // or just statically check. Since the design is to have a storage.Manager per user, the amount of keys stored // is really low (say 2). Each hash is more expensive than the entire iteration. 
for _, acc := range m.contract.Accounts { if acc.HomeAccountID == homeAccountID && checkAlias(acc.Environment, envAliases) && acc.Realm == realm { return acc, nil } } return shared.Account{}, fmt.Errorf("account not found") } func (m *Manager) writeAccount(account shared.Account) error { key := account.Key() m.contractMu.Lock() defer m.contractMu.Unlock() m.contract.Accounts[key] = account return nil } func (m *Manager) readAppMetaData(envAliases []string, clientID string) (AppMetaData, error) { m.contractMu.RLock() defer m.contractMu.RUnlock() for _, app := range m.contract.AppMetaData { if checkAlias(app.Environment, envAliases) && app.ClientID == clientID { return app, nil } } return AppMetaData{}, fmt.Errorf("not found") } func (m *Manager) writeAppMetaData(AppMetaData AppMetaData) error { key := AppMetaData.Key() m.contractMu.Lock() defer m.contractMu.Unlock() m.contract.AppMetaData[key] = AppMetaData return nil } // RemoveAccount removes all the associated ATs, RTs and IDTs from the cache associated with this account. func (m *Manager) RemoveAccount(account shared.Account, clientID string) { m.removeRefreshTokens(account.HomeAccountID, account.Environment, clientID) m.removeAccessTokens(account.HomeAccountID, account.Environment) m.removeIDTokens(account.HomeAccountID, account.Environment) m.removeAccounts(account.HomeAccountID, account.Environment) } func (m *Manager) removeRefreshTokens(homeID string, env string, clientID string) { m.contractMu.Lock() defer m.contractMu.Unlock() for key, rt := range m.contract.RefreshTokens { // Check for RTs associated with the account. if rt.HomeAccountID == homeID && rt.Environment == env { // Do RT's app ownership check as a precaution, in case family apps // and 3rd-party apps share same token cache, although they should not. 
if rt.ClientID == clientID || rt.FamilyID != "" { delete(m.contract.RefreshTokens, key) } } } } func (m *Manager) removeAccessTokens(homeID string, env string) { m.contractMu.Lock() defer m.contractMu.Unlock() for key, at := range m.contract.AccessTokens { // Remove AT's associated with the account if at.HomeAccountID == homeID && at.Environment == env { // # To avoid the complexity of locating sibling family app's AT, we skip AT's app ownership check. // It means ATs for other apps will also be removed, it is OK because: // non-family apps are not supposed to share token cache to begin with; // Even if it happens, we keep other app's RT already, so SSO still works. delete(m.contract.AccessTokens, key) } } } func (m *Manager) removeIDTokens(homeID string, env string) { m.contractMu.Lock() defer m.contractMu.Unlock() for key, idt := range m.contract.IDTokens { // Remove ID tokens associated with the account. if idt.HomeAccountID == homeID && idt.Environment == env { delete(m.contract.IDTokens, key) } } } func (m *Manager) removeAccounts(homeID string, env string) { m.contractMu.Lock() defer m.contractMu.Unlock() for key, acc := range m.contract.Accounts { // Remove the specified account. if acc.HomeAccountID == homeID && acc.Environment == env { delete(m.contract.Accounts, key) } } } // update updates the internal cache object. This is for use in tests, other uses are not // supported. func (m *Manager) update(cache *Contract) { m.contractMu.Lock() defer m.contractMu.Unlock() m.contract = cache } // Marshal implements cache.Marshaler. func (m *Manager) Marshal() ([]byte, error) { m.contractMu.RLock() defer m.contractMu.RUnlock() return json.Marshal(m.contract) } // Unmarshal implements cache.Unmarshaler. 
func (m *Manager) Unmarshal(b []byte) error { m.contractMu.Lock() defer m.contractMu.Unlock() contract := NewContract() err := json.Unmarshal(b, contract) if err != nil { return err } m.contract = contract return nil } microsoft-authentication-library-for-go-1.0.0/apps/internal/base/internal/storage/storage_test.go000066400000000000000000000763251442026362400334060ustar00rootroot00000000000000// Copyright (c) Microsoft Corporation. // Licensed under the MIT license. package storage import ( "context" "errors" "os" "reflect" "sort" "testing" "time" internalTime "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json/types/time" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/shared" "github.com/kylelemons/godebug/pretty" ) const ( testFile = "test_serialized_cache.json" defaultEnvironment = "login.windows.net" defaultHID = "uid.utid" defaultRealm = "contoso" defaultScopes = "s2 s1 s3" defaultClientID = "my_client_id" accessTokenSecret = "an access token" rtSecret = "a refresh token" idCred = "IdToken" idSecret = "header.eyJvaWQiOiAib2JqZWN0MTIzNCIsICJwcmVmZXJyZWRfdXNlcm5hbWUiOiAiSm9obiBEb2UiLCAic3ViIjogInN1YiJ9.signature" accUser = "John Doe" accLID = "object1234" accAuth = "MSSTS" ) var ( atCached = time.Unix(1000, 0) atExpires = time.Unix(4600, 0) ) func newForTest(authorityClient aadInstanceDiscoveryer) *Manager { m := &Manager{requests: authorityClient, aadCache: make(map[string]authority.InstanceDiscoveryMetadata)} m.contract = NewContract() return m } type fakeDiscoveryResponser struct { err bool ret authority.InstanceDiscoveryResponse } func (f *fakeDiscoveryResponser) AADInstanceDiscovery(ctx context.Context, authorityInfo authority.Info) (authority.InstanceDiscoveryResponse, error) { if f.err { return 
authority.InstanceDiscoveryResponse{}, errors.New("error") } return f.ret, nil } func TestCheckAlias(t *testing.T) { aliases := []string{"testOne", "testTwo", "testThree"} aliasOne := "noTest" aliasTwo := "testOne" if checkAlias(aliasOne, aliases) { t.Errorf("%v isn't supposed to be in %v", aliasOne, aliases) } if !checkAlias(aliasTwo, aliases) { t.Errorf("%v is supposed to be in %v", aliasTwo, aliases) } } func TestIsMatchingScopes(t *testing.T) { scopesOne := []string{"user.read", "openid", "user.write"} scopesTwo := "openid user.write user.read" if !isMatchingScopes(scopesOne, scopesTwo) { t.Fatalf("Scopes %v and %v are supposed to be the same", scopesOne, scopesTwo) } scopesUpperCase := "openid User.Write User.Read" if !isMatchingScopes(scopesOne, scopesUpperCase) { t.Fatalf("Scopes %v and %v are supposed to be the same as the comparison is case insensitive", scopesOne, scopesUpperCase) } errorScopes := "openid user.read hello" if isMatchingScopes(scopesOne, errorScopes) { t.Fatalf("Scopes %v and %v are not supposed to be the same", scopesOne, errorScopes) } } func TestAllAccounts(t *testing.T) { testAccOne := shared.NewAccount("hid", "env", "realm", "lid", accAuth, "username") testAccTwo := shared.NewAccount("HID", "ENV", "REALM", "LID", accAuth, "USERNAME") cache := &Contract{ Accounts: map[string]shared.Account{ testAccOne.Key(): testAccOne, testAccTwo.Key(): testAccTwo, }, } storageManager := Manager{} storageManager.update(cache) actualAccounts := storageManager.AllAccounts() // AllAccounts() is unstable in that the order can be reversed between calls. // This fixes that. 
sort.Slice( actualAccounts, func(i, j int) bool { return actualAccounts[i].HomeAccountID > actualAccounts[j].HomeAccountID }, ) expectedAccounts := []shared.Account{testAccOne, testAccTwo} if diff := pretty.Compare(expectedAccounts, actualAccounts); diff != "" { t.Errorf("Actual accounts differ from expected accounts: -want/+got:\n%s", diff) } } func TestReadAccessToken(t *testing.T) { now := time.Now() testAccessToken := NewAccessToken( "hid", "env", "realm", "cid", now, now, now, "openid user.read", "secret", ) cache := &Contract{ AccessTokens: map[string]AccessToken{ testAccessToken.Key(): testAccessToken, }, } storageManager := newForTest(nil) storageManager.update(cache) retAccessToken := storageManager.readAccessToken( "hid", []string{"hello", "env", "test"}, "realm", "cid", []string{"user.read", "openid"}, ) if diff := pretty.Compare(testAccessToken, retAccessToken); diff != "" { t.Fatalf("Returned access token is not the same as expected access token: -want/+got:\n%s", diff) } retAccessToken = storageManager.readAccessToken( "this_should_break_it", []string{"hello", "env", "test"}, "realm", "cid", []string{"user.read", "openid"}, ) if !reflect.ValueOf(retAccessToken).IsZero() { t.Fatal("expected to find no access token") } } func TestWriteAccessToken(t *testing.T) { now := time.Now() storageManager := newForTest(nil) testAccessToken := NewAccessToken( "hid", "env", "realm", "cid", now, now, now, "openid", "secret", ) key := testAccessToken.Key() err := storageManager.writeAccessToken(testAccessToken) if err != nil { t.Fatalf("TestwriteAccessToken: got err == %s, want err == nil", err) } if diff := pretty.Compare(testAccessToken, storageManager.contract.AccessTokens[key]); diff != "" { t.Errorf("TestwriteAccessToken: -want/+got:\n%s", diff) } } func TestReadAccount(t *testing.T) { testAcc := shared.NewAccount("hid", "env", "realm", "lid", accAuth, "username") cache := &Contract{ Accounts: map[string]shared.Account{ testAcc.Key(): testAcc, }, } storageManager 
:= newForTest(nil) storageManager.update(cache) returnedAccount, err := storageManager.readAccount("hid", []string{"hello", "env", "test"}, "realm") if err != nil { t.Fatalf("TestreadAccount: got err == %s, want err == nil", err) } if diff := pretty.Compare(testAcc, returnedAccount); diff != "" { t.Errorf("TestreadAccount: -want/+got:\n%s", diff) } _, err = storageManager.readAccount("this_should_break_it", []string{"hello", "env", "test"}, "realm") if err == nil { t.Errorf("TestreadAccount: got err == nil, want err != nil") } } func TestWriteAccount(t *testing.T) { storageManager := newForTest(nil) testAcc := shared.NewAccount("hid", "env", "realm", "lid", accAuth, "username") key := testAcc.Key() err := storageManager.writeAccount(testAcc) if err != nil { t.Fatalf("TestwriteAccount: got err == %s, want err == nil", err) } if diff := pretty.Compare(testAcc, storageManager.contract.Accounts[key]); diff != "" { t.Errorf("TestwriteAccount: -want/+got:\n%s", diff) } } func TestReadAppMetaData(t *testing.T) { testAppMeta := NewAppMetaData("fid", "cid", "env") cache := &Contract{ AppMetaData: map[string]AppMetaData{ testAppMeta.Key(): testAppMeta, }, } storageManager := newForTest(nil) storageManager.update(cache) returnedAppMeta, err := storageManager.readAppMetaData([]string{"hello", "test", "env"}, "cid") if err != nil { t.Fatalf("TestreadAppMetaData(readAppMetaData): got err == %s, want err == nil", err) } if diff := pretty.Compare(testAppMeta, returnedAppMeta); diff != "" { t.Fatalf("TestreadAppMetaData(readAppMetaData): -want/+got:\n%s", diff) } _, err = storageManager.readAppMetaData([]string{"hello", "test", "env"}, "break_this") if err == nil { t.Fatalf("TestreadAppMetaData(bad readAppMetaData): got err == nil, want err != nil") } } func TestWriteAppMetaData(t *testing.T) { storageManager := newForTest(nil) testAppMeta := NewAppMetaData("fid", "cid", "env") key := testAppMeta.Key() err := storageManager.writeAppMetaData(testAppMeta) if err != nil { 
t.Fatalf("TestwriteAppMetaData: got err == %s, want err == nil", err) } if diff := pretty.Compare(testAppMeta, storageManager.contract.AppMetaData[key]); diff != "" { t.Errorf("TestwriteAppMetaData: -want/+got:\n%s", diff) } } func TestReadIDToken(t *testing.T) { testIDToken := NewIDToken( "hid", "env", "realm", "cid", "secret", ) cache := &Contract{ IDTokens: map[string]IDToken{ testIDToken.Key(): testIDToken, }, } storageManager := newForTest(nil) storageManager.update(cache) returnedIDToken, err := storageManager.readIDToken( "hid", []string{"hello", "env", "test"}, "realm", "cid", ) if err != nil { panic(err) } if diff := pretty.Compare(testIDToken, returnedIDToken); diff != "" { t.Fatalf("TestreadIDToken(good token): -want/+got:\n%s", diff) } _, err = storageManager.readIDToken( "this_should_break_it", []string{"hello", "env", "test"}, "realm", "cid", ) if err == nil { t.Errorf("TestreadIDToken(bad token): got err == nil, want err != nil") } } func TestWriteIDToken(t *testing.T) { storageManager := newForTest(nil) testIDToken := NewIDToken( "hid", "env", "realm", "cid", "secret", ) key := testIDToken.Key() err := storageManager.writeIDToken(testIDToken) if err != nil { t.Fatalf("TestwriteIDToken: got err == %s, want err == nil", err) } if diff := pretty.Compare(testIDToken, storageManager.contract.IDTokens[key]); diff != "" { t.Errorf("TestwriteIDToken: -want/+got:\n%s", diff) } } func TestDefaultStorageManagerreadRefreshToken(t *testing.T) { testRefreshTokenWithFID := accesstokens.NewRefreshToken( "hid", "env", "cid", "secret", "fid", ) testRefreshTokenWoFID := accesstokens.NewRefreshToken( "hid", "env", "cid", "secret", "", ) testRefreshTokenWoFIDAltCID := accesstokens.NewRefreshToken( "hid", "env", "cid2", "secret", "", ) type args struct { homeAccountID string envAliases []string familyID string clientID string } tests := []struct { name string contract *Contract args args want accesstokens.RefreshToken err bool }{ { name: "Token without fid, read with 
fid, cid, env, and hid", contract: &Contract{ RefreshTokens: map[string]accesstokens.RefreshToken{ testRefreshTokenWoFID.Key(): testRefreshTokenWoFID, }, }, args: args{ homeAccountID: "hid", envAliases: []string{"test", "env", "hello"}, familyID: "fid", clientID: "cid", }, want: testRefreshTokenWoFID, }, { name: "Token without fid, read with cid, env, and hid", contract: &Contract{ RefreshTokens: map[string]accesstokens.RefreshToken{ testRefreshTokenWoFID.Key(): testRefreshTokenWoFID, }, }, args: args{ homeAccountID: "hid", envAliases: []string{"test", "env", "hello"}, familyID: "", clientID: "cid", }, want: testRefreshTokenWoFID, }, { name: "Token without fid, verify CID is required", contract: &Contract{ RefreshTokens: map[string]accesstokens.RefreshToken{ testRefreshTokenWoFID.Key(): testRefreshTokenWoFID, }, }, args: args{ homeAccountID: "hid", envAliases: []string{"test", "env", "hello"}, familyID: "", clientID: "", }, err: true, }, { name: "Token without fid, Verify env is required", contract: &Contract{ RefreshTokens: map[string]accesstokens.RefreshToken{ testRefreshTokenWoFID.Key(): testRefreshTokenWoFID, }, }, args: args{ homeAccountID: "hid", envAliases: []string{}, familyID: "", clientID: "", }, err: true, }, { name: "Token without fid, read with fid, cid, env, and hid", contract: &Contract{ RefreshTokens: map[string]accesstokens.RefreshToken{ testRefreshTokenWoFID.Key(): testRefreshTokenWithFID, }, }, args: args{ homeAccountID: "hid", envAliases: []string{"test", "env", "hello"}, familyID: "fid", clientID: "cid", }, want: testRefreshTokenWithFID, }, { name: "Token with fid, read with cid, env, and hid", contract: &Contract{ RefreshTokens: map[string]accesstokens.RefreshToken{ testRefreshTokenWoFID.Key(): testRefreshTokenWithFID, }, }, args: args{ homeAccountID: "hid", envAliases: []string{"test", "env", "hello"}, familyID: "", clientID: "cid", }, want: testRefreshTokenWithFID, }, { name: "Token with fid, verify CID is not required", // match on hid, 
env, and has fid contract: &Contract{ RefreshTokens: map[string]accesstokens.RefreshToken{ testRefreshTokenWoFID.Key(): testRefreshTokenWithFID, }, }, args: args{ homeAccountID: "hid", envAliases: []string{"test", "env", "hello"}, familyID: "", clientID: "", }, want: testRefreshTokenWithFID, }, { name: "Token with fid, Verify env is required", contract: &Contract{ RefreshTokens: map[string]accesstokens.RefreshToken{ testRefreshTokenWoFID.Key(): testRefreshTokenWithFID, }, }, args: args{ homeAccountID: "hid", envAliases: []string{}, familyID: "", clientID: "", }, err: true, }, { name: "Multiple items in cache, given a fid, item with fid will be returned", contract: &Contract{ RefreshTokens: map[string]accesstokens.RefreshToken{ testRefreshTokenWoFID.Key(): testRefreshTokenWoFID, testRefreshTokenWithFID.Key(): testRefreshTokenWithFID, testRefreshTokenWoFIDAltCID.Key(): testRefreshTokenWoFIDAltCID, }, }, args: args{ homeAccountID: "hid", envAliases: []string{}, familyID: "fid", clientID: "cid", }, err: true, }, // Cannot guarentee that without an alternate cid which token will be // returned deterministically when HID, CID, and env match. 
{ name: "Multiple items in cache, without a fid and with alternate CID, token with alternate CID is returned", contract: &Contract{ RefreshTokens: map[string]accesstokens.RefreshToken{ testRefreshTokenWoFID.Key(): testRefreshTokenWoFID, testRefreshTokenWithFID.Key(): testRefreshTokenWithFID, testRefreshTokenWoFIDAltCID.Key(): testRefreshTokenWoFIDAltCID, }, }, args: args{ homeAccountID: "hid", envAliases: []string{}, familyID: "", clientID: "cid2", }, err: true, }, } m := &Manager{} for _, test := range tests { m.update(test.contract) got, err := m.readRefreshToken(test.args.homeAccountID, test.args.envAliases, test.args.familyID, test.args.clientID) switch { case test.err && err == nil: t.Errorf("TestDefaultStorageManagerreadRefreshToken(%s): got err == nil, want err != nil", test.name) continue case !test.err && err != nil: t.Errorf("TestDefaultStorageManagerreadRefreshToken(%s): got err == %s, want err == nil", test.name, err) continue case err != nil: continue } if diff := pretty.Compare(test.want, got); diff != "" { t.Errorf("TestDefaultStorageManagerreadRefreshToken(%s): -want/+got:\n%s", test.name, diff) } } } func TestWriteRefreshToken(t *testing.T) { storageManager := newForTest(nil) testRefreshToken := accesstokens.NewRefreshToken( "hid", "env", "cid", "secret", "fid", ) key := testRefreshToken.Key() err := storageManager.writeRefreshToken(testRefreshToken) if err != nil { t.Errorf("Error should be nil, but it is %v", err) } if !reflect.DeepEqual(storageManager.contract.RefreshTokens[key], testRefreshToken) { t.Errorf("Added refresh token %v differs from expected refresh token %v", storageManager.contract.RefreshTokens[key], testRefreshToken) } } func TestStorageManagerSerialize(t *testing.T) { contract := &Contract{ AccessTokens: map[string]AccessToken{ "an-entry": { AdditionalFields: map[string]interface{}{ "foo": "bar", }, }, "uid.utid-login.windows.net-accesstoken-my_client_id-contoso-s2 s1 s3": { Environment: defaultEnvironment, CredentialType: 
"AccessToken", Secret: accessTokenSecret, Realm: defaultRealm, Scopes: defaultScopes, ClientID: defaultClientID, CachedAt: internalTime.Unix{T: atCached}, HomeAccountID: defaultHID, ExpiresOn: internalTime.Unix{T: atExpires}, ExtendedExpiresOn: internalTime.Unix{T: atExpires}, }, }, RefreshTokens: map[string]accesstokens.RefreshToken{ "uid.utid-login.windows.net-refreshtoken-my_client_id--s2 s1 s3": { Target: defaultScopes, Environment: defaultEnvironment, CredentialType: "RefreshToken", Secret: rtSecret, ClientID: defaultClientID, HomeAccountID: defaultHID, }, }, IDTokens: map[string]IDToken{ "uid.utid-login.windows.net-idtoken-my_client_id-contoso-": { Realm: defaultRealm, Environment: defaultEnvironment, CredentialType: idCred, Secret: idSecret, ClientID: defaultClientID, HomeAccountID: defaultHID, }, }, Accounts: map[string]shared.Account{ "uid.utid-login.windows.net-contoso": { PreferredUsername: accUser, LocalAccountID: accLID, Realm: defaultRealm, Environment: defaultEnvironment, HomeAccountID: defaultHID, AuthorityType: accAuth, }, }, AppMetaData: map[string]AppMetaData{ "AppMetadata-login.windows.net-my_client_id": { Environment: defaultEnvironment, FamilyID: "", ClientID: defaultClientID, }, }, } manager := newForTest(nil) manager.update(contract) _, err := manager.Marshal() if err != nil { t.Errorf("Error should be nil; instead it is %v", err) } } func TestUnmarshal(t *testing.T) { manager := newForTest(nil) b, err := os.ReadFile(testFile) if err != nil { panic(err) } err = manager.Unmarshal(b) if err != nil { t.Fatalf("TestUnmarshal(unmarshal): got err == %s, want err == nil", err) } actualAccessTokenSecret := manager.contract.AccessTokens["uid.utid-login.windows.net-accesstoken-my_client_id-contoso-s2 s1 s3"].Secret if accessTokenSecret != actualAccessTokenSecret { t.Errorf("TestUnmarshal(access token secret):got %q, want %q", actualAccessTokenSecret, accessTokenSecret) } actualRTSecret := 
manager.contract.RefreshTokens["uid.utid-login.windows.net-refreshtoken-my_client_id--s2 s1 s3"].Secret if diff := pretty.Compare(rtSecret, actualRTSecret); diff != "" { t.Errorf("TestUnmarshal(refresh token secret): -want/+got:\n%s", diff) } actualIDSecret := manager.contract.IDTokens["uid.utid-login.windows.net-idtoken-my_client_id-contoso-"].Secret if diff := pretty.Compare(idSecret, actualIDSecret); diff != "" { t.Errorf("TestUnmarshal(id secret): -want/+got:\n%s", diff) } actualUser := manager.contract.Accounts["uid.utid-login.windows.net-contoso"].PreferredUsername if diff := pretty.Compare(actualUser, accUser); diff != "" { t.Errorf("TestUnmarshal(actula user): -want/+got:\n%s", diff) } if manager.contract.AppMetaData["AppMetadata-login.windows.net-my_client_id"].FamilyID != "" { t.Errorf("TestUnmarshal(app metadata family id): got %q, want empty string", manager.contract.AppMetaData["AppMetadata-login.windows.net-my_client_id"].FamilyID) } } func TestIsAccessTokenValid(t *testing.T) { cachedAt := time.Now() badCachedAt := time.Now().Add(500 * time.Second) expiresOn := time.Now().Add(1000 * time.Second) badExpiresOn := time.Now().Add(200 * time.Second) extended := time.Now() tests := []struct { desc string token AccessToken err bool }{ { desc: "Success", token: NewAccessToken("hid", "env", "realm", "cid", cachedAt, expiresOn, extended, "openid", "secret"), }, { desc: "ExpiresOnUnixTimestamp has expired", token: NewAccessToken("hid", "env", "realm", "cid", cachedAt, badExpiresOn, extended, "openid", "secret"), err: true, }, { desc: "Success", token: NewAccessToken("hid", "env", "realm", "cid", badCachedAt, expiresOn, extended, "openid", "secret"), err: true, }, } for _, test := range tests { err := test.token.Validate() switch { case err == nil && test.err: t.Errorf("TestIsAccessTokenValid(%s): got err == nil, want err != nil", test.desc) case err != nil && !test.err: t.Errorf("TestIsAccessTokenValid(%s): got err == %s, want err == nil", test.desc, err) } } } 
func TestRead(t *testing.T) { accessTokenCacheItem := NewAccessToken( "hid", "env", "realm", "cid", time.Now(), time.Now().Add(1000*time.Second), time.Now(), "openid profile", "secret", ) testIDToken := NewIDToken("hid", "env", "realm", "cid", "secret") testAppMeta := NewAppMetaData("fid", "cid", "env") testRefreshToken := accesstokens.NewRefreshToken("hid", "env", "cid", "secret", "fid") testAccount := shared.NewAccount("hid", "env", "realm", "lid", accAuth, "username") contract := &Contract{ RefreshTokens: map[string]accesstokens.RefreshToken{ testRefreshToken.Key(): testRefreshToken, }, Accounts: map[string]shared.Account{ testAccount.Key(): testAccount, }, AppMetaData: map[string]AppMetaData{ testAppMeta.Key(): testAppMeta, }, IDTokens: map[string]IDToken{ testIDToken.Key(): testIDToken, }, AccessTokens: map[string]AccessToken{ accessTokenCacheItem.Key(): accessTokenCacheItem, }, } authInfo := authority.Info{ Host: "env", Tenant: "realm", } authParameters := authority.AuthParams{ HomeAccountID: "hid", AuthorityInfo: authInfo, ClientID: "cid", Scopes: []string{"openid", "profile"}, } tests := []struct { desc string discRespErr bool discResp authority.InstanceDiscoveryResponse err bool want TokenResponse }{ { desc: "Error: AAD Discovery Fails", discRespErr: true, err: true, }, { desc: "Success", discResp: authority.InstanceDiscoveryResponse{ TenantDiscoveryEndpoint: "tenant", Metadata: []authority.InstanceDiscoveryMetadata{ { Aliases: []string{"env", "alias2"}, }, { Aliases: []string{"alias3", "alias4"}, }, }, }, want: TokenResponse{ AccessToken: accessTokenCacheItem, RefreshToken: testRefreshToken, IDToken: testIDToken, Account: testAccount, }, }, } for _, test := range tests { responder := &fakeDiscoveryResponser{err: test.discRespErr, ret: test.discResp} manager := newForTest(responder) manager.update(contract) got, err := manager.Read(context.Background(), authParameters) switch { case err == nil && test.err: t.Errorf("TestRead(%s): got err == nil, want err 
!= nil", test.desc) continue case err != nil && !test.err: t.Errorf("TestRead(%s): got err == %s, want err == nil", test.desc, err) continue case err != nil: continue } if diff := pretty.Compare(test.want, got); diff != "" { t.Errorf("TestRead(%s): -want/+got:\n%s", test.desc, diff) } } } func removeSubSeconds(t time.Time) time.Time { t = t.Add(-time.Duration(t.Nanosecond())) return t } func TestWrite(t *testing.T) { now := removeSubSeconds(time.Now().UTC()) cacheManager := newForTest(nil) clientInfo := accesstokens.ClientInfo{ UID: "testUID", UTID: "testUtid", } idToken := accesstokens.IDToken{ RawToken: "idToken", Oid: "lid", PreferredUsername: "username", } expiresOn := internalTime.DurationTime{T: now.Add(1000 * time.Second)} tokenResponse := accesstokens.TokenResponse{ AccessToken: "accessToken", RefreshToken: "refreshToken", IDToken: idToken, FamilyID: "fid", ClientInfo: clientInfo, GrantedScopes: accesstokens.Scopes{Slice: []string{"openid", "profile"}}, ExpiresOn: expiresOn, ExtExpiresOn: internalTime.DurationTime{T: now}, } authInfo := authority.Info{Host: "env", Tenant: "realm", AuthorityType: accAuth} authParams := authority.AuthParams{ AuthorityInfo: authInfo, ClientID: "cid", } testRefreshToken := accesstokens.NewRefreshToken( "testUID.testUtid", "env", "cid", "refreshToken", "fid", ) AccessToken := NewAccessToken( "testUID.testUtid", "env", "realm", "cid", now, now.Add(1000*time.Second), now, "openid profile", "accessToken", ) testIDToken := NewIDToken( "testUID.testUtid", "env", "realm", "cid", "idToken", ) testAccount := shared.NewAccount("testUID.testUtid", "env", "realm", "lid", accAuth, "username") testAppMeta := NewAppMetaData("fid", "cid", "env") actualAccount, err := cacheManager.Write(authParams, tokenResponse) if err != nil { t.Errorf("Error should be nil; instead, it is %v", err) } if !reflect.DeepEqual(actualAccount, testAccount) { t.Errorf("Actual account %+v differs from expected account %+v", actualAccount, testAccount) } gotRefresh, ok 
:= cacheManager.contract.RefreshTokens[testRefreshToken.Key()] if !ok { t.Fatalf("TestWrite(refresh token): refresh token was not written as expected") } if diff := pretty.Compare(testRefreshToken, gotRefresh); diff != "" { t.Fatalf("TestWrite(refresh token): -want/+got\n%s", diff) } gotAccess, ok := cacheManager.contract.AccessTokens[AccessToken.Key()] if !ok { t.Fatalf("TestWrite(access token): access token was not written as expected") } // CachedAt is generated for this exact moment, not from input. We would need to // fake time.Now() call with a var now = time.Now() in the package in order to // control this or we can just ignore this value. We are going to simply check its // not zero and then zero it for our got/want comparison. if gotAccess.CachedAt.T.IsZero() { t.Fatalf("TestWrite(access token): AccessToken.CachedAt is the zero value, which is incorrect") } gotAccess.CachedAt = internalTime.Unix{} AccessToken.CachedAt = internalTime.Unix{} if diff := pretty.Compare(AccessToken, gotAccess); diff != "" { t.Fatalf("TestWrite(access token): -want/+got\n%s", diff) } gotToken, ok := cacheManager.contract.IDTokens[testIDToken.Key()] if !ok { t.Fatalf("TestWrite(id token): id token was not written as expected") } if diff := pretty.Compare(testIDToken, gotToken); diff != "" { t.Fatalf("TestWrite(id token): -want/+got\n%s", diff) } gotAccount, ok := cacheManager.contract.Accounts[testAccount.Key()] if !ok { t.Fatalf("TestWrite(account): account was not written as expected") } if diff := pretty.Compare(testAccount, gotAccount); diff != "" { t.Fatalf("TestWrite(account): -want/+got\n%s", diff) } gotMeta, ok := cacheManager.contract.AppMetaData[testAppMeta.Key()] if !ok { t.Fatalf("TestWrite(app metadata): metadata was not written as expected") } if diff := pretty.Compare(testAppMeta, gotMeta); diff != "" { t.Fatalf("TestWrite(app metadata): -want/+got\n%s", diff) } } func TestRemoveRefreshTokens(t *testing.T) { storageManager := newForTest(nil) testRefreshToken := 
accesstokens.NewRefreshToken("hid", "env", "cid", "secret", "fid") key := testRefreshToken.Key() contract := &Contract{ RefreshTokens: map[string]accesstokens.RefreshToken{ key: testRefreshToken, }, } storageManager.update(contract) storageManager.removeRefreshTokens("hid", "env", "cid") if val, ok := storageManager.contract.RefreshTokens[key]; ok { t.Fatalf("TestRemoveRefreshTokens: got refreshToken == %s, want refreshToken == empty", val) } } func TestRemoveAccessTokens(t *testing.T) { now := time.Now() storageManager := newForTest(nil) testAccessToken := NewAccessToken("hid", "env", "realm", "cid", now, now, now, "openid", "secret") key := testAccessToken.Key() contract := &Contract{ AccessTokens: map[string]AccessToken{ key: testAccessToken, }, } storageManager.update(contract) storageManager.removeAccessTokens("hid", "env") if val, ok := storageManager.contract.AccessTokens[key]; ok { t.Fatalf("TestRemoveAccessTokens: got accessToken == %s, want accessToken == empty", val) } } func TestRemoveIDTokens(t *testing.T) { storageManager := newForTest(nil) testIDToken := NewIDToken("hid", "env", "realm", "cid", "secret") key := testIDToken.Key() contract := &Contract{ IDTokens: map[string]IDToken{ key: testIDToken, }, } storageManager.update(contract) storageManager.removeIDTokens("hid", "env") if val, ok := storageManager.contract.IDTokens[key]; ok { t.Fatalf("TestRemoveIDTokens: got IDToken == %s, want IDToken == empty", val) } } func TestRemoveAccountObject(t *testing.T) { storageManager := newForTest(nil) testAccount := shared.NewAccount("hid", "env", "realm", "lid", accAuth, "username") key := testAccount.Key() contract := &Contract{ Accounts: map[string]shared.Account{ key: testAccount, }, } storageManager.update(contract) storageManager.removeAccounts("hid", "env") if val, ok := storageManager.contract.Accounts[key]; ok { t.Fatalf("TestRemoveAccountObject: got Account == %s, want Account == empty", val) } } func TestRemoveAccount(t *testing.T) { now := 
time.Now() testAccessToken := NewAccessToken("hid", "env", "realm", "cid", now, now, now, "openid profile", "secret") testIDToken := NewIDToken("hid", "env", "realm", "cid", "secret") testAppMeta := NewAppMetaData("fid", "cid", "env") testRefreshToken := accesstokens.NewRefreshToken("hid", "env", "cid", "secret", "fid") testAccount := shared.NewAccount("hid", "env", "realm", "lid", accAuth, "username") contract := &Contract{ RefreshTokens: map[string]accesstokens.RefreshToken{ testRefreshToken.Key(): testRefreshToken, }, Accounts: map[string]shared.Account{ testAccount.Key(): testAccount, }, AppMetaData: map[string]AppMetaData{ testAppMeta.Key(): testAppMeta, }, IDTokens: map[string]IDToken{ testIDToken.Key(): testIDToken, }, AccessTokens: map[string]AccessToken{ testAccessToken.Key(): testAccessToken, }, } manager := newForTest(nil) manager.update(contract) manager.RemoveAccount(testAccount, "cid") if val, ok := manager.contract.RefreshTokens[testRefreshToken.Key()]; ok { t.Fatalf("TestRemoveAccount: got refreshToken == %s, want refreshToken == empty", val) } if val, ok := manager.contract.AccessTokens[testAccessToken.Key()]; ok { t.Fatalf("TestRemoveAccount: got accessToken == %s, want accessToken == empty", val) } if val, ok := manager.contract.IDTokens[testIDToken.Key()]; ok { t.Fatalf("TestRemoveAccount: got IDToken == %s, want IDToken == empty", val) } if val, ok := manager.contract.Accounts[testAccount.Key()]; ok { t.Fatalf("TestRemoveAccount: got Account == %s, want Account == empty", val) } } func TestRemoveEmptyAccount(t *testing.T) { now := time.Now() testAccessToken := NewAccessToken("hid", "env", "realm", "cid", now, now, now, "openid profile", "secret") testIDToken := NewIDToken("hid", "env", "realm", "cid", "secret") testAppMeta := NewAppMetaData("fid", "cid", "env") testRefreshToken := accesstokens.NewRefreshToken("hid", "env", "cid", "secret", "fid") testAccount := shared.NewAccount("hid", "env", "realm", "lid", accAuth, "username") contract := 
&Contract{ RefreshTokens: map[string]accesstokens.RefreshToken{ testRefreshToken.Key(): testRefreshToken, }, Accounts: map[string]shared.Account{ testAccount.Key(): testAccount, }, AppMetaData: map[string]AppMetaData{ testAppMeta.Key(): testAppMeta, }, IDTokens: map[string]IDToken{ testIDToken.Key(): testIDToken, }, AccessTokens: map[string]AccessToken{ testAccessToken.Key(): testAccessToken, }, } manager := newForTest(nil) manager.update(contract) manager.RemoveAccount(shared.Account{}, "cid") if _, ok := manager.contract.RefreshTokens[testRefreshToken.Key()]; !ok { t.Fatalf("TestRemoveEmptyAccount: got refreshToken == empty, want refreshToken == %s", testRefreshToken) } if _, ok := manager.contract.AccessTokens[testAccessToken.Key()]; !ok { t.Fatalf("TestRemoveEmptyAccount: got accessToken == empty, want accessToken == %s", testAccessToken) } if _, ok := manager.contract.IDTokens[testIDToken.Key()]; !ok { t.Fatalf("TestRemoveEmptyAccount: got IDToken == empty, want IDToken == %s", testIDToken) } if _, ok := manager.contract.Accounts[testAccount.Key()]; !ok { t.Fatalf("TestRemoveEmptyAccount: got Account == empty, want Account == %s", testAccount) } } test_serialized_cache.json000066400000000000000000000034111442026362400354670ustar00rootroot00000000000000microsoft-authentication-library-for-go-1.0.0/apps/internal/base/internal/storage{ "Account": { "uid.utid-login.windows.net-contoso": { "username": "John Doe", "local_account_id": "object1234", "realm": "contoso", "environment": "login.windows.net", "home_account_id": "uid.utid", "authority_type": "MSSTS" } }, "RefreshToken": { "uid.utid-login.windows.net-refreshtoken-my_client_id--s2 s1 s3": { "target": "s2 s1 s3", "environment": "login.windows.net", "credential_type": "RefreshToken", "secret": "a refresh token", "client_id": "my_client_id", "home_account_id": "uid.utid" } }, "AccessToken": { "an-entry": { "foo": "bar" }, "uid.utid-login.windows.net-accesstoken-my_client_id-contoso-s2 s1 s3": { "environment": 
"login.windows.net", "credential_type": "AccessToken", "secret": "an access token", "realm": "contoso", "target": "s2 s1 s3", "client_id": "my_client_id", "cached_at": "1000", "home_account_id": "uid.utid", "extended_expires_on": "4600", "expires_on": "4600" } }, "IdToken": { "uid.utid-login.windows.net-idtoken-my_client_id-contoso-": { "realm": "contoso", "environment": "login.windows.net", "credential_type": "IdToken", "secret": "header.eyJvaWQiOiAib2JqZWN0MTIzNCIsICJwcmVmZXJyZWRfdXNlcm5hbWUiOiAiSm9obiBEb2UiLCAic3ViIjogInN1YiJ9.signature", "client_id": "my_client_id", "home_account_id": "uid.utid" } }, "unknownEntity": {"field1":"1","field2":"whats"}, "AppMetadata": { "AppMetadata-login.windows.net-my_client_id": { "environment": "login.windows.net", "client_id": "my_client_id" } } }microsoft-authentication-library-for-go-1.0.0/apps/internal/exported/000077500000000000000000000000001442026362400257775ustar00rootroot00000000000000microsoft-authentication-library-for-go-1.0.0/apps/internal/exported/exported.go000066400000000000000000000023151442026362400301610ustar00rootroot00000000000000// Copyright (c) Microsoft Corporation. // Licensed under the MIT license. // package exported contains internal types that are re-exported from a public package package exported // AssertionRequestOptions has information required to generate a client assertion type AssertionRequestOptions struct { // ClientID identifies the application for which an assertion is requested. Used as the assertion's "iss" and "sub" claims. ClientID string // TokenEndpoint is the intended token endpoint. Used as the assertion's "aud" claim. 
TokenEndpoint string } // TokenProviderParameters is the authentication parameters passed to token providers type TokenProviderParameters struct { // Claims contains any additional claims requested for the token Claims string // CorrelationID of the authentication request CorrelationID string // Scopes requested for the token Scopes []string // TenantID identifies the tenant in which to authenticate TenantID string } // TokenProviderResult is the authentication result returned by custom token providers type TokenProviderResult struct { // AccessToken is the requested token AccessToken string // ExpiresInSeconds is the lifetime of the token in seconds ExpiresInSeconds int } microsoft-authentication-library-for-go-1.0.0/apps/internal/json/000077500000000000000000000000001442026362400251165ustar00rootroot00000000000000microsoft-authentication-library-for-go-1.0.0/apps/internal/json/design.md000066400000000000000000000127051442026362400267160ustar00rootroot00000000000000# JSON Package Design Author: John Doak(jdoak@microsoft.com) ## Why? This project needs a special type of marshal/unmarshal not directly supported by the encoding/json package. The need revolves around a few key wants/needs: - unmarshal and marshal structs representing JSON messages - fields in the messgage not in the struct must be maintained when unmarshalled - those same fields must be marshalled back when encoded again The initial version used map[string]interface{} to put in the keys that were known and then any other keys were put into a field called AdditionalFields. This has a few negatives: - Dual marshaling/unmarshalling is required - Adding a struct field requires manually adding a key by name to be encoded/decoded from the map (which is a loosely coupled construct), which can lead to bugs that aren't detected or have bad side effects - Tests can become quickly disconnected if those keys aren't put in tests as well. So you think you have support working, but you don't. 
Existing tests were found that didn't test the marshalling output. - There is no enforcement that if AdditionalFields is required on one struct, it should be on all containers that don't have custom marshal/unmarshal. This package aims to support our needs by providing custom Marshal()/Unmarshal() functions. This prevents all the negatives in the initial solution listed above. However, it does add its own negative: - Custom encoding/decoding via reflection is messy (as can be seen in encoding/json itself) Go proverb: Reflection is never clear Suggested reading: https://blog.golang.org/laws-of-reflection ## Important design decisions - We don't want to understand all JSON decoding rules - We don't want to deal with all the quoting, commas, etc on decode - Need support for json.Marshaler/Unmarshaler, so we can support types like time.Time - If struct does not implement json.Unmarshaler, it must have AdditionalFields defined - We only support root level objects that are \*struct or struct To facilitate these goals, we will utilize the json.Encoder and json.Decoder. They provide streaming processing (efficient) and return errors on bad JSON. Support for json.Marshaler/Unmarshaler allows us to use non-basic types that must be specially encoded/decoded (like time.Time objects). We don't support types that neither implement custom unmarshalling nor have AdditionalFields, in order to prevent future devs from forgetting that important field and generating bad return values. Support for root level objects of \*struct or struct simply acknowledges the fact that this is designed only for the purposes listed in the "Why?" section. Anything outside that (like encoding a lone number) should be done with the regular json package (as it will not have additional fields).
We don't support a few things on json supported reference types and structs: - \*map: no need for pointers to maps - \*slice: no need for pointers to slices - any further pointers on struct after \*struct There should never be a need for this in Go. ## Design ## State Machines This uses state machine designs that are based upon the Rob Pike talk on lexers and parsers: https://www.youtube.com/watch?v=HxaD_trXwRE This is the most common pattern for state machines in Go and the model to follow closely when dealing with streaming processing of textual data. Our state machines are based on the type: ```go type stateFn func() (stateFn, error) ``` The state machine itself is simply a struct that has methods that satisfy stateFn. Our state machines have a few standard calls - run(): runs the state machine - start(): always the first stateFn to be called All state machines have the following logic: * run() is called * start() is called and returns the next stateFn or error * stateFn is called - If returned stateFn(next state) is non-nil, call it - If error is non-nil, run() returns the error - If stateFn == nil and err == nil, run() returns nil ## Supporting types Marshalling/Unmarshalling must support (within top level struct): - struct - \*struct - []struct - []\*struct - []map[string]structContainer - [][]structContainer **Term note:** structContainer == type that has a struct or \*struct inside it We specifically do not support []interface or map[string]interface where the interface value would hold some value with a struct in it. Those will still marshal/unmarshal, but without support for AdditionalFields. ## Marshalling The marshalling design will be based around a statemachine design.
The basic logic is as follows: * If struct has custom marshaller, call it and return * If struct has field "AdditionalFields", it must be a map[string]interface{} * If struct does not have "AdditionalFields", give an error * Get struct tag detailing json names to go names, create mapping * For each public field name - Write field name out - If field value is a struct, recursively call our state machine - Otherwise, use the json.Encoder to write out the value ## Unmarshalling The unmarshalling design is also based around a statemachine design. The basic logic is as follows: * If struct has custom unmarshaller, call it * If struct has field "AdditionalFields", it must be a map[string]interface{} * Get struct tag detailing json names to go names, create mapping * For each key found - If key exists, - If value is basic type, extract value into struct field using Decoder - If value is struct type, recursively call statemachine - If key doesn't exist, add it to AdditionalFields if it exists using Decoder microsoft-authentication-library-for-go-1.0.0/apps/internal/json/json.go000066400000000000000000000117621442026362400264250ustar00rootroot00000000000000// Copyright (c) Microsoft Corporation. // Licensed under the MIT license. // Package json provides functions for marshalling and unmarshalling types to and from JSON. These functions are meant to // be utilized inside of structs that implement json.Unmarshaler and json.Marshaler interfaces. // This package provides the additional functionality of writing fields that are not in the struct when unmarshalling // to a field called AdditionalFields if that field exists and is a map[string]interface{}. // When marshalling, if the struct has all the same prerequisites, it will use the keys in AdditionalFields as // extra fields. This package uses encoding/json underneath.
package json import ( "bytes" "encoding/json" "fmt" "reflect" "strings" ) const addField = "AdditionalFields" const ( marshalJSON = "MarshalJSON" unmarshalJSON = "UnmarshalJSON" ) var ( leftBrace = []byte("{")[0] rightBrace = []byte("}")[0] comma = []byte(",")[0] leftParen = []byte("[")[0] rightParen = []byte("]")[0] ) var mapStrInterType = reflect.TypeOf(map[string]interface{}{}) // stateFn defines a state machine function. This will be used in all state // machines in this package. type stateFn func() (stateFn, error) // Marshal is used to marshal a type into its JSON representation. It // wraps the stdlib calls in order to marshal a struct or *struct so // that a field called "AdditionalFields" of type map[string]interface{} // with "-" used inside struct tag `json:"-"` can be marshalled as if // they were fields within the struct. func Marshal(i interface{}) ([]byte, error) { buff := bytes.Buffer{} enc := json.NewEncoder(&buff) enc.SetEscapeHTML(false) enc.SetIndent("", "") v := reflect.ValueOf(i) if v.Kind() != reflect.Ptr && v.CanAddr() { v = v.Addr() } err := marshalStruct(v, &buff, enc) if err != nil { return nil, err } return buff.Bytes(), nil } // Unmarshal unmarshals a []byte representing JSON into i, which must be a *struct. In addition, if the struct has // a field called AdditionalFields of type map[string]interface{}, JSON data representing fields not in the struct // will be written as key/value pairs to AdditionalFields. func Unmarshal(b []byte, i interface{}) error { if len(b) == 0 { return nil } jdec := json.NewDecoder(bytes.NewBuffer(b)) jdec.UseNumber() return unmarshalStruct(jdec, i) } // MarshalRaw marshals i into a json.RawMessage. If I cannot be marshalled, // this will panic. This is exposed to help test AdditionalField values // which are stored as json.RawMessage. 
func MarshalRaw(i interface{}) json.RawMessage { b, err := json.Marshal(i) if err != nil { panic(err) } return json.RawMessage(b) } // isDelim simply tests to see if a json.Token is a delimeter. func isDelim(got json.Token) bool { switch got.(type) { case json.Delim: return true } return false } // delimIs tests got to see if it is want. func delimIs(got json.Token, want rune) bool { switch v := got.(type) { case json.Delim: if v == json.Delim(want) { return true } } return false } // hasMarshalJSON will determine if the value or a pointer to this value has // the MarshalJSON method. func hasMarshalJSON(v reflect.Value) bool { if method := v.MethodByName(marshalJSON); method.Kind() != reflect.Invalid { _, ok := v.Interface().(json.Marshaler) return ok } if v.Kind() == reflect.Ptr { v = v.Elem() } else { if !v.CanAddr() { return false } v = v.Addr() } if method := v.MethodByName(marshalJSON); method.Kind() != reflect.Invalid { _, ok := v.Interface().(json.Marshaler) return ok } return false } // callMarshalJSON will call MarshalJSON() method on the value or a pointer to this value. // This will panic if the method is not defined. func callMarshalJSON(v reflect.Value) ([]byte, error) { if method := v.MethodByName(marshalJSON); method.Kind() != reflect.Invalid { marsh := v.Interface().(json.Marshaler) return marsh.MarshalJSON() } if v.Kind() == reflect.Ptr { v = v.Elem() } else { if v.CanAddr() { v = v.Addr() } } if method := v.MethodByName(unmarshalJSON); method.Kind() != reflect.Invalid { marsh := v.Interface().(json.Marshaler) return marsh.MarshalJSON() } panic(fmt.Sprintf("callMarshalJSON called on type %T that does not have MarshalJSON defined", v.Interface())) } // hasUnmarshalJSON will determine if the value or a pointer to this value has // the UnmarshalJSON method. func hasUnmarshalJSON(v reflect.Value) bool { // You can't unmarshal on a non-pointer type. 
if v.Kind() != reflect.Ptr { if !v.CanAddr() { return false } v = v.Addr() } if method := v.MethodByName(unmarshalJSON); method.Kind() != reflect.Invalid { _, ok := v.Interface().(json.Unmarshaler) return ok } return false } // hasOmitEmpty indicates if the field has instructed us to not output // the field if omitempty is set on the tag. tag is the string // returned by reflect.StructField.Tag().Get(). func hasOmitEmpty(tag string) bool { sl := strings.Split(tag, ",") for _, str := range sl { if str == "omitempty" { return true } } return false } microsoft-authentication-library-for-go-1.0.0/apps/internal/json/json_test.go000066400000000000000000000073171442026362400274650ustar00rootroot00000000000000// Copyright (c) Microsoft Corporation. // Licensed under the MIT license. package json import ( "encoding/json" "fmt" "testing" "time" "github.com/kylelemons/godebug/pretty" ) type StructA struct { Name string ID int `json:"id"` Meta *StructB AdditionalFields map[string]interface{} } type StructB struct { Address string AdditionalFields map[string]interface{} } type StructC struct { Time time.Time Project StructD AdditionalFields map[string]interface{} } type StructD struct { Project string Info StructE AdditionalFields map[string]interface{} } type StructE struct { Employees int AdditionalFields map[string]interface{} } func TestUnmarshal(t *testing.T) { now := time.Now() nowJSON, err := now.MarshalJSON() if err != nil { panic(err) } tests := []struct { desc string b []byte got interface{} want interface{} err bool }{ { desc: "receiver not a pointer", got: StructA{}, b: []byte(`{"content": "value"}`), err: true, }, { desc: "receiver not a pointer to a struct", got: new(string), b: []byte(`{"content": "value"}`), err: true, }, { desc: "AdditionalFields not a map", b: []byte(`{"content": "value"}`), got: &struct { AdditionalFields string }{}, err: true, }, { desc: "Success, no json.Unmarshaler types", b: []byte( ` { "Name": "John", "id": 3, "Meta": { "Address": "291 
Street", "unknown0": 3.2 }, "unknown0": 10, "unknown1": "hello" } `, ), got: &StructA{}, want: &StructA{ Name: "John", ID: 3, Meta: &StructB{ Address: "291 Street", AdditionalFields: map[string]interface{}{ "unknown0": MarshalRaw(3.2), }, }, AdditionalFields: map[string]interface{}{ "unknown0": MarshalRaw(10), "unknown1": MarshalRaw("hello"), }, }, }, { desc: "Success, a type has json.Unmarshaler", b: []byte(fmt.Sprintf(` { "Time":%s, "Project": { "Project":"myProject", "Info":{ "Employees":2 } } } `, string(nowJSON))), got: &StructC{}, want: &StructC{ Time: now, Project: StructD{ Project: "myProject", Info: StructE{ Employees: 2, }, }, }, }, } for _, test := range tests { err := Unmarshal(test.b, test.got) switch { case err == nil && test.err: t.Errorf("TestUnmarshal(%s): got err == nil, want err != nil", test.desc) continue case err != nil && !test.err: t.Errorf("TestUnmarshal(%s): got err == %s, want err == nil", test.desc, err) continue case err != nil: continue } if diff := (&pretty.Config{IncludeUnexported: false}).Compare(test.want, test.got); diff != "" { t.Errorf("TestUnmarshal(%s): -want/+got:\n%s", test.desc, diff) } } } func TestIsDelim(t *testing.T) { tests := []struct { desc string token json.Token want bool }{ {desc: "Is delim", token: json.Delim('{'), want: true}, {desc: "Not a delim", token: json.Token("{"), want: false}, } for _, test := range tests { got := isDelim(test.token) if got != test.want { t.Errorf("TestIsDelim(%s): got %v, want %v", test.desc, got, test.want) } } } func TestDelimIs(t *testing.T) { tests := []struct { desc string token json.Token delim rune want bool }{ {desc: "Token is a match", token: json.Delim('{'), delim: '{', want: true}, {desc: "Token is not a match", token: json.Delim('{'), delim: '}', want: false}, } for _, test := range tests { got := delimIs(test.token, test.delim) if got != test.want { t.Errorf("TestDelimIs(%s): got %v, want %v", test.desc, got, test.want) } } } 
microsoft-authentication-library-for-go-1.0.0/apps/internal/json/mapslice.go000066400000000000000000000164271442026362400272540ustar00rootroot00000000000000// Copyright (c) Microsoft Corporation. // Licensed under the MIT license. package json import ( "encoding/json" "fmt" "reflect" ) // unmarshalMap unmarshal's a map. func unmarshalMap(dec *json.Decoder, m reflect.Value) error { if m.Kind() != reflect.Ptr || m.Elem().Kind() != reflect.Map { panic("unmarshalMap called on non-*map value") } mapValueType := m.Elem().Type().Elem() walk := mapWalk{dec: dec, m: m, valueType: mapValueType} if err := walk.run(); err != nil { return err } return nil } type mapWalk struct { dec *json.Decoder key string m reflect.Value valueType reflect.Type } // run runs our decoder state machine. func (m *mapWalk) run() error { var state = m.start var err error for { state, err = state() if err != nil { return err } if state == nil { return nil } } } func (m *mapWalk) start() (stateFn, error) { // maps can have custom unmarshaler's. if hasUnmarshalJSON(m.m) { err := m.dec.Decode(m.m.Interface()) if err != nil { return nil, err } return nil, nil } // We only want to use this if the map value is: // *struct/struct/map/slice // otherwise use standard decode t, _ := m.valueBaseType() switch t.Kind() { case reflect.Struct, reflect.Map, reflect.Slice: delim, err := m.dec.Token() if err != nil { return nil, err } // This indicates the value was set to JSON null. if delim == nil { return nil, nil } if !delimIs(delim, '{') { return nil, fmt.Errorf("Unmarshal expected opening {, received %v", delim) } return m.next, nil case reflect.Ptr: return nil, fmt.Errorf("do not support maps with values of '**type' or '*reference") } // This is a basic map type, so just use Decode(). 
if err := m.dec.Decode(m.m.Interface()); err != nil { return nil, err } return nil, nil } func (m *mapWalk) next() (stateFn, error) { if m.dec.More() { key, err := m.dec.Token() if err != nil { return nil, err } m.key = key.(string) return m.storeValue, nil } // No more entries, so remove final }. _, err := m.dec.Token() if err != nil { return nil, err } return nil, nil } func (m *mapWalk) storeValue() (stateFn, error) { v := m.valueType for { switch v.Kind() { case reflect.Ptr: v = v.Elem() continue case reflect.Struct: return m.storeStruct, nil case reflect.Map: return m.storeMap, nil case reflect.Slice: return m.storeSlice, nil } return nil, fmt.Errorf("bug: mapWalk.storeValue() called on unsupported type: %v", v.Kind()) } } func (m *mapWalk) storeStruct() (stateFn, error) { v := newValue(m.valueType) if err := unmarshalStruct(m.dec, v.Interface()); err != nil { return nil, err } if m.valueType.Kind() == reflect.Ptr { m.m.Elem().SetMapIndex(reflect.ValueOf(m.key), v) return m.next, nil } m.m.Elem().SetMapIndex(reflect.ValueOf(m.key), v.Elem()) return m.next, nil } func (m *mapWalk) storeMap() (stateFn, error) { v := reflect.MakeMap(m.valueType) ptr := newValue(v.Type()) ptr.Elem().Set(v) if err := unmarshalMap(m.dec, ptr); err != nil { return nil, err } m.m.Elem().SetMapIndex(reflect.ValueOf(m.key), v) return m.next, nil } func (m *mapWalk) storeSlice() (stateFn, error) { v := newValue(m.valueType) if err := unmarshalSlice(m.dec, v); err != nil { return nil, err } m.m.Elem().SetMapIndex(reflect.ValueOf(m.key), v.Elem()) return m.next, nil } // valueType returns the underlying Type. So a *struct would yield // struct, etc... func (m *mapWalk) valueBaseType() (reflect.Type, bool) { ptr := false v := m.valueType if v.Kind() == reflect.Ptr { ptr = true v = v.Elem() } return v, ptr } // unmarshalSlice unmarshal's the next value, which must be a slice, into // ptrSlice, which must be a pointer to a slice. newValue() can be use to // create the slice. 
func unmarshalSlice(dec *json.Decoder, ptrSlice reflect.Value) error { if ptrSlice.Kind() != reflect.Ptr || ptrSlice.Elem().Kind() != reflect.Slice { panic("unmarshalSlice called on non-*[]slice value") } sliceValueType := ptrSlice.Elem().Type().Elem() walk := sliceWalk{ dec: dec, s: ptrSlice, valueType: sliceValueType, } if err := walk.run(); err != nil { return err } return nil } type sliceWalk struct { dec *json.Decoder s reflect.Value // *[]slice valueType reflect.Type } // run runs our decoder state machine. func (s *sliceWalk) run() error { var state = s.start var err error for { state, err = state() if err != nil { return err } if state == nil { return nil } } } func (s *sliceWalk) start() (stateFn, error) { // slices can have custom unmarshaler's. if hasUnmarshalJSON(s.s) { err := s.dec.Decode(s.s.Interface()) if err != nil { return nil, err } return nil, nil } // We only want to use this if the slice value is: // []*struct/[]struct/[]map/[]slice // otherwise use standard decode t := s.valueBaseType() switch t.Kind() { case reflect.Ptr: return nil, fmt.Errorf("cannot unmarshal into a ** or *") case reflect.Struct, reflect.Map, reflect.Slice: delim, err := s.dec.Token() if err != nil { return nil, err } // This indicates the value was set to nil. 
if delim == nil { return nil, nil } if !delimIs(delim, '[') { return nil, fmt.Errorf("Unmarshal expected opening [, received %v", delim) } return s.next, nil } if err := s.dec.Decode(s.s.Interface()); err != nil { return nil, err } return nil, nil } func (s *sliceWalk) next() (stateFn, error) { if s.dec.More() { return s.storeValue, nil } // Nothing left in the slice, remove closing ] _, err := s.dec.Token() return nil, err } func (s *sliceWalk) storeValue() (stateFn, error) { t := s.valueBaseType() switch t.Kind() { case reflect.Ptr: return nil, fmt.Errorf("do not support 'pointer to pointer' or 'pointer to reference' types") case reflect.Struct: return s.storeStruct, nil case reflect.Map: return s.storeMap, nil case reflect.Slice: return s.storeSlice, nil } return nil, fmt.Errorf("bug: sliceWalk.storeValue() called on unsupported type: %v", t.Kind()) } func (s *sliceWalk) storeStruct() (stateFn, error) { v := newValue(s.valueType) if err := unmarshalStruct(s.dec, v.Interface()); err != nil { return nil, err } if s.valueType.Kind() == reflect.Ptr { s.s.Elem().Set(reflect.Append(s.s.Elem(), v)) return s.next, nil } s.s.Elem().Set(reflect.Append(s.s.Elem(), v.Elem())) return s.next, nil } func (s *sliceWalk) storeMap() (stateFn, error) { v := reflect.MakeMap(s.valueType) ptr := newValue(v.Type()) ptr.Elem().Set(v) if err := unmarshalMap(s.dec, ptr); err != nil { return nil, err } s.s.Elem().Set(reflect.Append(s.s.Elem(), v)) return s.next, nil } func (s *sliceWalk) storeSlice() (stateFn, error) { v := newValue(s.valueType) if err := unmarshalSlice(s.dec, v); err != nil { return nil, err } s.s.Elem().Set(reflect.Append(s.s.Elem(), v.Elem())) return s.next, nil } // valueType returns the underlying Type. So a *struct would yield // struct, etc... func (s *sliceWalk) valueBaseType() reflect.Type { v := s.valueType if v.Kind() == reflect.Ptr { v = v.Elem() } return v } // newValue() returns a new *type that represents type passed. 
func newValue(valueType reflect.Type) reflect.Value {
	// Strip one level of pointer so the result is a *T for both T and *T.
	if valueType.Kind() == reflect.Ptr {
		return reflect.New(valueType.Elem())
	}
	return reflect.New(valueType)
}
microsoft-authentication-library-for-go-1.0.0/apps/internal/json/mapslice_test.go000066400000000000000000000147311442026362400303070ustar00rootroot00000000000000package json

import (
	"bytes"
	"encoding/json"
	"reflect"
	"testing"

	"github.com/kylelemons/godebug/pretty"
)

// StructWithUnmarshal has its own UnmarshalJSON method (defined below).
type StructWithUnmarshal struct {
	Name string
}

// StructName has an AdditionalFields map, so JSON keys without a matching
// field are expected to be collected there.
type StructName struct {
	Name             string
	AdditionalFields map[string]interface{}
}

// UnmarshalJSON decodes only the Name field and panics on malformed input.
func (s *StructWithUnmarshal) UnmarshalJSON(b []byte) error {
	// Note this looks silly, but you can't call json.Unmarshal on this same
	// type inside its UnmarshalJSON; it causes a recursion loop. Decoding into
	// a local type that has no UnmarshalJSON method is a simple workaround.
	type unmarshal struct {
		Name string
	}
	u := unmarshal{}
	err := json.Unmarshal(b, &u)
	if err != nil {
		panic(err)
	}
	s.Name = u.Name
	return nil
}

// TestUnmarshalMap runs unmarshalMap on each input, decoding into test.got,
// and compares the result with test.want. Cases with err set must fail.
func TestUnmarshalMap(t *testing.T) {
	tests := []struct {
		desc  string
		input string
		got   interface{}
		want  interface{}
		err   bool
	}{
		{
			desc:  "error: struct has no AdditionalFields",
			input: ` { "key": { "Name": "John" } } `,
			got:   &map[string]struct{ Name string }{},
			err:   true,
		},
		{
			desc:  "success: basic map[string]interface{}",
			input: ` { "key": { "Name": "John" } } `,
			got:   &map[string]interface{}{},
			want: map[string]interface{}{
				"key": map[string]interface{}{
					"Name": "John",
				},
			},
		},
		{
			desc:  "success: struct has UnmarshalJSON",
			input: ` { "key": { "Name": "John" } } `,
			got:   &map[string]*StructWithUnmarshal{},
			want: map[string]*StructWithUnmarshal{
				"key": {
					Name: "John",
				},
			},
		},
		{
			desc:  "success: map[string]struct",
			input: ` { "key": { "Name": "John", "extra": "extra" } } `,
			got:   &map[string]StructName{},
			want: map[string]StructName{
				"key": {
					Name: "John",
					AdditionalFields: map[string]interface{}{
						"extra": MarshalRaw("extra"),
					},
				},
			},
		},
		{
			desc:  "success: map[string]*struct",
			input: ` { "key": { "Name": "John", "extra": "extra" } } `,
			got:   &map[string]*StructName{},
			want: map[string]*StructName{
				"key": {
					Name: "John",
					AdditionalFields: map[string]interface{}{
						"extra": MarshalRaw("extra"),
					},
				},
			},
		},
		{
			desc:  "success: map[string][]struct",
			input: ` { "key": [ { "Name": "John", "extra": "extra" } ] } `,
			got:   &map[string][]StructName{},
			want: map[string][]StructName{
				"key": {
					{
						Name: "John",
						AdditionalFields: map[string]interface{}{
							"extra": MarshalRaw("extra"),
						},
					},
				},
			},
		},
		{
			desc:  "success: map[string][]*struct",
			input: ` { "key": [ { "Name": "John", "extra": "extra" } ] } `,
			got:   &map[string][]*StructName{},
			want: map[string][]*StructName{
				"key": {
					{
						Name: "John",
						AdditionalFields: map[string]interface{}{
							"extra": MarshalRaw("extra"),
						},
					},
				},
			},
		},
	}

	for _, test := range tests {
		dec := json.NewDecoder(bytes.NewBuffer([]byte(test.input)))
		err := unmarshalMap(dec, reflect.ValueOf(test.got))
		switch {
		case err == nil && test.err:
			t.Errorf("TestUnmarshalMap(%s): err == nil, want err != nil", test.desc)
			continue
		case err != nil && !test.err:
			t.Errorf("TestUnmarshalMap(%s): err == %s, want err == nil", test.desc, err)
			continue
		case err != nil:
			// Expected failure; nothing further to compare.
			continue
		}
		if diff := pretty.Compare(test.want, test.got); diff != "" {
			t.Errorf("TestUnmarshalMap(%s): -want/+got\n%s", test.desc, diff)
		}
	}
}

// TestUnmarshalSlice runs unmarshalSlice on each input, decoding into
// test.got, and compares the result with test.want. Cases with err set must
// fail.
func TestUnmarshalSlice(t *testing.T) {
	tests := []struct {
		desc  string
		input string
		got   interface{}
		want  interface{}
		err   bool
	}{
		{
			desc:  "error: struct has no AdditionalFields",
			input: ` [ { "Name": "John" } ] `,
			got:   new([]struct{ Name string }),
			err:   true,
		},
		{
			desc:  "success: basic slice",
			input: ` [ "John", "Steve" ] `,
			got:   new([]string),
			want:  []string{"John", "Steve"},
		},
		{
			desc:  "success: struct has UnmarshalJSON",
			input: ` [ { "Name": "John" } ] `,
			got:   new([]*StructWithUnmarshal),
			want: []*StructWithUnmarshal{
				{
					Name: "John",
				},
			},
		},
		{
			desc:  "success: []struct",
			input: ` [ { "Name": "John", "extra": "extra" } ] `,
			got:   new([]StructName),
			want: []StructName{
				{
					Name: "John",
					AdditionalFields: map[string]interface{}{
						"extra": MarshalRaw("extra"),
					},
				},
			},
		},
		{
			desc:  "success: []*struct",
			input: ` [ { "Name": "John", "extra": "extra" } ] `,
			got:   new([]*StructName),
			want: []*StructName{
				{
					Name: "John",
					AdditionalFields: map[string]interface{}{
						"extra": MarshalRaw("extra"),
					},
				},
			},
		},
		{
			desc:  "success: [][]struct",
			input: ` [ [ { "Name": "John", "extra": "extra" } ] ] `,
			got:   new([][]StructName),
			want: [][]StructName{
				{
					{
						Name: "John",
						AdditionalFields: map[string]interface{}{
							"extra": MarshalRaw("extra"),
						},
					},
				},
			},
		},
		{
			desc:  "success: [][]*struct",
			input: ` [ [ { "Name": "John", "extra": "extra" } ] ] `,
			got:   new([][]*StructName),
			want: [][]*StructName{
				{
					{
						Name: "John",
						AdditionalFields: map[string]interface{}{
							"extra": MarshalRaw("extra"),
						},
					},
				},
			},
		},
		{
			desc:  "success: []map[string]struct",
			input: ` [ { "key": { "Name": "John", "extra": "extra" } } ] `,
			got:   new([]map[string]StructName),
			want: []map[string]StructName{
				{
					"key": {
						Name: "John",
						AdditionalFields: map[string]interface{}{
							"extra": MarshalRaw("extra"),
						},
					},
				},
			},
		},
	}

	for _, test := range tests {
		dec := json.NewDecoder(bytes.NewBuffer([]byte(test.input)))
		err := unmarshalSlice(dec, reflect.ValueOf(test.got))
		switch {
		case err == nil && test.err:
			t.Errorf("TestUnmarshalSlice(%s): err == nil, want err != nil", test.desc)
			continue
		case err != nil && !test.err:
			t.Errorf("TestUnmarshalSlice(%s): err == %s, want err == nil", test.desc, err)
			continue
		case err != nil:
			// Expected failure; nothing further to compare.
			continue
		}
		if diff := pretty.Compare(test.want, test.got); diff != "" {
			t.Errorf("TestUnmarshalSlice(%s): -want/+got\n%s", test.desc, diff)
		}
	}
}
microsoft-authentication-library-for-go-1.0.0/apps/internal/json/marshal.go000066400000000000000000000177531442026362400271110ustar00rootroot00000000000000// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
package json

import (
	"bytes"
	"encoding/json"
	"fmt"
	"reflect"
)

// marshalStruct takes in v, which must be a *struct or struct, and marshals its
// content as JSON into buff (sometimes with writes to buff directly, sometimes via enc).
// This call is recursive for all fields of *struct or struct type.
//
// Types with a MarshalJSON method are written by that method. Otherwise:
// unexported fields are skipped, omitempty fields holding their zero value are
// skipped, nil pointer fields are written as null, and the entries of a
// non-empty AdditionalFields map are written inline as top-level keys.
func marshalStruct(v reflect.Value, buff *bytes.Buffer, enc *json.Encoder) error {
	if v.Kind() == reflect.Ptr {
		// Elem() of a nil pointer is an invalid Value; v.Interface() on it in
		// the error path below would panic, so report it here instead.
		if v.IsNil() {
			return fmt.Errorf("bug: marshal() received a nil %T", v.Interface())
		}
		v = v.Elem()
	}
	// We only care about custom Marshalling a struct.
	if v.Kind() != reflect.Struct {
		return fmt.Errorf("bug: marshal() received a non *struct or struct, received type %T", v.Interface())
	}

	if hasMarshalJSON(v) {
		b, err := callMarshalJSON(v)
		if err != nil {
			return err
		}
		buff.Write(b)
		return nil
	}

	t := v.Type()

	// If it has an AdditionalFields field make sure it's the right type.
	f := v.FieldByName(addField)
	if f.Kind() != reflect.Invalid {
		if f.Kind() != reflect.Map {
			return fmt.Errorf("type %T has field 'AdditionalFields' that is not a map[string]interface{}", v.Interface())
		}
		if !f.Type().AssignableTo(mapStrInterType) {
			return fmt.Errorf("type %T has field 'AdditionalFields' that is not a map[string]interface{}", v.Interface())
		}
	}

	translator, err := findFields(v)
	if err != nil {
		return err
	}

	buff.WriteByte(leftBrace)
	for x := 0; x < v.NumField(); x++ {
		sf := t.Field(x)
		field := v.Field(x)

		// We don't access unexported fields. A non-empty PkgPath is how the
		// reflect package marks them; unlike testing the first letter for lower
		// case, it also catches names such as "_x".
		if sf.PkgPath != "" {
			continue
		}

		if sf.Name == addField {
			if field.Len() > 0 {
				if err := writeAddFields(field.Interface(), buff, enc); err != nil {
					return err
				}
				buff.WriteByte(comma)
			}
			continue
		}

		// If they have omitempty set, we don't write out the field if
		// it is the zero value.
		if hasOmitEmpty(sf.Tag.Get("json")) && field.IsZero() {
			continue
		}

		// Write out the field name part.
		jsonName := translator.jsonName(sf.Name)
		buff.WriteString(fmt.Sprintf("%q:", jsonName))

		if field.Kind() == reflect.Ptr {
			// A nil pointer has no element to marshal: Elem() would be an
			// invalid Value and marshalStructField would panic on it.
			if field.IsNil() {
				buff.WriteString("null")
				buff.WriteByte(comma)
				continue
			}
			field = field.Elem()
		}

		if err := marshalStructField(field, buff, enc); err != nil {
			return err
		}
	}

	// Remove the final comma. When no field was written the last byte is the
	// opening brace, which must be kept or the output would be just "}".
	if b := buff.Bytes(); b[len(b)-1] == comma {
		buff.Truncate(buff.Len() - 1)
	}
	buff.WriteByte(rightBrace)

	return nil
}

// marshalStructField writes a single (already dereferenced) field value to buff,
// recursing into structs, maps and slices, and always appends a trailing comma.
func marshalStructField(field reflect.Value, buff *bytes.Buffer, enc *json.Encoder) error {
	// Determine if we need a trailing comma.
	defer buff.WriteByte(comma)

	switch field.Kind() {
	// If it was a *struct or struct, we need to recursively call marshal().
	case reflect.Struct:
		if field.CanAddr() {
			field = field.Addr()
		}
		return marshalStruct(field, buff, enc)
	case reflect.Map:
		return marshalMap(field, buff, enc)
	case reflect.Slice:
		return marshalSlice(field, buff, enc)
	}

	// It is just a basic type, so encode it.
	if err := enc.Encode(field.Interface()); err != nil {
		return err
	}
	buff.Truncate(buff.Len() - 1) // Remove Encode() added \n

	return nil
}

// marshalMap writes the map v to buff as a JSON object. An empty (or nil) map
// is written as {}.
func marshalMap(v reflect.Value, buff *bytes.Buffer, enc *json.Encoder) error {
	if v.Kind() != reflect.Map {
		return fmt.Errorf("bug: marshalMap() called on %T", v.Interface())
	}
	if v.Len() == 0 {
		buff.WriteByte(leftBrace)
		buff.WriteByte(rightBrace)
		return nil
	}
	encoder := mapEncode{m: v, buff: buff, enc: enc}
	return encoder.run()
}

// mapEncode holds the state of the map encoding state machine.
type mapEncode struct {
	m    reflect.Value
	buff *bytes.Buffer
	enc  *json.Encoder

	// valueBaseType is the map's value type with one pointer level removed;
	// it is set by start.
	valueBaseType reflect.Type
}

// run runs our encoder state machine.
func (m *mapEncode) run() error { var state = m.start var err error for { state, err = state() if err != nil { return err } if state == nil { return nil } } } func (m *mapEncode) start() (stateFn, error) { if hasMarshalJSON(m.m) { b, err := callMarshalJSON(m.m) if err != nil { return nil, err } m.buff.Write(b) return nil, nil } valueBaseType := m.m.Type().Elem() if valueBaseType.Kind() == reflect.Ptr { valueBaseType = valueBaseType.Elem() } m.valueBaseType = valueBaseType switch valueBaseType.Kind() { case reflect.Ptr: return nil, fmt.Errorf("Marshal does not support ** or *") case reflect.Struct, reflect.Map, reflect.Slice: return m.encode, nil } // If the map value doesn't have a struct/map/slice, just Encode() it. if err := m.enc.Encode(m.m.Interface()); err != nil { return nil, err } m.buff.Truncate(m.buff.Len() - 1) // Remove Encode() added \n return nil, nil } func (m *mapEncode) encode() (stateFn, error) { m.buff.WriteByte(leftBrace) iter := m.m.MapRange() for iter.Next() { // Write the key. 
k := iter.Key() m.buff.WriteString(fmt.Sprintf("%q:", k.String())) v := iter.Value() switch m.valueBaseType.Kind() { case reflect.Struct: if v.CanAddr() { v = v.Addr() } if err := marshalStruct(v, m.buff, m.enc); err != nil { return nil, err } case reflect.Map: if err := marshalMap(v, m.buff, m.enc); err != nil { return nil, err } case reflect.Slice: if err := marshalSlice(v, m.buff, m.enc); err != nil { return nil, err } default: panic(fmt.Sprintf("critical bug: mapEncode.encode() called with value base type: %v", m.valueBaseType.Kind())) } m.buff.WriteByte(comma) } m.buff.Truncate(m.buff.Len() - 1) // Remove final comma m.buff.WriteByte(rightBrace) return nil, nil } func marshalSlice(v reflect.Value, buff *bytes.Buffer, enc *json.Encoder) error { if v.Kind() != reflect.Slice { return fmt.Errorf("bug: marshalSlice() called on %T", v.Interface()) } if v.Len() == 0 { buff.WriteByte(leftParen) buff.WriteByte(rightParen) return nil } encoder := sliceEncode{s: v, buff: buff, enc: enc} return encoder.run() } type sliceEncode struct { s reflect.Value buff *bytes.Buffer enc *json.Encoder valueBaseType reflect.Type } // run runs our encoder state machine. func (s *sliceEncode) run() error { var state = s.start var err error for { state, err = state() if err != nil { return err } if state == nil { return nil } } } func (s *sliceEncode) start() (stateFn, error) { if hasMarshalJSON(s.s) { b, err := callMarshalJSON(s.s) if err != nil { return nil, err } s.buff.Write(b) return nil, nil } valueBaseType := s.s.Type().Elem() if valueBaseType.Kind() == reflect.Ptr { valueBaseType = valueBaseType.Elem() } s.valueBaseType = valueBaseType switch valueBaseType.Kind() { case reflect.Ptr: return nil, fmt.Errorf("Marshal does not support ** or *") case reflect.Struct, reflect.Map, reflect.Slice: return s.encode, nil } // If the map value doesn't have a struct/map/slice, just Encode() it. 
if err := s.enc.Encode(s.s.Interface()); err != nil { return nil, err } s.buff.Truncate(s.buff.Len() - 1) // Remove Encode added \n return nil, nil } func (s *sliceEncode) encode() (stateFn, error) { s.buff.WriteByte(leftParen) for i := 0; i < s.s.Len(); i++ { v := s.s.Index(i) switch s.valueBaseType.Kind() { case reflect.Struct: if v.CanAddr() { v = v.Addr() } if err := marshalStruct(v, s.buff, s.enc); err != nil { return nil, err } case reflect.Map: if err := marshalMap(v, s.buff, s.enc); err != nil { return nil, err } case reflect.Slice: if err := marshalSlice(v, s.buff, s.enc); err != nil { return nil, err } default: panic(fmt.Sprintf("critical bug: mapEncode.encode() called with value base type: %v", s.valueBaseType.Kind())) } s.buff.WriteByte(comma) } s.buff.Truncate(s.buff.Len() - 1) // Remove final comma s.buff.WriteByte(rightParen) return nil, nil } // writeAddFields writes the AdditionalFields struct field out to JSON as field // values. i must be a map[string]interface{} or this will panic. func writeAddFields(i interface{}, buff *bytes.Buffer, enc *json.Encoder) error { m := i.(map[string]interface{}) x := 0 for k, v := range m { buff.WriteString(fmt.Sprintf("%q:", k)) if err := enc.Encode(v); err != nil { return err } buff.Truncate(buff.Len() - 1) // Remove Encode() added \n if x+1 != len(m) { buff.WriteByte(comma) } x++ } return nil } microsoft-authentication-library-for-go-1.0.0/apps/internal/json/marshal_test.go000066400000000000000000000120201442026362400301260ustar00rootroot00000000000000// Copyright (c) Microsoft Corporation. // Licensed under the MIT license. 
package json

import (
	"encoding/json"
	"testing"

	"github.com/kylelemons/godebug/pretty"
)

// TestMarshalStruct marshals each test.value with Marshal, decodes the output
// with encoding/json into a map[string]interface{} and compares it with
// test.want. Cases with err set must make Marshal fail.
func TestMarshalStruct(t *testing.T) {
	tests := []struct {
		desc  string
		value interface{}
		want  map[string]interface{}
		err   bool
	}{
		{
			desc: "struct with no additional fields",
			value: struct {
				Name string
				Int  int
			}{
				Name: "my name",
				Int:  5,
			},
			want: map[string]interface{}{
				"Name": "my name",
				"Int":  5,
			},
		},
		{
			desc: "*struct with AdditionalFields",
			value: &struct {
				Name             string
				Int              int
				AdditionalFields map[string]interface{} `json:"-"`
			}{
				Name: "John Doak",
				Int:  45,
				AdditionalFields: map[string]interface{}{
					"Hello": "World",
					"Float": 3.2,
				},
			},
			want: map[string]interface{}{
				"Name":  "John Doak",
				"Int":   45,
				"Float": 3.2,
				"Hello": "World",
			},
		},
		{
			desc: "AdditionalFields is not a map",
			value: struct {
				AdditionalFields string `json:"-"`
			}{
				AdditionalFields: "hello",
			},
			err: true,
		},
		{
			desc: "AdditionalFields is not a map[string]interface{}",
			value: struct {
				AdditionalFields map[string]string `json:"-"`
			}{
				AdditionalFields: map[string]string{
					"Hello": "World",
				},
			},
			err: true,
		},
		{
			desc: "Multiple Structs",
			value: &StructA{
				Name: "John",
				ID:   3,
				Meta: &StructB{
					Address: "291 Street",
					AdditionalFields: map[string]interface{}{
						"unknown0": MarshalRaw(3.2),
					},
				},
				AdditionalFields: map[string]interface{}{
					"unknown0": MarshalRaw(10),
					"unknown1": MarshalRaw("hello"),
				},
			},
			want: map[string]interface{}{
				"Name": "John",
				"id":   3,
				"Meta": map[string]interface{}{
					"Address":  "291 Street",
					"unknown0": 3.2,
				},
				"unknown0": 10,
				"unknown1": "hello",
			},
		},
		{
			desc: "Struct with map[string]interface{}",
			value: struct {
				Name             string
				Map              map[string]interface{}
				AdditionalFields map[string]interface{}
			}{
				Name: "John",
				Map: map[string]interface{}{
					"key": "value",
				},
			},
			want: map[string]interface{}{
				"Name": "John",
				"Map": map[string]interface{}{
					"key": "value",
				},
			},
		},
		{
			desc: "Struct with map[string]struct{}",
			value: struct {
				Name             string
				Map              map[string]StructB
				AdditionalFields map[string]interface{}
			}{
				Name: "John",
				Map: map[string]StructB{
					"key": {
						Address: "addr",
					},
				},
			},
			want: map[string]interface{}{
				"Name": "John",
				"Map": map[string]interface{}{
					"key": map[string]interface{}{
						"Address": "addr",
					},
				},
			},
		},
		{
			desc: "Struct with map[string][]",
			value: struct {
				Name             string
				Map              map[string]interface{}
				AdditionalFields map[string]interface{}
			}{
				Name: "John",
				Map: map[string]interface{}{
					"key": []string{
						"apples",
					},
				},
			},
			want: map[string]interface{}{
				"Name": "John",
				"Map": map[string]interface{}{
					"key": []string{"apples"},
				},
			},
		},
		{
			desc: "Struct with map[string][]struct",
			value: struct {
				Name             string
				Map              map[string][]StructB
				AdditionalFields map[string]interface{}
			}{
				Name: "John",
				Map: map[string][]StructB{
					"key": {
						{Address: "addr"},
					},
				},
			},
			want: map[string]interface{}{
				"Name": "John",
				"Map": map[string]interface{}{
					"key": []interface{}{
						map[string]interface{}{
							"Address": "addr",
						},
					},
				},
			},
		},
	}

	for _, test := range tests {
		b, err := Marshal(test.value)
		switch {
		case err == nil && test.err:
			t.Errorf("TestMarshal(%s): got err == nil, want err != nil", test.desc)
			continue
		case err != nil && !test.err:
			t.Errorf("TestMarshal(%s): got err == %s, want err == nil", test.desc, err)
			continue
		case err != nil:
			// Expected failure; nothing further to compare.
			continue
		}

		// Decode with the standard library to check the output is valid JSON
		// and to compare it independently of key order.
		got := map[string]interface{}{}
		if err := json.Unmarshal(b, &got); err != nil {
			t.Errorf("TestMarshal(%s): Marshal produced invalid JSON:\n%s\n%s", test.desc, err, string(b))
			continue
		}

		if diff := pretty.Compare(test.want, got); diff != "" {
			t.Errorf("TestMarshal(%s): -want/+got:\n%s", test.desc, diff)
		}
	}
}

// TestEmptyTypes marshals a struct holding an empty map, a nil slice and a
// zero int with Marshal, unmarshals the result with Unmarshal, and compares
// the two values.
func TestEmptyTypes(t *testing.T) {
	type structA struct {
		EmptyMap         map[string]bool
		EmptySlice       []string
		Slice            []string
		EmptyInt         int
		Int              int
		AdditionalFields map[string]interface{}
	}

	val := structA{
		EmptyMap: map[string]bool{},
		Slice:    []string{"hello"},
		Int:      1,
	}

	b, err := Marshal(val)
	if err != nil {
		t.Fatalf("TestEmptyTypes: unexpected error on Marshal: %v", err)
	}

	got := structA{}
	if err := Unmarshal(b, &got); err != nil {
		t.Fatalf("TestEmptyTypes: unexpected error when Umarshalling: %v", err)
	}

	// NOTE(review): arguments are (got, val), so the "-" side of the reported
	// diff is actually got, not want.
	if diff := pretty.Compare(got, val); diff != "" {
		t.Fatalf("TestEmptyTypes: -want/+got:\n%s", diff)
	}
}
microsoft-authentication-library-for-go-1.0.0/apps/internal/json/struct.go000066400000000000000000000164141442026362400267770ustar00rootroot00000000000000// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

package json

import (
	"encoding/json"
	"fmt"
	"reflect"
	"strings"
)

// unmarshalStruct decodes the next JSON value read from jdec into i, which must
// be a *struct. A struct with a custom UnmarshalJSON is decoded by jdec.Decode;
// any other struct must have an AdditionalFields field assignable to
// map[string]interface{} and is decoded field by field by a decoder state machine.
func unmarshalStruct(jdec *json.Decoder, i interface{}) error {
	v := reflect.ValueOf(i)
	if v.Kind() != reflect.Ptr {
		return fmt.Errorf("Unmarshal() received type %T, which is not a *struct", i)
	}
	v = v.Elem()
	if v.Kind() != reflect.Struct {
		return fmt.Errorf("Unmarshal() received type %T, which is not a *struct", i)
	}

	if hasUnmarshalJSON(v) {
		// Indicates that this type has a custom Unmarshaler.
		return jdec.Decode(v.Addr().Interface())
	}

	f := v.FieldByName(addField)
	if f.Kind() == reflect.Invalid {
		return fmt.Errorf("Unmarshal(%T) only supports structs that have the field AdditionalFields or implements json.Unmarshaler", i)
	}

	if f.Kind() != reflect.Map || !f.Type().AssignableTo(mapStrInterType) {
		return fmt.Errorf("type %T has field 'AdditionalFields' that is not a map[string]interface{}", i)
	}

	dec := newDecoder(jdec, v)
	return dec.run()
}

// decoder holds the state of the struct decoding state machine.
type decoder struct {
	dec        *json.Decoder
	value      reflect.Value // This will be a reflect.Struct
	translator translateFields
	// key is the JSON object key most recently read by next.
	key string
}

// newDecoder returns a decoder that reads from dec into value.
func newDecoder(dec *json.Decoder, value reflect.Value) *decoder {
	return &decoder{value: value, dec: dec}
}

// run runs our decoder state machine.
func (d *decoder) run() error {
	var state = d.start
	var err error
	for {
		state, err = state()
		if err != nil {
			return err
		}
		if state == nil {
			return nil
		}
	}
}

// start looks for our opening delimiter '{' and then transitions to looping through our fields.
func (d *decoder) start() (stateFn, error) { var err error d.translator, err = findFields(d.value) if err != nil { return nil, err } delim, err := d.dec.Token() if err != nil { return nil, err } if !delimIs(delim, '{') { return nil, fmt.Errorf("Unmarshal expected opening {, received %v", delim) } return d.next, nil } // next gets the next struct field name from the raw json or stops the machine if we get our closing }. func (d *decoder) next() (stateFn, error) { if !d.dec.More() { // Remove the closing }. if _, err := d.dec.Token(); err != nil { return nil, err } return nil, nil } key, err := d.dec.Token() if err != nil { return nil, err } d.key = key.(string) return d.storeValue, nil } // storeValue takes the next value and stores it our struct. If the field can't be found // in the struct, it pushes the operation to storeAdditional(). func (d *decoder) storeValue() (stateFn, error) { goName := d.translator.goName(d.key) if goName == "" { goName = d.key } // We don't have the field in the struct, so it goes in AdditionalFields. f := d.value.FieldByName(goName) if f.Kind() == reflect.Invalid { return d.storeAdditional, nil } // Indicates that this type has a custom Unmarshaler. if hasUnmarshalJSON(f) { err := d.dec.Decode(f.Addr().Interface()) if err != nil { return nil, err } return d.next, nil } t, isPtr, err := fieldBaseType(d.value, goName) if err != nil { return nil, fmt.Errorf("type(%s) had field(%s) %w", d.value.Type().Name(), goName, err) } switch t.Kind() { // We need to recursively call ourselves on any *struct or struct. 
case reflect.Struct: if isPtr { if f.IsNil() { f.Set(reflect.New(t)) } } else { f = f.Addr() } if err := unmarshalStruct(d.dec, f.Interface()); err != nil { return nil, err } return d.next, nil case reflect.Map: v := reflect.MakeMap(f.Type()) ptr := newValue(f.Type()) ptr.Elem().Set(v) if err := unmarshalMap(d.dec, ptr); err != nil { return nil, err } f.Set(ptr.Elem()) return d.next, nil case reflect.Slice: v := reflect.MakeSlice(f.Type(), 0, 0) ptr := newValue(f.Type()) ptr.Elem().Set(v) if err := unmarshalSlice(d.dec, ptr); err != nil { return nil, err } f.Set(ptr.Elem()) return d.next, nil } if !isPtr { f = f.Addr() } // For values that are pointers, we need them to be non-nil in order // to decode into them. if f.IsNil() { f.Set(reflect.New(t)) } if err := d.dec.Decode(f.Interface()); err != nil { return nil, err } return d.next, nil } // storeAdditional pushes the key/value into our .AdditionalFields map. func (d *decoder) storeAdditional() (stateFn, error) { rw := json.RawMessage{} if err := d.dec.Decode(&rw); err != nil { return nil, err } field := d.value.FieldByName(addField) if field.IsNil() { field.Set(reflect.MakeMap(field.Type())) } field.SetMapIndex(reflect.ValueOf(d.key), reflect.ValueOf(rw)) return d.next, nil } func fieldBaseType(v reflect.Value, fieldName string) (t reflect.Type, isPtr bool, err error) { sf, ok := v.Type().FieldByName(fieldName) if !ok { return nil, false, fmt.Errorf("bug: fieldBaseType() lookup of field(%s) on type(%s): do not have field", fieldName, v.Type().Name()) } t = sf.Type if t.Kind() == reflect.Ptr { t = t.Elem() isPtr = true } if t.Kind() == reflect.Ptr { return nil, isPtr, fmt.Errorf("received pointer to pointer type, not supported") } return t, isPtr, nil } type translateField struct { jsonName string goName string } // translateFields is a list of translateFields with a handy lookup method. 
type translateFields []translateField

// goName loops through a list of fields looking for one containing the jsonName and
// returning the goName. If not found, returns the empty string.
// Note: not a map because at this size slices are faster even in tight loops.
func (t translateFields) goName(jsonName string) string {
	for _, entry := range t {
		if entry.jsonName == jsonName {
			return entry.goName
		}
	}
	return ""
}

// jsonName loops through a list of fields looking for one containing the goName and
// returning the jsonName. If not found, returns the empty string.
// Note: not a map because at this size slices are faster even in tight loops.
func (t translateFields) jsonName(goName string) string {
	for _, entry := range t {
		if entry.goName == goName {
			return entry.jsonName
		}
	}
	return ""
}

// umarshalerType is the reflect.Type of the json.Unmarshaler interface.
var umarshalerType = reflect.TypeOf((*json.Unmarshaler)(nil)).Elem()

// findFields parses a struct (or *struct) and returns, for every field, its Go
// name paired with its JSON name: the first entry of the field's `json` tag, or
// the Go name when that entry is empty or "-".
//
// It returns an error if v is not a struct or *struct, or if a field holding a
// struct (or a non-nil *struct) has a non-pointer type that implements
// json.Unmarshaler. Nil *struct fields are not checked.
//
// NOTE(review): fields tagged `json:"-"` are mapped to their Go name rather than
// excluded; confirm this is intended.
func findFields(v reflect.Value) (translateFields, error) {
	if v.Kind() == reflect.Ptr {
		v = v.Elem()
	}
	if v.Kind() != reflect.Struct {
		return nil, fmt.Errorf("findFields received a %s type, expected *struct or struct", v.Type().Name())
	}

	vt := v.Type()
	tfs := make([]translateField, 0, v.NumField())
	for i := 0; i < v.NumField(); i++ {
		sf := vt.Field(i)
		tf := translateField{
			goName:   sf.Name,
			jsonName: parseTag(sf.Tag.Get("json")),
		}
		switch tf.jsonName {
		case "", "-":
			tf.jsonName = tf.goName
		}
		tfs = append(tfs, tf)

		f := v.Field(i)
		if f.Kind() == reflect.Ptr {
			// For a nil pointer this yields an invalid Value, skipping the check.
			f = f.Elem()
		}
		// The error fires when the type DOES implement json.Unmarshaler; the
		// message states that condition rather than its opposite.
		if f.Kind() == reflect.Struct && f.Type().Implements(umarshalerType) {
			return nil, fmt.Errorf("struct type %q has field %q whose non-pointer type "+
				"implements json.Unmarshaler, which is not supported", vt.Name(), sf.Name)
		}
	}
	return tfs, nil
}

// parseTag just returns the first entry in the tag. tag is the string
// returned by reflect.StructField.Tag().Get().
func parseTag(tag string) string { if idx := strings.Index(tag, ","); idx != -1 { return tag[:idx] } return tag } microsoft-authentication-library-for-go-1.0.0/apps/internal/json/struct_test.go000066400000000000000000000146231442026362400300360ustar00rootroot00000000000000// Copyright (c) Microsoft Corporation. // Licensed under the MIT license. package json import ( "bytes" "encoding/json" "reflect" "runtime" "testing" "github.com/kylelemons/godebug/pretty" ) func TestDecoderStart(t *testing.T) { tests := []struct { desc string b []byte i interface{} stateFn stateFn err bool }{ { desc: "No content to decode", i: &StructA{}, stateFn: nil, err: true, }, { desc: "No opening brace", b: []byte("3"), i: &StructA{}, stateFn: nil, err: true, }, { desc: "Success", b: []byte(`{"Name": "value"}`), i: &StructA{}, stateFn: (new(decoder).next), }, } for _, test := range tests { dec := newDecoder(json.NewDecoder(bytes.NewBuffer(test.b)), reflect.ValueOf(test.i)) stateFn, err := dec.start() switch { case err == nil && test.err: t.Errorf("TestDecoderStart(%s): got err == nil, want err != nil", test.desc) continue case err != nil && !test.err: t.Errorf("TestDecoderStart(%s): got err == %s, want err == nil", test.desc, err) continue case err != nil: continue } gotStateFn := runtime.FuncForPC(reflect.ValueOf(stateFn).Pointer()).Name() wantStateFn := runtime.FuncForPC(reflect.ValueOf(test.stateFn).Pointer()).Name() if gotStateFn != wantStateFn { t.Errorf("TestDecoderStart(%s): got(stateFn) %s, want %s", test.desc, gotStateFn, wantStateFn) } } } func TestDecoderNext(t *testing.T) { tests := []struct { desc string b []byte // advToken advanced the decoder this may Token() calls, as the decoder only works // on well formed JSON. 
advToken int i interface{} key string stateFn stateFn err bool }{ { desc: "No content to decode", i: &StructA{}, stateFn: nil, err: true, }, { desc: "Bad ] found", b: []byte("{]"), advToken: 1, i: &StructA{}, stateFn: nil, err: true, }, { desc: "Closing brace", b: []byte("{}"), advToken: 1, i: &StructA{}, stateFn: nil, err: false, }, { desc: "Success", b: []byte(`{"Name": "value"}`), advToken: 1, i: &StructA{}, key: "Name", stateFn: (new(decoder).storeValue), }, } for _, test := range tests { dec := newDecoder(json.NewDecoder(bytes.NewBuffer(test.b)), reflect.ValueOf(test.i)) for i := 0; i < test.advToken; i++ { if _, err := dec.dec.Token(); err != nil { panic(err) } } stateFn, err := dec.next() switch { case err == nil && test.err: t.Errorf("TestDecoderNext(%s): got err == nil, want err != nil", test.desc) continue case err != nil && !test.err: t.Errorf("TestDecoderNext(%s): got err == %s, want err == nil", test.desc, err) continue case err != nil: continue } if dec.key != test.key { t.Errorf("TestDecoderNext(%s): got(.key) %s, want %s", test.desc, dec.key, test.key) } gotStateFn := runtime.FuncForPC(reflect.ValueOf(stateFn).Pointer()).Name() wantStateFn := runtime.FuncForPC(reflect.ValueOf(test.stateFn).Pointer()).Name() if gotStateFn != wantStateFn { t.Errorf("TestDecoderNext(%s): got(stateFn) %s, want %s", test.desc, gotStateFn, wantStateFn) } } } func TestDecoderStoreValue(t *testing.T) { tests := []struct { desc string b []byte want StructA stateFn stateFn }{ { desc: "Field found, no struct tag", b: []byte(`{"Name": "myName"}`), want: StructA{Name: "myName"}, stateFn: (new(decoder).next), }, { desc: "Field found, using struct tag", b: []byte(`{"id": 3}`), want: StructA{ID: 3}, stateFn: (new(decoder).next), }, { desc: "Field not found, go to storeAdditional()", b: []byte(`{"blah": 3}`), want: StructA{}, stateFn: (new(decoder).storeAdditional), }, } for _, test := range tests { got := StructA{} dec := newDecoder(json.NewDecoder(bytes.NewBuffer(test.b)), 
reflect.ValueOf(&got).Elem()) _, err := dec.start() // populates our translator field if err != nil { panic(err) } _, err = dec.next() if err != nil { panic(err) } stateFn, err := dec.storeValue() if err != nil { t.Errorf("TestDecoderStoreValue(%s): got err == %s, want err == nil", test.desc, err) continue } if diff := pretty.Compare(test.want, got); diff != "" { t.Errorf("TestDecoderStoreValue(%s): -want/+got:\n%s", test.desc, diff) continue } gotStateFn := runtime.FuncForPC(reflect.ValueOf(stateFn).Pointer()).Name() wantStateFn := runtime.FuncForPC(reflect.ValueOf(test.stateFn).Pointer()).Name() if gotStateFn != wantStateFn { t.Errorf("TestDecoderStoreValue(%s): got(stateFn) %s, want %s", test.desc, gotStateFn, wantStateFn) } } } func TestDecoderStoreAdditional(t *testing.T) { tests := []struct { desc string b []byte got StructA want StructA stateFn stateFn }{ { desc: "Map not initialized", b: []byte(`{"blah": "whatever"}`), got: StructA{}, want: StructA{ AdditionalFields: map[string]interface{}{ "blah": json.RawMessage(`"whatever"`), }, }, stateFn: (new(decoder).next), }, { desc: "Map exists", b: []byte(`{"blah": "whatever"}`), got: StructA{ AdditionalFields: map[string]interface{}{ "else": json.RawMessage(`"if"`), }, }, want: StructA{ AdditionalFields: map[string]interface{}{ "else": json.RawMessage(`"if"`), "blah": json.RawMessage(`"whatever"`), }, }, stateFn: (new(decoder).next), }, } for _, test := range tests { dec := newDecoder(json.NewDecoder(bytes.NewBuffer(test.b)), reflect.ValueOf(&test.got).Elem()) _, err := dec.start() // populates our translator field if err != nil { panic(err) } _, err = dec.next() if err != nil { panic(err) } stateFn, err := dec.storeAdditional() if err != nil { t.Errorf("TestDecoderStoreAdditional(%s): got err == %s, want err == nil", test.desc, err) continue } if diff := pretty.Compare(test.want, test.got); diff != "" { t.Errorf("TestDecoderStoreAdditional(%s): -want/+got:\n%s", test.desc, diff) continue } gotStateFn := 
runtime.FuncForPC(reflect.ValueOf(stateFn).Pointer()).Name() wantStateFn := runtime.FuncForPC(reflect.ValueOf(test.stateFn).Pointer()).Name() if gotStateFn != wantStateFn { t.Errorf("TestDecoderStoreAdditional(%s): got(stateFn) %s, want %s", test.desc, gotStateFn, wantStateFn) } } } microsoft-authentication-library-for-go-1.0.0/apps/internal/json/types/000077500000000000000000000000001442026362400262625ustar00rootroot00000000000000microsoft-authentication-library-for-go-1.0.0/apps/internal/json/types/time/000077500000000000000000000000001442026362400272205ustar00rootroot00000000000000microsoft-authentication-library-for-go-1.0.0/apps/internal/json/types/time/time.go000066400000000000000000000043441442026362400305120ustar00rootroot00000000000000// Copyright (c) Microsoft Corporation. // Licensed under the MIT license. // Package time provides for custom types to translate time from JSON and other formats // into time.Time objects. package time import ( "fmt" "strconv" "strings" "time" ) // Unix provides a type that can marshal and unmarshal a string representation // of the unix epoch into a time.Time object. type Unix struct { T time.Time } // MarshalJSON implements encoding/json.MarshalJSON(). func (u Unix) MarshalJSON() ([]byte, error) { if u.T.IsZero() { return []byte(""), nil } return []byte(fmt.Sprintf("%q", strconv.FormatInt(u.T.Unix(), 10))), nil } // UnmarshalJSON implements encoding/json.UnmarshalJSON(). func (u *Unix) UnmarshalJSON(b []byte) error { i, err := strconv.Atoi(strings.Trim(string(b), `"`)) if err != nil { return fmt.Errorf("unix time(%s) could not be converted from string to int: %w", string(b), err) } u.T = time.Unix(int64(i), 0) return nil } // DurationTime provides a type that can marshal and unmarshal a string representation // of a duration from now into a time.Time object. // Note: I'm not sure this is the best way to do this. What happens is we get a field // called "expires_in" that represents the seconds from now that this expires. 
We // turn that into a time we call .ExpiresOn. But maybe we should be recording // when the token was received at .TokenRecieved and .ExpiresIn should remain as a duration. // Then we could have a method called ExpiresOn(). Honestly, the whole thing is // bad because the server doesn't return a concrete time. I think this is // cleaner, but its not great either. type DurationTime struct { T time.Time } // MarshalJSON implements encoding/json.MarshalJSON(). func (d DurationTime) MarshalJSON() ([]byte, error) { if d.T.IsZero() { return []byte(""), nil } dt := time.Until(d.T) return []byte(fmt.Sprintf("%d", int64(dt*time.Second))), nil } // UnmarshalJSON implements encoding/json.UnmarshalJSON(). func (d *DurationTime) UnmarshalJSON(b []byte) error { i, err := strconv.Atoi(strings.Trim(string(b), `"`)) if err != nil { return fmt.Errorf("unix time(%s) could not be converted from string to int: %w", string(b), err) } d.T = time.Now().Add(time.Duration(i) * time.Second) return nil } microsoft-authentication-library-for-go-1.0.0/apps/internal/local/000077500000000000000000000000001442026362400252375ustar00rootroot00000000000000microsoft-authentication-library-for-go-1.0.0/apps/internal/local/server.go000066400000000000000000000100141442026362400270700ustar00rootroot00000000000000// Copyright (c) Microsoft Corporation. // Licensed under the MIT license. // Package local contains a local HTTP server used with interactive authentication. package local import ( "context" "fmt" "net" "net/http" "strconv" "strings" "time" ) var okPage = []byte(` Authentication Complete

Authentication complete. You can return to the application. Feel free to close this browser tab.

`) const failPage = ` Authentication Failed

Authentication failed. You can return to the application. Feel free to close this browser tab.

Error details: error %s error_description: %s

` // Result is the result from the redirect. type Result struct { // Code is the code sent by the authority server. Code string // Err is set if there was an error. Err error } // Server is an HTTP server. type Server struct { // Addr is the address the server is listening on. Addr string resultCh chan Result s *http.Server reqState string } // New creates a local HTTP server and starts it. func New(reqState string, port int) (*Server, error) { var l net.Listener var err error var portStr string if port > 0 { // use port provided by caller l, err = net.Listen("tcp", fmt.Sprintf("localhost:%d", port)) portStr = strconv.FormatInt(int64(port), 10) } else { // find a free port for i := 0; i < 10; i++ { l, err = net.Listen("tcp", "localhost:0") if err != nil { continue } addr := l.Addr().String() portStr = addr[strings.LastIndex(addr, ":")+1:] break } } if err != nil { return nil, err } serv := &Server{ Addr: fmt.Sprintf("http://localhost:%s", portStr), s: &http.Server{Addr: "localhost:0", ReadHeaderTimeout: time.Second}, reqState: reqState, resultCh: make(chan Result, 1), } serv.s.Handler = http.HandlerFunc(serv.handler) if err := serv.start(l); err != nil { return nil, err } return serv, nil } func (s *Server) start(l net.Listener) error { go func() { err := s.s.Serve(l) if err != nil { select { case s.resultCh <- Result{Err: err}: default: } } }() return nil } // Result gets the result of the redirect operation. Once a single result is returned, the server // is shutdown. ctx deadline will be honored. func (s *Server) Result(ctx context.Context) Result { select { case <-ctx.Done(): return Result{Err: ctx.Err()} case r := <-s.resultCh: return r } } // Shutdown shuts down the server. func (s *Server) Shutdown() { // Note: You might get clever and think you can do this in handler() as a defer, you can't. 
_ = s.s.Shutdown(context.Background()) } func (s *Server) putResult(r Result) { select { case s.resultCh <- r: default: } } func (s *Server) handler(w http.ResponseWriter, r *http.Request) { q := r.URL.Query() headerErr := q.Get("error") if headerErr != "" { desc := q.Get("error_description") // Note: It is a little weird we handle some errors by not going to the failPage. If they all should, // change this to s.error() and make s.error() write the failPage instead of an error code. _, _ = w.Write([]byte(fmt.Sprintf(failPage, headerErr, desc))) s.putResult(Result{Err: fmt.Errorf(desc)}) return } respState := q.Get("state") switch respState { case s.reqState: case "": s.error(w, http.StatusInternalServerError, "server didn't send OAuth state") return default: s.error(w, http.StatusInternalServerError, "mismatched OAuth state, req(%s), resp(%s)", s.reqState, respState) return } code := q.Get("code") if code == "" { s.error(w, http.StatusInternalServerError, "authorization code missing in query string") return } _, _ = w.Write(okPage) s.putResult(Result{Code: code}) } func (s *Server) error(w http.ResponseWriter, code int, str string, i ...interface{}) { err := fmt.Errorf(str, i...) http.Error(w, err.Error(), code) s.putResult(Result{Err: err}) } microsoft-authentication-library-for-go-1.0.0/apps/internal/local/server_test.go000066400000000000000000000066021442026362400301370ustar00rootroot00000000000000// Copyright (c) Microsoft Corporation. // Licensed under the MIT license. 
package local import ( "context" "io" "net/http" "net/url" "strings" "testing" "time" "github.com/kylelemons/godebug/pretty" ) func TestServer(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() tests := []struct { desc string reqState string port int q url.Values failPage bool statusCode int }{ { desc: "Error: Query Values has 'error' key", reqState: "state", port: 0, q: url.Values{"state": []string{"state"}, "error": []string{"error"}}, statusCode: 200, failPage: true, }, { desc: "Error: Query Values missing 'state' key", reqState: "state", port: 0, q: url.Values{"code": []string{"code"}}, statusCode: http.StatusInternalServerError, }, { desc: "Error: Query Values missing had 'state' key value that was different that requested", reqState: "state", port: 0, q: url.Values{"state": []string{"etats"}, "code": []string{"code"}}, statusCode: http.StatusInternalServerError, }, { desc: "Error: Query Values missing 'code' key", reqState: "state", port: 0, q: url.Values{"state": []string{"state"}}, statusCode: http.StatusInternalServerError, }, { desc: "Success", reqState: "state", port: 0, q: url.Values{"state": []string{"state"}, "code": []string{"code"}}, statusCode: 200, }, } for _, test := range tests { serv, err := New(test.reqState, test.port) if err != nil { panic(err) } defer serv.Shutdown() if !strings.HasPrefix(serv.Addr, "http://localhost") { t.Fatalf("unexpected server address %s", serv.Addr) } u, err := url.Parse(serv.Addr) if err != nil { panic(err) } u.RawQuery = test.q.Encode() resp, err := http.DefaultClient.Do( &http.Request{ Method: "GET", URL: u, }, ) if err != nil { panic(err) } if resp.StatusCode != test.statusCode { if test.statusCode == 200 { t.Errorf("TestServer(%s): got StatusCode == %d, want StatusCode == 200", test.desc, resp.StatusCode) res := serv.Result(ctx) if res.Err == nil { t.Errorf("TestServer(%s): Result.Err == nil, want Result.Err != nil", test.desc) } continue } 
t.Errorf("TestServer(%s): got StatusCode == %d, want StatusCode == %d", test.desc, resp.StatusCode, test.statusCode) res := serv.Result(ctx) if res.Err == nil { t.Errorf("TestServer(%s): Result.Err == nil, want Result.Err != nil", test.desc) } continue } if resp.StatusCode != 200 { continue } content, err := io.ReadAll(resp.Body) if err != nil { panic(err) } if test.failPage { if !strings.Contains(string(content), "Authentication Failed") { t.Errorf("TestServer(%s): got okay page, want failed page", test.desc) } res := serv.Result(ctx) if res.Err == nil { t.Errorf("TestServer(%s): Result.Err == nil, want Result.Err != nil", test.desc) } continue } if !strings.Contains(string(content), "Authentication Complete") { t.Errorf("TestServer(%s): got failed page, okay page", test.desc) } res := serv.Result(ctx) if diff := pretty.Compare(Result{Code: "code"}, res); diff != "" { t.Errorf("TestServer(%s): -want/+got:\n%s", test.desc, diff) } } } microsoft-authentication-library-for-go-1.0.0/apps/internal/mock/000077500000000000000000000000001442026362400250765ustar00rootroot00000000000000microsoft-authentication-library-for-go-1.0.0/apps/internal/mock/mock.go000066400000000000000000000112151442026362400263560ustar00rootroot00000000000000// Copyright (c) Microsoft Corporation. // Licensed under the MIT license. package mock import ( "bytes" "encoding/base64" "fmt" "io" "net/http" "strings" "time" ) type response struct { body []byte callback func(*http.Request) code int headers http.Header } type responseOption interface { apply(*response) } type respOpt func(*response) func (fn respOpt) apply(r *response) { fn(r) } // WithBody sets the HTTP response's body to the specified value. func WithBody(b []byte) responseOption { return respOpt(func(r *response) { r.body = b }) } // WithCallback sets a callback to invoke before returning the response. 
func WithCallback(callback func(*http.Request)) responseOption { return respOpt(func(r *response) { r.callback = callback }) } // Client is a mock HTTP client that returns a sequence of responses. Use AppendResponse to specify the sequence. type Client struct { resp []response } func (c *Client) AppendResponse(opts ...responseOption) { r := response{code: http.StatusOK, headers: http.Header{}} for _, o := range opts { o.apply(&r) } c.resp = append(c.resp, r) } func (c *Client) Do(req *http.Request) (*http.Response, error) { if len(c.resp) == 0 { panic(fmt.Sprintf(`no response for "%s"`, req.URL.String())) } resp := c.resp[0] c.resp = c.resp[1:] if resp.callback != nil { resp.callback(req) } res := http.Response{Header: resp.headers, StatusCode: resp.code} res.Body = io.NopCloser(bytes.NewReader(resp.body)) return &res, nil } // CloseIdleConnections implements the comm.HTTPClient interface func (*Client) CloseIdleConnections() {} func GetAccessTokenBody(accessToken, idToken, refreshToken, clientInfo string, expiresIn int) []byte { body := fmt.Sprintf( `{"access_token": "%s","expires_in": %d,"expires_on": %d`, accessToken, expiresIn, time.Now().Add(time.Duration(expiresIn)*time.Second).Unix(), ) if clientInfo != "" { body += fmt.Sprintf(`, "client_info": "%s"`, clientInfo) } if idToken != "" { body += fmt.Sprintf(`, "id_token": "%s"`, idToken) } if refreshToken != "" { body += fmt.Sprintf(`, "refresh_token": "%s"`, refreshToken) } body += "}" return []byte(body) } func GetIDToken(tenant, issuer string) string { now := time.Now().Unix() payload := []byte(fmt.Sprintf(`{"aud": "%s","exp": %d,"iat": %d,"iss": "%s","tid": "%s"}`, tenant, now+3600, now, issuer, tenant)) return fmt.Sprintf("header.%s.signature", base64.RawStdEncoding.EncodeToString(payload)) } func GetInstanceDiscoveryBody(host, tenant string) []byte { authority := fmt.Sprintf("https://%s/%s", host, tenant) body := fmt.Sprintf(`{"tenant_discovery_endpoint": 
"%s/v2.0/.well-known/openid-configuration","api-version": "1.1","metadata": [{"preferred_network": "%s","preferred_cache": "%s","aliases": ["%s"]}]}`, authority, host, host, host, ) headers := http.Header{} headers.Add("Content-Type", "application/json; charset=utf-8") return []byte(body) } func GetTenantDiscoveryBody(host, tenant string) []byte { authority := fmt.Sprintf("https://%s/%s", host, tenant) content := strings.ReplaceAll(`{"token_endpoint": "{authority}/oauth2/v2.0/token", "token_endpoint_auth_methods_supported": [ "client_secret_post", "private_key_jwt", "client_secret_basic" ], "jwks_uri": "{authority}/discovery/v2.0/keys", "response_modes_supported": [ "query", "fragment", "form_post" ], "subject_types_supported": [ "pairwise" ], "id_token_signing_alg_values_supported": [ "RS256" ], "response_types_supported": [ "code", "id_token", "code id_token", "id_token token" ], "scopes_supported": [ "openid", "profile", "email", "offline_access" ], "issuer": "{authority}/v2.0", "request_uri_parameter_supported": false, "userinfo_endpoint": "https://graph.microsoft.com/oidc/userinfo", "authorization_endpoint": "{authority}/oauth2/v2.0/authorize", "device_authorization_endpoint": "{authority}/oauth2/v2.0/devicecode", "http_logout_supported": true, "frontchannel_logout_supported": true, "end_session_endpoint": "{authority}/oauth2/v2.0/logout", "claims_supported": [ "sub", "iss", "cloud_instance_name", "cloud_instance_host_name", "cloud_graph_host_name", "msgraph_host", "aud", "exp", "iat", "auth_time", "acr", "nonce", "preferred_username", "name", "tid", "ver", "at_hash", "c_hash", "email" ], "kerberos_endpoint": "{authority}/kerberos", "tenant_region_scope": "NA", "cloud_instance_name": "microsoftonline.com", "cloud_graph_host_name": "graph.windows.net", "msgraph_host": "graph.microsoft.com", "rbac_url": "https://pas.windows.net" }`, "{authority}", authority) return []byte(content) } 
microsoft-authentication-library-for-go-1.0.0/apps/internal/oauth/000077500000000000000000000000001442026362400252655ustar00rootroot00000000000000microsoft-authentication-library-for-go-1.0.0/apps/internal/oauth/fake/000077500000000000000000000000001442026362400261735ustar00rootroot00000000000000microsoft-authentication-library-for-go-1.0.0/apps/internal/oauth/fake/fake.go000066400000000000000000000145141442026362400274350ustar00rootroot00000000000000// Copyright (c) Microsoft Corporation. // Licensed under the MIT license. package fake import ( "context" "errors" "fmt" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/wstrust" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/wstrust/defs" ) // ResolveEndpoints is a fake implementation of the oauth.resolveEndpointer interface. type ResolveEndpoints struct { // Set this to true to have all APIs return an error. Err bool // fake result to return Endpoints authority.Endpoints } func (f ResolveEndpoints) ResolveEndpoints(ctx context.Context, authorityInfo authority.Info, userPrincipalName string) (authority.Endpoints, error) { if f.Err { return authority.Endpoints{}, errors.New("error") } return f.Endpoints, nil } // AccessTokens is a fake implementation of the oauth.accessTokens interface. type AccessTokens struct { // Set this to true to have all APIs return an error. Err bool // Result is for use with FromDeviceCodeResult. On each call it returns // the next item in this slice. They must be either an error or nil. 
Result []error Next int // fake result to return AccessToken accesstokens.TokenResponse // fake result to return DeviceCode accesstokens.DeviceCodeResult // FromRefreshTokenCallback is an optional callback invoked by FromRefreshToken FromRefreshTokenCallback func(appType accesstokens.AppType, authParams authority.AuthParams, cc *accesstokens.Credential, refreshToken string) // ValidateAssertion is an optional callback for validating an assertion generated by confidential.Client ValidateAssertion func(string) } func (f *AccessTokens) FromUsernamePassword(ctx context.Context, authParameters authority.AuthParams) (accesstokens.TokenResponse, error) { if f.Err { return accesstokens.TokenResponse{}, fmt.Errorf("error") } return f.AccessToken, nil } func (f *AccessTokens) FromAuthCode(ctx context.Context, req accesstokens.AuthCodeRequest) (accesstokens.TokenResponse, error) { if f.Err { return accesstokens.TokenResponse{}, fmt.Errorf("error") } return f.AccessToken, nil } func (f *AccessTokens) FromRefreshToken(ctx context.Context, appType accesstokens.AppType, authParams authority.AuthParams, cc *accesstokens.Credential, refreshToken string) (accesstokens.TokenResponse, error) { if f.FromRefreshTokenCallback != nil { f.FromRefreshTokenCallback(appType, authParams, cc, refreshToken) } if f.Err { return accesstokens.TokenResponse{}, fmt.Errorf("error") } return f.AccessToken, nil } func (f *AccessTokens) FromClientSecret(ctx context.Context, authParameters authority.AuthParams, clientSecret string) (accesstokens.TokenResponse, error) { if f.Err { return accesstokens.TokenResponse{}, fmt.Errorf("error") } return f.AccessToken, nil } func (f *AccessTokens) FromAssertion(ctx context.Context, authParameters authority.AuthParams, assertion string) (accesstokens.TokenResponse, error) { if f.Err { return accesstokens.TokenResponse{}, fmt.Errorf("error") } if f.ValidateAssertion != nil { f.ValidateAssertion(assertion) } return f.AccessToken, nil } func (f *AccessTokens) 
FromUserAssertionClientSecret(ctx context.Context, authParameters authority.AuthParams, userAssertion, clientSecret string) (accesstokens.TokenResponse, error) { if f.Err { return accesstokens.TokenResponse{}, fmt.Errorf("error") } return f.AccessToken, nil } func (f *AccessTokens) FromUserAssertionClientCertificate(ctx context.Context, authParameters authority.AuthParams, userAssertion, assertion string) (accesstokens.TokenResponse, error) { if f.Err { return accesstokens.TokenResponse{}, fmt.Errorf("error") } return f.AccessToken, nil } func (f *AccessTokens) DeviceCodeResult(ctx context.Context, authParameters authority.AuthParams) (accesstokens.DeviceCodeResult, error) { if f.Err { return accesstokens.DeviceCodeResult{}, fmt.Errorf("error") } return f.DeviceCode, nil } func (f *AccessTokens) FromDeviceCodeResult(ctx context.Context, authParameters authority.AuthParams, deviceCodeResult accesstokens.DeviceCodeResult) (accesstokens.TokenResponse, error) { if f.Next < len(f.Result) { defer func() { f.Next++ }() v := f.Result[f.Next] if v == nil { return f.AccessToken, nil } return accesstokens.TokenResponse{}, v } panic("AccessTokens.FromDeviceCodeResult() asked for more return values than provided") } func (f *AccessTokens) FromSamlGrant(ctx context.Context, authParameters authority.AuthParams, samlGrant wstrust.SamlTokenInfo) (accesstokens.TokenResponse, error) { if f.Err { return accesstokens.TokenResponse{}, fmt.Errorf("error") } return f.AccessToken, nil } // Authority is a fake implementation of the oauth.fetchAuthority interface. type Authority struct { // Set this to true to have all APIs return an error. Err bool // The fake UserRealm to return from the UserRealm() API. 
Realm authority.UserRealm // fake result to return InstanceResp authority.InstanceDiscoveryResponse } func (f Authority) UserRealm(ctx context.Context, params authority.AuthParams) (authority.UserRealm, error) { if f.Err { return authority.UserRealm{}, errors.New("error") } return f.Realm, nil } func (f Authority) AADInstanceDiscovery(ctx context.Context, info authority.Info) (authority.InstanceDiscoveryResponse, error) { if f.Err { return authority.InstanceDiscoveryResponse{}, errors.New("error") } return f.InstanceResp, nil } // WSTrust is a fake implementation of the oauth.fetchWSTrust interface. type WSTrust struct { // Set these to true to have their respective APIs return an error. GetMexErr, GetSAMLTokenInfoErr bool // fake result to return MexDocument defs.MexDocument // fake result to return SamlTokenInfo wstrust.SamlTokenInfo } func (f WSTrust) Mex(ctx context.Context, federationMetadataURL string) (defs.MexDocument, error) { if f.GetMexErr { return defs.MexDocument{}, errors.New("error") } return f.MexDocument, nil } func (f WSTrust) SAMLTokenInfo(ctx context.Context, authParameters authority.AuthParams, cloudAudienceURN string, endpoint defs.Endpoint) (wstrust.SamlTokenInfo, error) { if f.GetSAMLTokenInfoErr { return wstrust.SamlTokenInfo{}, errors.New("error") } return f.SamlTokenInfo, nil } microsoft-authentication-library-for-go-1.0.0/apps/internal/oauth/oauth.go000066400000000000000000000335341442026362400267440ustar00rootroot00000000000000// Copyright (c) Microsoft Corporation. // Licensed under the MIT license. 
package oauth

import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"time"

	"github.com/AzureAD/microsoft-authentication-library-for-go/apps/errors"
	"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/exported"
	internalTime "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json/types/time"
	"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops"
	"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens"
	"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority"
	"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/wstrust"
	"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/wstrust/defs"
	"github.com/google/uuid"
)

// ResolveEndpointer contains the methods for resolving authority endpoints.
//
// Client.ResolveEndpoints forwards to the Client's Resolver field, which New
// populates with newAuthorityEndpoint.
type ResolveEndpointer interface {
	// ResolveEndpoints returns the endpoints for authorityInfo.
	// userPrincipalName is passed through; presumably some authority types use
	// it during lookup (TODO confirm against implementations).
	ResolveEndpoints(ctx context.Context, authorityInfo authority.Info, userPrincipalName string) (authority.Endpoints, error)
}

// AccessTokens contains the methods for fetching tokens from different sources.
type AccessTokens interface {
	// DeviceCodeResult returns an accesstokens.DeviceCodeResult, the same type
	// FromDeviceCodeResult accepts. Presumably this is the device code the
	// user must enter (TODO confirm).
	DeviceCodeResult(ctx context.Context, authParameters authority.AuthParams) (accesstokens.DeviceCodeResult, error)
	// FromUsernamePassword fetches a token. It takes no separate credential
	// arguments, so the username/password presumably travel in authParameters.
	FromUsernamePassword(ctx context.Context, authParameters authority.AuthParams) (accesstokens.TokenResponse, error)
	// FromAuthCode exchanges req for a token. Client.AuthCode calls it after
	// validating scopes and resolving the endpoint.
	FromAuthCode(ctx context.Context, req accesstokens.AuthCodeRequest) (accesstokens.TokenResponse, error)
	// FromRefreshToken is called by Client.Refresh with the refresh token's Secret.
	FromRefreshToken(ctx context.Context, appType accesstokens.AppType, authParams authority.AuthParams, cc *accesstokens.Credential, refreshToken string) (accesstokens.TokenResponse, error)
	// FromClientSecret is used by Client.Credential when cred.Secret is non-empty.
	FromClientSecret(ctx context.Context, authParameters authority.AuthParams, clientSecret string) (accesstokens.TokenResponse, error)
	// FromAssertion is used by Client.Credential with the JWT produced by cred.JWT.
	FromAssertion(ctx context.Context, authParameters authority.AuthParams, assertion string) (accesstokens.TokenResponse, error)
	// FromUserAssertionClientSecret is used by Client.OnBehalfOf when
	// cred.Secret is non-empty; userAssertion is authParams.UserAssertion.
	FromUserAssertionClientSecret(ctx context.Context, authParameters authority.AuthParams, userAssertion string, clientSecret string) (accesstokens.TokenResponse, error)
	// FromUserAssertionClientCertificate is used by Client.OnBehalfOf when no
	// secret is set. assertion is the JWT produced by cred.JWT.
	FromUserAssertionClientCertificate(ctx context.Context, authParameters authority.AuthParams, userAssertion string, assertion string) (accesstokens.TokenResponse, error)
	// FromDeviceCodeResult fetches a token for a value previously returned by
	// DeviceCodeResult. Presumably it polls until the user completes sign-in
	// (TODO confirm).
	FromDeviceCodeResult(ctx context.Context, authParameters authority.AuthParams, deviceCodeResult accesstokens.DeviceCodeResult) (accesstokens.TokenResponse, error)
	// FromSamlGrant fetches a token from a wstrust.SamlTokenInfo. That is the
	// same type FetchWSTrust.SAMLTokenInfo returns.
	FromSamlGrant(ctx context.Context, authParameters authority.AuthParams, samlGrant wstrust.SamlTokenInfo) (accesstokens.TokenResponse, error)
}

// FetchAuthority will be implemented by authority.Authority.
// New fills Client.Authority with the value returned by ops.New(...).Authority().
type FetchAuthority interface {
	// UserRealm returns the authority.UserRealm for the given parameters.
	UserRealm(context.Context, authority.AuthParams) (authority.UserRealm, error)
	// AADInstanceDiscovery is the target of Client.AADInstanceDiscovery.
	AADInstanceDiscovery(context.Context, authority.Info) (authority.InstanceDiscoveryResponse, error)
}

// FetchWSTrust contains the methods for interacting with WSTrust endpoints.
type FetchWSTrust interface { Mex(ctx context.Context, federationMetadataURL string) (defs.MexDocument, error) SAMLTokenInfo(ctx context.Context, authParameters authority.AuthParams, cloudAudienceURN string, endpoint defs.Endpoint) (wstrust.SamlTokenInfo, error) } // Client provides tokens for various types of token requests. type Client struct { Resolver ResolveEndpointer AccessTokens AccessTokens Authority FetchAuthority WSTrust FetchWSTrust } // New is the constructor for Token. func New(httpClient ops.HTTPClient) *Client { r := ops.New(httpClient) return &Client{ Resolver: newAuthorityEndpoint(r), AccessTokens: r.AccessTokens(), Authority: r.Authority(), WSTrust: r.WSTrust(), } } // ResolveEndpoints gets the authorization and token endpoints and creates an AuthorityEndpoints instance. func (t *Client) ResolveEndpoints(ctx context.Context, authorityInfo authority.Info, userPrincipalName string) (authority.Endpoints, error) { return t.Resolver.ResolveEndpoints(ctx, authorityInfo, userPrincipalName) } // AADInstanceDiscovery attempts to discover a tenant endpoint (used in OIDC auth with an authorization endpoint). // This is done by AAD which allows for aliasing of tenants (windows.sts.net is the same as login.windows.com). func (t *Client) AADInstanceDiscovery(ctx context.Context, authorityInfo authority.Info) (authority.InstanceDiscoveryResponse, error) { return t.Authority.AADInstanceDiscovery(ctx, authorityInfo) } // AuthCode returns a token based on an authorization code. 
func (t *Client) AuthCode(ctx context.Context, req accesstokens.AuthCodeRequest) (accesstokens.TokenResponse, error) { if err := scopeError(req.AuthParams); err != nil { return accesstokens.TokenResponse{}, err } if err := t.resolveEndpoint(ctx, &req.AuthParams, ""); err != nil { return accesstokens.TokenResponse{}, err } tResp, err := t.AccessTokens.FromAuthCode(ctx, req) if err != nil { return accesstokens.TokenResponse{}, fmt.Errorf("could not retrieve token from auth code: %w", err) } return tResp, nil } // Credential acquires a token from the authority using a client credentials grant. func (t *Client) Credential(ctx context.Context, authParams authority.AuthParams, cred *accesstokens.Credential) (accesstokens.TokenResponse, error) { if cred.TokenProvider != nil { now := time.Now() scopes := make([]string, len(authParams.Scopes)) copy(scopes, authParams.Scopes) params := exported.TokenProviderParameters{ Claims: authParams.Claims, CorrelationID: uuid.New().String(), Scopes: scopes, TenantID: authParams.AuthorityInfo.Tenant, } tr, err := cred.TokenProvider(ctx, params) if err != nil { if len(scopes) == 0 { err = fmt.Errorf("token request had an empty authority.AuthParams.Scopes, which may cause the following error: %w", err) return accesstokens.TokenResponse{}, err } return accesstokens.TokenResponse{}, err } return accesstokens.TokenResponse{ AccessToken: tr.AccessToken, ExpiresOn: internalTime.DurationTime{ T: now.Add(time.Duration(tr.ExpiresInSeconds) * time.Second), }, GrantedScopes: accesstokens.Scopes{Slice: authParams.Scopes}, }, nil } if err := t.resolveEndpoint(ctx, &authParams, ""); err != nil { return accesstokens.TokenResponse{}, err } if cred.Secret != "" { return t.AccessTokens.FromClientSecret(ctx, authParams, cred.Secret) } jwt, err := cred.JWT(ctx, authParams) if err != nil { return accesstokens.TokenResponse{}, err } return t.AccessTokens.FromAssertion(ctx, authParams, jwt) } // Credential acquires a token from the authority using a client 
credentials grant. func (t *Client) OnBehalfOf(ctx context.Context, authParams authority.AuthParams, cred *accesstokens.Credential) (accesstokens.TokenResponse, error) { if err := scopeError(authParams); err != nil { return accesstokens.TokenResponse{}, err } if err := t.resolveEndpoint(ctx, &authParams, ""); err != nil { return accesstokens.TokenResponse{}, err } if cred.Secret != "" { return t.AccessTokens.FromUserAssertionClientSecret(ctx, authParams, authParams.UserAssertion, cred.Secret) } jwt, err := cred.JWT(ctx, authParams) if err != nil { return accesstokens.TokenResponse{}, err } tr, err := t.AccessTokens.FromUserAssertionClientCertificate(ctx, authParams, authParams.UserAssertion, jwt) if err != nil { return accesstokens.TokenResponse{}, err } return tr, nil } func (t *Client) Refresh(ctx context.Context, reqType accesstokens.AppType, authParams authority.AuthParams, cc *accesstokens.Credential, refreshToken accesstokens.RefreshToken) (accesstokens.TokenResponse, error) { if err := scopeError(authParams); err != nil { return accesstokens.TokenResponse{}, err } if err := t.resolveEndpoint(ctx, &authParams, ""); err != nil { return accesstokens.TokenResponse{}, err } tr, err := t.AccessTokens.FromRefreshToken(ctx, reqType, authParams, cc, refreshToken.Secret) if err != nil { return accesstokens.TokenResponse{}, err } return tr, nil } // UsernamePassword retrieves a token where a username and password is used. However, if this is // a user realm of "Federated", this uses SAML tokens. If "Managed", uses normal username/password. 
func (t *Client) UsernamePassword(ctx context.Context, authParams authority.AuthParams) (accesstokens.TokenResponse, error) { if err := scopeError(authParams); err != nil { return accesstokens.TokenResponse{}, err } if authParams.AuthorityInfo.AuthorityType == authority.ADFS { if err := t.resolveEndpoint(ctx, &authParams, authParams.Username); err != nil { return accesstokens.TokenResponse{}, err } return t.AccessTokens.FromUsernamePassword(ctx, authParams) } if err := t.resolveEndpoint(ctx, &authParams, ""); err != nil { return accesstokens.TokenResponse{}, err } userRealm, err := t.Authority.UserRealm(ctx, authParams) if err != nil { return accesstokens.TokenResponse{}, fmt.Errorf("problem getting user realm from authority: %w", err) } switch userRealm.AccountType { case authority.Federated: mexDoc, err := t.WSTrust.Mex(ctx, userRealm.FederationMetadataURL) if err != nil { err = fmt.Errorf("problem getting mex doc from federated url(%s): %w", userRealm.FederationMetadataURL, err) return accesstokens.TokenResponse{}, err } saml, err := t.WSTrust.SAMLTokenInfo(ctx, authParams, userRealm.CloudAudienceURN, mexDoc.UsernamePasswordEndpoint) if err != nil { err = fmt.Errorf("problem getting SAML token info: %w", err) return accesstokens.TokenResponse{}, err } tr, err := t.AccessTokens.FromSamlGrant(ctx, authParams, saml) if err != nil { return accesstokens.TokenResponse{}, err } return tr, nil case authority.Managed: if len(authParams.Scopes) == 0 { err = fmt.Errorf("token request had an empty authority.AuthParams.Scopes, which may cause the following error: %w", err) return accesstokens.TokenResponse{}, err } return t.AccessTokens.FromUsernamePassword(ctx, authParams) } return accesstokens.TokenResponse{}, errors.New("unknown account type") } // DeviceCode is the result of a call to Token.DeviceCode(). type DeviceCode struct { // Result is the device code result from the first call in the device code flow. 
This allows // the caller to retrieve the displayed code that is used to authorize on the second device. Result accesstokens.DeviceCodeResult authParams authority.AuthParams accessTokens AccessTokens } // Token returns a token AFTER the user uses the user code on the second device. This will block // until either: (1) the code is input by the user and the service releases a token, (2) the token // expires, (3) the Context passed to .DeviceCode() is cancelled or expires, (4) some other service // error occurs. func (d DeviceCode) Token(ctx context.Context) (accesstokens.TokenResponse, error) { if d.accessTokens == nil { return accesstokens.TokenResponse{}, fmt.Errorf("DeviceCode was either created outside its package or the creating method had an error. DeviceCode is not valid") } var cancel context.CancelFunc if deadline, ok := ctx.Deadline(); !ok || d.Result.ExpiresOn.Before(deadline) { ctx, cancel = context.WithDeadline(ctx, d.Result.ExpiresOn) } else { ctx, cancel = context.WithCancel(ctx) } defer cancel() var interval = 50 * time.Millisecond timer := time.NewTimer(interval) defer timer.Stop() for { timer.Reset(interval) select { case <-ctx.Done(): return accesstokens.TokenResponse{}, ctx.Err() case <-timer.C: interval += interval * 2 if interval > 5*time.Second { interval = 5 * time.Second } } token, err := d.accessTokens.FromDeviceCodeResult(ctx, d.authParams, d.Result) if err != nil && isWaitDeviceCodeErr(err) { continue } return token, err // This handles if it was a non-wait error or success } } type deviceCodeError struct { Error string `json:"error"` } func isWaitDeviceCodeErr(err error) bool { var c errors.CallErr if !errors.As(err, &c) { return false } if c.Resp.StatusCode != 400 { return false } var dCErr deviceCodeError defer c.Resp.Body.Close() body, err := io.ReadAll(c.Resp.Body) if err != nil { return false } err = json.Unmarshal(body, &dCErr) if err != nil { return false } if dCErr.Error == "authorization_pending" || dCErr.Error == "slow_down" { 
return true } return false } // DeviceCode returns a DeviceCode object that can be used to get the code that must be entered on the second // device and optionally the token once the code has been entered on the second device. func (t *Client) DeviceCode(ctx context.Context, authParams authority.AuthParams) (DeviceCode, error) { if err := scopeError(authParams); err != nil { return DeviceCode{}, err } if err := t.resolveEndpoint(ctx, &authParams, ""); err != nil { return DeviceCode{}, err } dcr, err := t.AccessTokens.DeviceCodeResult(ctx, authParams) if err != nil { return DeviceCode{}, err } return DeviceCode{Result: dcr, authParams: authParams, accessTokens: t.AccessTokens}, nil } func (t *Client) resolveEndpoint(ctx context.Context, authParams *authority.AuthParams, userPrincipalName string) error { endpoints, err := t.Resolver.ResolveEndpoints(ctx, authParams.AuthorityInfo, userPrincipalName) if err != nil { return fmt.Errorf("unable to resolve an endpoint: %s", err) } authParams.Endpoints = endpoints return nil } // scopeError takes an authority.AuthParams and returns an error // if len(AuthParams.Scope) == 0. func scopeError(a authority.AuthParams) error { // TODO(someone): we could look deeper at the message to determine if // it's a scope error, but this is a good start. /* {error":"invalid_scope","error_description":"AADSTS1002012: The provided value for scope openid offline_access profile is not valid. Client credential flows must have a scope value with /.default suffixed to the resource identifier (application ID URI)...} */ if len(a.Scopes) == 0 { return fmt.Errorf("token request had an empty authority.AuthParams.Scopes, which is invalid") } return nil } microsoft-authentication-library-for-go-1.0.0/apps/internal/oauth/oauth_test.go000066400000000000000000000251601442026362400277770ustar00rootroot00000000000000// Copyright (c) Microsoft Corporation. // Licensed under the MIT license. 
package oauth

// NOTE: These tests cover that we handle errors from other lower level modules.
// We don't actually care about a TokenResponse{}, that is gathered from a remote system
// and they are tested via integration tests (data retrieved from one system and passed from
// one to another). We care about execution behavior (service X says there is an error and we handle it,
// we require .X is set and input doesn't have it, ...)

import (
	"bytes"
	"context"
	"crypto/x509"
	"io"
	"net/http"
	"testing"
	"time"

	"github.com/AzureAD/microsoft-authentication-library-for-go/apps/errors"
	"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/exported"
	"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/fake"
	"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens"
	"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority"
)

var testScopes = []string{"scope"}

func TestAuthCode(t *testing.T) {
	tests := []struct {
		desc string
		re   fake.ResolveEndpoints
		at   *fake.AccessTokens
		err  bool
	}{
		{
			desc: "Error: Unable to resolve endpoints",
			re:   fake.ResolveEndpoints{Err: true},
			at:   &fake.AccessTokens{},
			err:  true,
		},
		{
			desc: "Error: REST access token error",
			re:   fake.ResolveEndpoints{},
			at:   &fake.AccessTokens{Err: true},
			err:  true,
		},
		{
			desc: "Success",
			re:   fake.ResolveEndpoints{},
			at:   &fake.AccessTokens{},
		},
	}

	token := &Client{}
	for _, test := range tests {
		token.AccessTokens = test.at
		token.Resolver = test.re

		_, err := token.AuthCode(context.Background(), accesstokens.AuthCodeRequest{AuthParams: authority.AuthParams{Scopes: testScopes}})
		switch {
		case err == nil && test.err:
			t.Errorf("TestAuthCode(%s): got err == nil, want err != nil", test.desc)
		case err != nil && !test.err:
			t.Errorf("TestAuthCode(%s): got err == %s, want err == nil", test.desc, err)
		}
	}
}

func TestCredential(t *testing.T) {
	callback := func(context.Context, exported.AssertionRequestOptions) (string, error) {
		return "assertion", nil
	}
	tests := []struct {
		desc       string
		re         fake.ResolveEndpoints
		at         *fake.AccessTokens
		authParams authority.AuthParams
		cred       *accesstokens.Credential
		err        bool
	}{
		{
			desc: "Error: Unable to resolve endpoints",
			re:   fake.ResolveEndpoints{Err: true},
			at:   &fake.AccessTokens{},
			cred: &accesstokens.Credential{
				AssertionCallback: callback,
			},
			err: true,
		},
		{
			// Uses a secret so the failure comes from the FromClientSecret path.
			desc: "Error: REST access token error on secret",
			re:   fake.ResolveEndpoints{},
			at:   &fake.AccessTokens{Err: true},
			cred: &accesstokens.Credential{
				Secret: "secret",
			},
			err: true,
		},
		{
			// No AssertionCallback, so Credential.JWT must sign with Key, which is nil and fails.
			// The fake token client does not error, so only the JWT failure can fail this case.
			desc: "Error: could not generate JWT from cred assertion",
			re:   fake.ResolveEndpoints{},
			at:   &fake.AccessTokens{},
			cred: &accesstokens.Credential{
				Cert: &x509.Certificate{},
			},
			err: true,
		},
		{
			desc: "Error: REST access token error on assertion",
			re:   fake.ResolveEndpoints{},
			at:   &fake.AccessTokens{Err: true},
			cred: &accesstokens.Credential{
				AssertionCallback: callback,
			},
			err: true,
		},
		{
			desc: "Success: secret cred",
			re:   fake.ResolveEndpoints{},
			at:   &fake.AccessTokens{},
			cred: &accesstokens.Credential{
				Secret: "secret",
			},
		},
		{
			desc: "Success: assertion cred",
			re:   fake.ResolveEndpoints{},
			at:   &fake.AccessTokens{},
			cred: &accesstokens.Credential{
				AssertionCallback: callback,
			},
		},
	}

	token := &Client{}
	for _, test := range tests {
		token.AccessTokens = test.at
		token.Resolver = test.re

		_, err := token.Credential(context.Background(), test.authParams, test.cred)
		switch {
		case err == nil && test.err:
			t.Errorf("TestCredential(%s): got err == nil, want err != nil", test.desc)
		case err != nil && !test.err:
			t.Errorf("TestCredential(%s): got err == %s, want err == nil", test.desc, err)
		}
	}
}

func TestRefresh(t *testing.T) {
	tests := []struct {
		desc string
		re   fake.ResolveEndpoints
		at   *fake.AccessTokens
		err  bool
	}{
		{
			desc: "Error: Unable to resolve endpoints",
			re:   fake.ResolveEndpoints{Err: true},
			at:   &fake.AccessTokens{},
			err:  true,
		},
		{
			desc: "Error: REST access token error",
			re:   fake.ResolveEndpoints{},
			at:   &fake.AccessTokens{Err: true},
			err:  true,
		},
		{
			desc: "Success",
			re:   fake.ResolveEndpoints{},
			at:   &fake.AccessTokens{},
		},
	}

	token := &Client{}
	for _, test := range tests {
		token.AccessTokens = test.at
		token.Resolver = test.re

		_, err := token.Refresh(
			context.Background(),
			accesstokens.ATPublic,
			authority.AuthParams{Scopes: testScopes},
			&accesstokens.Credential{},
			accesstokens.RefreshToken{},
		)
		switch {
		case err == nil && test.err:
			t.Errorf("TestRefresh(%s): got err == nil, want err != nil", test.desc)
		case err != nil && !test.err:
			t.Errorf("TestRefresh(%s): got err == %s, want err == nil", test.desc, err)
		}
	}
}

func TestUsernamePassword(t *testing.T) {
	tests := []struct {
		desc string
		re   fake.ResolveEndpoints
		at   *fake.AccessTokens
		au   fake.Authority
		ws   fake.WSTrust
		err  bool
	}{
		{
			desc: "Error: Unable to resolve endpoints",
			re:   fake.ResolveEndpoints{Err: true},
			at:   &fake.AccessTokens{},
			au:   fake.Authority{Realm: authority.UserRealm{AccountType: authority.Managed}},
			err:  true,
		},
		{
			desc: "Error: authority.Federated and Mex() error",
			re:   fake.ResolveEndpoints{Err: false},
			at:   &fake.AccessTokens{},
			au:   fake.Authority{Realm: authority.UserRealm{AccountType: authority.Federated}},
			ws:   fake.WSTrust{GetMexErr: true},
			err:  true,
		},
		{
			desc: "Error: authority.Federated and SAMLTokenInfo() error",
			re:   fake.ResolveEndpoints{Err: false},
			at:   &fake.AccessTokens{},
			au:   fake.Authority{Realm: authority.UserRealm{AccountType: authority.Federated}},
			ws:   fake.WSTrust{GetSAMLTokenInfoErr: true},
			err:  true,
		},
		{
			desc: "Error: authority.Federated and GetAccessTokenFromSamlGrant() error",
			re:   fake.ResolveEndpoints{Err: false},
			au:   fake.Authority{Realm: authority.UserRealm{AccountType: authority.Federated}},
			at:   &fake.AccessTokens{Err: true},
			err:  true,
		},
		{
			desc: "Error: authority.Managed and REST access token error",
			re:   fake.ResolveEndpoints{Err: false},
			at:   &fake.AccessTokens{Err: true},
			au:   fake.Authority{Realm: authority.UserRealm{AccountType: authority.Managed}},
			err:  true,
		},
		{
			desc: "Success: authority.Managed",
			re:   fake.ResolveEndpoints{Err: false},
			at:   &fake.AccessTokens{},
			au:   fake.Authority{Realm: authority.UserRealm{AccountType: authority.Managed}},
		},
		{
			desc: "Success: authority.Federated",
			re:   fake.ResolveEndpoints{Err: false},
			at:   &fake.AccessTokens{},
			au:   fake.Authority{Realm: authority.UserRealm{AccountType: authority.Federated}},
		},
	}

	token := &Client{}
	for _, test := range tests {
		token.AccessTokens = test.at
		token.Authority = test.au
		token.Resolver = test.re
		token.WSTrust = test.ws

		_, err := token.UsernamePassword(context.Background(), authority.AuthParams{Scopes: testScopes})
		switch {
		case err == nil && test.err:
			t.Errorf("TestUsernamePassword(%s): got err == nil, want err != nil", test.desc)
		case err != nil && !test.err:
			t.Errorf("TestUsernamePassword(%s): got err == %s, want err == nil", test.desc, err)
		}
	}
}

func TestDeviceCode(t *testing.T) {
	// callErr builds an HTTP 400 CallErr whose body is the given JSON, as the token
	// endpoint returns while a device code is still being authorized.
	callErr := func(body string) errors.CallErr {
		return errors.CallErr{
			Resp: &http.Response{
				StatusCode: 400,
				Body:       io.NopCloser(bytes.NewReader([]byte(body))),
			},
		}
	}

	tests := []struct {
		desc string
		dc   DeviceCode
		err  bool
	}{
		{
			desc: "Error: .accessTokens == nil",
			dc:   DeviceCode{},
			err:  true,
		},
		{
			// ExpiresOn is in the future so polling actually runs; the error must come
			// from the third, non-wait response rather than from an expired context.
			desc: "Error: FromDeviceCodeResult() returned a !isWaitDeviceCodeErr",
			dc: DeviceCode{
				Result: accesstokens.DeviceCodeResult{
					ExpiresOn: time.Now().Add(5 * time.Minute),
				},
				accessTokens: &fake.AccessTokens{
					Result: []error{
						callErr(`{"error": "authorization_pending"}`),
						callErr(`{"error": "slow_down"}`),
						callErr(`{"error": "bad_error"}`),
						nil,
					},
				},
			},
			err: true,
		},
		{
			desc: "Success",
			dc: DeviceCode{
				Result: accesstokens.DeviceCodeResult{
					ExpiresOn: time.Now().Add(5 * time.Minute),
				},
				accessTokens: &fake.AccessTokens{
					Result: []error{
						callErr(`{"error": "authorization_pending"}`),
						callErr(`{"error": "slow_down"}`),
						nil,
					},
				},
			},
		},
	}

	for _, test := range tests {
		_, err := test.dc.Token(context.Background())
		switch {
		case err == nil && test.err:
			t.Errorf("TestDeviceCode(%s): got err == nil, want err != nil", test.desc)
		case err != nil && !test.err:
			t.Errorf("TestDeviceCode(%s): got err == %s, want err == nil", test.desc, err)
		}
	}
}

func TestDeviceCodeToken(t *testing.T) {
	tests := []struct {
		desc string
		re   fake.ResolveEndpoints
		at   *fake.AccessTokens
		err  bool
	}{
		{
			desc: "Error: Unable to resolve endpoints",
			re:   fake.ResolveEndpoints{Err: true},
			at:   &fake.AccessTokens{},
			err:  true,
		},
		{
			desc: "Error: REST access token error",
			re:   fake.ResolveEndpoints{},
			at:   &fake.AccessTokens{Err: true},
			err:  true,
		},
		{
			desc: "Success",
			re:   fake.ResolveEndpoints{},
			at:   &fake.AccessTokens{},
		},
	}

	token := &Client{}
	for _, test := range tests {
		token.AccessTokens = test.at
		token.Resolver = test.re

		dc, err := token.DeviceCode(context.Background(), authority.AuthParams{Scopes: testScopes})
		switch {
		case err == nil && test.err:
			t.Errorf("TestDeviceCodeToken(%s): got err == nil, want err != nil", test.desc)
			continue
		case err != nil && !test.err:
			t.Errorf("TestDeviceCodeToken(%s): got err == %s, want err == nil", test.desc, err)
			continue
		case err != nil:
			continue
		}

		if dc.accessTokens == nil {
			t.Errorf("TestDeviceCodeToken(%s): got DeviceCode{} back that did not have accessTokens set", test.desc)
		}
	}
}
microsoft-authentication-library-for-go-1.0.0/apps/internal/oauth/ops/000077500000000000000000000000001442026362400260665ustar00rootroot00000000000000microsoft-authentication-library-for-go-1.0.0/apps/internal/oauth/ops/accesstokens/000077500000000000000000000000001442026362400305535ustar00rootroot00000000000000microsoft-authentication-library-for-go-1.0.0/apps/internal/oauth/ops/accesstokens/accesstokens.go000066400000000000000000000356761442026362400336100ustar00rootroot00000000000000// Copyright (c) Microsoft Corporation. // Licensed under the MIT license. /* Package accesstokens exposes a REST client for querying backend systems to get various types of access tokens (oauth) for use in authentication. These calls are of type "application/x-www-form-urlencoded". This means we use url.Values to represent arguments and then encode them into the POST body message. We receive JSON in return for the requests. The request definition is defined in https://tools.ietf.org/html/rfc7521#section-4.2 . */ package accesstokens import ( "context" "crypto" /* #nosec */ "crypto/sha1" "crypto/x509" "encoding/base64" "encoding/json" "fmt" "net/url" "strconv" "strings" "time" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/exported" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/internal/grant" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/wstrust" "github.com/golang-jwt/jwt/v4" "github.com/google/uuid" ) const ( grantType = "grant_type" deviceCode = "device_code" clientID = "client_id" clientInfo = "client_info" clientInfoVal = "1" username = "username" password = "password" ) //go:generate stringer -type=AppType // AppType is whether the authorization code flow is for a public or confidential client. type AppType int8 const ( // ATUnknown is the zero value when the type hasn't been set. 
ATUnknown AppType = iota // ATPublic indicates this if for the Public.Client. ATPublic // ATConfidential indicates this if for the Confidential.Client. ATConfidential ) type urlFormCaller interface { URLFormCall(ctx context.Context, endpoint string, qv url.Values, resp interface{}) error } // DeviceCodeResponse represents the HTTP response received from the device code endpoint type DeviceCodeResponse struct { authority.OAuthResponseBase UserCode string `json:"user_code"` DeviceCode string `json:"device_code"` VerificationURL string `json:"verification_url"` ExpiresIn int `json:"expires_in"` Interval int `json:"interval"` Message string `json:"message"` AdditionalFields map[string]interface{} } // Convert converts the DeviceCodeResponse to a DeviceCodeResult func (dcr DeviceCodeResponse) Convert(clientID string, scopes []string) DeviceCodeResult { expiresOn := time.Now().UTC().Add(time.Duration(dcr.ExpiresIn) * time.Second) return NewDeviceCodeResult(dcr.UserCode, dcr.DeviceCode, dcr.VerificationURL, expiresOn, dcr.Interval, dcr.Message, clientID, scopes) } // Credential represents the credential used in confidential client flows. This can be either // a Secret or Cert/Key. type Credential struct { // Secret contains the credential secret if we are doing auth by secret. Secret string // Cert is the public certificate, if we're authenticating by certificate. Cert *x509.Certificate // Key is the private key for signing, if we're authenticating by certificate. Key crypto.PrivateKey // X5c is the JWT assertion's x5c header value, required for SN/I authentication. X5c []string // AssertionCallback is a function provided by the application, if we're authenticating by assertion. 
AssertionCallback func(context.Context, exported.AssertionRequestOptions) (string, error) // TokenProvider is a function provided by the application that implements custom authentication // logic for a confidential client TokenProvider func(context.Context, exported.TokenProviderParameters) (exported.TokenProviderResult, error) } // JWT gets the jwt assertion when the credential is not using a secret. func (c *Credential) JWT(ctx context.Context, authParams authority.AuthParams) (string, error) { if c.AssertionCallback != nil { options := exported.AssertionRequestOptions{ ClientID: authParams.ClientID, TokenEndpoint: authParams.Endpoints.TokenEndpoint, } return c.AssertionCallback(ctx, options) } token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{ "aud": authParams.Endpoints.TokenEndpoint, "exp": json.Number(strconv.FormatInt(time.Now().Add(10*time.Minute).Unix(), 10)), "iss": authParams.ClientID, "jti": uuid.New().String(), "nbf": json.Number(strconv.FormatInt(time.Now().Unix(), 10)), "sub": authParams.ClientID, }) token.Header = map[string]interface{}{ "alg": "RS256", "typ": "JWT", "x5t": base64.StdEncoding.EncodeToString(thumbprint(c.Cert)), } if authParams.SendX5C { token.Header["x5c"] = c.X5c } assertion, err := token.SignedString(c.Key) if err != nil { return "", fmt.Errorf("unable to sign a JWT token using private key: %w", err) } return assertion, nil } // thumbprint runs the asn1.Der bytes through sha1 for use in the x5t parameter of JWT. // https://tools.ietf.org/html/rfc7517#section-4.8 func thumbprint(cert *x509.Certificate) []byte { /* #nosec */ a := sha1.Sum(cert.Raw) return a[:] } // Client represents the REST calls to get tokens from token generator backends. type Client struct { // Comm provides the HTTP transport client. Comm urlFormCaller testing bool } // FromUsernamePassword uses a username and password to get an access token. 
func (c Client) FromUsernamePassword(ctx context.Context, authParameters authority.AuthParams) (TokenResponse, error) { qv := url.Values{} if err := addClaims(qv, authParameters); err != nil { return TokenResponse{}, err } qv.Set(grantType, grant.Password) qv.Set(username, authParameters.Username) qv.Set(password, authParameters.Password) qv.Set(clientID, authParameters.ClientID) qv.Set(clientInfo, clientInfoVal) addScopeQueryParam(qv, authParameters) return c.doTokenResp(ctx, authParameters, qv) } // AuthCodeRequest stores the values required to request a token from the authority using an authorization code type AuthCodeRequest struct { AuthParams authority.AuthParams Code string CodeChallenge string Credential *Credential AppType AppType } // NewCodeChallengeRequest returns an AuthCodeRequest that uses a code challenge.. func NewCodeChallengeRequest(params authority.AuthParams, appType AppType, cc *Credential, code, challenge string) (AuthCodeRequest, error) { if appType == ATUnknown { return AuthCodeRequest{}, fmt.Errorf("bug: NewCodeChallengeRequest() called with AppType == ATUnknown") } return AuthCodeRequest{ AuthParams: params, AppType: appType, Code: code, CodeChallenge: challenge, Credential: cc, }, nil } // FromAuthCode uses an authorization code to retrieve an access token. 
func (c Client) FromAuthCode(ctx context.Context, req AuthCodeRequest) (TokenResponse, error) { var qv url.Values switch req.AppType { case ATUnknown: return TokenResponse{}, fmt.Errorf("bug: Token.AuthCode() received request with AppType == ATUnknown") case ATConfidential: var err error if req.Credential == nil { return TokenResponse{}, fmt.Errorf("AuthCodeRequest had nil Credential for Confidential app") } qv, err = prepURLVals(ctx, req.Credential, req.AuthParams) if err != nil { return TokenResponse{}, err } case ATPublic: qv = url.Values{} default: return TokenResponse{}, fmt.Errorf("bug: Token.AuthCode() received request with AppType == %v, which we do not recongnize", req.AppType) } qv.Set(grantType, grant.AuthCode) qv.Set("code", req.Code) qv.Set("code_verifier", req.CodeChallenge) qv.Set("redirect_uri", req.AuthParams.Redirecturi) qv.Set(clientID, req.AuthParams.ClientID) qv.Set(clientInfo, clientInfoVal) addScopeQueryParam(qv, req.AuthParams) if err := addClaims(qv, req.AuthParams); err != nil { return TokenResponse{}, err } return c.doTokenResp(ctx, req.AuthParams, qv) } // FromRefreshToken uses a refresh token (for refreshing credentials) to get a new access token. func (c Client) FromRefreshToken(ctx context.Context, appType AppType, authParams authority.AuthParams, cc *Credential, refreshToken string) (TokenResponse, error) { qv := url.Values{} if appType == ATConfidential { var err error qv, err = prepURLVals(ctx, cc, authParams) if err != nil { return TokenResponse{}, err } } if err := addClaims(qv, authParams); err != nil { return TokenResponse{}, err } qv.Set(grantType, grant.RefreshToken) qv.Set(clientID, authParams.ClientID) qv.Set(clientInfo, clientInfoVal) qv.Set("refresh_token", refreshToken) addScopeQueryParam(qv, authParams) return c.doTokenResp(ctx, authParams, qv) } // FromClientSecret uses a client's secret (aka password) to get a new token. 
func (c Client) FromClientSecret(ctx context.Context, authParameters authority.AuthParams, clientSecret string) (TokenResponse, error) { qv := url.Values{} if err := addClaims(qv, authParameters); err != nil { return TokenResponse{}, err } qv.Set(grantType, grant.ClientCredential) qv.Set("client_secret", clientSecret) qv.Set(clientID, authParameters.ClientID) addScopeQueryParam(qv, authParameters) token, err := c.doTokenResp(ctx, authParameters, qv) if err != nil { return token, fmt.Errorf("FromClientSecret(): %w", err) } return token, nil } func (c Client) FromAssertion(ctx context.Context, authParameters authority.AuthParams, assertion string) (TokenResponse, error) { qv := url.Values{} if err := addClaims(qv, authParameters); err != nil { return TokenResponse{}, err } qv.Set(grantType, grant.ClientCredential) qv.Set("client_assertion_type", grant.ClientAssertion) qv.Set("client_assertion", assertion) qv.Set(clientID, authParameters.ClientID) qv.Set(clientInfo, clientInfoVal) addScopeQueryParam(qv, authParameters) token, err := c.doTokenResp(ctx, authParameters, qv) if err != nil { return token, fmt.Errorf("FromAssertion(): %w", err) } return token, nil } func (c Client) FromUserAssertionClientSecret(ctx context.Context, authParameters authority.AuthParams, userAssertion string, clientSecret string) (TokenResponse, error) { qv := url.Values{} if err := addClaims(qv, authParameters); err != nil { return TokenResponse{}, err } qv.Set(grantType, grant.JWT) qv.Set(clientID, authParameters.ClientID) qv.Set("client_secret", clientSecret) qv.Set("assertion", userAssertion) qv.Set(clientInfo, clientInfoVal) qv.Set("requested_token_use", "on_behalf_of") addScopeQueryParam(qv, authParameters) return c.doTokenResp(ctx, authParameters, qv) } func (c Client) FromUserAssertionClientCertificate(ctx context.Context, authParameters authority.AuthParams, userAssertion string, assertion string) (TokenResponse, error) { qv := url.Values{} if err := addClaims(qv, authParameters); err 
!= nil { return TokenResponse{}, err } qv.Set(grantType, grant.JWT) qv.Set("client_assertion_type", grant.ClientAssertion) qv.Set("client_assertion", assertion) qv.Set(clientID, authParameters.ClientID) qv.Set("assertion", userAssertion) qv.Set(clientInfo, clientInfoVal) qv.Set("requested_token_use", "on_behalf_of") addScopeQueryParam(qv, authParameters) return c.doTokenResp(ctx, authParameters, qv) } func (c Client) DeviceCodeResult(ctx context.Context, authParameters authority.AuthParams) (DeviceCodeResult, error) { qv := url.Values{} if err := addClaims(qv, authParameters); err != nil { return DeviceCodeResult{}, err } qv.Set(clientID, authParameters.ClientID) addScopeQueryParam(qv, authParameters) endpoint := strings.Replace(authParameters.Endpoints.TokenEndpoint, "token", "devicecode", -1) resp := DeviceCodeResponse{} err := c.Comm.URLFormCall(ctx, endpoint, qv, &resp) if err != nil { return DeviceCodeResult{}, err } return resp.Convert(authParameters.ClientID, authParameters.Scopes), nil } func (c Client) FromDeviceCodeResult(ctx context.Context, authParameters authority.AuthParams, deviceCodeResult DeviceCodeResult) (TokenResponse, error) { qv := url.Values{} if err := addClaims(qv, authParameters); err != nil { return TokenResponse{}, err } qv.Set(grantType, grant.DeviceCode) qv.Set(deviceCode, deviceCodeResult.DeviceCode) qv.Set(clientID, authParameters.ClientID) qv.Set(clientInfo, clientInfoVal) addScopeQueryParam(qv, authParameters) return c.doTokenResp(ctx, authParameters, qv) } func (c Client) FromSamlGrant(ctx context.Context, authParameters authority.AuthParams, samlGrant wstrust.SamlTokenInfo) (TokenResponse, error) { qv := url.Values{} if err := addClaims(qv, authParameters); err != nil { return TokenResponse{}, err } qv.Set(username, authParameters.Username) qv.Set(password, authParameters.Password) qv.Set(clientID, authParameters.ClientID) qv.Set(clientInfo, clientInfoVal) qv.Set("assertion", 
base64.StdEncoding.WithPadding(base64.StdPadding).EncodeToString([]byte(samlGrant.Assertion))) addScopeQueryParam(qv, authParameters) switch samlGrant.AssertionType { case grant.SAMLV1: qv.Set(grantType, grant.SAMLV1) case grant.SAMLV2: qv.Set(grantType, grant.SAMLV2) default: return TokenResponse{}, fmt.Errorf("GetAccessTokenFromSamlGrant returned unknown SAML assertion type: %q", samlGrant.AssertionType) } return c.doTokenResp(ctx, authParameters, qv) } func (c Client) doTokenResp(ctx context.Context, authParams authority.AuthParams, qv url.Values) (TokenResponse, error) { resp := TokenResponse{} err := c.Comm.URLFormCall(ctx, authParams.Endpoints.TokenEndpoint, qv, &resp) if err != nil { return resp, err } resp.ComputeScope(authParams) if c.testing { return resp, nil } return resp, resp.Validate() } // prepURLVals returns an url.Values that sets various key/values if we are doing secrets // or JWT assertions. func prepURLVals(ctx context.Context, cc *Credential, authParams authority.AuthParams) (url.Values, error) { params := url.Values{} if cc.Secret != "" { params.Set("client_secret", cc.Secret) return params, nil } jwt, err := cc.JWT(ctx, authParams) if err != nil { return nil, err } params.Set("client_assertion", jwt) params.Set("client_assertion_type", grant.ClientAssertion) return params, nil } // openid required to get an id token // offline_access required to get a refresh token // profile required to get the client_info field back var detectDefaultScopes = map[string]bool{ "openid": true, "offline_access": true, "profile": true, } var defaultScopes = []string{"openid", "offline_access", "profile"} func AppendDefaultScopes(authParameters authority.AuthParams) []string { scopes := make([]string, 0, len(authParameters.Scopes)+len(defaultScopes)) for _, scope := range authParameters.Scopes { s := strings.TrimSpace(scope) if s == "" { continue } if detectDefaultScopes[scope] { continue } scopes = append(scopes, scope) } scopes = append(scopes, 
defaultScopes...) return scopes } // addClaims adds client capabilities and claims from AuthParams to the given url.Values func addClaims(v url.Values, ap authority.AuthParams) error { claims, err := ap.MergeCapabilitiesAndClaims() if err == nil && claims != "" { v.Set("claims", claims) } return err } func addScopeQueryParam(queryParams url.Values, authParameters authority.AuthParams) { scopes := AppendDefaultScopes(authParameters) queryParams.Set("scope", strings.Join(scopes, " ")) } accesstokens_test.go000066400000000000000000000631661442026362400345630ustar00rootroot00000000000000microsoft-authentication-library-for-go-1.0.0/apps/internal/oauth/ops/accesstokens// Copyright (c) Microsoft Corporation. // Licensed under the MIT license. package accesstokens import ( "context" "encoding/base64" "errors" "fmt" "net/url" "reflect" "strings" "testing" "time" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json" internalTime "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json/types/time" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/internal/grant" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/wstrust" "github.com/kylelemons/godebug/pretty" ) var testAuthorityEndpoints = authority.NewEndpoints( "https://login.microsoftonline.com/v2.0/authorize", "https://login.microsoftonline.com/v2.0/token", "https://login.microsoftonline.com/v2.0", "login.microsoftonline.com", ) var jwtDecoderFake = func(s string) ([]byte, error) { if s == "error" { return nil, errors.New("error") } return []byte(s), nil } type fakeURLCaller struct { err bool gotEndpoint string gotQV url.Values gotResp interface{} } func (f *fakeURLCaller) URLFormCall(ctx context.Context, endpoint string, qv url.Values, resp interface{}) error { if f.err { return errors.New("error") } f.gotEndpoint 
= endpoint f.gotQV = qv f.gotResp = resp return nil } func (f *fakeURLCaller) compare(endpoint string, qv url.Values) error { if f.gotEndpoint != endpoint { return fmt.Errorf("got endpoint == %s, want endpoint == %s", f.gotEndpoint, endpoint) } if diff := pretty.Compare(qv, f.gotQV); diff != "" { return fmt.Errorf("qv -want/+got:\n%s", diff) } return nil } func TestAccessTokenFromUsernamePassword(t *testing.T) { authParams := authority.AuthParams{ Username: "username", Password: "password", Endpoints: testAuthorityEndpoints, ClientID: "clientID", } tests := []struct { desc string err bool commErr bool createErr bool qv url.Values }{ { desc: "Error: comm returns error", err: true, commErr: true, }, { desc: "Success", qv: url.Values{ grantType: []string{grant.Password}, username: []string{authParams.Username}, password: []string{authParams.Password}, clientID: []string{authParams.ClientID}, clientInfo: []string{clientInfoVal}, }, }, } for _, test := range tests { if test.qv != nil { addScopeQueryParam(test.qv, authParams) } fake := &fakeURLCaller{err: test.commErr} client := Client{Comm: fake, testing: true} // We don't care about the result, that is just a translation from the JSON handled // automatically in the comm package. We care only that the comm package got what // it needed. 
_, err := client.FromUsernamePassword(context.Background(), authParams) switch { case err == nil && test.err: t.Errorf("TestAccessTokenFromUsernamePassword(%s): got err == nil , want err != nil", test.desc) continue case err != nil && !test.err: t.Errorf("TestAccessTokenFromUsernamePassword(%s): got err == %s , want err == nil", test.desc, err) continue case err != nil: continue } if err := fake.compare(authParams.Endpoints.TokenEndpoint, test.qv); err != nil { t.Errorf("TestAccessTokenFromUsernamePassword(%s): %s", test.desc, err) } } } func TestAccessTokenFromAuthCode(t *testing.T) { authParams := authority.AuthParams{ Endpoints: testAuthorityEndpoints, ClientID: "clientID", Redirecturi: "redirectURI", } tests := []struct { desc string err bool // commErr causes the comm call to return an error. commErr bool // createErr causes the TokenResponse creation to error. createErr bool authCodeRequest AuthCodeRequest authCode string codeVerifier string qv url.Values }{ { desc: "Error: comm returns error", err: true, commErr: true, authCodeRequest: AuthCodeRequest{ AuthParams: authParams, Code: "authCode", CodeChallenge: "codeVerifier", Credential: &Credential{Secret: "secret"}, AppType: ATConfidential, }, qv: url.Values{ "code": []string{"authCode"}, "code_verifier": []string{"codeVerifier"}, "redirect_uri": []string{"redirectURI"}, grantType: []string{grant.AuthCode}, clientID: []string{authParams.ClientID}, clientInfo: []string{clientInfoVal}, }, }, { desc: "Error: Credential is nil", authCodeRequest: AuthCodeRequest{ AuthParams: authParams, Code: "authCode", CodeChallenge: "codeVerifier", AppType: ATConfidential, }, qv: url.Values{ "code": []string{"authCode"}, "code_verifier": []string{"codeVerifier"}, "redirect_uri": []string{"redirectURI"}, grantType: []string{grant.AuthCode}, clientID: []string{authParams.ClientID}, clientInfo: []string{clientInfoVal}, }, err: true, }, { desc: "Success", authCodeRequest: AuthCodeRequest{ AuthParams: authParams, Code: "authCode", 
CodeChallenge: "codeVerifier", AppType: ATConfidential, Credential: &Credential{Secret: "secret"}, }, qv: url.Values{ "code": []string{"authCode"}, "code_verifier": []string{"codeVerifier"}, "redirect_uri": []string{"redirectURI"}, "client_secret": []string{"secret"}, grantType: []string{grant.AuthCode}, clientID: []string{authParams.ClientID}, clientInfo: []string{clientInfoVal}, }, }, } for _, test := range tests { if test.qv != nil { addScopeQueryParam(test.qv, authParams) } fake := &fakeURLCaller{err: test.err} client := Client{Comm: fake, testing: true} // We don't care about the result, that is just a translation from the JSON handled // automatically in the comm package. We care only that the comm package got what // it needed. _, err := client.FromAuthCode(context.Background(), test.authCodeRequest) switch { case err == nil && test.err: t.Errorf("TestAccessTokenFromAuthCode(%s): got err == nil , want err != nil", test.desc) continue case err != nil && !test.err: t.Errorf("TestAccessTokenFromAuthCode(%s): got err == %s , want err == nil", test.desc, err) continue case err != nil: continue } if err := fake.compare(authParams.Endpoints.TokenEndpoint, test.qv); err != nil { t.Errorf("TestAccessTokenFromAuthCode(%s): %s", test.desc, err) } } } func TestAccessTokenFromRefreshToken(t *testing.T) { authParams := authority.AuthParams{ Endpoints: testAuthorityEndpoints, ClientID: "clientID", Redirecturi: "redirectURI", } tests := []struct { desc string err bool commErr bool createErr bool cred *Credential refreshToken string qv url.Values }{ { desc: "Error: comm returns error", err: true, commErr: true, refreshToken: "refreshToken", qv: url.Values{ "refresh_token": []string{"refreshToken"}, grantType: []string{grant.RefreshToken}, clientID: []string{authParams.ClientID}, clientInfo: []string{clientInfoVal}, }, }, { desc: "Success(public app)", refreshToken: "refreshToken", qv: url.Values{ "refresh_token": []string{"refreshToken"}, grantType: 
[]string{grant.RefreshToken}, clientID: []string{authParams.ClientID}, clientInfo: []string{clientInfoVal}, }, }, { desc: "Success(confidential app)", refreshToken: "refreshToken", qv: url.Values{ "refresh_token": []string{"refreshToken"}, grantType: []string{grant.RefreshToken}, clientID: []string{authParams.ClientID}, clientInfo: []string{clientInfoVal}, }, }, } for _, test := range tests { if test.qv != nil { addScopeQueryParam(test.qv, authParams) } fake := &fakeURLCaller{err: test.commErr} client := Client{Comm: fake, testing: true} // We don't care about the result, that is just a translation from the JSON handled // automatically in the comm package. We care only that the comm package got what // it needed. _, err := client.FromRefreshToken(context.Background(), ATPublic, authParams, test.cred, test.refreshToken) switch { case err == nil && test.err: t.Errorf("TestAccessTokenFromRefreshToken(%s): got err == nil , want err != nil", test.desc) continue case err != nil && !test.err: t.Errorf("TestAccessTokenFromRefreshToken(%s): got err == %s , want err == nil", test.desc, err) continue case err != nil: continue } if err := fake.compare(authParams.Endpoints.TokenEndpoint, test.qv); err != nil { t.Errorf("TestAccessTokenFromRefreshToken(%s): %s", test.desc, err) } } } func TestAccessTokenWithClientSecret(t *testing.T) { authParams := authority.AuthParams{ Endpoints: testAuthorityEndpoints, ClientID: "clientID", Redirecturi: "redirectURI", } tests := []struct { desc string err bool commErr bool createErr bool clientSecret string qv url.Values }{ { desc: "Error: comm returns error", err: true, commErr: true, clientSecret: "clientSecret", qv: url.Values{ "client_secret": []string{"clientSecret"}, grantType: []string{grant.ClientCredential}, clientID: []string{authParams.ClientID}, }, }, { desc: "Success", clientSecret: "clientSecret", qv: url.Values{ "client_secret": []string{"clientSecret"}, grantType: []string{grant.ClientCredential}, clientID: 
[]string{authParams.ClientID}, }, }, } for _, test := range tests { if test.qv != nil { addScopeQueryParam(test.qv, authParams) } fake := &fakeURLCaller{err: test.err} client := Client{Comm: fake, testing: true} // We don't care about the result, that is just a translation from the JSON handled // automatically in the comm package. We care only that the comm package got what // it needed. _, err := client.FromClientSecret(context.Background(), authParams, test.clientSecret) switch { case err == nil && test.err: t.Errorf("TestAccessTokenWithClientSecret(%s): got err == nil , want err != nil", test.desc) continue case err != nil && !test.err: t.Errorf("TestAccessTokenWithClientSecret(%s): got err == %s , want err == nil", test.desc, err) continue case err != nil: continue } if err := fake.compare(authParams.Endpoints.TokenEndpoint, test.qv); err != nil { t.Errorf("TestAccessTokenWithClientSecret(%s): %s", test.desc, err) } } } func TestAccessTokenWithAssertion(t *testing.T) { authParams := authority.AuthParams{ Endpoints: testAuthorityEndpoints, ClientID: "clientID", Redirecturi: "redirectURI", } tests := []struct { desc string err bool commErr bool createErr bool assertion string params url.Values qv url.Values }{ { desc: "Error: comm returns error", err: true, commErr: true, assertion: "assertion", qv: url.Values{ "client_assertion_type": []string{grant.ClientAssertion}, "client_assertion": []string{"assertion"}, grantType: []string{grant.ClientCredential}, clientInfo: []string{clientInfoVal}, clientID: []string{authParams.ClientID}, }, }, { desc: "Success", assertion: "assertion", qv: url.Values{ "client_assertion_type": []string{grant.ClientAssertion}, "client_assertion": []string{"assertion"}, grantType: []string{grant.ClientCredential}, clientInfo: []string{clientInfoVal}, clientID: []string{authParams.ClientID}, }, }, } for _, test := range tests { if test.qv != nil { addScopeQueryParam(test.qv, authParams) } fake := &fakeURLCaller{err: test.commErr} client := 
Client{Comm: fake, testing: true} // We don't care about the result, that is just a translation from the JSON handled // automatically in the comm package. We care only that the comm package got what // it needed. _, err := client.FromAssertion(context.Background(), authParams, test.assertion) switch { case err == nil && test.err: t.Errorf("TestAccessTokenWithAssertion(%s): got err == nil , want err != nil", test.desc) continue case err != nil && !test.err: t.Errorf("TestAccessTokenWithAssertion(%s): got err == %s , want err == nil", test.desc, err) continue case err != nil: continue } if err := fake.compare(authParams.Endpoints.TokenEndpoint, test.qv); err != nil { t.Errorf("TestAccessTokenWithAssertion(%s): %s", test.desc, err) } } } func TestDeviceCodeResult(t *testing.T) { authParams := authority.AuthParams{ Endpoints: testAuthorityEndpoints, ClientID: "clientID", Redirecturi: "redirectURI", } tests := []struct { desc string err bool commErr bool createErr bool assertion string params url.Values qv url.Values }{ { desc: "Error: comm returns error", err: true, commErr: true, qv: url.Values{ clientID: []string{authParams.ClientID}, }, }, { desc: "Success", qv: url.Values{ clientID: []string{authParams.ClientID}, }, }, } for _, test := range tests { if test.qv != nil { addScopeQueryParam(test.qv, authParams) } fake := &fakeURLCaller{err: test.commErr} client := Client{Comm: fake, testing: true} // We don't care about the result, that is just a translation from the JSON handled // automatically in the comm package. We care only that the comm package got what // it needed. 
_, err := client.DeviceCodeResult(context.Background(), authParams) switch { case err == nil && test.err: t.Errorf("TestDeviceCodeResult(%s): got err == nil , want err != nil", test.desc) continue case err != nil && !test.err: t.Errorf("TestDeviceCodeResult(%s): got err == %s , want err == nil", test.desc, err) continue case err != nil: continue } wantEndpoint := strings.Replace(authParams.Endpoints.TokenEndpoint, "token", "devicecode", -1) if err := fake.compare(wantEndpoint, test.qv); err != nil { t.Errorf("TestDeviceCodeResult(%s): %s", test.desc, err) } } } func TestFromDeviceCodeResult(t *testing.T) { authParams := authority.AuthParams{ Endpoints: testAuthorityEndpoints, ClientID: "clientID", Redirecturi: "redirectURI", } tests := []struct { desc string err bool commErr bool createErr bool deviceCodeResult DeviceCodeResult qv url.Values }{ { desc: "Error: comm returns error", err: true, commErr: true, deviceCodeResult: NewDeviceCodeResult( "userCode", "deviceCode", "verificationURL", time.Now(), 1, "message", "clientID", nil, ), qv: url.Values{ deviceCode: []string{"deviceCode"}, grantType: []string{grant.DeviceCode}, clientID: []string{authParams.ClientID}, clientInfo: []string{clientInfoVal}, }, }, { desc: "Success", deviceCodeResult: NewDeviceCodeResult( "userCode", "deviceCode", "verificationURL", time.Now(), 1, "message", "clientID", nil, ), qv: url.Values{ deviceCode: []string{"deviceCode"}, grantType: []string{grant.DeviceCode}, clientID: []string{authParams.ClientID}, clientInfo: []string{clientInfoVal}, }, }, } for _, test := range tests { if test.qv != nil { addScopeQueryParam(test.qv, authParams) } fake := &fakeURLCaller{err: test.commErr} client := Client{Comm: fake, testing: true} // We don't care about the result, that is just a translation from the JSON handled // automatically in the comm package. We care only that the comm package got what // it needed. 
_, err := client.FromDeviceCodeResult(context.Background(), authParams, test.deviceCodeResult) switch { case err == nil && test.err: t.Errorf("TestFromDeviceCodeResult(%s): got err == nil , want err != nil", test.desc) continue case err != nil && !test.err: t.Errorf("TestFromDeviceCodeResult(%s): got err == %s , want err == nil", test.desc, err) continue case err != nil: continue } if err := fake.compare(authParams.Endpoints.TokenEndpoint, test.qv); err != nil { t.Errorf("TestFromDeviceCodeResult(%s): %s", test.desc, err) } } } func TestAccessTokenFromSamlGrant(t *testing.T) { authParams := authority.AuthParams{ Username: "username", Password: "password", Endpoints: testAuthorityEndpoints, ClientID: "clientID", } base64Assertion := base64.StdEncoding.WithPadding(base64.StdPadding).EncodeToString([]byte("assertion")) tests := []struct { desc string err bool commErr bool createErr bool samlGrant wstrust.SamlTokenInfo qv url.Values }{ { desc: "Error: comm returns error", err: true, commErr: true, samlGrant: wstrust.SamlTokenInfo{ AssertionType: grant.SAMLV1, Assertion: "assertion", }, qv: url.Values{ username: []string{"username"}, password: []string{"password"}, grantType: []string{grant.SAMLV1}, clientID: []string{authParams.ClientID}, clientInfo: []string{clientInfoVal}, "assertion": []string{base64Assertion}, }, }, { desc: "Error: unknown grant type(empty space)", err: true, samlGrant: wstrust.SamlTokenInfo{ Assertion: "assertion", }, qv: url.Values{ username: []string{"username"}, password: []string{"password"}, grantType: []string{grant.SAMLV1}, clientID: []string{authParams.ClientID}, clientInfo: []string{clientInfoVal}, "assertion": []string{base64Assertion}, }, }, { desc: "Success: SAMLV1Grant", samlGrant: wstrust.SamlTokenInfo{ AssertionType: grant.SAMLV1, Assertion: "assertion", }, qv: url.Values{ username: []string{"username"}, password: []string{"password"}, grantType: []string{grant.SAMLV1}, clientID: []string{authParams.ClientID}, clientInfo: 
[]string{clientInfoVal}, "assertion": []string{base64Assertion}, }, }, { desc: "Success: SAMLV2Grant", samlGrant: wstrust.SamlTokenInfo{ AssertionType: grant.SAMLV2, Assertion: "assertion", }, qv: url.Values{ username: []string{"username"}, password: []string{"password"}, grantType: []string{grant.SAMLV2}, clientID: []string{authParams.ClientID}, clientInfo: []string{clientInfoVal}, "assertion": []string{base64Assertion}, }, }, } for _, test := range tests { if test.qv != nil { addScopeQueryParam(test.qv, authParams) } fake := &fakeURLCaller{err: test.commErr} client := Client{Comm: fake, testing: true} // We don't care about the result, that is just a translation from the JSON handled // automatically in the comm package. We care only that the comm package got what // it needed. _, err := client.FromSamlGrant(context.Background(), authParams, test.samlGrant) switch { case err == nil && test.err: t.Errorf("TestAccessTokenFromSamlGrant(%s): got err == nil , want err != nil", test.desc) continue case err != nil && !test.err: t.Errorf("TestAccessTokenFromSamlGrant(%s): got err == %s , want err == nil", test.desc, err) continue case err != nil: continue } if err := fake.compare(authParams.Endpoints.TokenEndpoint, test.qv); err != nil { t.Errorf("TestAccessTokenFromSamlGrant(%s): %s", test.desc, err) } } } func TestDecodeJWT(t *testing.T) { encodedStr := "A-z_4ME" expectedStr := []byte{3, 236, 255, 224, 193} actualString, err := decodeJWT(encodedStr) if err != nil { t.Errorf("Error should be nil but it is %v", err) } if !reflect.DeepEqual(expectedStr, actualString) { t.Errorf("Actual decoded string %s differs from expected decoded string %s", actualString, expectedStr) } } func TestLocalAccountID(t *testing.T) { id := &IDToken{ Subject: "sub", } actualLID := id.LocalAccountID() if !reflect.DeepEqual("sub", actualLID) { t.Errorf("Expected local account ID sub differs from actual local account ID %s", actualLID) } id.Oid = "oid" actualLID = id.LocalAccountID() if 
!reflect.DeepEqual("oid", actualLID) { t.Errorf("Expected local account ID oid differs from actual local account ID %s", actualLID) } } func TestTokenResponseUnmarshal(t *testing.T) { tests := []struct { desc string payload string want TokenResponse jwtDecoder func(data string) ([]byte, error) err bool }{ { desc: "Error: decodeJWT is going to error", payload: ` { "access_token": "secret", "expires_in": 86399, "ext_expires_in": 86399, "client_info": error, "scope": "openid profile" }`, err: true, jwtDecoder: jwtDecoderFake, }, { desc: "Success", payload: ` { "access_token": "secret", "expires_in": 86399, "ext_expires_in": 86399, "client_info": {"uid": "uid","utid": "utid"}, "scope": "openid profile" }`, want: TokenResponse{ AccessToken: "secret", ExpiresOn: internalTime.DurationTime{T: time.Unix(86399, 0)}, ExtExpiresOn: internalTime.DurationTime{T: time.Unix(86399, 0)}, GrantedScopes: Scopes{Slice: []string{"openid", "profile"}}, ClientInfo: ClientInfo{ UID: "uid", UTID: "utid", }, }, jwtDecoder: jwtDecoderFake, }, } for _, test := range tests { jwtDecoder = test.jwtDecoder got := TokenResponse{} err := json.Unmarshal([]byte(test.payload), &got) switch { case err == nil && test.err: t.Errorf("TestCreateTokenResponse(%s): got err == nil, want err != nil", test.desc) continue case err != nil && !test.err: t.Errorf("TestCreateTokenResponse(%s): got err == %v, want err == nil", test.desc, err) continue case err != nil: continue } // Note: IncludeUnexported prevents minor differences in time.Time due to internal fields. 
if diff := (&pretty.Config{IncludeUnexported: false}).Compare(test.want, got); diff != "" { t.Errorf("TestCreateTokenResponse: -want/+got:\n%s", diff) } } } func TestTokenResponseValidate(t *testing.T) { tests := []struct { desc string input TokenResponse err bool }{ { desc: "Error: TokenResponse had .Error set", input: TokenResponse{ OAuthResponseBase: authority.OAuthResponseBase{ Error: "error", }, AccessToken: "token", scopesComputed: true, }, err: true, }, { desc: "Error: .AccessToken was empty", input: TokenResponse{ scopesComputed: true, }, err: true, }, { desc: "Error: .scopesComputed was false", input: TokenResponse{ AccessToken: "token", scopesComputed: false, }, err: true, }, { desc: "Success", input: TokenResponse{ AccessToken: "token", scopesComputed: true, }, }, } for _, test := range tests { err := test.input.Validate() switch { case err == nil && test.err: t.Errorf("TestTokenResponseValidate(%s): got err == nil, want err != nil", test.desc) case err != nil && !test.err: t.Errorf("TestTokenResponseValidate(%s): got err == %s, want err == nil", test.desc, err) } } } func TestComputeScopes(t *testing.T) { tests := []struct { desc string authParams authority.AuthParams input TokenResponse want TokenResponse }{ { desc: "authParam scopes copied in, no declined scopes", authParams: authority.AuthParams{ Scopes: []string{ "scope0", "scope1", }, }, input: TokenResponse{}, want: TokenResponse{ GrantedScopes: Scopes{ Slice: []string{"scope0", "scope1"}, }, scopesComputed: true, }, }, { desc: "a few declined scopes", authParams: authority.AuthParams{ Scopes: []string{ "scope0", "scope1", "scope2", }, }, input: TokenResponse{ GrantedScopes: Scopes{ Slice: []string{ "scope0", "scope1", }, }, }, want: TokenResponse{ GrantedScopes: Scopes{ Slice: []string{"scope0", "scope1"}, }, DeclinedScopes: []string{"scope2"}, scopesComputed: true, }, }, { desc: "no declined scopes case insensitive", authParams: authority.AuthParams{ Scopes: []string{ "scope0", "scope1", }, }, 
input: TokenResponse{ GrantedScopes: Scopes{ Slice: []string{ "Scope0", "Scope1", }, }, }, want: TokenResponse{ GrantedScopes: Scopes{ Slice: []string{"Scope0", "Scope1"}, }, DeclinedScopes: nil, scopesComputed: true, }, }, } for _, test := range tests { test.input.ComputeScope(test.authParams) if diff := pretty.Compare(test.want, test.input); diff != "" { t.Errorf("TestComputeScopes(%s): -want/+got:\n%s", test.desc, diff) } } } func TestHomeAccountID(t *testing.T) { tests := []struct { desc string ci ClientInfo want string }{ { desc: "UID and UTID is not set", }, { desc: "UID is not set", ci: ClientInfo{UTID: "utid"}, }, { desc: "UTID is not set", ci: ClientInfo{UID: "uid"}, want: "uid.uid", }, { desc: "UID and UTID are set", ci: ClientInfo{UID: "uid", UTID: "utid"}, want: "uid.utid", }, } for _, test := range tests { got := test.ci.HomeAccountID() if got != test.want { t.Errorf("TestHomeAccountID(%s): got %q, want %q", test.desc, got, test.want) } } } func TestFindDeclinedScopes(t *testing.T) { requestedScopes := []string{"user.read", "openid"} grantedScopes := []string{"user.read"} expectedDeclinedScopes := []string{"openid"} actualDeclinedScopes := findDeclinedScopes(requestedScopes, grantedScopes) if !reflect.DeepEqual(expectedDeclinedScopes, actualDeclinedScopes) { t.Errorf("Actual declined scopes %v differ from expected declined scopes %v", actualDeclinedScopes, expectedDeclinedScopes) } } microsoft-authentication-library-for-go-1.0.0/apps/internal/oauth/ops/accesstokens/apptype_string.go000066400000000000000000000012271442026362400341540ustar00rootroot00000000000000// Code generated by "stringer -type=AppType"; DO NOT EDIT. package accesstokens import "strconv" func _() { // An "invalid array index" compiler error signifies that the constant values have changed. // Re-run the stringer command to generate them again. 
var x [1]struct{} _ = x[ATUnknown-0] _ = x[ATPublic-1] _ = x[ATConfidential-2] } const _AppType_name = "ATUnknownATPublicATConfidential" var _AppType_index = [...]uint8{0, 9, 17, 31} func (i AppType) String() string { if i < 0 || i >= AppType(len(_AppType_index)-1) { return "AppType(" + strconv.FormatInt(int64(i), 10) + ")" } return _AppType_name[_AppType_index[i]:_AppType_index[i+1]] } microsoft-authentication-library-for-go-1.0.0/apps/internal/oauth/ops/accesstokens/tokens.go000066400000000000000000000251051442026362400324100ustar00rootroot00000000000000// Copyright (c) Microsoft Corporation. // Licensed under the MIT license. package accesstokens import ( "bytes" "encoding/base64" "encoding/json" "errors" "fmt" "reflect" "strings" "time" internalTime "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json/types/time" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/shared" ) // IDToken consists of all the information used to validate a user. // https://docs.microsoft.com/azure/active-directory/develop/id-tokens . 
type IDToken struct { PreferredUsername string `json:"preferred_username,omitempty"` GivenName string `json:"given_name,omitempty"` FamilyName string `json:"family_name,omitempty"` MiddleName string `json:"middle_name,omitempty"` Name string `json:"name,omitempty"` Oid string `json:"oid,omitempty"` TenantID string `json:"tid,omitempty"` Subject string `json:"sub,omitempty"` UPN string `json:"upn,omitempty"` Email string `json:"email,omitempty"` AlternativeID string `json:"alternative_id,omitempty"` Issuer string `json:"iss,omitempty"` Audience string `json:"aud,omitempty"` ExpirationTime int64 `json:"exp,omitempty"` IssuedAt int64 `json:"iat,omitempty"` NotBefore int64 `json:"nbf,omitempty"` RawToken string AdditionalFields map[string]interface{} } var null = []byte("null") // UnmarshalJSON implements json.Unmarshaler. func (i *IDToken) UnmarshalJSON(b []byte) error { if bytes.Equal(null, b) { return nil } // Because we have a custom unmarshaler, you // cannot directly call json.Unmarshal here. If you do, it will call this function // recursively until reach our recursion limit. We have to create a new type // that doesn't have this method in order to use json.Unmarshal. type idToken2 IDToken jwt := strings.Trim(string(b), `"`) jwtArr := strings.Split(jwt, ".") if len(jwtArr) < 2 { return errors.New("IDToken returned from server is invalid") } jwtPart := jwtArr[1] jwtDecoded, err := decodeJWT(jwtPart) if err != nil { return fmt.Errorf("unable to unmarshal IDToken, problem decoding JWT: %w", err) } token := idToken2{} err = json.Unmarshal(jwtDecoded, &token) if err != nil { return fmt.Errorf("unable to unmarshal IDToken: %w", err) } token.RawToken = jwt *i = IDToken(token) return nil } // IsZero indicates if the IDToken is the zero value. 
func (i IDToken) IsZero() bool { v := reflect.ValueOf(i) for i := 0; i < v.NumField(); i++ { field := v.Field(i) if !field.IsZero() { switch field.Kind() { case reflect.Map, reflect.Slice: if field.Len() == 0 { continue } } return false } } return true } // LocalAccountID extracts an account's local account ID from an ID token. func (i IDToken) LocalAccountID() string { if i.Oid != "" { return i.Oid } return i.Subject } // jwtDecoder is provided to allow tests to provide their own. var jwtDecoder = decodeJWT // ClientInfo is used to create a Home Account ID for an account. type ClientInfo struct { UID string `json:"uid"` UTID string `json:"utid"` AdditionalFields map[string]interface{} } // UnmarshalJSON implements json.Unmarshaler.s func (c *ClientInfo) UnmarshalJSON(b []byte) error { s := strings.Trim(string(b), `"`) // Client info may be empty in some flows, e.g. certificate exchange. if len(s) == 0 { return nil } // Because we have a custom unmarshaler, you // cannot directly call json.Unmarshal here. If you do, it will call this function // recursively until reach our recursion limit. We have to create a new type // that doesn't have this method in order to use json.Unmarshal. type clientInfo2 ClientInfo raw, err := jwtDecoder(s) if err != nil { return fmt.Errorf("TokenResponse client_info field had JWT decode error: %w", err) } var c2 clientInfo2 err = json.Unmarshal(raw, &c2) if err != nil { return fmt.Errorf("was unable to unmarshal decoded JWT in TokenRespone to ClientInfo: %w", err) } *c = ClientInfo(c2) return nil } // HomeAccountID creates the home account ID. func (c ClientInfo) HomeAccountID() string { if c.UID == "" { return "" } else if c.UTID == "" { return fmt.Sprintf("%s.%s", c.UID, c.UID) } else { return fmt.Sprintf("%s.%s", c.UID, c.UTID) } } // Scopes represents scopes in a TokenResponse. type Scopes struct { Slice []string } // UnmarshalJSON implements json.Unmarshal. 
func (s *Scopes) UnmarshalJSON(b []byte) error { str := strings.Trim(string(b), `"`) if len(str) == 0 { return nil } sl := strings.Split(str, " ") s.Slice = sl return nil } // TokenResponse is the information that is returned from a token endpoint during a token acquisition flow. type TokenResponse struct { authority.OAuthResponseBase AccessToken string `json:"access_token"` RefreshToken string `json:"refresh_token"` FamilyID string `json:"foci"` IDToken IDToken `json:"id_token"` ClientInfo ClientInfo `json:"client_info"` ExpiresOn internalTime.DurationTime `json:"expires_in"` ExtExpiresOn internalTime.DurationTime `json:"ext_expires_in"` GrantedScopes Scopes `json:"scope"` DeclinedScopes []string // This is derived AdditionalFields map[string]interface{} scopesComputed bool } // ComputeScope computes the final scopes based on what was granted by the server and // what our AuthParams were from the authority server. Per OAuth spec, if no scopes are returned, the response should be treated as if all scopes were granted // This behavior can be observed in client assertion flows, but can happen at any time, this check ensures we treat // those special responses properly Link to spec: https://tools.ietf.org/html/rfc6749#section-3.3 func (tr *TokenResponse) ComputeScope(authParams authority.AuthParams) { if len(tr.GrantedScopes.Slice) == 0 { tr.GrantedScopes = Scopes{Slice: authParams.Scopes} } else { tr.DeclinedScopes = findDeclinedScopes(authParams.Scopes, tr.GrantedScopes.Slice) } tr.scopesComputed = true } // Validate validates the TokenResponse has basic valid values. It must be called // after ComputeScopes() is called. 
func (tr *TokenResponse) Validate() error { if tr.Error != "" { return fmt.Errorf("%s: %s", tr.Error, tr.ErrorDescription) } if tr.AccessToken == "" { return errors.New("response is missing access_token") } if !tr.scopesComputed { return fmt.Errorf("TokenResponse hasn't had ScopesComputed() called") } return nil } func (tr *TokenResponse) CacheKey(authParams authority.AuthParams) string { if authParams.AuthorizationType == authority.ATOnBehalfOf { return authParams.AssertionHash() } if authParams.AuthorizationType == authority.ATClientCredentials { return authParams.AppKey() } if authParams.IsConfidentialClient || authParams.AuthorizationType == authority.ATRefreshToken { return tr.ClientInfo.HomeAccountID() } return "" } func findDeclinedScopes(requestedScopes []string, grantedScopes []string) []string { declined := []string{} grantedMap := map[string]bool{} for _, s := range grantedScopes { grantedMap[strings.ToLower(s)] = true } // Comparing the requested scopes with the granted scopes to see if there are any scopes that have been declined. for _, r := range requestedScopes { if !grantedMap[strings.ToLower(r)] { declined = append(declined, r) } } return declined } // decodeJWT decodes a JWT and converts it to a byte array representing a JSON object // JWT has headers and payload base64url encoded without padding // https://tools.ietf.org/html/rfc7519#section-3 and // https://tools.ietf.org/html/rfc7515#section-2 func decodeJWT(data string) ([]byte, error) { // https://tools.ietf.org/html/rfc7515#appendix-C return base64.RawURLEncoding.DecodeString(data) } // RefreshToken is the JSON representation of a MSAL refresh token for encoding to storage. 
type RefreshToken struct { HomeAccountID string `json:"home_account_id,omitempty"` Environment string `json:"environment,omitempty"` CredentialType string `json:"credential_type,omitempty"` ClientID string `json:"client_id,omitempty"` FamilyID string `json:"family_id,omitempty"` Secret string `json:"secret,omitempty"` Realm string `json:"realm,omitempty"` Target string `json:"target,omitempty"` UserAssertionHash string `json:"user_assertion_hash,omitempty"` AdditionalFields map[string]interface{} } // NewRefreshToken is the constructor for RefreshToken. func NewRefreshToken(homeID, env, clientID, refreshToken, familyID string) RefreshToken { return RefreshToken{ HomeAccountID: homeID, Environment: env, CredentialType: "RefreshToken", ClientID: clientID, FamilyID: familyID, Secret: refreshToken, } } // Key outputs the key that can be used to uniquely look up this entry in a map. func (rt RefreshToken) Key() string { var fourth = rt.FamilyID if fourth == "" { fourth = rt.ClientID } return strings.Join( []string{rt.HomeAccountID, rt.Environment, rt.CredentialType, fourth}, shared.CacheKeySeparator, ) } func (rt RefreshToken) GetSecret() string { return rt.Secret } // DeviceCodeResult stores the response from the STS device code endpoint. type DeviceCodeResult struct { // UserCode is the code the user needs to provide when authentication at the verification URI. UserCode string // DeviceCode is the code used in the access token request. DeviceCode string // VerificationURL is the the URL where user can authenticate. VerificationURL string // ExpiresOn is the expiration time of device code in seconds. ExpiresOn time.Time // Interval is the interval at which the STS should be polled at. Interval int // Message is the message which should be displayed to the user. Message string // ClientID is the UUID issued by the authorization server for your application. ClientID string // Scopes is the OpenID scopes used to request access a protected API. 
Scopes []string } // NewDeviceCodeResult creates a DeviceCodeResult instance. func NewDeviceCodeResult(userCode, deviceCode, verificationURL string, expiresOn time.Time, interval int, message, clientID string, scopes []string) DeviceCodeResult { return DeviceCodeResult{userCode, deviceCode, verificationURL, expiresOn, interval, message, clientID, scopes} } func (dcr DeviceCodeResult) String() string { return fmt.Sprintf("UserCode: (%v)\nDeviceCode: (%v)\nURL: (%v)\nMessage: (%v)\n", dcr.UserCode, dcr.DeviceCode, dcr.VerificationURL, dcr.Message) } microsoft-authentication-library-for-go-1.0.0/apps/internal/oauth/ops/authority/000077500000000000000000000000001442026362400301165ustar00rootroot00000000000000microsoft-authentication-library-for-go-1.0.0/apps/internal/oauth/ops/authority/authority.go000066400000000000000000000460611442026362400325040ustar00rootroot00000000000000// Copyright (c) Microsoft Corporation. // Licensed under the MIT license. package authority import ( "context" "crypto/sha256" "encoding/base64" "encoding/json" "errors" "fmt" "io" "net/http" "net/url" "os" "path" "strings" "time" "github.com/google/uuid" ) const ( authorizationEndpoint = "https://%v/%v/oauth2/v2.0/authorize" instanceDiscoveryEndpoint = "https://%v/common/discovery/instance" tenantDiscoveryEndpointWithRegion = "https://%s.%s/%s/v2.0/.well-known/openid-configuration" regionName = "REGION_NAME" defaultAPIVersion = "2021-10-01" imdsEndpoint = "http://169.254.169.254/metadata/instance/compute/location?format=text&api-version=" + defaultAPIVersion autoDetectRegion = "TryAutoDetect" ) // These are various hosts that host AAD Instance discovery endpoints. const ( defaultHost = "login.microsoftonline.com" loginMicrosoft = "login.microsoft.com" loginWindows = "login.windows.net" loginSTSWindows = "sts.windows.net" loginMicrosoftOnline = defaultHost ) // jsonCaller is an interface that allows us to mock the JSONCall method. 
type jsonCaller interface {
	JSONCall(ctx context.Context, endpoint string, headers http.Header, qv url.Values, body, resp interface{}) error
}

// aadTrustedHostList is the set of AAD hosts that TrustedHost accepts.
var aadTrustedHostList = map[string]bool{
	"login.windows.net":            true, // Microsoft Azure Worldwide - Used in validation scenarios where host is not this list
	"login.chinacloudapi.cn":       true, // Microsoft Azure China
	"login.microsoftonline.de":     true, // Microsoft Azure Blackforest
	"login-us.microsoftonline.com": true, // Microsoft Azure US Government - Legacy
	"login.microsoftonline.us":     true, // Microsoft Azure US Government
	"login.microsoftonline.com":    true, // Microsoft Azure Worldwide
	"login.cloudgovapi.us":         true, // Microsoft Azure US Government
}

// TrustedHost checks if an AAD host is trusted/valid. A host is trusted exactly when it
// is a key of aadTrustedHostList; the comparison is case-sensitive.
func TrustedHost(host string) bool {
	_, listed := aadTrustedHostList[host]
	return listed
}

// OAuthResponseBase is the base JSON return message for an OAuth call.
// This is embedded in other calls to get the base fields from every response.
type OAuthResponseBase struct {
	Error            string `json:"error"`
	SubError         string `json:"suberror"`
	ErrorDescription string `json:"error_description"`
	ErrorCodes       []int  `json:"error_codes"`
	CorrelationID    string `json:"correlation_id"`
	Claims           string `json:"claims"`
}

// TenantDiscoveryResponse is the tenant endpoints from the OpenID configuration endpoint.
type TenantDiscoveryResponse struct {
	OAuthResponseBase

	AuthorizationEndpoint string `json:"authorization_endpoint"`
	TokenEndpoint         string `json:"token_endpoint"`
	Issuer                string `json:"issuer"`

	AdditionalFields map[string]interface{}
}

// Validate validates that the response had the correct values required.
func (r *TenantDiscoveryResponse) Validate() error { switch "" { case r.AuthorizationEndpoint: return errors.New("TenantDiscoveryResponse: authorize endpoint was not found in the openid configuration") case r.TokenEndpoint: return errors.New("TenantDiscoveryResponse: token endpoint was not found in the openid configuration") case r.Issuer: return errors.New("TenantDiscoveryResponse: issuer was not found in the openid configuration") } return nil } type InstanceDiscoveryMetadata struct { PreferredNetwork string `json:"preferred_network"` PreferredCache string `json:"preferred_cache"` Aliases []string `json:"aliases"` AdditionalFields map[string]interface{} } type InstanceDiscoveryResponse struct { TenantDiscoveryEndpoint string `json:"tenant_discovery_endpoint"` Metadata []InstanceDiscoveryMetadata `json:"metadata"` AdditionalFields map[string]interface{} } //go:generate stringer -type=AuthorizeType // AuthorizeType represents the type of token flow. type AuthorizeType int // These are all the types of token flows. const ( ATUnknown AuthorizeType = iota ATUsernamePassword ATWindowsIntegrated ATAuthCode ATInteractive ATClientCredentials ATDeviceCode ATRefreshToken AccountByID ATOnBehalfOf ) // These are all authority types const ( AAD = "MSSTS" ADFS = "ADFS" ) // AuthParams represents the parameters used for authorization for token acquisition. type AuthParams struct { AuthorityInfo Info CorrelationID string Endpoints Endpoints ClientID string // Redirecturi is used for auth flows that specify a redirect URI (e.g. local server for interactive auth flow). Redirecturi string HomeAccountID string // Username is the user-name portion for username/password auth flow. Username string // Password is the password portion for username/password auth flow. Password string // Scopes is the list of scopes the user consents to. Scopes []string // AuthorizationType specifies the auth flow being used. 
AuthorizationType AuthorizeType // State is a random value used to prevent cross-site request forgery attacks. State string // CodeChallenge is derived from a code verifier and is sent in the auth request. CodeChallenge string // CodeChallengeMethod describes the method used to create the CodeChallenge. CodeChallengeMethod string // Prompt specifies the user prompt type during interactive auth. Prompt string // IsConfidentialClient specifies if it is a confidential client. IsConfidentialClient bool // SendX5C specifies if x5c claim(public key of the certificate) should be sent to STS. SendX5C bool // UserAssertion is the access token used to acquire token on behalf of user UserAssertion string // Capabilities the client will include with each token request, for example "CP1". // Call [NewClientCapabilities] to construct a value for this field. Capabilities ClientCapabilities // Claims required for an access token to satisfy a conditional access policy Claims string // KnownAuthorityHosts don't require metadata discovery because they're known to the user KnownAuthorityHosts []string // LoginHint is a username with which to pre-populate account selection during interactive auth LoginHint string // DomainHint is a directive that can be used to accelerate the user to their federated IdP sign-in page DomainHint string } // NewAuthParams creates an authorization parameters object. func NewAuthParams(clientID string, authorityInfo Info) AuthParams { return AuthParams{ ClientID: clientID, AuthorityInfo: authorityInfo, CorrelationID: uuid.New().String(), } } // WithTenant returns a copy of the AuthParams having the specified tenant ID. If the given // ID is empty, the copy is identical to the original. 
This function returns an error in // several cases: // - ID isn't specific (for example, it's "common") // - ID is non-empty and the authority doesn't support tenants (for example, it's an ADFS authority) // - the client is configured to authenticate only Microsoft accounts via the "consumers" endpoint // - the resulting authority URL is invalid func (p AuthParams) WithTenant(ID string) (AuthParams, error) { switch ID { case "", p.AuthorityInfo.Tenant: // keep the default tenant because the caller didn't override it return p, nil case "common", "consumers", "organizations": if p.AuthorityInfo.AuthorityType == AAD { return p, fmt.Errorf(`tenant ID must be a specific tenant, not "%s"`, ID) } // else we'll return a better error below } if p.AuthorityInfo.AuthorityType != AAD { return p, errors.New("the authority doesn't support tenants") } if p.AuthorityInfo.Tenant == "consumers" { return p, errors.New(`client is configured to authenticate only personal Microsoft accounts, via the "consumers" endpoint`) } authority := "https://" + path.Join(p.AuthorityInfo.Host, ID) info, err := NewInfoFromAuthorityURI(authority, p.AuthorityInfo.ValidateAuthority, p.AuthorityInfo.InstanceDiscoveryDisabled) if err == nil { info.Region = p.AuthorityInfo.Region p.AuthorityInfo = info } return p, err } // MergeCapabilitiesAndClaims combines client capabilities and challenge claims into a value suitable for an authentication request's "claims" parameter. func (p AuthParams) MergeCapabilitiesAndClaims() (string, error) { claims := p.Claims if len(p.Capabilities.asMap) > 0 { if claims == "" { // without claims the result is simply the capabilities return p.Capabilities.asJSON, nil } // Otherwise, merge claims and capabilties into a single JSON object. // We handle the claims challenge as a map because we don't know its structure. var challenge map[string]any if err := json.Unmarshal([]byte(claims), &challenge); err != nil { return "", fmt.Errorf(`claims must be JSON. Are they base64 encoded? 
json.Unmarshal returned "%v"`, err) } if err := merge(p.Capabilities.asMap, challenge); err != nil { return "", err } b, err := json.Marshal(challenge) if err != nil { return "", err } claims = string(b) } return claims, nil } // merges a into b without overwriting b's values. Returns an error when a and b share a key for which either has a non-object value. func merge(a, b map[string]any) error { for k, av := range a { if bv, ok := b[k]; !ok { // b doesn't contain this key => simply set it to a's value b[k] = av } else { // b does contain this key => recursively merge a[k] into b[k], provided both are maps. If a[k] or b[k] isn't // a map, return an error because merging would overwrite some value in b. Errors shouldn't occur in practice // because the challenge will be from AAD, which knows the capabilities format. if A, ok := av.(map[string]any); ok { if B, ok := bv.(map[string]any); ok { return merge(A, B) } else { // b[k] isn't a map return errors.New("challenge claims conflict with client capabilities") } } else { // a[k] isn't a map return errors.New("challenge claims conflict with client capabilities") } } } return nil } // ClientCapabilities stores capabilities in the formats used by AuthParams.MergeCapabilitiesAndClaims. // [NewClientCapabilities] precomputes these representations because capabilities are static for the // lifetime of a client and are included with every authentication request i.e., these computations // always have the same result and would otherwise have to be repeated for every request. 
type ClientCapabilities struct { // asJSON is for the common case: adding the capabilities to an auth request with no challenge claims asJSON string // asMap is for merging the capabilities with challenge claims asMap map[string]any } func NewClientCapabilities(capabilities []string) (ClientCapabilities, error) { c := ClientCapabilities{} var err error if len(capabilities) > 0 { cpbs := make([]string, len(capabilities)) for i := 0; i < len(cpbs); i++ { cpbs[i] = fmt.Sprintf(`"%s"`, capabilities[i]) } c.asJSON = fmt.Sprintf(`{"access_token":{"xms_cc":{"values":[%s]}}}`, strings.Join(cpbs, ",")) // note our JSON is valid but we can't stop users breaking it with garbage like "}" err = json.Unmarshal([]byte(c.asJSON), &c.asMap) } return c, err } // Info consists of information about the authority. type Info struct { Host string CanonicalAuthorityURI string AuthorityType string UserRealmURIPrefix string ValidateAuthority bool Tenant string Region string InstanceDiscoveryDisabled bool } func firstPathSegment(u *url.URL) (string, error) { pathParts := strings.Split(u.EscapedPath(), "/") if len(pathParts) >= 2 { return pathParts[1], nil } return "", errors.New(`authority must be an https URL such as "https://login.microsoftonline.com/"`) } // NewInfoFromAuthorityURI creates an AuthorityInfo instance from the authority URL provided. 
func NewInfoFromAuthorityURI(authority string, validateAuthority bool, instanceDiscoveryDisabled bool) (Info, error) { u, err := url.Parse(strings.ToLower(authority)) if err != nil || u.Scheme != "https" { return Info{}, errors.New(`authority must be an https URL such as "https://login.microsoftonline.com/"`) } tenant, err := firstPathSegment(u) if err != nil { return Info{}, err } authorityType := AAD if tenant == "adfs" { authorityType = ADFS } // u.Host includes the port, if any, which is required for private cloud deployments return Info{ Host: u.Host, CanonicalAuthorityURI: fmt.Sprintf("https://%v/%v/", u.Host, tenant), AuthorityType: authorityType, UserRealmURIPrefix: fmt.Sprintf("https://%v/common/userrealm/", u.Hostname()), ValidateAuthority: validateAuthority, Tenant: tenant, InstanceDiscoveryDisabled: instanceDiscoveryDisabled, }, nil } // Endpoints consists of the endpoints from the tenant discovery response. type Endpoints struct { AuthorizationEndpoint string TokenEndpoint string selfSignedJwtAudience string authorityHost string } // NewEndpoints creates an Endpoints object. func NewEndpoints(authorizationEndpoint string, tokenEndpoint string, selfSignedJwtAudience string, authorityHost string) Endpoints { return Endpoints{authorizationEndpoint, tokenEndpoint, selfSignedJwtAudience, authorityHost} } // UserRealmAccountType refers to the type of user realm. type UserRealmAccountType string // These are the different types of user realms. 
const ( Unknown UserRealmAccountType = "" Federated UserRealmAccountType = "Federated" Managed UserRealmAccountType = "Managed" ) // UserRealm is used for the username password request to determine user type type UserRealm struct { AccountType UserRealmAccountType `json:"account_type"` DomainName string `json:"domain_name"` CloudInstanceName string `json:"cloud_instance_name"` CloudAudienceURN string `json:"cloud_audience_urn"` // required if accountType is Federated FederationProtocol string `json:"federation_protocol"` FederationMetadataURL string `json:"federation_metadata_url"` AdditionalFields map[string]interface{} } func (u UserRealm) validate() error { switch "" { case string(u.AccountType): return errors.New("the account type (Federated or Managed) is missing") case u.DomainName: return errors.New("domain name of user realm is missing") case u.CloudInstanceName: return errors.New("cloud instance name of user realm is missing") case u.CloudAudienceURN: return errors.New("cloud Instance URN is missing") } if u.AccountType == Federated { switch "" { case u.FederationProtocol: return errors.New("federation protocol of user realm is missing") case u.FederationMetadataURL: return errors.New("federation metadata URL of user realm is missing") } } return nil } // Client represents the REST calls to authority backends. type Client struct { // Comm provides the HTTP transport client. 
Comm jsonCaller // *comm.Client } func (c Client) UserRealm(ctx context.Context, authParams AuthParams) (UserRealm, error) { endpoint := fmt.Sprintf("https://%s/common/UserRealm/%s", authParams.Endpoints.authorityHost, url.PathEscape(authParams.Username)) qv := url.Values{ "api-version": []string{"1.0"}, } resp := UserRealm{} err := c.Comm.JSONCall( ctx, endpoint, http.Header{"client-request-id": []string{authParams.CorrelationID}}, qv, nil, &resp, ) if err != nil { return resp, err } return resp, resp.validate() } func (c Client) GetTenantDiscoveryResponse(ctx context.Context, openIDConfigurationEndpoint string) (TenantDiscoveryResponse, error) { resp := TenantDiscoveryResponse{} err := c.Comm.JSONCall( ctx, openIDConfigurationEndpoint, http.Header{}, nil, nil, &resp, ) return resp, err } // AADInstanceDiscovery attempts to discover a tenant endpoint (used in OIDC auth with an authorization endpoint). // This is done by AAD which allows for aliasing of tenants (windows.sts.net is the same as login.windows.com). 
func (c Client) AADInstanceDiscovery(ctx context.Context, authorityInfo Info) (InstanceDiscoveryResponse, error) { region := "" var err error resp := InstanceDiscoveryResponse{} if authorityInfo.Region != "" && authorityInfo.Region != autoDetectRegion { region = authorityInfo.Region } else if authorityInfo.Region == autoDetectRegion { region = detectRegion(ctx) } if region != "" { environment := authorityInfo.Host switch environment { case loginMicrosoft, loginWindows, loginSTSWindows, defaultHost: environment = loginMicrosoft } resp.TenantDiscoveryEndpoint = fmt.Sprintf(tenantDiscoveryEndpointWithRegion, region, environment, authorityInfo.Tenant) metadata := InstanceDiscoveryMetadata{ PreferredNetwork: fmt.Sprintf("%v.%v", region, authorityInfo.Host), PreferredCache: authorityInfo.Host, Aliases: []string{fmt.Sprintf("%v.%v", region, authorityInfo.Host), authorityInfo.Host}, } resp.Metadata = []InstanceDiscoveryMetadata{metadata} } else { qv := url.Values{} qv.Set("api-version", "1.1") qv.Set("authorization_endpoint", fmt.Sprintf(authorizationEndpoint, authorityInfo.Host, authorityInfo.Tenant)) discoveryHost := defaultHost if TrustedHost(authorityInfo.Host) { discoveryHost = authorityInfo.Host } endpoint := fmt.Sprintf(instanceDiscoveryEndpoint, discoveryHost) err = c.Comm.JSONCall(ctx, endpoint, http.Header{}, qv, nil, &resp) } return resp, err } func detectRegion(ctx context.Context) string { region := os.Getenv(regionName) if region != "" { region = strings.ReplaceAll(region, " ", "") return strings.ToLower(region) } // HTTP call to IMDS endpoint to get region // Refer : https://identitydivision.visualstudio.com/DevEx/_git/AuthLibrariesApiReview?path=%2FPinAuthToRegion%2FAAD%20SDK%20Proposal%20to%20Pin%20Auth%20to%20region.md&_a=preview&version=GBdev // Set a 2 second timeout for this http client which only does calls to IMDS endpoint client := http.Client{ Timeout: time.Duration(2 * time.Second), } req, _ := http.NewRequest("GET", imdsEndpoint, nil) 
req.Header.Set("Metadata", "true") resp, err := client.Do(req) // If the request times out or there is an error, it is retried once if err != nil || resp.StatusCode != 200 { resp, err = client.Do(req) if err != nil || resp.StatusCode != 200 { return "" } } defer resp.Body.Close() response, err := io.ReadAll(resp.Body) if err != nil { return "" } return string(response) } func (a *AuthParams) CacheKey(isAppCache bool) string { if a.AuthorizationType == ATOnBehalfOf { return a.AssertionHash() } if a.AuthorizationType == ATClientCredentials || isAppCache { return a.AppKey() } if a.AuthorizationType == ATRefreshToken || a.AuthorizationType == AccountByID { return a.HomeAccountID } return "" } func (a *AuthParams) AssertionHash() string { hasher := sha256.New() // Per documentation this never returns an error : https://pkg.go.dev/hash#pkg-types _, _ = hasher.Write([]byte(a.UserAssertion)) sha := base64.URLEncoding.EncodeToString(hasher.Sum(nil)) return sha } func (a *AuthParams) AppKey() string { if a.AuthorityInfo.Tenant != "" { return fmt.Sprintf("%s_%s_AppTokenCache", a.ClientID, a.AuthorityInfo.Tenant) } return fmt.Sprintf("%s__AppTokenCache", a.ClientID) } microsoft-authentication-library-for-go-1.0.0/apps/internal/oauth/ops/authority/authority_test.go000066400000000000000000000341741442026362400335450ustar00rootroot00000000000000// Copyright (c) Microsoft Corporation. // Licensed under the MIT license. 
package authority

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"net/http"
	"net/url"
	"reflect"
	"strings"
	"testing"

	"github.com/kylelemons/godebug/pretty"
)

// fakeJSONCaller is a test double for the JSON transport used by Client. It records the
// arguments of the last JSONCall. When resp is set, it unmarshals resp into the caller's
// response object. When err is set, every call fails.
type fakeJSONCaller struct {
	err  bool
	resp []byte

	gotEndpoint string
	gotHeaders  http.Header
	gotQV       url.Values
	gotBody     interface{}
	gotResp     interface{}
}

func (f *fakeJSONCaller) JSONCall(ctx context.Context, endpoint string, headers http.Header, qv url.Values, body, resp interface{}) error {
	if f.err {
		return errors.New("error")
	}
	f.gotEndpoint = endpoint
	f.gotHeaders = headers
	f.gotQV = qv
	f.gotBody = body
	f.gotResp = resp

	if f.resp != nil {
		if err := json.Unmarshal(f.resp, resp); err != nil {
			return err
		}
	}
	return nil
}

// compare checks the recorded call against the wanted arguments. For resp, only the
// pointed-to type name is compared.
func (f *fakeJSONCaller) compare(endpoint string, headers http.Header, qv url.Values, body, resp interface{}) error {
	if f.gotEndpoint != endpoint {
		return fmt.Errorf("got endpoint == %s, want endpoint == %s", f.gotEndpoint, endpoint)
	}
	if diff := pretty.Compare(headers, f.gotHeaders); diff != "" {
		return fmt.Errorf("headers -want/+got:\n%s", diff)
	}
	if diff := pretty.Compare(qv, f.gotQV); diff != "" {
		return fmt.Errorf("qv -want/+got:\n%s", diff)
	}
	if diff := pretty.Compare(body, f.gotBody); diff != "" {
		return fmt.Errorf("body -want/+got:\n%s", diff)
	}

	gotValue := reflect.ValueOf(f.gotResp)
	if gotValue.Kind() != reflect.Ptr {
		return fmt.Errorf("resp cannot be a non-pointer type")
	}
	gotValue = gotValue.Elem()
	gotName := gotValue.Type().Name()
	wantName := reflect.ValueOf(resp).Elem().Type().Name()
	if gotName != wantName {
		return fmt.Errorf("resp type was %s, want %s", gotName, wantName)
	}
	return nil
}

var testAuthorityEndpoints = NewEndpoints(
	"https://login.microsoftonline.com/v2.0/authorize",
	"https://login.microsoftonline.com/v2.0/token",
	"https://login.microsoftonline.com/v2.0",
	"login.microsoftonline.com",
)

func TestUserRealm(t *testing.T) {
	authParams := AuthParams{
		Username:      "username",
		Endpoints:     testAuthorityEndpoints,
		CorrelationID: "id",
	}

	tests := []struct {
		desc     string
		err      bool
		endpoint string
		jsonResp *UserRealm
		headers  http.Header
		qv       url.Values
		resp     interface{}
	}{
		{
			desc: "Error: comm returns error",
			err:  true,
		},
		{
			desc:     "Success",
			endpoint: fmt.Sprintf("https://login.microsoftonline.com/common/UserRealm/%s", url.PathEscape(authParams.Username)),
			headers: http.Header{
				"client-request-id": []string{"id"},
			},
			qv: url.Values{
				"api-version": []string{"1.0"},
			},
			jsonResp: &UserRealm{
				AccountType:       "Managed",
				DomainName:        "microsoftonline.com",
				CloudInstanceName: "instance",
				CloudAudienceURN:  "urn",
			},
			resp: &UserRealm{},
		},
	}

	for _, test := range tests {
		fake := &fakeJSONCaller{err: test.err}
		client := Client{fake}

		if test.jsonResp != nil {
			b, err := json.Marshal(test.jsonResp)
			if err != nil {
				panic(err)
			}
			fake.resp = b
		}

		// We don't care about the result, that is just a translation from the JSON handled
		// automatically in the comm package. We care only that the comm package got what
		// it needed.
		_, err := client.UserRealm(context.Background(), authParams)
		switch {
		case err == nil && test.err:
			t.Errorf("TestUserRealm(%s): got err == nil , want err != nil", test.desc)
			continue
		case err != nil && !test.err:
			t.Errorf("TestUserRealm(%s): got err == %s , want err == nil", test.desc, err)
			continue
		case err != nil:
			continue
		}

		if err := fake.compare(test.endpoint, test.headers, test.qv, nil, test.resp); err != nil {
			t.Errorf("TestUserRealm(%s): %s", test.desc, err)
		}
	}
}

func TestTenantDiscoveryResponse(t *testing.T) {
	tests := []struct {
		desc     string
		err      bool
		endpoint string
		resp     interface{}
	}{
		{
			desc: "Error: comm returns error",
			err:  true,
		},
		{
			desc:     "Success",
			endpoint: "endpoint",
			resp:     &TenantDiscoveryResponse{},
		},
	}

	for _, test := range tests {
		fake := &fakeJSONCaller{err: test.err}
		client := Client{fake}

		// We don't care about the result, that is just a translation from the JSON handled
		// automatically in the comm package. We care only that the comm package got what
		// it needed.
		_, err := client.GetTenantDiscoveryResponse(context.Background(), "endpoint")
		switch {
		case err == nil && test.err:
			t.Errorf("TestTenantDiscoveryResponse(%s): got err == nil , want err != nil", test.desc)
			continue
		case err != nil && !test.err:
			t.Errorf("TestTenantDiscoveryResponse(%s): got err == %s , want err == nil", test.desc, err)
			continue
		case err != nil:
			continue
		}

		if err := fake.compare(test.endpoint, http.Header{}, nil, nil, test.resp); err != nil {
			t.Errorf("TestTenantDiscoveryResponse(%s): %s", test.desc, err)
		}
	}
}

func TestAADInstanceDiscovery(t *testing.T) {
	tests := []struct {
		desc     string
		err      bool
		authInfo Info
		endpoint string
		qv       url.Values
		resp     interface{}
	}{
		{
			desc: "Error: comm returns error",
			err:  true,
		},
		{
			desc:     "Success with authorityInfo.Host not in trusted list",
			endpoint: fmt.Sprintf(instanceDiscoveryEndpoint, defaultHost),
			authInfo: Info{
				Host:   "host",
				Tenant: "tenant",
			},
			qv: url.Values{
				"api-version":            []string{"1.1"},
				"authorization_endpoint": []string{fmt.Sprintf(authorizationEndpoint, "host", "tenant")},
			},
			resp: &InstanceDiscoveryResponse{},
		},
		{
			desc:     "Success with authorityInfo.Host in trusted list",
			endpoint: fmt.Sprintf(instanceDiscoveryEndpoint, "login.microsoftonline.de"),
			authInfo: Info{
				Host:   "login.microsoftonline.de",
				Tenant: "tenant",
			},
			qv: url.Values{
				"api-version":            []string{"1.1"},
				"authorization_endpoint": []string{fmt.Sprintf(authorizationEndpoint, "login.microsoftonline.de", "tenant")},
			},
			resp: &InstanceDiscoveryResponse{},
		},
	}

	for _, test := range tests {
		fake := &fakeJSONCaller{err: test.err}
		client := Client{fake}

		// We don't care about the result, that is just a translation from the JSON handled
		// automatically in the comm package. We care only that the comm package got what
		// it needed.
		_, err := client.AADInstanceDiscovery(context.Background(), test.authInfo)
		switch {
		case err == nil && test.err:
			t.Errorf("AADInstanceDiscovery(%s): got err == nil , want err != nil", test.desc)
			continue
		case err != nil && !test.err:
			t.Errorf("AADInstanceDiscovery(%s): got err == %s , want err == nil", test.desc, err)
			continue
		case err != nil:
			continue
		}

		if err := fake.compare(test.endpoint, http.Header{}, test.qv, nil, test.resp); err != nil {
			t.Errorf("AADInstanceDiscovery(%s): %s", test.desc, err)
		}
	}
}

func TestAADInstanceDiscoveryWithRegion(t *testing.T) {
	client := Client{&fakeJSONCaller{}}
	region := "region"
	discoveryPath := "tenant/v2.0/.well-known/openid-configuration"
	publicCloudEndpoint := fmt.Sprintf("https://%s.login.microsoft.com/%s", region, discoveryPath)
	for _, test := range []struct{ host, expectedEndpoint string }{
		{"login.chinacloudapi.cn", fmt.Sprintf("https://%s.login.chinacloudapi.cn/%s", region, discoveryPath)},
		{"login.microsoft.com", publicCloudEndpoint},
		{"login.microsoftonline.com", publicCloudEndpoint},
		{"login.windows.net", publicCloudEndpoint},
		{"login.windows-ppe.net", fmt.Sprintf("https://%s.login.windows-ppe.net/%s", region, discoveryPath)},
		{"sts.windows.net", publicCloudEndpoint},
	} {
		t.Run(test.host, func(t *testing.T) {
			authInfo := Info{Host: test.host, Tenant: "tenant", Region: region}
			resp, err := client.AADInstanceDiscovery(context.Background(), authInfo)
			if err != nil {
				// Fatal rather than Error: the checks below index resp.Metadata and would panic.
				t.Fatalf("AADInstanceDiscoveryWithRegion failing with %s", err)
			}
			if len(resp.Metadata) == 0 {
				t.Fatal("AADInstanceDiscoveryWithRegion returned no metadata")
			}
			expectedPreferredNetwork := fmt.Sprintf("%v.%v", region, test.host)
			expectedPreferredCache := test.host
			if resp.TenantDiscoveryEndpoint != test.expectedEndpoint {
				t.Errorf("AADInstanceDiscoveryWithRegion incorrect TenantDiscoveryEndpoint: got: %s, want: %s", resp.TenantDiscoveryEndpoint, test.expectedEndpoint)
			}
			if resp.Metadata[0].PreferredNetwork != expectedPreferredNetwork {
				t.Errorf("AADInstanceDiscoveryWithRegion incorrect Preferred Network got: %s, want: %s", resp.Metadata[0].PreferredNetwork, expectedPreferredNetwork)
			}
			if resp.Metadata[0].PreferredCache != expectedPreferredCache {
				t.Errorf("AADInstanceDiscoveryWithRegion incorrect Preferred Cache got: %s, want: %s", resp.Metadata[0].PreferredCache, expectedPreferredCache)
			}
		})
	}
}

func TestCreateAuthorityInfoFromAuthorityUri(t *testing.T) {
	const authorityURI = "https://login.microsoftonline.com/common/"

	want := Info{
		Host:                  "login.microsoftonline.com",
		CanonicalAuthorityURI: authorityURI,
		AuthorityType:         "MSSTS",
		UserRealmURIPrefix:    "https://login.microsoftonline.com/common/userrealm/",
		Tenant:                "common",
		ValidateAuthority:     true,
	}

	got, err := NewInfoFromAuthorityURI(authorityURI, true, false)
	if err != nil {
		t.Fatalf("TestCreateAuthorityInfoFromAuthorityUri: got err == %s, want err == nil", err)
	}

	if diff := pretty.Compare(want, got); diff != "" {
		t.Errorf("TestCreateAuthorityInfoFromAuthorityUri: -want/+got:\n%s", diff)
	}
}

func TestAuthParamsWithTenant(t *testing.T) {
	uuid1 := "00000000-0000-0000-0000-000000000000"
	uuid2 := strings.ReplaceAll(uuid1, "0", "1")
	host := "https://localhost/"
	for _, test := range []struct {
		authority, expectedAuthority, tenant string
		expectError                          bool
	}{
		{authority: host + "common", tenant: uuid1, expectedAuthority: host + uuid1},
		{authority: host + "organizations", tenant: uuid1, expectedAuthority: host + uuid1},
		{authority: host + uuid1, tenant: uuid2, expectedAuthority: host + uuid2},
		{authority: host + uuid1, tenant: "common", expectError: true},
		{authority: host + uuid1, tenant: "organizations", expectError: true},
		{authority: host + "adfs", tenant: uuid1, expectError: true},
		{authority: host + "consumers", tenant: uuid1, expectError: true},
	} {
		t.Run("", func(t *testing.T) {
			info, err := NewInfoFromAuthorityURI(test.authority, false, false)
			if err != nil {
				t.Fatal(err)
			}
			params := NewAuthParams("client-id", info)
			p, err := params.WithTenant(test.tenant)
			if test.expectError {
				if err == nil {
					t.Fatal("expected an error")
				}
				return
			}
			if err != nil {
				t.Fatal(err)
			}
			if v := strings.TrimSuffix(p.AuthorityInfo.CanonicalAuthorityURI, "/"); v != test.expectedAuthority {
				t.Fatalf(`unexpected tenant "%s"`, v)
			}
		})
	}

	// WithTenant shouldn't change AuthorityInfo fields unrelated to the tenant, such as Region
	t.Run("AuthorityInfo", func(t *testing.T) {
		a := "A"
		b := "B"
		before, err := NewInfoFromAuthorityURI("https://localhost/"+a, true, false)
		if err != nil {
			t.Fatal(err)
		}
		before.Region = "region"
		params := NewAuthParams("client-id", before)
		p, err := params.WithTenant(b)
		if err != nil {
			t.Fatal(err)
		}
		after := p.AuthorityInfo
		// these values should be different because they contain the tenant (this is tested above)
		after.CanonicalAuthorityURI = before.CanonicalAuthorityURI
		after.Tenant = before.Tenant
		// With those fields equal, we can compare the before and after Infos without enumerating
		// their fields i.e., we can implicitly compare all the other fields at once. With this
		// approach, when Info gets a new field, this test needs an update only if that field
		// contains the tenant, in which case this test will break so maintainers don't overlook it.
		if diff := pretty.Compare(before, after); diff != "" {
			t.Fatal(diff)
		}
	})
}

func TestMergeCapabilitiesAndClaims(t *testing.T) {
	for _, test := range []struct {
		capabilities             []string
		challenge, desc, expected string
		err                      bool
	}{
		{
			desc:     "no capabilities or challenge",
			expected: "",
		},
		{
			desc:         "encoded challenge",
			capabilities: []string{"cp1"},
			challenge:    "eyJpZF90b2tlbiI6eyJhdXRoX3RpbWUiOnsiZXNzZW50aWFsIjp0cnVlfX19",
			err:          true,
		},
		{
			desc:         "only capabilities",
			capabilities: []string{"cp1"},
			expected:     `{"access_token":{"xms_cc":{"values":["cp1"]}}}`,
		},
		{
			desc:      "only challenge",
			challenge: `{"id_token":{"auth_time":{"essential":true}}}`,
			expected:  `{"id_token":{"auth_time":{"essential":true}}}`,
		},
		{
			desc:         "overlapping claim", // i.e. capabilities and claims are siblings
			capabilities: []string{"cp1", "cp2"},
			challenge:    `{"access_token":{"nbf":{"essential":true, "value":"42"}}}`,
			expected:     `{"access_token":{"nbf":{"essential":true, "value":"42"}, "xms_cc":{"values":["cp1","cp2"]}}}`,
		},
		{
			desc:         "non-overlapping claim",
			capabilities: []string{"cp1", "cp2"},
			challenge:    `{"id_token":{"auth_time":{"essential":true}}}`,
			expected:     `{"id_token":{"auth_time":{"essential":true}}, "access_token":{"xms_cc":{"values":["cp1","cp2"]}}}`,
		},
		{
			desc:         "overlapping and non-overlapping claims",
			capabilities: []string{"cp1", "cp2"},
			challenge:    `{"id_token":{"auth_time":{"essential":true}},"access_token":{"nbf":{"essential":true, "value":"42"}}}`,
			expected:     `{"id_token":{"auth_time":{"essential":true}},"access_token":{"nbf":{"essential":true, "value":"42"},"xms_cc":{"values":["cp1","cp2"]}}}`,
		},
	} {
		cpb, err := NewClientCapabilities(test.capabilities)
		if err != nil {
			t.Fatal(err)
		}
		ap := AuthParams{Capabilities: cpb, Claims: test.challenge}
		t.Run(test.desc, func(t *testing.T) {
			var expected map[string]any
			if err := json.Unmarshal([]byte(test.expected), &expected); err != nil && test.expected != "" {
				t.Fatal("test bug: the expected result must be JSON or an empty string")
			}
			merged, err := ap.MergeCapabilitiesAndClaims()
			if err != nil {
				if test.err {
					return
				}
				t.Fatal(err)
			}
			// Fail explicitly instead of relying on the comparison below to notice a missing error.
			if test.err {
				t.Fatal("expected an error")
			}
			if merged == test.expected {
				return
			}
			var actual map[string]any
			if err = json.Unmarshal([]byte(merged), &actual); err != nil {
				t.Fatal(err)
			}
			if diff := pretty.Compare(expected, actual); diff != "" {
				t.Fatal(diff)
			}
		})
	}
}
authorizetype_string.go000066400000000000000000000017031442026362400346710ustar00rootroot00000000000000microsoft-authentication-library-for-go-1.0.0/apps/internal/oauth/ops/authority// Code generated by "stringer -type=AuthorizeType"; DO NOT EDIT.

package authority

import "strconv"

func _() {
	// An "invalid array index" compiler error signifies that the constant values have changed.
	// Re-run the stringer command to generate them again.
	var x [1]struct{}
	_ = x[ATUnknown-0]
	_ = x[ATUsernamePassword-1]
	_ = x[ATWindowsIntegrated-2]
	_ = x[ATAuthCode-3]
	_ = x[ATInteractive-4]
	_ = x[ATClientCredentials-5]
	_ = x[ATDeviceCode-6]
	_ = x[ATRefreshToken-7]
}

const _AuthorizeType_name = "ATUnknownATUsernamePasswordATWindowsIntegratedATAuthCodeATInteractiveATClientCredentialsATDeviceCodeATRefreshToken"

var _AuthorizeType_index = [...]uint8{0, 9, 27, 46, 56, 69, 88, 100, 114}

func (i AuthorizeType) String() string {
	if i < 0 || i >= AuthorizeType(len(_AuthorizeType_index)-1) {
		return "AuthorizeType(" + strconv.FormatInt(int64(i), 10) + ")"
	}
	return _AuthorizeType_name[_AuthorizeType_index[i]:_AuthorizeType_index[i+1]]
}
microsoft-authentication-library-for-go-1.0.0/apps/internal/oauth/ops/internal/000077500000000000000000000000001442026362400277025ustar00rootroot00000000000000microsoft-authentication-library-for-go-1.0.0/apps/internal/oauth/ops/internal/comm/000077500000000000000000000000001442026362400306355ustar00rootroot00000000000000microsoft-authentication-library-for-go-1.0.0/apps/internal/oauth/ops/internal/comm/comm.go000066400000000000000000000226511442026362400321250ustar00rootroot00000000000000// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

// Package comm provides helpers for communicating with HTTP backends.
package comm

import (
	"bytes"
	"context"
	"encoding/json"
	"encoding/xml"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"reflect"
	"runtime"
	"strings"
	"time"

	"github.com/AzureAD/microsoft-authentication-library-for-go/apps/errors"
	customJSON "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json"
	"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/version"
	"github.com/google/uuid"
)

// HTTPClient represents an HTTP client.
// It's usually an *http.Client from the standard library.
type HTTPClient interface {
	// Do sends an HTTP request and returns an HTTP response.
	Do(req *http.Request) (*http.Response, error)

	// CloseIdleConnections closes any idle connections in a "keep-alive" state.
	CloseIdleConnections()
}

// Client provides a wrapper to our *http.Client that handles compression and serialization needs.
type Client struct {
	client HTTPClient
}

// New returns a new Client object.
func New(httpClient HTTPClient) *Client {
	if httpClient == nil {
		panic("http.Client cannot == nil")
	}

	return &Client{client: httpClient}
}

// JSONCall connects to the REST endpoint passing the HTTP query values, headers and JSON conversion
// of body in the HTTP body. It automatically handles compression and decompression with gzip. The response is JSON
// unmarshalled into resp. resp must be a pointer to a struct. If the body struct contains a field called
// "AdditionalFields" we use a custom marshal/unmarshal engine.
//
// headers may be nil. A non-nil headers map is modified in place: the standard headers are
// added, plus Content-Type when body is non-nil. A non-nil body is sent as a POST with a known
// Content-Length. Otherwise the request is a GET.
func (c *Client) JSONCall(ctx context.Context, endpoint string, headers http.Header, qv url.Values, body, resp interface{}) error {
	if qv == nil {
		qv = url.Values{}
	}
	// A nil header map would panic on the first write in addStdHeaders.
	if headers == nil {
		headers = http.Header{}
	}

	v := reflect.ValueOf(resp)
	if err := c.checkResp(v); err != nil {
		return err
	}

	// Choose a JSON marshal/unmarshal depending on if we have AdditionalFields attribute.
	var marshal = json.Marshal
	var unmarshal = json.Unmarshal
	if _, ok := v.Elem().Type().FieldByName("AdditionalFields"); ok {
		marshal = customJSON.Marshal
		unmarshal = customJSON.Unmarshal
	}

	u, err := url.Parse(endpoint)
	if err != nil {
		return fmt.Errorf("could not parse path URL(%s): %w", endpoint, err)
	}
	u.RawQuery = qv.Encode()

	addStdHeaders(headers)

	req := &http.Request{Method: http.MethodGet, URL: u, Header: headers}

	if body != nil {
		// Note: In case you're wondering why we are not gzip encoding....
		// I'm not sure if these various services support gzip on send.
		headers.Set("Content-Type", "application/json; charset=utf-8")
		data, err := marshal(body)
		if err != nil {
			return fmt.Errorf("bug: conn.Call(): could not marshal the body object: %w", err)
		}
		// Declare the length (rather than letting net/http fall back to chunked encoding)
		// and allow the body to be replayed on redirects, matching URLFormCall.
		req.ContentLength = int64(len(data))
		req.Body = io.NopCloser(bytes.NewReader(data))
		req.GetBody = func() (io.ReadCloser, error) {
			return io.NopCloser(bytes.NewReader(data)), nil
		}
		req.Method = http.MethodPost
	}

	data, err := c.do(ctx, req)
	if err != nil {
		return err
	}

	if resp != nil {
		if err := unmarshal(data, resp); err != nil {
			return fmt.Errorf("json decode error: %w\njson message bytes were: %s", err, string(data))
		}
	}
	return nil
}

// XMLCall connects to an endpoint and decodes the XML response into resp. This is used when
// sending application/xml . If sending XML via SOAP, use SOAPCall().
// headers may be nil. A non-nil headers map is modified in place.
func (c *Client) XMLCall(ctx context.Context, endpoint string, headers http.Header, qv url.Values, resp interface{}) error {
	if err := c.checkResp(reflect.ValueOf(resp)); err != nil {
		return err
	}

	if qv == nil {
		qv = url.Values{}
	}
	// A nil header map would panic on headers.Set below.
	if headers == nil {
		headers = http.Header{}
	}

	u, err := url.Parse(endpoint)
	if err != nil {
		return fmt.Errorf("could not parse path URL(%s): %w", endpoint, err)
	}
	u.RawQuery = qv.Encode()

	headers.Set("Content-Type", "application/xml; charset=utf-8") // This was not set in the original Mex(), but...
	addStdHeaders(headers)

	return c.xmlCall(ctx, u, headers, "", resp)
}

// SOAPCall returns the SOAP message given an endpoint, action, body of the request and the response object to marshal into.
func (c *Client) SOAPCall(ctx context.Context, endpoint, action string, headers http.Header, qv url.Values, body string, resp interface{}) error { if body == "" { return fmt.Errorf("cannot make a SOAP call with body set to empty string") } if err := c.checkResp(reflect.ValueOf(resp)); err != nil { return err } if qv == nil { qv = url.Values{} } u, err := url.Parse(endpoint) if err != nil { return fmt.Errorf("could not parse path URL(%s): %w", endpoint, err) } u.RawQuery = qv.Encode() headers.Set("Content-Type", "application/soap+xml; charset=utf-8") headers.Set("SOAPAction", action) addStdHeaders(headers) return c.xmlCall(ctx, u, headers, body, resp) } // xmlCall sends an XML in body and decodes into resp. This simply does the transport and relies on // an upper level call to set things such as SOAP parameters and Content-Type, if required. func (c *Client) xmlCall(ctx context.Context, u *url.URL, headers http.Header, body string, resp interface{}) error { req := &http.Request{Method: http.MethodGet, URL: u, Header: headers} if len(body) > 0 { req.Method = http.MethodPost req.Body = io.NopCloser(strings.NewReader(body)) } data, err := c.do(ctx, req) if err != nil { return err } return xml.Unmarshal(data, resp) } // URLFormCall is used to make a call where we need to send application/x-www-form-urlencoded data // to the backend and receive JSON back. qv will be encoded into the request body. 
func (c *Client) URLFormCall(ctx context.Context, endpoint string, qv url.Values, resp interface{}) error { if len(qv) == 0 { return fmt.Errorf("URLFormCall() requires qv to have non-zero length") } if err := c.checkResp(reflect.ValueOf(resp)); err != nil { return err } u, err := url.Parse(endpoint) if err != nil { return fmt.Errorf("could not parse path URL(%s): %w", endpoint, err) } headers := http.Header{} headers.Set("Content-Type", "application/x-www-form-urlencoded; charset=utf-8") addStdHeaders(headers) enc := qv.Encode() req := &http.Request{ Method: http.MethodPost, URL: u, Header: headers, ContentLength: int64(len(enc)), Body: io.NopCloser(strings.NewReader(enc)), GetBody: func() (io.ReadCloser, error) { return io.NopCloser(strings.NewReader(enc)), nil }, } data, err := c.do(ctx, req) if err != nil { return err } v := reflect.ValueOf(resp) if err := c.checkResp(v); err != nil { return err } var unmarshal = json.Unmarshal if _, ok := v.Elem().Type().FieldByName("AdditionalFields"); ok { unmarshal = customJSON.Unmarshal } if resp != nil { if err := unmarshal(data, resp); err != nil { return fmt.Errorf("json decode error: %w\nraw message was: %s", err, string(data)) } } return nil } // do makes the HTTP call to the server and returns the contents of the body. func (c *Client) do(ctx context.Context, req *http.Request) ([]byte, error) { if _, ok := ctx.Deadline(); !ok { var cancel context.CancelFunc ctx, cancel = context.WithTimeout(ctx, 30*time.Second) defer cancel() } req = req.WithContext(ctx) reply, err := c.client.Do(req) if err != nil { return nil, fmt.Errorf("server response error:\n %w", err) } defer reply.Body.Close() data, err := c.readBody(reply) if err != nil { return nil, fmt.Errorf("could not read the body of an HTTP Response: %w", err) } reply.Body = io.NopCloser(bytes.NewBuffer(data)) // NOTE: This doesn't happen immediately after the call so that we can get an error message // from the server and include it in our error. 
switch reply.StatusCode { case 200, 201: default: sd := strings.TrimSpace(string(data)) if sd != "" { // We probably have the error in the body. return nil, errors.CallErr{ Req: req, Resp: reply, Err: fmt.Errorf("http call(%s)(%s) error: reply status code was %d:\n%s", req.URL.String(), req.Method, reply.StatusCode, sd), } } return nil, errors.CallErr{ Req: req, Resp: reply, Err: fmt.Errorf("http call(%s)(%s) error: reply status code was %d", req.URL.String(), req.Method, reply.StatusCode), } } return data, nil } // checkResp checks a response object o make sure it is a pointer to a struct. func (c *Client) checkResp(v reflect.Value) error { if v.Kind() != reflect.Ptr { return fmt.Errorf("bug: resp argument must a *struct, was %T", v.Interface()) } v = v.Elem() if v.Kind() != reflect.Struct { return fmt.Errorf("bug: resp argument must be a *struct, was %T", v.Interface()) } return nil } // readBody reads the body out of an *http.Response. It supports gzip encoded responses. func (c *Client) readBody(resp *http.Response) ([]byte, error) { var reader io.Reader = resp.Body switch resp.Header.Get("Content-Encoding") { case "": // Do nothing case "gzip": reader = gzipDecompress(resp.Body) default: return nil, fmt.Errorf("bug: comm.Client.JSONCall(): content was send with unsupported content-encoding %s", resp.Header.Get("Content-Encoding")) } return io.ReadAll(reader) } var testID string // addStdHeaders adds the standard headers we use on all calls. func addStdHeaders(headers http.Header) http.Header { headers.Set("Accept-Encoding", "gzip") // So that I can have a static id for tests. 
if testID != "" { headers.Set("client-request-id", testID) headers.Set("Return-Client-Request-Id", "false") } else { headers.Set("client-request-id", uuid.New().String()) headers.Set("Return-Client-Request-Id", "false") } headers.Set("x-client-sku", "MSAL.Go") headers.Set("x-client-os", runtime.GOOS) headers.Set("x-client-cpu", runtime.GOARCH) headers.Set("x-client-ver", version.Version) return headers } microsoft-authentication-library-for-go-1.0.0/apps/internal/oauth/ops/internal/comm/comm_test.go000066400000000000000000000325531442026362400331660ustar00rootroot00000000000000// Copyright (c) Microsoft Corporation. // Licensed under the MIT license. package comm import ( "context" "encoding/json" "encoding/xml" "io" "net/http" "net/http/httptest" "net/url" "testing" customJSON "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json" "github.com/kylelemons/godebug/diff" "github.com/kylelemons/godebug/pretty" ) type recorder struct { xml bool statusCode int ret interface{} gotMethod string gotQV url.Values gotBody []byte gotHeaders http.Header } func (rec *recorder) reset() { rec.statusCode = 0 rec.ret = nil rec.gotMethod = "" rec.gotQV = nil rec.gotBody = nil rec.gotHeaders = nil } func (rec *recorder) ServeHTTP(w http.ResponseWriter, r *http.Request) { if rec.statusCode != http.StatusOK { http.Error(w, "error", http.StatusBadRequest) return } rec.gotMethod = r.Method rec.gotQV = r.URL.Query() b, err := io.ReadAll(r.Body) if err != nil { panic(err) } rec.gotBody = b // This gets added by the test server. 
delete(r.Header, "User-Agent") delete(r.Header, "Content-Length") rec.gotHeaders = r.Header if rec.xml { b, err = xml.Marshal(rec.ret) if err != nil { panic(err) } } else { b, err = customJSON.Marshal(rec.ret) if err != nil { panic(err) } } if _, err := w.Write(b); err != nil { panic(err) } } type SampleData struct { Ok string } func init() { testID = "testID" } func TestJSONCall(t *testing.T) { tests := []struct { desc string statusCode int headers http.Header qv url.Values body, resp interface{} expectMethod string expectHeaders http.Header expectBody interface{} want interface{} err bool }{ { desc: "Error: non-struct resp value", statusCode: http.StatusOK, resp: new(int), err: true, }, { desc: "Error: non-pointer resp value", statusCode: http.StatusOK, resp: SampleData{}, err: true, }, { desc: "Body == nil[http Get]", statusCode: http.StatusOK, headers: http.Header{"header": []string{"here"}}, qv: url.Values{"key": []string{"value"}}, resp: &SampleData{Ok: "true"}, expectMethod: http.MethodGet, expectHeaders: addStdHeaders(http.Header{"Header": []string{"here"}}), want: &SampleData{Ok: "true"}, }, { desc: "Body != nil[http Post]", statusCode: http.StatusOK, headers: http.Header{"header": []string{"here"}}, qv: url.Values{"key": []string{"value"}}, body: &SampleData{Ok: "false"}, resp: &SampleData{Ok: "true"}, expectMethod: http.MethodPost, expectHeaders: addStdHeaders( http.Header{ "Header": []string{"here"}, "Content-Type": []string{"application/json; charset=utf-8"}, }, ), want: &SampleData{Ok: "true"}, }, { desc: "Error: non-200 response", statusCode: http.StatusBadRequest, headers: http.Header{}, qv: url.Values{}, resp: &SampleData{Ok: "true"}, err: true, }, } rec := &recorder{} serv := httptest.NewServer(rec) defer serv.Close() for _, test := range tests { rec.reset() rec.statusCode = test.statusCode rec.ret = test.resp comm := New(serv.Client()) err := comm.JSONCall(context.Background(), serv.URL, test.headers, test.qv, test.body, test.resp) switch { case 
err == nil && test.err: t.Errorf("TestJSONCall(%s): got err == nil, want err != nil", test.desc) continue case err != nil && !test.err: t.Errorf("TestJSONCall(%s): got err == %s, want err == nil", test.desc, err) continue case err != nil: continue } if test.expectMethod != rec.gotMethod { t.Errorf("TestJSONCall(%s): got method == %s, want http method == %s", test.desc, test.expectMethod, rec.gotMethod) continue } if diff := pretty.Compare(test.qv, rec.gotQV); diff != "" { t.Errorf("TestJSONCall(%s): query values: -want/+got:\n%s", test.desc, diff) continue } if test.expectHeaders != nil { if diff := pretty.Compare(test.expectHeaders, rec.gotHeaders); diff != "" { t.Errorf("TestJSONCall(%s): headers: -want/+got:\n%s", test.desc, diff) continue } } if test.expectBody != nil { gotBody := SampleData{} if err := json.Unmarshal(rec.gotBody, &gotBody); err != nil { panic(err) } if diff := pretty.Compare(test.expectBody, gotBody); diff != "" { t.Errorf("TestJSONCall(%s): body: -want/+got:\n%s", test.desc, diff) continue } } if diff := pretty.Compare(test.want, test.resp); diff != "" { t.Errorf("TestJSONCall(%s): result: -want/+got:\n%s", test.desc, diff) } } } func TestXMLCall(t *testing.T) { tests := []struct { desc string statusCode int headers http.Header qv url.Values resp interface{} expectHeaders http.Header expectBody interface{} want interface{} err bool }{ { desc: "Error: non-struct resp value", statusCode: http.StatusOK, resp: new(int), err: true, }, { desc: "Error: non-pointer resp value", statusCode: http.StatusOK, resp: SampleData{}, err: true, }, { desc: "Success", statusCode: http.StatusOK, headers: http.Header{"header": []string{"here"}}, qv: url.Values{"key": []string{"value"}}, resp: &SampleData{Ok: "true"}, expectHeaders: addStdHeaders( http.Header{ "Header": []string{"here"}, "Content-Type": []string{"application/xml; charset=utf-8"}, }, ), want: &SampleData{Ok: "true"}, }, { desc: "Error: non-200 response", statusCode: http.StatusBadRequest, headers: 
http.Header{}, qv: url.Values{}, resp: &SampleData{Ok: "true"}, err: true, }, } rec := &recorder{xml: true} serv := httptest.NewServer(rec) defer serv.Close() for _, test := range tests { rec.reset() rec.statusCode = test.statusCode rec.ret = test.resp comm := New(serv.Client()) err := comm.XMLCall(context.Background(), serv.URL, test.headers, test.qv, test.resp) switch { case err == nil && test.err: t.Errorf("TestXMLCall(%s): got err == nil, want err != nil", test.desc) continue case err != nil && !test.err: t.Errorf("TestXMLCall(%s): got err == %s, want err == nil", test.desc, err) continue case err != nil: continue } if rec.gotMethod != http.MethodGet { t.Errorf("TestXMLCall(%s): got method == %s, want http method == GET", test.desc, rec.gotMethod) continue } if diff := pretty.Compare(test.qv, rec.gotQV); diff != "" { t.Errorf("TestXMLCall(%s): query values: -want/+got:\n%s", test.desc, diff) continue } if test.expectHeaders != nil { if diff := pretty.Compare(test.expectHeaders, rec.gotHeaders); diff != "" { t.Errorf("TestXMLCall(%s): headers: -want/+got:\n%s", test.desc, diff) continue } } if test.expectBody != nil { gotBody := SampleData{} if err := xml.Unmarshal(rec.gotBody, &gotBody); err != nil { panic(err) } if diff := pretty.Compare(test.expectBody, gotBody); diff != "" { t.Errorf("TestXMLCall(%s): body: -want/+got:\n%s", test.desc, diff) continue } } if diff := pretty.Compare(test.want, test.resp); diff != "" { t.Errorf("TestXMLCall(%s): result: -want/+got:\n%s", test.desc, diff) } } } func TestSoapCall(t *testing.T) { const soapActionDefault = "http://docs.oasis-open.org/ws-sx/ws-trust/200512/RST/Issue" req := SampleData{Ok: "whatever"} body, err := xml.Marshal(req) if err != nil { panic(err) } tests := []struct { desc string statusCode int action string body string headers http.Header qv url.Values resp interface{} expectHeaders http.Header expectBody interface{} want interface{} err bool }{ { desc: "Error: non-struct resp value", statusCode: 
http.StatusOK, resp: new(int), err: true, }, { desc: "Error: non-pointer resp value", statusCode: http.StatusOK, resp: SampleData{}, err: true, }, { desc: "Error: body arg was empty string", statusCode: http.StatusOK, action: soapActionDefault, headers: http.Header{"header": []string{"here"}}, qv: url.Values{"key": []string{"value"}}, resp: &SampleData{Ok: "true"}, err: true, }, { desc: "Success", statusCode: http.StatusOK, headers: http.Header{"header": []string{"here"}}, qv: url.Values{"key": []string{"value"}}, action: soapActionDefault, body: string(body), resp: &SampleData{Ok: "true"}, expectHeaders: addStdHeaders( http.Header{ "Header": []string{"here"}, "Content-Type": []string{"application/soap+xml; charset=utf-8"}, "Soapaction": []string{soapActionDefault}, }, ), want: &SampleData{Ok: "true"}, }, { desc: "Error: non-200 response", statusCode: http.StatusBadRequest, headers: http.Header{}, qv: url.Values{}, resp: &SampleData{Ok: "true"}, err: true, }, } rec := &recorder{xml: true} serv := httptest.NewServer(rec) defer serv.Close() for _, test := range tests { rec.reset() rec.statusCode = test.statusCode rec.ret = test.resp comm := New(serv.Client()) err := comm.SOAPCall(context.Background(), serv.URL, test.action, test.headers, test.qv, test.body, test.resp) switch { case err == nil && test.err: t.Errorf("TestXMLCall(%s): got err == nil, want err != nil", test.desc) continue case err != nil && !test.err: t.Errorf("TestXMLCall(%s): got err == %s, want err == nil", test.desc, err) continue case err != nil: continue } if rec.gotMethod != http.MethodPost { t.Errorf("TestXMLCall(%s): got method == %s, want http method == POST", test.desc, rec.gotMethod) continue } if diff := pretty.Compare(test.qv, rec.gotQV); diff != "" { t.Errorf("TestXMLCall(%s): query values: -want/+got:\n%s", test.desc, diff) continue } if test.expectHeaders != nil { if diff := pretty.Compare(test.expectHeaders, rec.gotHeaders); diff != "" { t.Errorf("TestXMLCall(%s): headers: 
-want/+got:\n%s", test.desc, diff) continue } } if test.expectBody != nil { gotBody := SampleData{} if err := xml.Unmarshal(rec.gotBody, &gotBody); err != nil { panic(err) } if diff := pretty.Compare(test.expectBody, gotBody); diff != "" { t.Errorf("TestXMLCall(%s): body: -want/+got:\n%s", test.desc, diff) continue } } if diff := pretty.Compare(test.want, test.resp); diff != "" { t.Errorf("TestXMLCall(%s): result: -want/+got:\n%s", test.desc, diff) } } } func TestURLFormCall(t *testing.T) { tests := []struct { desc string statusCode int action string body string headers http.Header qv url.Values resp interface{} expectHeaders http.Header expectEndpoint string want interface{} err bool }{ { desc: "Error: non-struct resp value", statusCode: http.StatusOK, resp: new(int), err: true, }, { desc: "Error: non-pointer resp value", statusCode: http.StatusOK, resp: SampleData{}, err: true, }, { desc: "Error: empty query values", statusCode: http.StatusOK, headers: http.Header{"header": []string{"here"}}, resp: &SampleData{Ok: "true"}, expectHeaders: addStdHeaders( http.Header{ "Content-Type": []string{"application/x-www-form-urlencoded; charset=utf-8"}, }, ), err: true, }, { desc: "Success", statusCode: http.StatusOK, headers: http.Header{"header": []string{"here"}}, qv: url.Values{"key": []string{"value"}}, resp: &SampleData{Ok: "true"}, expectHeaders: addStdHeaders( http.Header{ "Content-Type": []string{"application/x-www-form-urlencoded; charset=utf-8"}, }, ), want: &SampleData{Ok: "true"}, }, { desc: "Error: non-200 response", statusCode: http.StatusBadRequest, headers: http.Header{"header": []string{"here"}}, qv: url.Values{"key": []string{"value"}}, resp: &SampleData{Ok: "true"}, err: true, }, } rec := &recorder{} serv := httptest.NewServer(rec) defer serv.Close() for _, test := range tests { rec.reset() rec.statusCode = test.statusCode rec.ret = test.resp comm := New(serv.Client()) err := comm.URLFormCall(context.Background(), serv.URL, test.qv, test.resp) switch { 
case err == nil && test.err: t.Errorf("TestURLFormCall(%s): got err == nil, want err != nil", test.desc) continue case err != nil && !test.err: t.Errorf("TestURLFormCall(%s): got err == %s, want err == nil", test.desc, err) continue case err != nil: continue } if rec.gotMethod != http.MethodPost { t.Errorf("TestURLFormCall(%s): got method == %s, want http method == POST", test.desc, rec.gotMethod) continue } if test.expectHeaders != nil { if diff := pretty.Compare(test.expectHeaders, rec.gotHeaders); diff != "" { t.Errorf("TestURLFormCall(%s): headers: -want/+got:\n%s", test.desc, diff) continue } } want := test.qv.Encode() got := string(rec.gotBody) if diff := diff.Diff(want, got); diff != "" { t.Errorf("TestXMLCall(%s): body: -want/+got:\n%s", test.desc, diff) continue } if diff := pretty.Compare(test.want, test.resp); diff != "" { t.Errorf("TestXMLCall(%s): result: -want/+got:\n%s", test.desc, diff) } } } microsoft-authentication-library-for-go-1.0.0/apps/internal/oauth/ops/internal/comm/compress.go000066400000000000000000000013041442026362400330150ustar00rootroot00000000000000// Copyright (c) Microsoft Corporation. // Licensed under the MIT license. package comm import ( "compress/gzip" "io" ) func gzipDecompress(r io.Reader) io.Reader { gzipReader, _ := gzip.NewReader(r) pipeOut, pipeIn := io.Pipe() go func() { // decompression bomb would have to come from Azure services. // If we want to limit, we should do that in comm.do(). _, err := io.Copy(pipeIn, gzipReader) //nolint if err != nil { // don't need the error. pipeIn.CloseWithError(err) //nolint gzipReader.Close() return } if err := gzipReader.Close(); err != nil { // don't need the error. 
pipeIn.CloseWithError(err) //nolint return } pipeIn.Close() }() return pipeOut } microsoft-authentication-library-for-go-1.0.0/apps/internal/oauth/ops/internal/grant/000077500000000000000000000000001442026362400310155ustar00rootroot00000000000000microsoft-authentication-library-for-go-1.0.0/apps/internal/oauth/ops/internal/grant/grant.go000066400000000000000000000011761442026362400324640ustar00rootroot00000000000000// Copyright (c) Microsoft Corporation. // Licensed under the MIT license. // Package grant holds types of grants issued by authorization services. package grant const ( Password = "password" JWT = "urn:ietf:params:oauth:grant-type:jwt-bearer" SAMLV1 = "urn:ietf:params:oauth:grant-type:saml1_1-bearer" SAMLV2 = "urn:ietf:params:oauth:grant-type:saml2-bearer" DeviceCode = "device_code" AuthCode = "authorization_code" RefreshToken = "refresh_token" ClientCredential = "client_credentials" ClientAssertion = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" ) microsoft-authentication-library-for-go-1.0.0/apps/internal/oauth/ops/ops.go000066400000000000000000000035441442026362400272240ustar00rootroot00000000000000// Copyright (c) Microsoft Corporation. // Licensed under the MIT license. /* Package ops provides operations to various backend services using REST clients. The REST type provides several clients that can be used to communicate to backends. Usage is simple: rest := ops.New() // Creates an authority client and calls the UserRealm() method. 
userRealm, err := rest.Authority().UserRealm(ctx, authParameters) if err != nil { // Do something } */ package ops import ( "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/internal/comm" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/wstrust" ) // HTTPClient represents an HTTP client. // It's usually an *http.Client from the standard library. type HTTPClient = comm.HTTPClient // REST provides REST clients for communicating with various backends used by MSAL. type REST struct { client *comm.Client } // New is the constructor for REST. func New(httpClient HTTPClient) *REST { return &REST{client: comm.New(httpClient)} } // Authority returns a client for querying information about various authorities. func (r *REST) Authority() authority.Client { return authority.Client{Comm: r.client} } // AccessTokens returns a client that can be used to get various access tokens for // authorization purposes. func (r *REST) AccessTokens() accesstokens.Client { return accesstokens.Client{Comm: r.client} } // WSTrust provides access to various metadata in a WSTrust service. This data can // be used to gain tokens based on SAML data using the client provided by AccessTokens(). 
func (r *REST) WSTrust() wstrust.Client { return wstrust.Client{Comm: r.client} } microsoft-authentication-library-for-go-1.0.0/apps/internal/oauth/ops/wstrust/000077500000000000000000000000001442026362400276215ustar00rootroot00000000000000microsoft-authentication-library-for-go-1.0.0/apps/internal/oauth/ops/wstrust/defs/000077500000000000000000000000001442026362400305425ustar00rootroot00000000000000endpointtype_string.go000066400000000000000000000013351442026362400351240ustar00rootroot00000000000000microsoft-authentication-library-for-go-1.0.0/apps/internal/oauth/ops/wstrust/defs// Code generated by "stringer -type=endpointType"; DO NOT EDIT. package defs import "strconv" func _() { // An "invalid array index" compiler error signifies that the constant values have changed. // Re-run the stringer command to generate them again. var x [1]struct{} _ = x[etUnknown-0] _ = x[etUsernamePassword-1] _ = x[etWindowsTransport-2] } const _endpointType_name = "etUnknownetUsernamePasswordetWindowsTransport" var _endpointType_index = [...]uint8{0, 9, 27, 45} func (i endpointType) String() string { if i < 0 || i >= endpointType(len(_endpointType_index)-1) { return "endpointType(" + strconv.FormatInt(int64(i), 10) + ")" } return _endpointType_name[_endpointType_index[i]:_endpointType_index[i+1]] } mex_document_definitions.go000066400000000000000000000277001442026362400361020ustar00rootroot00000000000000microsoft-authentication-library-for-go-1.0.0/apps/internal/oauth/ops/wstrust/defs// Copyright (c) Microsoft Corporation. // Licensed under the MIT license. 
package defs import "encoding/xml" type Definitions struct { XMLName xml.Name `xml:"definitions"` Text string `xml:",chardata"` Name string `xml:"name,attr"` TargetNamespace string `xml:"targetNamespace,attr"` WSDL string `xml:"wsdl,attr"` XSD string `xml:"xsd,attr"` T string `xml:"t,attr"` SOAPENC string `xml:"soapenc,attr"` SOAP string `xml:"soap,attr"` TNS string `xml:"tns,attr"` MSC string `xml:"msc,attr"` WSAM string `xml:"wsam,attr"` SOAP12 string `xml:"soap12,attr"` WSA10 string `xml:"wsa10,attr"` WSA string `xml:"wsa,attr"` WSAW string `xml:"wsaw,attr"` WSX string `xml:"wsx,attr"` WSAP string `xml:"wsap,attr"` WSU string `xml:"wsu,attr"` Trust string `xml:"trust,attr"` WSP string `xml:"wsp,attr"` Policy []Policy `xml:"Policy"` Types Types `xml:"types"` Message []Message `xml:"message"` PortType []PortType `xml:"portType"` Binding []Binding `xml:"binding"` Service Service `xml:"service"` } type Policy struct { Text string `xml:",chardata"` ID string `xml:"Id,attr"` ExactlyOne ExactlyOne `xml:"ExactlyOne"` } type ExactlyOne struct { Text string `xml:",chardata"` All All `xml:"All"` } type All struct { Text string `xml:",chardata"` NegotiateAuthentication NegotiateAuthentication `xml:"NegotiateAuthentication"` TransportBinding TransportBinding `xml:"TransportBinding"` UsingAddressing Text `xml:"UsingAddressing"` EndorsingSupportingTokens EndorsingSupportingTokens `xml:"EndorsingSupportingTokens"` WSS11 WSS11 `xml:"Wss11"` Trust10 Trust10 `xml:"Trust10"` SignedSupportingTokens SignedSupportingTokens `xml:"SignedSupportingTokens"` Trust13 WSTrust13 `xml:"Trust13"` SignedEncryptedSupportingTokens SignedEncryptedSupportingTokens `xml:"SignedEncryptedSupportingTokens"` } type NegotiateAuthentication struct { Text string `xml:",chardata"` HTTP string `xml:"http,attr"` XMLName xml.Name } type TransportBinding struct { Text string `xml:",chardata"` SP string `xml:"sp,attr"` Policy TransportBindingPolicy `xml:"Policy"` } type TransportBindingPolicy struct { Text string 
`xml:",chardata"` TransportToken TransportToken `xml:"TransportToken"` AlgorithmSuite AlgorithmSuite `xml:"AlgorithmSuite"` Layout Layout `xml:"Layout"` IncludeTimestamp Text `xml:"IncludeTimestamp"` } type TransportToken struct { Text string `xml:",chardata"` Policy TransportTokenPolicy `xml:"Policy"` } type TransportTokenPolicy struct { Text string `xml:",chardata"` HTTPSToken HTTPSToken `xml:"HttpsToken"` } type HTTPSToken struct { Text string `xml:",chardata"` RequireClientCertificate string `xml:"RequireClientCertificate,attr"` } type AlgorithmSuite struct { Text string `xml:",chardata"` Policy AlgorithmSuitePolicy `xml:"Policy"` } type AlgorithmSuitePolicy struct { Text string `xml:",chardata"` Basic256 Text `xml:"Basic256"` Basic128 Text `xml:"Basic128"` } type Layout struct { Text string `xml:",chardata"` Policy LayoutPolicy `xml:"Policy"` } type LayoutPolicy struct { Text string `xml:",chardata"` Strict Text `xml:"Strict"` } type EndorsingSupportingTokens struct { Text string `xml:",chardata"` SP string `xml:"sp,attr"` Policy EndorsingSupportingTokensPolicy `xml:"Policy"` } type EndorsingSupportingTokensPolicy struct { Text string `xml:",chardata"` X509Token X509Token `xml:"X509Token"` RSAToken RSAToken `xml:"RsaToken"` SignedParts SignedParts `xml:"SignedParts"` KerberosToken KerberosToken `xml:"KerberosToken"` IssuedToken IssuedToken `xml:"IssuedToken"` KeyValueToken KeyValueToken `xml:"KeyValueToken"` } type X509Token struct { Text string `xml:",chardata"` IncludeToken string `xml:"IncludeToken,attr"` Policy X509TokenPolicy `xml:"Policy"` } type X509TokenPolicy struct { Text string `xml:",chardata"` RequireThumbprintReference Text `xml:"RequireThumbprintReference"` WSSX509V3Token10 Text `xml:"WssX509V3Token10"` } type RSAToken struct { Text string `xml:",chardata"` IncludeToken string `xml:"IncludeToken,attr"` Optional string `xml:"Optional,attr"` MSSP string `xml:"mssp,attr"` } type SignedParts struct { Text string `xml:",chardata"` Header 
SignedPartsHeader `xml:"Header"` } type SignedPartsHeader struct { Text string `xml:",chardata"` Name string `xml:"Name,attr"` Namespace string `xml:"Namespace,attr"` } type KerberosToken struct { Text string `xml:",chardata"` IncludeToken string `xml:"IncludeToken,attr"` Policy KerberosTokenPolicy `xml:"Policy"` } type KerberosTokenPolicy struct { Text string `xml:",chardata"` WSSGSSKerberosV5ApReqToken11 Text `xml:"WssGssKerberosV5ApReqToken11"` } type IssuedToken struct { Text string `xml:",chardata"` IncludeToken string `xml:"IncludeToken,attr"` RequestSecurityTokenTemplate RequestSecurityTokenTemplate `xml:"RequestSecurityTokenTemplate"` Policy IssuedTokenPolicy `xml:"Policy"` } type RequestSecurityTokenTemplate struct { Text string `xml:",chardata"` KeyType Text `xml:"KeyType"` EncryptWith Text `xml:"EncryptWith"` SignatureAlgorithm Text `xml:"SignatureAlgorithm"` CanonicalizationAlgorithm Text `xml:"CanonicalizationAlgorithm"` EncryptionAlgorithm Text `xml:"EncryptionAlgorithm"` KeySize Text `xml:"KeySize"` KeyWrapAlgorithm Text `xml:"KeyWrapAlgorithm"` } type IssuedTokenPolicy struct { Text string `xml:",chardata"` RequireInternalReference Text `xml:"RequireInternalReference"` } type KeyValueToken struct { Text string `xml:",chardata"` IncludeToken string `xml:"IncludeToken,attr"` Optional string `xml:"Optional,attr"` } type WSS11 struct { Text string `xml:",chardata"` SP string `xml:"sp,attr"` Policy Wss11Policy `xml:"Policy"` } type Wss11Policy struct { Text string `xml:",chardata"` MustSupportRefThumbprint Text `xml:"MustSupportRefThumbprint"` } type Trust10 struct { Text string `xml:",chardata"` SP string `xml:"sp,attr"` Policy Trust10Policy `xml:"Policy"` } type Trust10Policy struct { Text string `xml:",chardata"` MustSupportIssuedTokens Text `xml:"MustSupportIssuedTokens"` RequireClientEntropy Text `xml:"RequireClientEntropy"` RequireServerEntropy Text `xml:"RequireServerEntropy"` } type SignedSupportingTokens struct { Text string `xml:",chardata"` SP 
string `xml:"sp,attr"` Policy SupportingTokensPolicy `xml:"Policy"` } type SupportingTokensPolicy struct { Text string `xml:",chardata"` UsernameToken UsernameToken `xml:"UsernameToken"` } type UsernameToken struct { Text string `xml:",chardata"` IncludeToken string `xml:"IncludeToken,attr"` Policy UsernameTokenPolicy `xml:"Policy"` } type UsernameTokenPolicy struct { Text string `xml:",chardata"` WSSUsernameToken10 WSSUsernameToken10 `xml:"WssUsernameToken10"` } type WSSUsernameToken10 struct { Text string `xml:",chardata"` XMLName xml.Name } type WSTrust13 struct { Text string `xml:",chardata"` SP string `xml:"sp,attr"` Policy WSTrust13Policy `xml:"Policy"` } type WSTrust13Policy struct { Text string `xml:",chardata"` MustSupportIssuedTokens Text `xml:"MustSupportIssuedTokens"` RequireClientEntropy Text `xml:"RequireClientEntropy"` RequireServerEntropy Text `xml:"RequireServerEntropy"` } type SignedEncryptedSupportingTokens struct { Text string `xml:",chardata"` SP string `xml:"sp,attr"` Policy SupportingTokensPolicy `xml:"Policy"` } type Types struct { Text string `xml:",chardata"` Schema Schema `xml:"schema"` } type Schema struct { Text string `xml:",chardata"` TargetNamespace string `xml:"targetNamespace,attr"` Import []Import `xml:"import"` } type Import struct { Text string `xml:",chardata"` SchemaLocation string `xml:"schemaLocation,attr"` Namespace string `xml:"namespace,attr"` } type Message struct { Text string `xml:",chardata"` Name string `xml:"name,attr"` Part Part `xml:"part"` } type Part struct { Text string `xml:",chardata"` Name string `xml:"name,attr"` Element string `xml:"element,attr"` } type PortType struct { Text string `xml:",chardata"` Name string `xml:"name,attr"` Operation Operation `xml:"operation"` } type Operation struct { Text string `xml:",chardata"` Name string `xml:"name,attr"` Input OperationIO `xml:"input"` Output OperationIO `xml:"output"` } type OperationIO struct { Text string `xml:",chardata"` Action string 
`xml:"Action,attr"` Message string `xml:"message,attr"` Body OperationIOBody `xml:"body"` } type OperationIOBody struct { Text string `xml:",chardata"` Use string `xml:"use,attr"` } type Binding struct { Text string `xml:",chardata"` Name string `xml:"name,attr"` Type string `xml:"type,attr"` PolicyReference PolicyReference `xml:"PolicyReference"` Binding DefinitionsBinding `xml:"binding"` Operation BindingOperation `xml:"operation"` } type PolicyReference struct { Text string `xml:",chardata"` URI string `xml:"URI,attr"` } type DefinitionsBinding struct { Text string `xml:",chardata"` Transport string `xml:"transport,attr"` } type BindingOperation struct { Text string `xml:",chardata"` Name string `xml:"name,attr"` Operation BindingOperationOperation `xml:"operation"` Input BindingOperationIO `xml:"input"` Output BindingOperationIO `xml:"output"` } type BindingOperationOperation struct { Text string `xml:",chardata"` SoapAction string `xml:"soapAction,attr"` Style string `xml:"style,attr"` } type BindingOperationIO struct { Text string `xml:",chardata"` Body OperationIOBody `xml:"body"` } type Service struct { Text string `xml:",chardata"` Name string `xml:"name,attr"` Port []Port `xml:"port"` } type Port struct { Text string `xml:",chardata"` Name string `xml:"name,attr"` Binding string `xml:"binding,attr"` Address Address `xml:"address"` EndpointReference PortEndpointReference `xml:"EndpointReference"` } type Address struct { Text string `xml:",chardata"` Location string `xml:"location,attr"` } type PortEndpointReference struct { Text string `xml:",chardata"` Address Text `xml:"Address"` Identity Identity `xml:"Identity"` } type Identity struct { Text string `xml:",chardata"` XMLNS string `xml:"xmlns,attr"` SPN Text `xml:"Spn"` } saml_assertion_definitions.go000066400000000000000000000171361442026362400364400ustar00rootroot00000000000000microsoft-authentication-library-for-go-1.0.0/apps/internal/oauth/ops/wstrust/defs// Copyright (c) Microsoft Corporation. 
// Licensed under the MIT license. package defs import "encoding/xml" // TODO(msal): Someone (and it ain't gonna be me) needs to document these attributes or // at the least put a link to RFC. type SAMLDefinitions struct { XMLName xml.Name `xml:"Envelope"` Text string `xml:",chardata"` S string `xml:"s,attr"` A string `xml:"a,attr"` U string `xml:"u,attr"` Header Header `xml:"Header"` Body Body `xml:"Body"` } type Header struct { Text string `xml:",chardata"` Action Action `xml:"Action"` Security Security `xml:"Security"` } type Action struct { Text string `xml:",chardata"` MustUnderstand string `xml:"mustUnderstand,attr"` } type Security struct { Text string `xml:",chardata"` MustUnderstand string `xml:"mustUnderstand,attr"` O string `xml:"o,attr"` Timestamp Timestamp `xml:"Timestamp"` } type Timestamp struct { Text string `xml:",chardata"` ID string `xml:"Id,attr"` Created Text `xml:"Created"` Expires Text `xml:"Expires"` } type Text struct { Text string `xml:",chardata"` } type Body struct { Text string `xml:",chardata"` RequestSecurityTokenResponseCollection RequestSecurityTokenResponseCollection `xml:"RequestSecurityTokenResponseCollection"` } type RequestSecurityTokenResponseCollection struct { Text string `xml:",chardata"` Trust string `xml:"trust,attr"` RequestSecurityTokenResponse []RequestSecurityTokenResponse `xml:"RequestSecurityTokenResponse"` } type RequestSecurityTokenResponse struct { Text string `xml:",chardata"` Lifetime Lifetime `xml:"Lifetime"` AppliesTo AppliesTo `xml:"AppliesTo"` RequestedSecurityToken RequestedSecurityToken `xml:"RequestedSecurityToken"` RequestedAttachedReference RequestedAttachedReference `xml:"RequestedAttachedReference"` RequestedUnattachedReference RequestedUnattachedReference `xml:"RequestedUnattachedReference"` TokenType Text `xml:"TokenType"` RequestType Text `xml:"RequestType"` KeyType Text `xml:"KeyType"` } type Lifetime struct { Text string `xml:",chardata"` Created WSUTimestamp `xml:"Created"` Expires WSUTimestamp 
`xml:"Expires"` } type WSUTimestamp struct { Text string `xml:",chardata"` Wsu string `xml:"wsu,attr"` } type AppliesTo struct { Text string `xml:",chardata"` Wsp string `xml:"wsp,attr"` EndpointReference EndpointReference `xml:"EndpointReference"` } type EndpointReference struct { Text string `xml:",chardata"` Wsa string `xml:"wsa,attr"` Address Text `xml:"Address"` } type RequestedSecurityToken struct { Text string `xml:",chardata"` AssertionRawXML string `xml:",innerxml"` Assertion Assertion `xml:"Assertion"` } type Assertion struct { XMLName xml.Name // Normally its `xml:"Assertion"`, but I think they want to capture the xmlns Text string `xml:",chardata"` MajorVersion string `xml:"MajorVersion,attr"` MinorVersion string `xml:"MinorVersion,attr"` AssertionID string `xml:"AssertionID,attr"` Issuer string `xml:"Issuer,attr"` IssueInstant string `xml:"IssueInstant,attr"` Saml string `xml:"saml,attr"` Conditions Conditions `xml:"Conditions"` AttributeStatement AttributeStatement `xml:"AttributeStatement"` AuthenticationStatement AuthenticationStatement `xml:"AuthenticationStatement"` Signature Signature `xml:"Signature"` } type Conditions struct { Text string `xml:",chardata"` NotBefore string `xml:"NotBefore,attr"` NotOnOrAfter string `xml:"NotOnOrAfter,attr"` AudienceRestrictionCondition AudienceRestrictionCondition `xml:"AudienceRestrictionCondition"` } type AudienceRestrictionCondition struct { Text string `xml:",chardata"` Audience Text `xml:"Audience"` } type AttributeStatement struct { Text string `xml:",chardata"` Subject Subject `xml:"Subject"` Attribute []Attribute `xml:"Attribute"` } type Subject struct { Text string `xml:",chardata"` NameIdentifier NameIdentifier `xml:"NameIdentifier"` SubjectConfirmation SubjectConfirmation `xml:"SubjectConfirmation"` } type NameIdentifier struct { Text string `xml:",chardata"` Format string `xml:"Format,attr"` } type SubjectConfirmation struct { Text string `xml:",chardata"` ConfirmationMethod Text 
`xml:"ConfirmationMethod"` } type Attribute struct { Text string `xml:",chardata"` AttributeName string `xml:"AttributeName,attr"` AttributeNamespace string `xml:"AttributeNamespace,attr"` AttributeValue Text `xml:"AttributeValue"` } type AuthenticationStatement struct { Text string `xml:",chardata"` AuthenticationMethod string `xml:"AuthenticationMethod,attr"` AuthenticationInstant string `xml:"AuthenticationInstant,attr"` Subject Subject `xml:"Subject"` } type Signature struct { Text string `xml:",chardata"` Ds string `xml:"ds,attr"` SignedInfo SignedInfo `xml:"SignedInfo"` SignatureValue Text `xml:"SignatureValue"` KeyInfo KeyInfo `xml:"KeyInfo"` } type SignedInfo struct { Text string `xml:",chardata"` CanonicalizationMethod Method `xml:"CanonicalizationMethod"` SignatureMethod Method `xml:"SignatureMethod"` Reference Reference `xml:"Reference"` } type Method struct { Text string `xml:",chardata"` Algorithm string `xml:"Algorithm,attr"` } type Reference struct { Text string `xml:",chardata"` URI string `xml:"URI,attr"` Transforms Transforms `xml:"Transforms"` DigestMethod Method `xml:"DigestMethod"` DigestValue Text `xml:"DigestValue"` } type Transforms struct { Text string `xml:",chardata"` Transform []Method `xml:"Transform"` } type KeyInfo struct { Text string `xml:",chardata"` Xmlns string `xml:"xmlns,attr"` X509Data X509Data `xml:"X509Data"` } type X509Data struct { Text string `xml:",chardata"` X509Certificate Text `xml:"X509Certificate"` } type RequestedAttachedReference struct { Text string `xml:",chardata"` SecurityTokenReference SecurityTokenReference `xml:"SecurityTokenReference"` } type SecurityTokenReference struct { Text string `xml:",chardata"` TokenType string `xml:"TokenType,attr"` O string `xml:"o,attr"` K string `xml:"k,attr"` KeyIdentifier KeyIdentifier `xml:"KeyIdentifier"` } type KeyIdentifier struct { Text string `xml:",chardata"` ValueType string `xml:"ValueType,attr"` } type RequestedUnattachedReference struct { Text string 
`xml:",chardata"` SecurityTokenReference SecurityTokenReference `xml:"SecurityTokenReference"` } microsoft-authentication-library-for-go-1.0.0/apps/internal/oauth/ops/wstrust/defs/version_string.go000066400000000000000000000012121442026362400341400ustar00rootroot00000000000000// Code generated by "stringer -type=Version"; DO NOT EDIT. package defs import "strconv" func _() { // An "invalid array index" compiler error signifies that the constant values have changed. // Re-run the stringer command to generate them again. var x [1]struct{} _ = x[TrustUnknown-0] _ = x[Trust2005-1] _ = x[Trust13-2] } const _Version_name = "TrustUnknownTrust2005Trust13" var _Version_index = [...]uint8{0, 12, 21, 28} func (i Version) String() string { if i < 0 || i >= Version(len(_Version_index)-1) { return "Version(" + strconv.FormatInt(int64(i), 10) + ")" } return _Version_name[_Version_index[i]:_Version_index[i+1]] } wstrust_endpoint.go000066400000000000000000000150751442026362400344550ustar00rootroot00000000000000microsoft-authentication-library-for-go-1.0.0/apps/internal/oauth/ops/wstrust/defs// Copyright (c) Microsoft Corporation. // Licensed under the MIT license. package defs import ( "encoding/xml" "fmt" "time" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority" uuid "github.com/google/uuid" ) //go:generate stringer -type=Version type Version int const ( TrustUnknown Version = iota Trust2005 Trust13 ) // Endpoint represents a WSTrust endpoint. type Endpoint struct { // Version is the version of the endpoint. Version Version // URL is the URL of the endpoint. 
URL string } type wsTrustTokenRequestEnvelope struct { XMLName xml.Name `xml:"s:Envelope"` Text string `xml:",chardata"` S string `xml:"xmlns:s,attr"` Wsa string `xml:"xmlns:wsa,attr"` Wsu string `xml:"xmlns:wsu,attr"` Header struct { Text string `xml:",chardata"` Action struct { Text string `xml:",chardata"` MustUnderstand string `xml:"s:mustUnderstand,attr"` } `xml:"wsa:Action"` MessageID struct { Text string `xml:",chardata"` } `xml:"wsa:messageID"` ReplyTo struct { Text string `xml:",chardata"` Address struct { Text string `xml:",chardata"` } `xml:"wsa:Address"` } `xml:"wsa:ReplyTo"` To struct { Text string `xml:",chardata"` MustUnderstand string `xml:"s:mustUnderstand,attr"` } `xml:"wsa:To"` Security struct { Text string `xml:",chardata"` MustUnderstand string `xml:"s:mustUnderstand,attr"` Wsse string `xml:"xmlns:wsse,attr"` Timestamp struct { Text string `xml:",chardata"` ID string `xml:"wsu:Id,attr"` Created struct { Text string `xml:",chardata"` } `xml:"wsu:Created"` Expires struct { Text string `xml:",chardata"` } `xml:"wsu:Expires"` } `xml:"wsu:Timestamp"` UsernameToken struct { Text string `xml:",chardata"` ID string `xml:"wsu:Id,attr"` Username struct { Text string `xml:",chardata"` } `xml:"wsse:Username"` Password struct { Text string `xml:",chardata"` } `xml:"wsse:Password"` } `xml:"wsse:UsernameToken"` } `xml:"wsse:Security"` } `xml:"s:Header"` Body struct { Text string `xml:",chardata"` RequestSecurityToken struct { Text string `xml:",chardata"` Wst string `xml:"xmlns:wst,attr"` AppliesTo struct { Text string `xml:",chardata"` Wsp string `xml:"xmlns:wsp,attr"` EndpointReference struct { Text string `xml:",chardata"` Address struct { Text string `xml:",chardata"` } `xml:"wsa:Address"` } `xml:"wsa:EndpointReference"` } `xml:"wsp:AppliesTo"` KeyType struct { Text string `xml:",chardata"` } `xml:"wst:KeyType"` RequestType struct { Text string `xml:",chardata"` } `xml:"wst:RequestType"` } `xml:"wst:RequestSecurityToken"` } `xml:"s:Body"` } func 
buildTimeString(t time.Time) string { // Golang time formats are weird: https://stackoverflow.com/questions/20234104/how-to-format-current-time-using-a-yyyymmddhhmmss-format return t.Format("2006-01-02T15:04:05.000Z") } func (wte *Endpoint) buildTokenRequestMessage(authType authority.AuthorizeType, cloudAudienceURN string, username string, password string) (string, error) { var soapAction string var trustNamespace string var keyType string var requestType string createdTime := time.Now().UTC() expiresTime := createdTime.Add(10 * time.Minute) switch wte.Version { case Trust2005: soapAction = trust2005Spec trustNamespace = "http://schemas.xmlsoap.org/ws/2005/02/trust" keyType = "http://schemas.xmlsoap.org/ws/2005/05/identity/NoProofKey" requestType = "http://schemas.xmlsoap.org/ws/2005/02/trust/Issue" case Trust13: soapAction = trust13Spec trustNamespace = "http://docs.oasis-open.org/ws-sx/ws-trust/200512" keyType = "http://docs.oasis-open.org/ws-sx/ws-trust/200512/Bearer" requestType = "http://docs.oasis-open.org/ws-sx/ws-trust/200512/Issue" default: return "", fmt.Errorf("buildTokenRequestMessage had Version == %q, which is not recognized", wte.Version) } var envelope wsTrustTokenRequestEnvelope messageUUID := uuid.New() envelope.S = "http://www.w3.org/2003/05/soap-envelope" envelope.Wsa = "http://www.w3.org/2005/08/addressing" envelope.Wsu = "http://docs.oasis-open.org/wss/2004/01/oasis-200401-wss-wssecurity-utility-1.0.xsd" envelope.Header.Action.MustUnderstand = "1" envelope.Header.Action.Text = soapAction envelope.Header.MessageID.Text = "urn:uuid:" + messageUUID.String() envelope.Header.ReplyTo.Address.Text = "http://www.w3.org/2005/08/addressing/anonymous" envelope.Header.To.MustUnderstand = "1" envelope.Header.To.Text = wte.URL switch authType { case authority.ATUnknown: return "", fmt.Errorf("buildTokenRequestMessage had no authority type(%v)", authType) case authority.ATUsernamePassword: endpointUUID := uuid.New() var trustID string if wte.Version == 
Trust2005 { trustID = "UnPwSecTok2005-" + endpointUUID.String() } else { trustID = "UnPwSecTok13-" + endpointUUID.String() } envelope.Header.Security.MustUnderstand = "1" envelope.Header.Security.Wsse = "http://docs.oasis-open.org/wss/2004/01/oasis-200401-wss-wssecurity-secext-1.0.xsd" envelope.Header.Security.Timestamp.ID = "MSATimeStamp" envelope.Header.Security.Timestamp.Created.Text = buildTimeString(createdTime) envelope.Header.Security.Timestamp.Expires.Text = buildTimeString(expiresTime) envelope.Header.Security.UsernameToken.ID = trustID envelope.Header.Security.UsernameToken.Username.Text = username envelope.Header.Security.UsernameToken.Password.Text = password default: // This is just to note that we don't do anything for other cases. // We aren't missing anything I know of. } envelope.Body.RequestSecurityToken.Wst = trustNamespace envelope.Body.RequestSecurityToken.AppliesTo.Wsp = "http://schemas.xmlsoap.org/ws/2004/09/policy" envelope.Body.RequestSecurityToken.AppliesTo.EndpointReference.Address.Text = cloudAudienceURN envelope.Body.RequestSecurityToken.KeyType.Text = keyType envelope.Body.RequestSecurityToken.RequestType.Text = requestType output, err := xml.Marshal(envelope) if err != nil { return "", err } return string(output), nil } func (wte *Endpoint) BuildTokenRequestMessageWIA(cloudAudienceURN string) (string, error) { return wte.buildTokenRequestMessage(authority.ATWindowsIntegrated, cloudAudienceURN, "", "") } func (wte *Endpoint) BuildTokenRequestMessageUsernamePassword(cloudAudienceURN string, username string, password string) (string, error) { return wte.buildTokenRequestMessage(authority.ATUsernamePassword, cloudAudienceURN, username, password) } wstrust_mex_document.go000066400000000000000000000105311442026362400353140ustar00rootroot00000000000000microsoft-authentication-library-for-go-1.0.0/apps/internal/oauth/ops/wstrust/defs// Copyright (c) Microsoft Corporation. // Licensed under the MIT license. 
package defs

import (
	"errors"
	"fmt"
	"strings"
)

//go:generate stringer -type=endpointType
type endpointType int

const (
	etUnknown endpointType = iota
	etUsernamePassword
	etWindowsTransport
)

// wsEndpointData is what a MEX binding resolves to: the WS-Trust version implied by the
// binding's SOAP action and the endpoint type implied by the binding's policy.
type wsEndpointData struct {
	Version      Version
	EndpointType endpointType
}

const trust13Spec string = "http://docs.oasis-open.org/ws-sx/ws-trust/200512/RST/Issue"
const trust2005Spec string = "http://schemas.xmlsoap.org/ws/2005/02/trust/RST/Issue"

// MexDocument holds the WS-Trust endpoints extracted from a metadata exchange (MEX) document.
type MexDocument struct {
	UsernamePasswordEndpoint Endpoint
	WindowsTransportEndpoint Endpoint
	policies                 map[string]endpointType
	bindings                 map[string]wsEndpointData
}

// updateEndpoint stores found into *cached when nothing usable is cached yet (its Version is
// TrustUnknown) or when found upgrades a cached WS-Trust 2005 endpoint to WS-Trust 1.3. Any
// other combination leaves *cached untouched. A nil cached is ignored; the previous version
// tested for nil and then dereferenced the pointer anyway, which panics.
func updateEndpoint(cached *Endpoint, found Endpoint) {
	if cached == nil {
		return
	}
	if cached.Version == TrustUnknown || (cached.Version == Trust2005 && found.Version == Trust13) {
		*cached = found
	}
}

// TODO(msal): Someone needs to write tests for everything below.

// NewFromDef creates a new MexDocument.
// It resolves policies, then the bindings that reference them, then the service ports that use
// those bindings; an error at any stage is returned with a zero MexDocument.
func NewFromDef(defs Definitions) (MexDocument, error) {
	policies, err := policies(defs)
	if err != nil {
		return MexDocument{}, err
	}

	bindings, err := bindings(defs, policies)
	if err != nil {
		return MexDocument{}, err
	}

	userPass, windows, err := endpoints(defs, bindings)
	if err != nil {
		return MexDocument{}, err
	}

	return MexDocument{
		UsernamePasswordEndpoint: userPass,
		WindowsTransportEndpoint: windows,
		policies:                 policies,
		bindings:                 bindings,
	}, nil
}

// policies maps "#"+policy.ID to the endpoint type implied by each policy's assertions. Only
// policies that have a TransportBinding and a non-empty ID are recorded; when a policy matches
// both a Windows-transport and a username/password assertion, the username/password type wins
// because it is assigned last. It is an error for no policy to qualify.
func policies(defs Definitions) (map[string]endpointType, error) {
	policies := make(map[string]endpointType, len(defs.Policy))

	for _, policy := range defs.Policy {
		if policy.ExactlyOne.All.NegotiateAuthentication.XMLName.Local != "" {
			if policy.ExactlyOne.All.TransportBinding.SP != "" && policy.ID != "" {
				policies["#"+policy.ID] = etWindowsTransport
			}
		}

		if policy.ExactlyOne.All.SignedEncryptedSupportingTokens.Policy.UsernameToken.Policy.WSSUsernameToken10.XMLName.Local != "" {
			if policy.ExactlyOne.All.TransportBinding.SP != "" && policy.ID != "" {
				policies["#"+policy.ID] = etUsernamePassword
			}
		}
		if policy.ExactlyOne.All.SignedSupportingTokens.Policy.UsernameToken.Policy.WSSUsernameToken10.XMLName.Local != "" {
			if policy.ExactlyOne.All.TransportBinding.SP != "" && policy.ID != "" {
				policies["#"+policy.ID] = etUsernamePassword
			}
		}
	}

	if len(policies) == 0 {
		return policies, errors.New("no policies for mex document")
	}

	return policies, nil
}

// bindings maps binding names to their WS-Trust version and endpoint type. Only bindings that
// use the SOAP-over-HTTP transport and reference a known policy are considered; such a binding
// whose SOAP action is neither the 1.3 nor the 2005 issue action is an error.
func bindings(defs Definitions, policies map[string]endpointType) (map[string]wsEndpointData, error) {
	bindings := make(map[string]wsEndpointData, len(defs.Binding))

	for _, binding := range defs.Binding {
		if binding.Binding.Transport != "http://schemas.xmlsoap.org/soap/http" {
			continue
		}
		policy, ok := policies[binding.PolicyReference.URI]
		if !ok {
			continue
		}

		switch binding.Operation.Operation.SoapAction {
		case trust13Spec:
			bindings[binding.Name] = wsEndpointData{Trust13, policy}
		case trust2005Spec:
			bindings[binding.Name] = wsEndpointData{Trust2005, policy}
		default:
			return nil, errors.New("found unknown spec version in mex document")
		}
	}

	return bindings, nil
}

// endpoints walks the service ports, matches each to a resolved binding and keeps the best
// username/password and Windows-transport endpoint (see updateEndpoint). A matched port with a
// blank address, an unknown WS-Trust version or an unknown endpoint type is an error.
func endpoints(defs Definitions, bindings map[string]wsEndpointData) (userPass, windows Endpoint, err error) {
	for _, port := range defs.Service.Port {
		bindingName := port.Binding

		// Drop everything up to the first ':' (e.g. a namespace prefix) so the name matches the
		// keys produced by bindings().
		if index := strings.Index(bindingName, ":"); index != -1 {
			bindingName = bindingName[index+1:]
		}

		binding, ok := bindings[bindingName]
		if !ok {
			continue
		}

		url := strings.TrimSpace(port.EndpointReference.Address.Text)
		if url == "" {
			return Endpoint{}, Endpoint{}, fmt.Errorf("MexDocument cannot have blank URL endpoint")
		}
		if binding.Version == TrustUnknown {
			return Endpoint{}, Endpoint{}, fmt.Errorf("endpoint version unknown")
		}

		endpoint := Endpoint{Version: binding.Version, URL: url}

		switch binding.EndpointType {
		case etUsernamePassword:
			updateEndpoint(&userPass, endpoint)
		case etWindowsTransport:
			updateEndpoint(&windows, endpoint)
		default:
			return Endpoint{}, Endpoint{}, errors.New("found unknown port type in MEX document")
		}
	}

	return userPass, windows, nil
}
microsoft-authentication-library-for-go-1.0.0/apps/internal/oauth/ops/wstrust/wstrust.go000066400000000000000000000116231442026362400317060ustar00000000000000// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

/*
Package wstrust provides a client for communicating with a WSTrust (https://en.wikipedia.org/wiki/WS-Trust#:~:text=WS%2DTrust%20is%20a%20WS,in%20a%20secure%20message%20exchange.)
for the purposes of extracting metadata from the service. This data can be used to acquire
tokens using the accesstokens.Client.GetAccessTokenFromSamlGrant() call.
*/
package wstrust

import (
	"context"
	"errors"
	"fmt"
	"net/http"
	"net/url"

	"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority"
	"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/internal/grant"
	"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/wstrust/defs"
)

// xmlCaller is the transport used by Client: a plain XML GET-style call and a SOAP call.
type xmlCaller interface {
	XMLCall(ctx context.Context, endpoint string, headers http.Header, qv url.Values, resp interface{}) error
	SOAPCall(ctx context.Context, endpoint, action string, headers http.Header, qv url.Values, body string, resp interface{}) error
}

// SamlTokenInfo is the SAML assertion extracted from a WS-Trust token response.
type SamlTokenInfo struct {
	AssertionType string // Either grant.SAMLV1 or grant.SAMLV2.
	Assertion     string
}

// Client represents the REST calls to get tokens from token generator backends.
type Client struct {
	// Comm provides the HTTP transport client.
	Comm xmlCaller
}

// TODO(msal): This allows me to call Mex without having a real Def file on line 45.
// This would fail because policies() would not find a policy. This is easy enough to
// fix in test data, but.... Definitions is defined with built in structs. That needs
// to be pulled apart and until then I have this hack in.
var newFromDef = defs.NewFromDef

// Mex provides metadata about a wstrust service.
func (c Client) Mex(ctx context.Context, federationMetadataURL string) (defs.MexDocument, error) { resp := defs.Definitions{} err := c.Comm.XMLCall( ctx, federationMetadataURL, http.Header{}, nil, &resp, ) if err != nil { return defs.MexDocument{}, err } return newFromDef(resp) } const ( SoapActionDefault = "http://docs.oasis-open.org/ws-sx/ws-trust/200512/RST/Issue" // Note: Commented out because this action is not supported. It was in the original code // but only used in a switch where it errored. Since there was only one value, a default // worked better. However, buildTokenRequestMessage() had 2005 support. I'm not actually // sure what's going on here. It like we have half support. For now this is here just // for documentation purposes in case we are going to add support. // // SoapActionWSTrust2005 = "http://schemas.xmlsoap.org/ws/2005/02/trust/RST/Issue" ) // SAMLTokenInfo provides SAML information that is used to generate a SAML token. func (c Client) SAMLTokenInfo(ctx context.Context, authParameters authority.AuthParams, cloudAudienceURN string, endpoint defs.Endpoint) (SamlTokenInfo, error) { var wsTrustRequestMessage string var err error switch authParameters.AuthorizationType { case authority.ATWindowsIntegrated: wsTrustRequestMessage, err = endpoint.BuildTokenRequestMessageWIA(cloudAudienceURN) if err != nil { return SamlTokenInfo{}, err } case authority.ATUsernamePassword: wsTrustRequestMessage, err = endpoint.BuildTokenRequestMessageUsernamePassword( cloudAudienceURN, authParameters.Username, authParameters.Password) if err != nil { return SamlTokenInfo{}, err } default: return SamlTokenInfo{}, fmt.Errorf("unknown auth type %v", authParameters.AuthorizationType) } var soapAction string switch endpoint.Version { case defs.Trust13: soapAction = SoapActionDefault case defs.Trust2005: return SamlTokenInfo{}, errors.New("WS Trust 2005 support is not implemented") default: return SamlTokenInfo{}, fmt.Errorf("the SOAP endpoint for a wstrust call had an 
invalid version: %v", endpoint.Version) } resp := defs.SAMLDefinitions{} err = c.Comm.SOAPCall(ctx, endpoint.URL, soapAction, http.Header{}, nil, wsTrustRequestMessage, &resp) if err != nil { return SamlTokenInfo{}, err } return c.samlAssertion(resp) } const ( samlv1Assertion = "urn:oasis:names:tc:SAML:1.0:assertion" samlv2Assertion = "urn:oasis:names:tc:SAML:2.0:assertion" ) func (c Client) samlAssertion(def defs.SAMLDefinitions) (SamlTokenInfo, error) { for _, tokenResponse := range def.Body.RequestSecurityTokenResponseCollection.RequestSecurityTokenResponse { token := tokenResponse.RequestedSecurityToken if token.Assertion.XMLName.Local != "" { assertion := token.AssertionRawXML samlVersion := token.Assertion.Saml switch samlVersion { case samlv1Assertion: return SamlTokenInfo{AssertionType: grant.SAMLV1, Assertion: assertion}, nil case samlv2Assertion: return SamlTokenInfo{AssertionType: grant.SAMLV2, Assertion: assertion}, nil } return SamlTokenInfo{}, fmt.Errorf("couldn't parse SAML assertion, version unknown: %q", samlVersion) } } return SamlTokenInfo{}, errors.New("unknown WS-Trust version") } microsoft-authentication-library-for-go-1.0.0/apps/internal/oauth/ops/wstrust/wstrust_test.go000066400000000000000000000362401442026362400327470ustar00rootroot00000000000000// Copyright (c) Microsoft Corporation. // Licensed under the MIT license. 
package wstrust

import (
	"context"
	"encoding/xml"
	"errors"
	"fmt"
	"net/http"
	"net/url"
	"reflect"
	"regexp"
	"strings"
	"testing"

	"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority"
	"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/wstrust/defs"
	"github.com/kylelemons/godebug/diff"
	"github.com/kylelemons/godebug/pretty"
)

var testAuthorityEndpoints = authority.NewEndpoints(
	"https://login.microsoftonline.com/v2.0/authorize",
	"https://login.microsoftonline.com/v2.0/token",
	"https://login.microsoftonline.com/v2.0",
	"login.microsoftonline.com",
)

// fakeXMLCaller is an xmlCaller that records the arguments of its most recent call. When err is
// set, both methods fail before recording anything. When giveResp is non-nil, SOAPCall copies it
// into resp through an XML marshal/unmarshal round trip.
type fakeXMLCaller struct {
	err      bool
	giveResp interface{}

	gotAction   string
	gotEndpoint string
	gotQV       url.Values
	gotHeaders  http.Header
	gotBody     interface{}
	gotResp     interface{}
}

func (f *fakeXMLCaller) XMLCall(ctx context.Context, endpoint string, headers http.Header, qv url.Values, resp interface{}) error {
	if f.err {
		return errors.New("error")
	}

	f.gotEndpoint = endpoint
	f.gotHeaders = headers
	f.gotQV = qv
	f.gotResp = resp
	return nil
}

func (f *fakeXMLCaller) SOAPCall(ctx context.Context, endpoint, action string, headers http.Header, qv url.Values, body string, resp interface{}) error {
	if f.err {
		return errors.New("error")
	}

	f.gotEndpoint = endpoint
	f.gotAction = action
	f.gotHeaders = headers
	f.gotQV = qv
	f.gotBody = body
	f.gotResp = resp

	if f.giveResp != nil {
		b, err := xml.MarshalIndent(f.giveResp, "", "\t")
		if err != nil {
			panic(err)
		}
		if err := xml.Unmarshal(b, resp); err != nil {
			panic(err)
		}
	}

	return nil
}

// compareBase checks the endpoint, headers and query values recorded by the last call, and that
// the recorded resp was a pointer to a type with the same name as resp's element type.
func (f *fakeXMLCaller) compareBase(endpoint string, headers http.Header, qv url.Values, resp interface{}) error {
	if f.gotEndpoint != endpoint {
		return fmt.Errorf("got endpoint == %s, want endpoint == %s", f.gotEndpoint, endpoint)
	}
	if diff := pretty.Compare(headers, f.gotHeaders); diff != "" {
		return fmt.Errorf("headers -want/+got:\n%s", diff)
	}
	if diff := pretty.Compare(qv, f.gotQV); diff != "" {
		return fmt.Errorf("qv -want/+got:\n%s", diff)
	}

	gotValue := reflect.ValueOf(f.gotResp)
	if gotValue.Kind() != reflect.Ptr {
		return fmt.Errorf("resp cannot be a non-pointer type")
	}
	gotValue = gotValue.Elem()
	gotName := gotValue.Type().Name()
	wantName := reflect.ValueOf(resp).Elem().Type().Name()
	if gotName != wantName {
		return fmt.Errorf("resp type was %s, want %s", gotName, wantName)
	}
	return nil
}

// compareXML checks the inputs recorded by the last XMLCall.
func (f *fakeXMLCaller) compareXML(endpoint string, resp interface{}) error {
	return f.compareBase(endpoint, http.Header{}, url.Values{}, resp)
}

var replaceURNRE = regexp.MustCompile(`urn:uuid:.*`)

// compareSOAP checks the inputs recorded by the last SOAPCall, including the SOAP action and
// the request body (ignoring everything from the first "urn:uuid:" to the end of its line).
func (f *fakeXMLCaller) compareSOAP(action, endpoint string, body, resp interface{}) error {
	if err := f.compareBase(endpoint, http.Header{}, nil, resp); err != nil {
		return err
	}
	if f.gotAction != action {
		// Previously this reported the endpoint values, hiding which action was actually sent.
		return fmt.Errorf("got action == %s, want action == %s", f.gotAction, action)
	}

	// Removes a uuid that will change every time.
	// example: `urn:uuid:373ea2fa-d586-4cad-8bb8-10392ddbb5c6``
	bodyStr := replaceURNRE.ReplaceAllString(body.(string), ``)
	// NOTE(review): this normalises the expected body again instead of f.gotBody, so the diff
	// below can never fail. Switching to f.gotBody.(string) needs the expected bodies in
	// TestSAMLTokenInfo re-verified against the real request XML first.
	gotBodyStr := replaceURNRE.ReplaceAllString(body.(string), ``)

	if diff := diff.Diff(
		strings.ReplaceAll(bodyStr, ">", ">\n"),    // So we can do a line by line comparison
		strings.ReplaceAll(gotBodyStr, ">", ">\n"), // So we can do a line by line comparison
	); diff != "" {
		return fmt.Errorf("body -want/+got:\n%s", diff)
	}
	return nil
}

func TestMex(t *testing.T) {
	tests := []struct {
		desc string
		// err is whether Mex is expected to fail.
		err bool
		// commErr makes the fake transport fail. It is separate from err so that the
		// "Definition was bad" case reaches newFromDef instead of failing in the transport.
		commErr               bool
		newFromDef            func(d defs.Definitions) (defs.MexDocument, error)
		federationMetadataURL string
	}{
		{
			desc:    "Error: comm returns error",
			err:     true,
			commErr: true,
		},
		{
			desc:                  "Definition was bad",
			federationMetadataURL: "",
			newFromDef: func(d defs.Definitions) (defs.MexDocument, error) {
				return defs.MexDocument{}, errors.New("error")
			},
			err: true,
		},
		{
			desc:                  "Success",
			federationMetadataURL: "",
			newFromDef: func(d defs.Definitions) (defs.MexDocument, error) {
				return defs.MexDocument{}, nil
			},
		},
	}

	defer func() {
		newFromDef = defs.NewFromDef
	}()

	for _, test := range tests {
		newFromDef = test.newFromDef
		fake := &fakeXMLCaller{err: test.commErr}
		client := Client{Comm: fake}

		// We don't care about the result, that is just a translation from the XML handled
		// in the comm package via wstrust.CreateWsTrustMexDocumentFromDef().
		// We care only that the comm package got the right inputs.
		_, err := client.Mex(context.Background(), "http://something")
		switch {
		case err == nil && test.err:
			t.Errorf("TestMex(%s): got err == nil , want err != nil", test.desc)
			continue
		case err != nil && !test.err:
			t.Errorf("TestMex(%s): got err == %s , want err == nil", test.desc, err)
			continue
		case err != nil:
			continue
		}

		if err := fake.compareXML("http://something", &defs.Definitions{}); err != nil {
			t.Errorf("TestMex(%s): %s", test.desc, err)
		}
	}
}

func TestSAMLTokenInfo(t *testing.T) {
	authParams := authority.AuthParams{
		Username:  "username",
		Password:  "password",
		Endpoints: testAuthorityEndpoints,
		ClientID:  "clientID",
	}

	// wantBody is the request body expected by every case. compareSOAP drops everything from
	// "urn:uuid:" onward before comparing, so only the prefix is significant.
	const wantBody = "http://docs.oasis-open.org/ws-sx/ws-trust/200512/RST/Issueurn:uuid:fb8ec65b-f117-468f-b4e8-50c5e802affehttp://www.w3.org/2005/08/addressing/anonymousupEndpointurnhttp://docs.oasis-open.org/ws-sx/ws-trust/200512/Bearerhttp://docs.oasis-open.org/ws-sx/ws-trust/200512/Issue"

	// assertionResp builds a fake token response holding one assertion with the given SAML namespace.
	assertionResp := func(samlVersion string) defs.SAMLDefinitions {
		return defs.SAMLDefinitions{
			Body: defs.Body{
				RequestSecurityTokenResponseCollection: defs.RequestSecurityTokenResponseCollection{
					RequestSecurityTokenResponse: []defs.RequestSecurityTokenResponse{
						{
							RequestedSecurityToken: defs.RequestedSecurityToken{
								Assertion: defs.Assertion{
									Text: "hello",
									XMLName: xml.Name{
										Local: "Assertion",
									},
									Saml: samlVersion,
								},
							},
						},
					},
				},
			},
		}
	}

	// Note: We don't test any error conditions built on buildTokenRequestMessage(),
	// as they can only fail if the xml marshaller fails.
	tests := []struct {
		desc              string
		err               bool
		commErr           bool
		endpoint          defs.Endpoint
		body              string
		action            string
		authorizationType authority.AuthorizeType
		giveResp          defs.SAMLDefinitions
	}{
		{
			desc:              "Error: comm returns error",
			err:               true,
			commErr:           true,
			endpoint:          defs.Endpoint{Version: defs.Trust13, URL: "upEndpoint"},
			action:            SoapActionDefault,
			authorizationType: authority.ATWindowsIntegrated,
			body:              wantBody,
			giveResp:          assertionResp(samlv1Assertion),
		},
		{
			desc:              "Error: Trust2005 endpoint, which isn't supported",
			err:               true,
			endpoint:          defs.Endpoint{Version: defs.Trust2005, URL: "upEndpoint"},
			action:            SoapActionDefault,
			authorizationType: authority.ATWindowsIntegrated,
			body:              wantBody,
			giveResp:          assertionResp(samlv1Assertion),
		},
		{
			desc:              "Success: SAMLV1 assertion with AuthorizationTypeWindowsIntegratedAuth",
			endpoint:          defs.Endpoint{Version: defs.Trust13, URL: "upEndpoint"},
			action:            SoapActionDefault,
			authorizationType: authority.ATWindowsIntegrated,
			body:              wantBody,
			giveResp:          assertionResp(samlv1Assertion),
		},
		{
			desc:              "Success: SAMLV2 assertion with AuthorizationTypeUsernamePassword",
			endpoint:          defs.Endpoint{Version: defs.Trust13, URL: "upEndpoint"},
			action:            SoapActionDefault,
			authorizationType: authority.ATUsernamePassword,
			body:              wantBody,
			giveResp:          assertionResp(samlv2Assertion),
		},
	}

	for _, test := range tests {
		fake := &fakeXMLCaller{err: test.commErr, giveResp: test.giveResp}
		client := Client{Comm: fake}
		authParams.AuthorizationType = test.authorizationType

		// We don't care about the result, that is just a translation from the XML handled
		// in the comm package via wstrust.CreateWsTrustMexDocumentFromDef().
		// We care only that the comm package got the right inputs.
		_, err := client.SAMLTokenInfo(context.Background(), authParams, "urn", test.endpoint)
		switch {
		case err == nil && test.err:
			t.Errorf("TestSAMLTokenInfo(%s): got err == nil , want err != nil", test.desc)
			continue
		case err != nil && !test.err:
			t.Errorf("TestSAMLTokenInfo(%s): got err == %s , want err == nil", test.desc, err)
			continue
		case err != nil:
			continue
		}

		if err := fake.compareSOAP(test.action, test.endpoint.URL, test.body, &defs.SAMLDefinitions{}); err != nil {
			t.Errorf("TestSAMLTokenInfo(%s): %s", test.desc, err)
		}
	}
}
microsoft-authentication-library-for-go-1.0.0/apps/internal/oauth/resolvers.go000066400000000000000000000114621442026362400276440ustar00000000000000// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

// TODO(msal): Write some tests. The original code this came from didn't have tests and I'm too
// tired at this point to do it. It, like many other *Manager code I found was broken because
// they didn't have mutex protection.

package oauth

import (
	"context"
	"errors"
	"fmt"
	"strings"
	"sync"

	"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops"
	"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority"
)

// ADFS is an active directory federation service authority type.
const ADFS = "ADFS"

// cacheEntry holds resolved endpoints and, for ADFS authorities, the set of UPN domains they
// have been resolved for.
type cacheEntry struct {
	Endpoints             authority.Endpoints
	ValidForDomainsInList map[string]bool
}

// createcacheEntry returns a cacheEntry for endpoints with an empty, non-nil domain set.
func createcacheEntry(endpoints authority.Endpoints) cacheEntry {
	return cacheEntry{endpoints, map[string]bool{}}
}

// authorityEndpoint retrieves endpoints from an authority for auth and token acquisition.
// Resolved endpoints are cached per canonical authority URI; mu guards cache.
type authorityEndpoint struct {
	rest *ops.REST

	mu    sync.Mutex
	cache map[string]cacheEntry
}

// newAuthorityEndpoint is the constructor for AuthorityEndpoint.
func newAuthorityEndpoint(rest *ops.REST) *authorityEndpoint { m := &authorityEndpoint{rest: rest, cache: map[string]cacheEntry{}} return m } // ResolveEndpoints gets the authorization and token endpoints and creates an AuthorityEndpoints instance func (m *authorityEndpoint) ResolveEndpoints(ctx context.Context, authorityInfo authority.Info, userPrincipalName string) (authority.Endpoints, error) { if endpoints, found := m.cachedEndpoints(authorityInfo, userPrincipalName); found { return endpoints, nil } endpoint, err := m.openIDConfigurationEndpoint(ctx, authorityInfo, userPrincipalName) if err != nil { return authority.Endpoints{}, err } resp, err := m.rest.Authority().GetTenantDiscoveryResponse(ctx, endpoint) if err != nil { return authority.Endpoints{}, err } if err := resp.Validate(); err != nil { return authority.Endpoints{}, fmt.Errorf("ResolveEndpoints(): %w", err) } tenant := authorityInfo.Tenant endpoints := authority.NewEndpoints( strings.Replace(resp.AuthorizationEndpoint, "{tenant}", tenant, -1), strings.Replace(resp.TokenEndpoint, "{tenant}", tenant, -1), strings.Replace(resp.Issuer, "{tenant}", tenant, -1), authorityInfo.Host) m.addCachedEndpoints(authorityInfo, userPrincipalName, endpoints) return endpoints, nil } // cachedEndpoints returns a the cached endpoints if they exists. If not, we return false. 
func (m *authorityEndpoint) cachedEndpoints(authorityInfo authority.Info, userPrincipalName string) (authority.Endpoints, bool) { m.mu.Lock() defer m.mu.Unlock() if cacheEntry, ok := m.cache[authorityInfo.CanonicalAuthorityURI]; ok { if authorityInfo.AuthorityType == ADFS { domain, err := adfsDomainFromUpn(userPrincipalName) if err == nil { if _, ok := cacheEntry.ValidForDomainsInList[domain]; ok { return cacheEntry.Endpoints, true } } } return cacheEntry.Endpoints, true } return authority.Endpoints{}, false } func (m *authorityEndpoint) addCachedEndpoints(authorityInfo authority.Info, userPrincipalName string, endpoints authority.Endpoints) { m.mu.Lock() defer m.mu.Unlock() updatedCacheEntry := createcacheEntry(endpoints) if authorityInfo.AuthorityType == ADFS { // Since we're here, we've made a call to the backend. We want to ensure we're caching // the latest values from the server. if cacheEntry, ok := m.cache[authorityInfo.CanonicalAuthorityURI]; ok { for k := range cacheEntry.ValidForDomainsInList { updatedCacheEntry.ValidForDomainsInList[k] = true } } domain, err := adfsDomainFromUpn(userPrincipalName) if err == nil { updatedCacheEntry.ValidForDomainsInList[domain] = true } } m.cache[authorityInfo.CanonicalAuthorityURI] = updatedCacheEntry } func (m *authorityEndpoint) openIDConfigurationEndpoint(ctx context.Context, authorityInfo authority.Info, userPrincipalName string) (string, error) { if authorityInfo.Tenant == "adfs" { return fmt.Sprintf("https://%s/adfs/.well-known/openid-configuration", authorityInfo.Host), nil } else if authorityInfo.ValidateAuthority && !authority.TrustedHost(authorityInfo.Host) { resp, err := m.rest.Authority().AADInstanceDiscovery(ctx, authorityInfo) if err != nil { return "", err } return resp.TenantDiscoveryEndpoint, nil } else if authorityInfo.Region != "" { resp, err := m.rest.Authority().AADInstanceDiscovery(ctx, authorityInfo) if err != nil { return "", err } return resp.TenantDiscoveryEndpoint, nil } return 
authorityInfo.CanonicalAuthorityURI + "v2.0/.well-known/openid-configuration", nil } func adfsDomainFromUpn(userPrincipalName string) (string, error) { parts := strings.Split(userPrincipalName, "@") if len(parts) < 2 { return "", errors.New("no @ present in user principal name") } return parts[1], nil } microsoft-authentication-library-for-go-1.0.0/apps/internal/options/000077500000000000000000000000001442026362400256405ustar00rootroot00000000000000microsoft-authentication-library-for-go-1.0.0/apps/internal/options/options.go000066400000000000000000000027571442026362400276750ustar00rootroot00000000000000// Copyright (c) Microsoft Corporation. // Licensed under the MIT license. package options import ( "errors" "fmt" ) // CallOption implements an optional argument to a method call. See // https://blog.devgenius.io/go-call-option-that-can-be-used-with-multiple-methods-6c81734f3dbe // for an explanation of the usage pattern. type CallOption interface { Do(any) error callOption() } // ApplyOptions applies all the callOptions to options. options must be a pointer to a struct and // callOptions must be a list of objects that implement CallOption. func ApplyOptions[O, C any](options O, callOptions []C) error { for _, o := range callOptions { if t, ok := any(o).(CallOption); !ok { return fmt.Errorf("unexpected option type %T", o) } else if err := t.Do(options); err != nil { return err } } return nil } // NewCallOption returns a new CallOption whose Do() method calls function "f". func NewCallOption(f func(any) error) CallOption { if f == nil { // This isn't a practical concern because only an MSAL maintainer can get // us here, by implementing a do-nothing option. But if someone does that, // the below ensures the method invoked with the option returns an error. 
return callOption(func(any) error { return errors.New("invalid option: missing implementation") }) } return callOption(f) } // callOption is an adapter for a function to a CallOption type callOption func(any) error func (c callOption) Do(a any) error { return c(a) } func (callOption) callOption() {} microsoft-authentication-library-for-go-1.0.0/apps/internal/shared/000077500000000000000000000000001442026362400254135ustar00rootroot00000000000000microsoft-authentication-library-for-go-1.0.0/apps/internal/shared/shared.go000066400000000000000000000037571442026362400272240ustar00rootroot00000000000000// Copyright (c) Microsoft Corporation. // Licensed under the MIT license. package shared import ( "net/http" "reflect" "strings" ) const ( // CacheKeySeparator is used in creating the keys of the cache. CacheKeySeparator = "-" ) type Account struct { HomeAccountID string `json:"home_account_id,omitempty"` Environment string `json:"environment,omitempty"` Realm string `json:"realm,omitempty"` LocalAccountID string `json:"local_account_id,omitempty"` AuthorityType string `json:"authority_type,omitempty"` PreferredUsername string `json:"username,omitempty"` GivenName string `json:"given_name,omitempty"` FamilyName string `json:"family_name,omitempty"` MiddleName string `json:"middle_name,omitempty"` Name string `json:"name,omitempty"` AlternativeID string `json:"alternative_account_id,omitempty"` RawClientInfo string `json:"client_info,omitempty"` UserAssertionHash string `json:"user_assertion_hash,omitempty"` AdditionalFields map[string]interface{} } // NewAccount creates an account. func NewAccount(homeAccountID, env, realm, localAccountID, authorityType, username string) Account { return Account{ HomeAccountID: homeAccountID, Environment: env, Realm: realm, LocalAccountID: localAccountID, AuthorityType: authorityType, PreferredUsername: username, } } // Key creates the key for storing accounts in the cache. 
func (acc Account) Key() string { return strings.Join([]string{acc.HomeAccountID, acc.Environment, acc.Realm}, CacheKeySeparator) } // IsZero checks the zero value of account. func (acc Account) IsZero() bool { v := reflect.ValueOf(acc) for i := 0; i < v.NumField(); i++ { field := v.Field(i) if !field.IsZero() { switch field.Kind() { case reflect.Map, reflect.Slice: if field.Len() == 0 { continue } } return false } } return true } // DefaultClient is our default shared HTTP client. var DefaultClient = &http.Client{} microsoft-authentication-library-for-go-1.0.0/apps/internal/shared/shared_test.go000066400000000000000000000042171442026362400302530ustar00rootroot00000000000000// Copyright (c) Microsoft Corporation. // Licensed under the MIT license. package shared import ( stdJSON "encoding/json" "testing" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json" "github.com/kylelemons/godebug/pretty" ) var ( accHID = "hid" accEnv = "env" accRealm = "realm" authType = "MSSTS" accLid = "lid" accUser = "user" ) func TestAccountUnmarshal(t *testing.T) { jsonMap := map[string]interface{}{ "home_account_id": "hid", "environment": "env", "extra": "this_is_extra", "authority_type": authType, } b, err := stdJSON.Marshal(jsonMap) if err != nil { panic(err) } want := Account{ HomeAccountID: accHID, Environment: accEnv, AuthorityType: authType, AdditionalFields: map[string]interface{}{ "extra": json.MarshalRaw("this_is_extra"), }, } got := Account{} err = json.Unmarshal(b, &got) if err != nil { panic(err) } if diff := pretty.Compare(want, got); diff != "" { t.Errorf("TestAccountUnmarshal: -want/+got:\n%s", diff) } } func TestAccountKey(t *testing.T) { acc := &Account{ HomeAccountID: accHID, Environment: accEnv, Realm: accRealm, } expectedKey := "hid-env-realm" actualKey := acc.Key() if expectedKey != actualKey { t.Errorf("Actual key %s differs from expected key %s", actualKey, expectedKey) } } func TestAccountMarshal(t *testing.T) { acc := Account{ 
HomeAccountID: accHID, Environment: accEnv, Realm: accRealm, LocalAccountID: accLid, AuthorityType: authType, PreferredUsername: accUser, AdditionalFields: map[string]interface{}{"extra": "extra"}, } want := map[string]interface{}{ "home_account_id": "hid", "environment": "env", "realm": "realm", "local_account_id": "lid", "authority_type": authType, "username": "user", "extra": "extra", } b, err := json.Marshal(acc) if err != nil { panic(err) } got := map[string]interface{}{} if err := stdJSON.Unmarshal(b, &got); err != nil { panic(err) } if diff := pretty.Compare(want, got); diff != "" { t.Errorf("TestAccountMarshal: -want/+got:\n%s", diff) } } microsoft-authentication-library-for-go-1.0.0/apps/internal/version/000077500000000000000000000000001442026362400256325ustar00rootroot00000000000000microsoft-authentication-library-for-go-1.0.0/apps/internal/version/version.go000066400000000000000000000004151442026362400276460ustar00rootroot00000000000000// Copyright (c) Microsoft Corporation. // Licensed under the MIT license. // Package version keeps the version number of the client package. package version // Version is the version of this client package that is communicated to the server. const Version = "1.0.0" microsoft-authentication-library-for-go-1.0.0/apps/public/000077500000000000000000000000001442026362400236075ustar00rootroot00000000000000microsoft-authentication-library-for-go-1.0.0/apps/public/public.go000066400000000000000000000522601442026362400254210ustar00rootroot00000000000000// Copyright (c) Microsoft Corporation. // Licensed under the MIT license. /* Package public provides a client for authentication of "public" applications. A "public" application is defined as an app that runs on client devices (android, ios, windows, linux, ...). These devices are "untrusted" and access resources via web APIs that must authenticate. */ package public /* Design note: public.Client uses client.Base as an embedded type. 
client.Base statically assigns its attributes during creation. As it doesn't have any pointers in it, anything borrowed from it, such as Base.AuthParams is a copy that is free to be manipulated here. */ // TODO(msal): This should have example code for each method on client using Go's example doc framework. // base usage details should be includee in the package documentation. import ( "context" "crypto/rand" "crypto/sha256" "encoding/base64" "fmt" "net/url" "strconv" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/base" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/local" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/options" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/shared" "github.com/google/uuid" "github.com/pkg/browser" ) // AuthResult contains the results of one token acquisition operation. // For details see https://aka.ms/msal-net-authenticationresult type AuthResult = base.AuthResult type Account = shared.Account // clientOptions configures the Client's behavior. 
type clientOptions struct { accessor cache.ExportReplace authority string capabilities []string disableInstanceDiscovery bool httpClient ops.HTTPClient } func (p *clientOptions) validate() error { u, err := url.Parse(p.authority) if err != nil { return fmt.Errorf("Authority options cannot be URL parsed: %w", err) } if u.Scheme != "https" { return fmt.Errorf("Authority(%s) did not start with https://", u.String()) } return nil } // Option is an optional argument to the New constructor. type Option func(o *clientOptions) // WithAuthority allows for a custom authority to be set. This must be a valid https url. func WithAuthority(authority string) Option { return func(o *clientOptions) { o.authority = authority } } // WithCache provides an accessor that will read and write authentication data to an externally managed cache. func WithCache(accessor cache.ExportReplace) Option { return func(o *clientOptions) { o.accessor = accessor } } // WithClientCapabilities allows configuring one or more client capabilities such as "CP1" func WithClientCapabilities(capabilities []string) Option { return func(o *clientOptions) { // there's no danger of sharing the slice's underlying memory with the application because // this slice is simply passed to base.WithClientCapabilities, which copies its data o.capabilities = capabilities } } // WithHTTPClient allows for a custom HTTP client to be set. func WithHTTPClient(httpClient ops.HTTPClient) Option { return func(o *clientOptions) { o.httpClient = httpClient } } // WithInstanceDiscovery set to false to disable authority validation (to support private cloud scenarios) func WithInstanceDiscovery(enabled bool) Option { return func(o *clientOptions) { o.disableInstanceDiscovery = !enabled } } // Client is a representation of authentication client for public applications as defined in the // package doc. For more information, visit https://docs.microsoft.com/azure/active-directory/develop/msal-client-applications. 
type Client struct { base base.Client } // New is the constructor for Client. func New(clientID string, options ...Option) (Client, error) { opts := clientOptions{ authority: base.AuthorityPublicCloud, httpClient: shared.DefaultClient, } for _, o := range options { o(&opts) } if err := opts.validate(); err != nil { return Client{}, err } base, err := base.New(clientID, opts.authority, oauth.New(opts.httpClient), base.WithCacheAccessor(opts.accessor), base.WithClientCapabilities(opts.capabilities), base.WithInstanceDiscovery(!opts.disableInstanceDiscovery)) if err != nil { return Client{}, err } return Client{base}, nil } // authCodeURLOptions contains options for AuthCodeURL type authCodeURLOptions struct { claims, loginHint, tenantID, domainHint string } // AuthCodeURLOption is implemented by options for AuthCodeURL type AuthCodeURLOption interface { authCodeURLOption() } // AuthCodeURL creates a URL used to acquire an authorization code. // // Options: [WithClaims], [WithDomainHint], [WithLoginHint], [WithTenantID] func (pca Client) AuthCodeURL(ctx context.Context, clientID, redirectURI string, scopes []string, opts ...AuthCodeURLOption) (string, error) { o := authCodeURLOptions{} if err := options.ApplyOptions(&o, opts); err != nil { return "", err } ap, err := pca.base.AuthParams.WithTenant(o.tenantID) if err != nil { return "", err } ap.Claims = o.claims ap.LoginHint = o.loginHint ap.DomainHint = o.domainHint return pca.base.AuthCodeURL(ctx, clientID, redirectURI, scopes, ap) } // WithClaims sets additional claims to request for the token, such as those required by conditional access policies. // Use this option when Azure AD returned a claims challenge for a prior request. The argument must be decoded. // This option is valid for any token acquisition method. 
func WithClaims(claims string) interface { AcquireByAuthCodeOption AcquireByDeviceCodeOption AcquireByUsernamePasswordOption AcquireInteractiveOption AcquireSilentOption AuthCodeURLOption options.CallOption } { return struct { AcquireByAuthCodeOption AcquireByDeviceCodeOption AcquireByUsernamePasswordOption AcquireInteractiveOption AcquireSilentOption AuthCodeURLOption options.CallOption }{ CallOption: options.NewCallOption( func(a any) error { switch t := a.(type) { case *acquireTokenByAuthCodeOptions: t.claims = claims case *acquireTokenByDeviceCodeOptions: t.claims = claims case *acquireTokenByUsernamePasswordOptions: t.claims = claims case *acquireTokenSilentOptions: t.claims = claims case *authCodeURLOptions: t.claims = claims case *interactiveAuthOptions: t.claims = claims default: return fmt.Errorf("unexpected options type %T", a) } return nil }, ), } } // WithTenantID specifies a tenant for a single authentication. It may be different than the tenant set in [New] by [WithAuthority]. // This option is valid for any token acquisition method. 
func WithTenantID(tenantID string) interface { AcquireByAuthCodeOption AcquireByDeviceCodeOption AcquireByUsernamePasswordOption AcquireInteractiveOption AcquireSilentOption AuthCodeURLOption options.CallOption } { return struct { AcquireByAuthCodeOption AcquireByDeviceCodeOption AcquireByUsernamePasswordOption AcquireInteractiveOption AcquireSilentOption AuthCodeURLOption options.CallOption }{ CallOption: options.NewCallOption( func(a any) error { switch t := a.(type) { case *acquireTokenByAuthCodeOptions: t.tenantID = tenantID case *acquireTokenByDeviceCodeOptions: t.tenantID = tenantID case *acquireTokenByUsernamePasswordOptions: t.tenantID = tenantID case *acquireTokenSilentOptions: t.tenantID = tenantID case *authCodeURLOptions: t.tenantID = tenantID case *interactiveAuthOptions: t.tenantID = tenantID default: return fmt.Errorf("unexpected options type %T", a) } return nil }, ), } } // acquireTokenSilentOptions are all the optional settings to an AcquireTokenSilent() call. // These are set by using various AcquireTokenSilentOption functions. type acquireTokenSilentOptions struct { account Account claims, tenantID string } // AcquireSilentOption is implemented by options for AcquireTokenSilent type AcquireSilentOption interface { acquireSilentOption() } // WithSilentAccount uses the passed account during an AcquireTokenSilent() call. func WithSilentAccount(account Account) interface { AcquireSilentOption options.CallOption } { return struct { AcquireSilentOption options.CallOption }{ CallOption: options.NewCallOption( func(a any) error { switch t := a.(type) { case *acquireTokenSilentOptions: t.account = account default: return fmt.Errorf("unexpected options type %T", a) } return nil }, ), } } // AcquireTokenSilent acquires a token from either the cache or using a refresh token. 
// // Options: [WithClaims], [WithSilentAccount], [WithTenantID] func (pca Client) AcquireTokenSilent(ctx context.Context, scopes []string, opts ...AcquireSilentOption) (AuthResult, error) { o := acquireTokenSilentOptions{} if err := options.ApplyOptions(&o, opts); err != nil { return AuthResult{}, err } silentParameters := base.AcquireTokenSilentParameters{ Scopes: scopes, Account: o.account, Claims: o.claims, RequestType: accesstokens.ATPublic, IsAppCache: false, TenantID: o.tenantID, } return pca.base.AcquireTokenSilent(ctx, silentParameters) } // acquireTokenByUsernamePasswordOptions contains optional configuration for AcquireTokenByUsernamePassword type acquireTokenByUsernamePasswordOptions struct { claims, tenantID string } // AcquireByUsernamePasswordOption is implemented by options for AcquireTokenByUsernamePassword type AcquireByUsernamePasswordOption interface { acquireByUsernamePasswordOption() } // AcquireTokenByUsernamePassword acquires a security token from the authority, via Username/Password Authentication. // NOTE: this flow is NOT recommended. 
// // Options: [WithClaims], [WithTenantID] func (pca Client) AcquireTokenByUsernamePassword(ctx context.Context, scopes []string, username, password string, opts ...AcquireByUsernamePasswordOption) (AuthResult, error) { o := acquireTokenByUsernamePasswordOptions{} if err := options.ApplyOptions(&o, opts); err != nil { return AuthResult{}, err } authParams, err := pca.base.AuthParams.WithTenant(o.tenantID) if err != nil { return AuthResult{}, err } authParams.Scopes = scopes authParams.AuthorizationType = authority.ATUsernamePassword authParams.Claims = o.claims authParams.Username = username authParams.Password = password token, err := pca.base.Token.UsernamePassword(ctx, authParams) if err != nil { return AuthResult{}, err } return pca.base.AuthResultFromToken(ctx, authParams, token, true) } type DeviceCodeResult = accesstokens.DeviceCodeResult // DeviceCode provides the results of the device code flows first stage (containing the code) // that must be entered on the second device and provides a method to retrieve the AuthenticationResult // once that code has been entered and verified. type DeviceCode struct { // Result holds the information about the device code (such as the code). Result DeviceCodeResult authParams authority.AuthParams client Client dc oauth.DeviceCode } // AuthenticationResult retreives the AuthenticationResult once the user enters the code // on the second device. Until then it blocks until the .AcquireTokenByDeviceCode() context // is cancelled or the token expires. 
func (d DeviceCode) AuthenticationResult(ctx context.Context) (AuthResult, error) { token, err := d.dc.Token(ctx) if err != nil { return AuthResult{}, err } return d.client.base.AuthResultFromToken(ctx, d.authParams, token, true) } // acquireTokenByDeviceCodeOptions contains optional configuration for AcquireTokenByDeviceCode type acquireTokenByDeviceCodeOptions struct { claims, tenantID string } // AcquireByDeviceCodeOption is implemented by options for AcquireTokenByDeviceCode type AcquireByDeviceCodeOption interface { acquireByDeviceCodeOptions() } // AcquireTokenByDeviceCode acquires a security token from the authority, by acquiring a device code and using that to acquire the token. // Users need to create an AcquireTokenDeviceCodeParameters instance and pass it in. // // Options: [WithClaims], [WithTenantID] func (pca Client) AcquireTokenByDeviceCode(ctx context.Context, scopes []string, opts ...AcquireByDeviceCodeOption) (DeviceCode, error) { o := acquireTokenByDeviceCodeOptions{} if err := options.ApplyOptions(&o, opts); err != nil { return DeviceCode{}, err } authParams, err := pca.base.AuthParams.WithTenant(o.tenantID) if err != nil { return DeviceCode{}, err } authParams.Scopes = scopes authParams.AuthorizationType = authority.ATDeviceCode authParams.Claims = o.claims dc, err := pca.base.Token.DeviceCode(ctx, authParams) if err != nil { return DeviceCode{}, err } return DeviceCode{Result: dc.Result, authParams: authParams, client: pca, dc: dc}, nil } // acquireTokenByAuthCodeOptions contains the optional parameters used to acquire an access token using the authorization code flow. type acquireTokenByAuthCodeOptions struct { challenge, claims, tenantID string } // AcquireByAuthCodeOption is implemented by options for AcquireTokenByAuthCode type AcquireByAuthCodeOption interface { acquireByAuthCodeOption() } // WithChallenge allows you to provide a code for the .AcquireTokenByAuthCode() call. 
func WithChallenge(challenge string) interface { AcquireByAuthCodeOption options.CallOption } { return struct { AcquireByAuthCodeOption options.CallOption }{ CallOption: options.NewCallOption( func(a any) error { switch t := a.(type) { case *acquireTokenByAuthCodeOptions: t.challenge = challenge default: return fmt.Errorf("unexpected options type %T", a) } return nil }, ), } } // AcquireTokenByAuthCode is a request to acquire a security token from the authority, using an authorization code. // The specified redirect URI must be the same URI that was used when the authorization code was requested. // // Options: [WithChallenge], [WithClaims], [WithTenantID] func (pca Client) AcquireTokenByAuthCode(ctx context.Context, code string, redirectURI string, scopes []string, opts ...AcquireByAuthCodeOption) (AuthResult, error) { o := acquireTokenByAuthCodeOptions{} if err := options.ApplyOptions(&o, opts); err != nil { return AuthResult{}, err } params := base.AcquireTokenAuthCodeParameters{ Scopes: scopes, Code: code, Challenge: o.challenge, Claims: o.claims, AppType: accesstokens.ATPublic, RedirectURI: redirectURI, TenantID: o.tenantID, } return pca.base.AcquireTokenByAuthCode(ctx, params) } // Accounts gets all the accounts in the token cache. // If there are no accounts in the cache the returned slice is empty. func (pca Client) Accounts(ctx context.Context) ([]Account, error) { return pca.base.AllAccounts(ctx) } // RemoveAccount signs the account out and forgets account from token cache. func (pca Client) RemoveAccount(ctx context.Context, account Account) error { return pca.base.RemoveAccount(ctx, account) } // interactiveAuthOptions contains the optional parameters used to acquire an access token for interactive auth code flow. 
type interactiveAuthOptions struct { claims, domainHint, loginHint, redirectURI, tenantID string } // AcquireInteractiveOption is implemented by options for AcquireTokenInteractive type AcquireInteractiveOption interface { acquireInteractiveOption() } // WithLoginHint pre-populates the login prompt with a username. func WithLoginHint(username string) interface { AcquireInteractiveOption AuthCodeURLOption options.CallOption } { return struct { AcquireInteractiveOption AuthCodeURLOption options.CallOption }{ CallOption: options.NewCallOption( func(a any) error { switch t := a.(type) { case *authCodeURLOptions: t.loginHint = username case *interactiveAuthOptions: t.loginHint = username default: return fmt.Errorf("unexpected options type %T", a) } return nil }, ), } } // WithDomainHint adds the IdP domain as domain_hint query parameter in the auth url. func WithDomainHint(domain string) interface { AcquireInteractiveOption AuthCodeURLOption options.CallOption } { return struct { AcquireInteractiveOption AuthCodeURLOption options.CallOption }{ CallOption: options.NewCallOption( func(a any) error { switch t := a.(type) { case *authCodeURLOptions: t.domainHint = domain case *interactiveAuthOptions: t.domainHint = domain default: return fmt.Errorf("unexpected options type %T", a) } return nil }, ), } } // WithRedirectURI sets a port for the local server used in interactive authentication, for // example http://localhost:port. All URI components other than the port are ignored. 
func WithRedirectURI(redirectURI string) interface { AcquireInteractiveOption options.CallOption } { return struct { AcquireInteractiveOption options.CallOption }{ CallOption: options.NewCallOption( func(a any) error { switch t := a.(type) { case *interactiveAuthOptions: t.redirectURI = redirectURI default: return fmt.Errorf("unexpected options type %T", a) } return nil }, ), } } // AcquireTokenInteractive acquires a security token from the authority using the default web browser to select the account. // https://docs.microsoft.com/en-us/azure/active-directory/develop/msal-authentication-flows#interactive-and-non-interactive-authentication // // Options: [WithDomainHint], [WithLoginHint], [WithRedirectURI], [WithTenantID] func (pca Client) AcquireTokenInteractive(ctx context.Context, scopes []string, opts ...AcquireInteractiveOption) (AuthResult, error) { o := interactiveAuthOptions{} if err := options.ApplyOptions(&o, opts); err != nil { return AuthResult{}, err } // the code verifier is a random 32-byte sequence that's been base-64 encoded without padding. 
// it's used to prevent MitM attacks during auth code flow, see https://tools.ietf.org/html/rfc7636 cv, challenge, err := codeVerifier() if err != nil { return AuthResult{}, err } var redirectURL *url.URL if o.redirectURI != "" { redirectURL, err = url.Parse(o.redirectURI) if err != nil { return AuthResult{}, err } } authParams, err := pca.base.AuthParams.WithTenant(o.tenantID) if err != nil { return AuthResult{}, err } authParams.Scopes = scopes authParams.AuthorizationType = authority.ATInteractive authParams.Claims = o.claims authParams.CodeChallenge = challenge authParams.CodeChallengeMethod = "S256" authParams.LoginHint = o.loginHint authParams.DomainHint = o.domainHint authParams.State = uuid.New().String() authParams.Prompt = "select_account" res, err := pca.browserLogin(ctx, redirectURL, authParams) if err != nil { return AuthResult{}, err } authParams.Redirecturi = res.redirectURI req, err := accesstokens.NewCodeChallengeRequest(authParams, accesstokens.ATPublic, nil, res.authCode, cv) if err != nil { return AuthResult{}, err } token, err := pca.base.Token.AuthCode(ctx, req) if err != nil { return AuthResult{}, err } return pca.base.AuthResultFromToken(ctx, authParams, token, true) } type interactiveAuthResult struct { authCode string redirectURI string } // provides a test hook to simulate opening a browser var browserOpenURL = func(authURL string) error { return browser.OpenURL(authURL) } // parses the port number from the provided URL. // returns 0 if nil or no port is specified. 
func parsePort(u *url.URL) (int, error) { if u == nil { return 0, nil } p := u.Port() if p == "" { return 0, nil } return strconv.Atoi(p) } // browserLogin launches the system browser for interactive login func (pca Client) browserLogin(ctx context.Context, redirectURI *url.URL, params authority.AuthParams) (interactiveAuthResult, error) { // start local redirect server so login can call us back port, err := parsePort(redirectURI) if err != nil { return interactiveAuthResult{}, err } srv, err := local.New(params.State, port) if err != nil { return interactiveAuthResult{}, err } defer srv.Shutdown() params.Scopes = accesstokens.AppendDefaultScopes(params) authURL, err := pca.base.AuthCodeURL(ctx, params.ClientID, srv.Addr, params.Scopes, params) if err != nil { return interactiveAuthResult{}, err } // open browser window so user can select credentials if err := browserOpenURL(authURL); err != nil { return interactiveAuthResult{}, err } // now wait until the logic calls us back res := srv.Result(ctx) if res.Err != nil { return interactiveAuthResult{}, res.Err } return interactiveAuthResult{ authCode: res.Code, redirectURI: srv.Addr, }, nil } // creates a code verifier string along with its SHA256 hash which // is used as the challenge when requesting an auth code. // used in interactive auth flow for PKCE. func codeVerifier() (codeVerifier string, challenge string, err error) { cvBytes := make([]byte, 32) if _, err = rand.Read(cvBytes); err != nil { return } codeVerifier = base64.RawURLEncoding.EncodeToString(cvBytes) // for PKCE, create a hash of the code verifier cvh := sha256.Sum256([]byte(codeVerifier)) challenge = base64.RawURLEncoding.EncodeToString(cvh[:]) return } microsoft-authentication-library-for-go-1.0.0/apps/public/public_test.go000066400000000000000000000736431442026362400264700ustar00rootroot00000000000000// Copyright (c) Microsoft Corporation. // Licensed under the MIT license. 
package public import ( "context" "encoding/base64" "encoding/json" "errors" "fmt" "net/http" "net/url" "strings" "testing" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/mock" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/fake" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/wstrust" "github.com/kylelemons/godebug/pretty" ) const authorityFmt = "https://%s/%s" var tokenScope = []string{"the_scope"} func fakeBrowserOpenURL(authURL string) error { // we will get called with the URL for requesting an auth code u, err := url.Parse(authURL) if err != nil { return err } // validate the URL content q := u.Query() if q.Get("code_challenge") == "" { return errors.New("missing query param 'code_challenge") } if m := q.Get("code_challenge_method"); m != "S256" { return fmt.Errorf("unexpected code_challenge_method '%s'", m) } if q.Get("prompt") == "" { return errors.New("missing query param 'prompt") } state := q.Get("state") if state == "" { return errors.New("missing query param 'state'") } redirect := q.Get("redirect_uri") if redirect == "" { return errors.New("missing query param 'redirect_uri'") } // now send the info to our local redirect server resp, err := http.DefaultClient.Get(redirect + fmt.Sprintf("/?state=%s&code=fake_auth_code", state)) if err != nil { return err } if resp.StatusCode != http.StatusOK { return fmt.Errorf("unexpected status code %d", resp.StatusCode) } return nil } func TestAcquireTokenInteractive(t *testing.T) { realBrowserOpenURL := browserOpenURL defer func() { browserOpenURL = realBrowserOpenURL }() browserOpenURL = fakeBrowserOpenURL client, err := New("some_client_id") if err != nil { t.Fatal(err) } client.base.Token.AccessTokens = &fake.AccessTokens{} client.base.Token.Authority = &fake.Authority{} client.base.Token.Resolver = &fake.ResolveEndpoints{} 
client.base.Token.WSTrust = &fake.WSTrust{} _, err = client.AcquireTokenInteractive(context.Background(), []string{"the_scope"}) if err != nil { t.Fatal(err) } } func TestAcquireTokenSilentHomeTenantAliases(t *testing.T) { accessToken := "*" homeTenant := "home-tenant" clientInfo := base64.RawStdEncoding.EncodeToString([]byte( fmt.Sprintf(`{"uid":"uid","utid":"%s"}`, homeTenant), )) lmo := "login.microsoftonline.com" for _, alias := range []string{"common", "organizations"} { mockClient := mock.Client{} mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, alias))) mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(accessToken, mock.GetIDToken(homeTenant, fmt.Sprintf(authorityFmt, lmo, homeTenant)), "rt", clientInfo, 3600))) mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(lmo, homeTenant))) client, err := New("client-id", WithAuthority(fmt.Sprintf(authorityFmt, lmo, alias)), WithHTTPClient(&mockClient)) if err != nil { t.Fatal(err) } // the auth flow isn't important, we just need to populate the cache ar, err := client.AcquireTokenByAuthCode(context.Background(), "code", "https://localhost", tokenScope) if err != nil { t.Fatal(err) } if ar.AccessToken != accessToken { t.Fatalf("expected %q, got %q", accessToken, ar.AccessToken) } account := ar.Account ar, err = client.AcquireTokenSilent(context.Background(), tokenScope, WithSilentAccount(account)) if err != nil { t.Fatal(err) } if ar.AccessToken != accessToken { t.Fatalf("expected %q, got %q", accessToken, ar.AccessToken) } } } func TestAcquireTokenSilentWithTenantID(t *testing.T) { tenantA, tenantB := "a", "b" lmo := "login.microsoftonline.com" mockClient := mock.Client{} mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(lmo, tenantA))) client, err := New("client-id", WithAuthority(fmt.Sprintf(authorityFmt, lmo, tenantA)), WithHTTPClient(&mockClient)) if err != nil { t.Fatal(err) } clientInfo := 
base64.RawStdEncoding.EncodeToString([]byte(`{"uid":"uid","utid":"utid"}`)) ctx := context.Background() // cache an access token for each tenant. To simplify determining their provenance below, the value of each token is the ID of the tenant that provided it. for _, tenant := range []string{tenantA, tenantB} { if _, err = client.AcquireTokenSilent(ctx, tokenScope, WithTenantID(tenant)); err == nil { t.Fatal("silent auth should fail because the cache is empty") } mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, tenant))) mockClient.AppendResponse(mock.WithBody([]byte(`{"account_type":"Managed","cloud_audience_urn":"urn","cloud_instance_name":"...","domain_name":"..."}`))) mockClient.AppendResponse(mock.WithBody( mock.GetAccessTokenBody(tenant, mock.GetIDToken(tenant, fmt.Sprintf(authorityFmt, lmo, tenant)), "rt-"+tenant, clientInfo, 3600)), ) ar, err := client.AcquireTokenByUsernamePassword(ctx, tokenScope, "username", "password", WithTenantID(tenant)) if err != nil { t.Fatal(err) } if ar.AccessToken != tenant { t.Fatalf(`unexpected token "%s"`, ar.AccessToken) } } // cache should return the correct access token for each tenant var account Account accounts, err := client.Accounts(ctx) if err != nil { t.Fatal(err) } // expecting one account for each tenant we authenticated in above if len(accounts) == 2 { account = accounts[0] } else { t.Fatalf("expected 2 accounts but got %d", len(accounts)) } for _, test := range []struct { desc, expected string opts []AcquireSilentOption }{ // when no tenant is specified the client should return the cached token for its configured authority {"no tenant specified", tenantA, []AcquireSilentOption{WithSilentAccount(account)}}, // when a tenant is specified the client should return the cached token for that tenant {"redundant tenant specified", tenantA, []AcquireSilentOption{WithSilentAccount(account), WithTenantID(tenantA)}}, {"different tenant specified", tenantB, 
[]AcquireSilentOption{WithSilentAccount(account), WithTenantID(tenantB)}}, } { t.Run(test.desc, func(t *testing.T) { ar, err := client.AcquireTokenSilent(ctx, tokenScope, test.opts...) if err != nil { t.Fatal(err) } if ar.AccessToken != test.expected { t.Fatalf(`expected "%s", got "%s"`, test.expected, ar.AccessToken) } }) } } func TestAcquireTokenWithTenantID(t *testing.T) { // replacing browserOpenURL with a fake for the duration of this test enables testing AcquireTokenInteractive realBrowserOpenURL := browserOpenURL defer func() { browserOpenURL = realBrowserOpenURL }() browserOpenURL = fakeBrowserOpenURL accessToken := "*" clientInfo := base64.RawStdEncoding.EncodeToString([]byte(`{"uid":"uid","utid":"utid"}`)) uuid1 := "00000000-0000-0000-0000-000000000000" uuid2 := strings.ReplaceAll(uuid1, "0", "1") lmo := "login.microsoftonline.com" host := fmt.Sprintf("https://%s/", lmo) for _, test := range []struct { authority, expectedAuthority, tenant string expectError bool }{ {authority: host + "common", tenant: uuid1, expectedAuthority: host + uuid1}, {authority: host + "organizations", tenant: uuid1, expectedAuthority: host + uuid1}, {authority: host + uuid1, tenant: uuid2, expectedAuthority: host + uuid2}, {authority: host + uuid1, tenant: "common", expectError: true}, {authority: host + uuid1, tenant: "organizations", expectError: true}, {authority: host + "consumers", tenant: uuid1, expectError: true}, } { for _, method := range []string{"authcode", "authcodeURL", "devicecode", "interactive", "password"} { t.Run(method, func(t *testing.T) { URL := "" mockClient := mock.Client{} if method == "obo" { // TODO: OBO does instance discovery twice before first token request https://github.com/AzureAD/microsoft-authentication-library-for-go/issues/351 mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(lmo, test.tenant))) } mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(lmo, test.tenant))) 
mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, test.tenant))) if method == "devicecode" { mockClient.AppendResponse(mock.WithBody([]byte(`{"device_code":"...","expires_in":600}`))) } else if method == "password" { // user realm metadata mockClient.AppendResponse(mock.WithBody([]byte(`{"account_type":"Managed","cloud_audience_urn":"urn","cloud_instance_name":"...","domain_name":"..."}`))) } mockClient.AppendResponse( mock.WithBody(mock.GetAccessTokenBody(accessToken, mock.GetIDToken(test.tenant, test.authority), "rt", clientInfo, 3600)), mock.WithCallback(func(r *http.Request) { URL = r.URL.String() }), ) client, err := New("client-id", WithAuthority(test.authority), WithHTTPClient(&mockClient)) if err != nil { t.Fatal(err) } ctx := context.Background() if _, err = client.AcquireTokenSilent(ctx, tokenScope, WithTenantID(test.tenant)); err == nil { t.Fatal("silent auth should fail because the cache is empty") } var ar AuthResult var dc DeviceCode switch method { case "authcode": ar, err = client.AcquireTokenByAuthCode(ctx, "auth code", "https://localhost", tokenScope, WithTenantID(test.tenant)) case "authcodeURL": URL, err = client.AuthCodeURL(ctx, "client-id", "https://localhost", tokenScope, WithTenantID(test.tenant)) case "devicecode": dc, err = client.AcquireTokenByDeviceCode(ctx, tokenScope, WithTenantID(test.tenant)) case "interactive": ar, err = client.AcquireTokenInteractive(ctx, tokenScope, WithTenantID(test.tenant)) case "password": ar, err = client.AcquireTokenByUsernamePassword(ctx, tokenScope, "username", "password", WithTenantID(test.tenant)) default: t.Fatalf("test bug: no test for " + method) } if err != nil { if test.expectError { return } t.Fatal(err) } else if test.expectError { t.Fatal("expected an error") } if method == "devicecode" { if ar, err = dc.AuthenticationResult(ctx); err != nil { t.Fatal(err) } } if !strings.HasPrefix(URL, test.expectedAuthority) { t.Fatalf(`expected "%s", got "%s"`, test.expectedAuthority, 
URL) } if method == "authcodeURL" { // didn't acquire a token, no need to test silent auth return } if ar.AccessToken != accessToken { t.Fatalf(`unexpected access token "%s"`, ar.AccessToken) } // silent authentication should succeed for the given tenant because the client has a cached // access token, and for a different tenant because the client has a cached refresh token if ar, err = client.AcquireTokenSilent(ctx, tokenScope, WithSilentAccount(ar.Account), WithTenantID(test.tenant)); err != nil { t.Fatal(err) } else if ar.AccessToken != accessToken { t.Fatal("cached access token should match the one returned by AcquireToken...") } otherTenant := "not-" + test.tenant mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, otherTenant))) mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(accessToken, mock.GetIDToken(otherTenant, test.authority), "rt", clientInfo, 3600))) if _, err = client.AcquireTokenSilent(ctx, tokenScope, WithSilentAccount(ar.Account), WithTenantID("not-"+test.tenant)); err != nil { t.Fatal(err) } }) } } } func TestWithInstanceDiscovery(t *testing.T) { // replacing browserOpenURL with a fake for the duration of this test enables testing AcquireTokenInteractive realBrowserOpenURL := browserOpenURL defer func() { browserOpenURL = realBrowserOpenURL }() browserOpenURL = fakeBrowserOpenURL accessToken := "*" clientInfo := base64.RawStdEncoding.EncodeToString([]byte(`{"uid":"uid","utid":"utid"}`)) host := "stack.local" stackurl := fmt.Sprintf("https://%s/", host) for _, tenant := range []string{ "adfs", "98b8267d-e97f-426e-8b3f-7956511fd63f", } { for _, method := range []string{"authcode", "devicecode", "interactive", "password"} { t.Run(method, func(t *testing.T) { authority := stackurl + tenant mockClient := mock.Client{} mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(host, tenant))) if method == "devicecode" { 
mockClient.AppendResponse(mock.WithBody([]byte(`{"device_code":"...","expires_in":600}`))) } else if method == "password" && tenant != "adfs" { // user realm metadata, which is not requested when AuthorityType is ADFS mockClient.AppendResponse(mock.WithBody([]byte(`{"account_type":"Managed","cloud_audience_urn":"urn","cloud_instance_name":"...","domain_name":"..."}`))) } mockClient.AppendResponse( mock.WithBody(mock.GetAccessTokenBody(accessToken, mock.GetIDToken(tenant, authority), "rt", clientInfo, 3600)), ) client, err := New("client-id", WithAuthority(authority), WithHTTPClient(&mockClient), WithInstanceDiscovery(false)) if err != nil { t.Fatal(err) } ctx := context.Background() if _, err = client.AcquireTokenSilent(ctx, tokenScope); err == nil { t.Fatal("silent auth should fail because the cache is empty") } var ar AuthResult var dc DeviceCode switch method { case "authcode": ar, err = client.AcquireTokenByAuthCode(ctx, "auth code", "https://localhost", tokenScope) case "devicecode": dc, err = client.AcquireTokenByDeviceCode(ctx, tokenScope) case "interactive": ar, err = client.AcquireTokenInteractive(ctx, tokenScope) case "password": ar, err = client.AcquireTokenByUsernamePassword(ctx, tokenScope, "username", "password") default: t.Fatal("test bug: no test for " + method) } if err != nil { t.Fatal(err) } if method == "devicecode" { if ar, err = dc.AuthenticationResult(ctx); err != nil { t.Fatal(err) } } if ar.AccessToken != accessToken { t.Fatalf(`unexpected access token "%s"`, ar.AccessToken) } if ar, err = client.AcquireTokenSilent(ctx, tokenScope, WithSilentAccount(ar.Account)); err != nil { t.Fatal(err) } else if ar.AccessToken != accessToken { t.Fatal("cached access token should match the one returned by AcquireToken...") } }) } } } // testCache is a simple in-memory cache.ExportReplace implementation type testCache map[string][]byte func (c testCache) Export(ctx context.Context, m cache.Marshaler, h cache.ExportHints) error { v, err := m.Marshal() if 
err == nil { c[h.PartitionKey] = v } return err } func (c testCache) Replace(ctx context.Context, u cache.Unmarshaler, h cache.ReplaceHints) error { if v, has := c[h.PartitionKey]; has { return u.Unmarshal(v) } return nil } func TestWithCache(t *testing.T) { cache := make(testCache) accessToken, refreshToken := "*", "rt" clientInfo := base64.RawStdEncoding.EncodeToString([]byte(`{"uid":"uid","utid":"utid"}`)) lmo := "login.microsoftonline.com" tenantA, tenantB := "a", "b" authorityA, authorityB := fmt.Sprintf(authorityFmt, lmo, tenantA), fmt.Sprintf(authorityFmt, lmo, tenantB) mockClient := mock.Client{} mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, tenantA))) mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(accessToken, mock.GetIDToken(tenantA, authorityA), refreshToken, clientInfo, 3600))) client, err := New("client-id", WithAuthority(authorityA), WithCache(&cache), WithHTTPClient(&mockClient)) if err != nil { t.Fatal(err) } // The particular flow isn't important, we just need to populate the cache. 
Auth code is the simplest for this test ar, err := client.AcquireTokenByAuthCode(context.Background(), "code", "https://localhost", tokenScope) if err != nil { t.Fatal(err) } if ar.AccessToken != accessToken { t.Fatalf(`unexpected access token "%s"`, ar.AccessToken) } account := ar.Account if actual := account.Realm; actual != tenantA { t.Fatalf(`unexpected realm "%s"`, actual) } // a client configured for a different tenant should be able to authenticate silently with the shared cache's data client, err = New("client-id", WithAuthority(authorityB), WithCache(&cache), WithHTTPClient(&mockClient)) if err != nil { t.Fatal(err) } accounts, err := client.Accounts(context.Background()) if err != nil { t.Fatal(err) } if actual := len(accounts); actual != 1 { t.Fatalf("expected 1 account but cache contains %d", actual) } if diff := pretty.Compare(account, accounts[0]); diff != "" { t.Fatal(diff) } // this should work because the cache contains an access token from tenantA mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(lmo, tenantA))) ar, err = client.AcquireTokenSilent(context.Background(), tokenScope, WithSilentAccount(account), WithTenantID(tenantA)) if err != nil { t.Fatal(err) } if ar.AccessToken != accessToken { t.Fatalf(`unexpected access token "%s"`, ar.AccessToken) } // this should work because the cache contains a refresh token for the user accessToken2 := accessToken + "2" mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, tenantB))) mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(accessToken2, mock.GetIDToken(tenantB, authorityB), refreshToken, clientInfo, 3600))) ar, err = client.AcquireTokenSilent(context.Background(), tokenScope, WithSilentAccount(account)) if err != nil { t.Fatal(err) } if ar.AccessToken != accessToken2 { t.Fatalf(`unexpected access token "%s"`, ar.AccessToken) } } func TestWithClaims(t *testing.T) { // replacing browserOpenURL with a fake for the duration of this test enables 
testing AcquireTokenInteractive realBrowserOpenURL := browserOpenURL defer func() { browserOpenURL = realBrowserOpenURL }() browserOpenURL = fakeBrowserOpenURL clientInfo := base64.RawStdEncoding.EncodeToString([]byte(`{"uid":"uid","utid":"utid"}`)) lmo, tenant := "login.microsoftonline.com", "tenant" authority := fmt.Sprintf(authorityFmt, lmo, tenant) accessToken, idToken, refreshToken := "at", mock.GetIDToken(tenant, lmo), "rt" for _, test := range []struct { capabilities []string claims, expected string }{ {}, { capabilities: []string{"cp1"}, expected: `{"access_token":{"xms_cc":{"values":["cp1"]}}}`, }, { claims: `{"id_token":{"auth_time":{"essential":true}}}`, expected: `{"id_token":{"auth_time":{"essential":true}}}`, }, { capabilities: []string{"cp1", "cp2"}, claims: `{"access_token":{"nbf":{"essential":true, "value":"42"}}}`, expected: `{"access_token":{"nbf":{"essential":true, "value":"42"}, "xms_cc":{"values":["cp1","cp2"]}}}`, }, } { var expected map[string]any if err := json.Unmarshal([]byte(test.expected), &expected); err != nil && test.expected != "" { t.Fatal("test bug: the expected result must be JSON or an empty string") } // validate determines whether a request's query or form values contain the expected claims validate := func(t *testing.T, v url.Values) { if test.expected == "" { if v.Has("claims") { t.Fatal("claims shouldn't be set") } return } claims, ok := v["claims"] if !ok { t.Fatal("claims should be set") } if len(claims) != 1 { t.Fatalf("expected exactly 1 claims value, got %d", len(claims)) } var actual map[string]any if err := json.Unmarshal([]byte(claims[0]), &actual); err != nil { t.Fatal(err) } if diff := pretty.Compare(expected, actual); diff != "" { t.Fatal(diff) } } for _, method := range []string{"authcode", "authcodeURL", "devicecode", "interactive", "password", "passwordFederated"} { t.Run(method, func(t *testing.T) { mockClient := mock.Client{} if method == "obo" { // TODO: OBO does instance discovery twice before first token 
request https://github.com/AzureAD/microsoft-authentication-library-for-go/issues/351 mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(lmo, tenant))) } mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(lmo, tenant))) mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, tenant))) switch method { case "devicecode": mockClient.AppendResponse(mock.WithBody([]byte(`{"device_code":".","expires_in":600}`))) case "password": mockClient.AppendResponse(mock.WithBody([]byte(`{"account_type":"Managed","cloud_audience_urn":".","cloud_instance_name":".","domain_name":"."}`))) case "passwordFederated": mockClient.AppendResponse(mock.WithBody([]byte(`{"account_type":"Federated","cloud_audience_urn":".","cloud_instance_name":".","domain_name":".","federation_protocol":".","federation_metadata_url":"."}`))) } mockClient.AppendResponse( mock.WithBody(mock.GetAccessTokenBody(accessToken, idToken, refreshToken, clientInfo, 3600)), mock.WithCallback(func(r *http.Request) { if err := r.ParseForm(); err != nil { t.Fatal(err) } validate(t, r.Form) }), ) client, err := New("client-id", WithAuthority(authority), WithClientCapabilities(test.capabilities), WithHTTPClient(&mockClient)) if err != nil { t.Fatal(err) } if _, err = client.AcquireTokenSilent(context.Background(), tokenScope); err == nil { t.Fatal("silent authentication should fail because the cache is empty") } ctx := context.Background() var ar AuthResult var dc DeviceCode switch method { case "authcode": ar, err = client.AcquireTokenByAuthCode(ctx, "auth code", "https://localhost", tokenScope, WithClaims(test.claims)) case "authcodeURL": u := "" if u, err = client.AuthCodeURL(ctx, "client-id", "https://localhost", tokenScope, WithClaims(test.claims)); err == nil { var parsed *url.URL if parsed, err = url.Parse(u); err == nil { validate(t, parsed.Query()) return // didn't acquire a token, no need for further validation } } case "devicecode": dc, err = 
client.AcquireTokenByDeviceCode(ctx, tokenScope, WithClaims(test.claims)) case "interactive": ar, err = client.AcquireTokenInteractive(ctx, tokenScope, WithClaims(test.claims)) case "password": ar, err = client.AcquireTokenByUsernamePassword(ctx, tokenScope, "username", "password", WithClaims(test.claims)) case "passwordFederated": client.base.Token.WSTrust = fake.WSTrust{SamlTokenInfo: wstrust.SamlTokenInfo{AssertionType: "urn:ietf:params:oauth:grant-type:saml1_1-bearer"}} ar, err = client.AcquireTokenByUsernamePassword(ctx, tokenScope, "username", "password", WithClaims(test.claims)) default: t.Fatalf("test bug: no test for " + method) } if method == "devicecode" && err == nil { // complete the device code flow ar, err = dc.AuthenticationResult(ctx) } if err != nil { t.Fatal(err) } if ar.AccessToken != accessToken { t.Fatalf(`unexpected access token "%s"`, ar.AccessToken) } // silent auth should now succeed because the client has an access token cached ar, err = client.AcquireTokenSilent(ctx, tokenScope, WithSilentAccount(ar.Account)) if err != nil { t.Fatal(err) } if ar.AccessToken != accessToken { t.Fatalf(`unexpected access token "%s"`, ar.AccessToken) } if test.claims != "" { // when given claims, AcquireTokenSilent should request a new access token instead of returning the cached one newToken := "new-access-token" mockClient.AppendResponse( mock.WithBody(mock.GetAccessTokenBody(newToken, idToken, "", clientInfo, 3600)), mock.WithCallback(func(r *http.Request) { if err := r.ParseForm(); err != nil { t.Fatal(err) } // all token requests should include any specified claims validate(t, r.Form) if actual := r.Form.Get("refresh_token"); actual != refreshToken { t.Fatalf(`unexpected refresh token "%s"`, actual) } }), ) ar, err = client.AcquireTokenSilent(ctx, tokenScope, WithClaims(test.claims), WithSilentAccount(ar.Account)) if err != nil { t.Fatal(err) } if actual := ar.AccessToken; actual != newToken { t.Fatalf("Expected %s, got %s. 
Client should have redeemed its cached refresh token for a new access token.", newToken, actual) } } }) } } } func TestWithPortAuthority(t *testing.T) { accessToken := "*" sl := "stack.local" port := ":3001" host := sl + port tenant := "00000000-0000-0000-0000-000000000000" authority := fmt.Sprintf("https://%s%s/%s", sl, port, tenant) idToken, refreshToken, URL := "", "", "" mockClient := mock.Client{} //2 calls to instance discovery are made because Host is not trusted mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(host, tenant))) mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(host, tenant))) mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(host, tenant))) mockClient.AppendResponse( mock.WithBody(mock.GetAccessTokenBody(accessToken, idToken, refreshToken, "", 3600)), mock.WithCallback(func(r *http.Request) { URL = r.URL.String() }), ) client, err := New("client-id", WithAuthority(authority), WithHTTPClient(&mockClient)) if err != nil { t.Fatal(err) } ctx := context.Background() if _, err = client.AcquireTokenSilent(ctx, tokenScope); err == nil { t.Fatal("silent auth should fail because the cache is empty") } var ar AuthResult ar, err = client.AcquireTokenByAuthCode(ctx, "auth code", "https://localhost", tokenScope) if err != nil { t.Fatal(err) } if !strings.HasPrefix(URL, authority) { t.Fatalf(`expected "%s", got "%s"`, authority, URL) } if ar.AccessToken != accessToken { t.Fatalf(`unexpected access token "%s"`, ar.AccessToken) } if ar, err = client.AcquireTokenSilent(ctx, tokenScope); err != nil { t.Fatal(err) } if ar.AccessToken != accessToken { t.Fatal("cached access token should match the one returned by AcquireToken...") } } func TestWithLoginHint(t *testing.T) { realBrowserOpenURL := browserOpenURL defer func() { browserOpenURL = realBrowserOpenURL }() upn := "user@localhost" client, err := New("client-id") if err != nil { t.Fatal(err) } client.base.Token.AccessTokens = &fake.AccessTokens{} 
client.base.Token.Authority = &fake.Authority{} client.base.Token.Resolver = &fake.ResolveEndpoints{} for _, expectHint := range []bool{true, false} { t.Run(fmt.Sprint(expectHint), func(t *testing.T) { // replace the browser launching function with a fake that validates login_hint is set as expected called := false validate := func(v url.Values) error { if !v.Has("login_hint") { if !expectHint { return nil } return errors.New("expected a login hint") } else if !expectHint { return errors.New("expected no login hint") } if actual := v["login_hint"]; len(actual) != 1 || actual[0] != upn { err = fmt.Errorf(`unexpected login_hint "%v"`, actual) } return err } browserOpenURL = func(authURL string) error { called = true parsed, err := url.Parse(authURL) if err != nil { return err } query, err := url.ParseQuery(parsed.RawQuery) if err != nil { return err } if err = validate(query); err != nil { t.Fatal(err) return err } // this helper validates the other params and completes the redirect return fakeBrowserOpenURL(authURL) } acquireOpts := []AcquireInteractiveOption{} urlOpts := []AuthCodeURLOption{} if expectHint { acquireOpts = append(acquireOpts, WithLoginHint(upn)) urlOpts = append(urlOpts, WithLoginHint(upn)) } _, err = client.AcquireTokenInteractive(context.Background(), tokenScope, acquireOpts...) if err != nil { t.Fatal(err) } if !called { t.Fatal("browserOpenURL wasn't called") } u, err := client.AuthCodeURL(context.Background(), "id", "https://localhost", tokenScope, urlOpts...) 
if err == nil { var parsed *url.URL parsed, err = url.Parse(u) if err == nil { err = validate(parsed.Query()) } } if err != nil { t.Fatal(err) } }) } } func TestWithDomainHint(t *testing.T) { realBrowserOpenURL := browserOpenURL defer func() { browserOpenURL = realBrowserOpenURL }() domain := "contoso.com" client, err := New("client-id") if err != nil { t.Fatal(err) } client.base.Token.AccessTokens = &fake.AccessTokens{} client.base.Token.Authority = &fake.Authority{} client.base.Token.Resolver = &fake.ResolveEndpoints{} for _, expectHint := range []bool{true, false} { t.Run(fmt.Sprint(expectHint), func(t *testing.T) { // replace the browser launching function with a fake that validates domain_hint is set as expected called := false validate := func(v url.Values) error { if !v.Has("domain_hint") { if !expectHint { return nil } return errors.New("expected a domain hint") } else if !expectHint { return errors.New("expected no domain hint") } if actual := v["domain_hint"]; len(actual) != 1 || actual[0] != domain { err = fmt.Errorf(`unexpected domain_hint "%v"`, actual) } return err } browserOpenURL = func(authURL string) error { called = true parsed, err := url.Parse(authURL) if err != nil { return err } query, err := url.ParseQuery(parsed.RawQuery) if err != nil { return err } if err = validate(query); err != nil { t.Fatal(err) return err } // this helper validates the other params and completes the redirect return fakeBrowserOpenURL(authURL) } var acquireOpts []AcquireInteractiveOption var urlOpts []AuthCodeURLOption if expectHint { acquireOpts = append(acquireOpts, WithDomainHint(domain)) urlOpts = append(urlOpts, WithDomainHint(domain)) } _, err = client.AcquireTokenInteractive(context.Background(), tokenScope, acquireOpts...) if err != nil { t.Fatal(err) } if !called { t.Fatal("browserOpenURL wasn't called") } u, err := client.AuthCodeURL(context.Background(), "id", "https://localhost", tokenScope, urlOpts...) 
if err == nil { var parsed *url.URL parsed, err = url.Parse(u) if err == nil { err = validate(parsed.Query()) } } if err != nil { t.Fatal(err) } }) } } microsoft-authentication-library-for-go-1.0.0/apps/testdata/000077500000000000000000000000001442026362400241425ustar00rootroot00000000000000microsoft-authentication-library-for-go-1.0.0/apps/testdata/test-cert-chain-reverse.pem000066400000000000000000000115011442026362400313060ustar00rootroot00000000000000-----BEGIN CERTIFICATE----- MIIFGTCCAwGgAwIBAgIUBpOlpNN/cgasvozVw6mfa04+ZC0wDQYJKoZIhvcNAQEL BQAwOzELMAkGA1UEBhMCVVMxDDAKBgNVBAoMA3h6eTEMMAoGA1UECwwDYWJjMRAw DgYDVQQDDAdST09ULUNOMCAXDTIwMDgyMTE3MTAyNVoYDzMzODkwODA0MTcxMDI1 WjA+MQswCQYDVQQGEwJVUzEMMAoGA1UECgwDeHl6MQwwCgYDVQQLDANhYmMxEzAR BgNVBAMMCklOVEVSSU0tQ04wggIiMA0GCSqGSIb3DQEBAQUAA4ICDwAwggIKAoIC AQCr+Tblr4DhX3Xahbei00OJnUgRw6FMsnyROZ170Lx0YNcOrRJ9PuaOZiYXY2Hm t71o/PZjMtmiYMIxFaiMnql/dCca777l+uBmlwFOR8bquBWiLStmPpvf7Kh5GZNw XvLGAhk/oxG0O9Pa3OfrlD5vrn/UEGJBu0C+c6ZSLyRk8RjAh8ZbUvnDhhQw3PoK MQSmFK8BN8X34elu7kq0j7nS0D6Mt7eS40oYeHEaQDdBGl8f7rcqC3RjJ/b/F9wA +CsKaps6TvpxE7ln9Y3+0yscgeRbyHW0zem6U7MMvVnK/znuNY90Wmajbea7SUj6 nGZpLGS1TqS4H5rn9U1N1WCSyFukTpAQLCPQHeUrSiHKa9Ye5KuC6u2ZXgy0qpGj nMLu+7746wemi7jN06yZjEmDVneMNCxjLYs4ZhuhiTEItlZpR0VBugNbKo2mJw2U UesizB3AzQkqGOKp70y74yC+ykLkR5vRNyY3MENJ+W83U1haS7C1rhqFV4eXflVe EHl8tj7p4KrfhSPr0Rd12UIWDXkYUpCAPlDMdEa9+SDAyuSnkN4P1fAeuzG01jeJ bnsrWgs3gH3KaGBcPTV4tOTavilGNYDvHZbN9XpYZoZQoPrDZc61M5Ol/cxBahkO n4aDyhpx5hHnSs7VQuHnjeMUxt3J5HqrXPvaf6uPYNT8KQIDAQABoxAwDjAMBgNV HRMEBTADAQH/MA0GCSqGSIb3DQEBCwUAA4ICAQCHCxFqJwfVMI9kMvwlj+sxd4Q5 KuyWxlXRfzpYZ/6JCUq7VBceRVJ87KytMNCyq61rd3Jhb8ssoMCENB68HYhIFUGz GR92AAc6LTh2Y3vQAg640Cz2vLCGnqnlbIslYV6fzxYqgSopR5wJ4D/kJ9w7NSrC paN6bS8Olv//tN6RSnvEMJZdXFA40xFin6qT8Op3nrysEE7Z84wPG9Wj2DXskX6v bZenCEgl1/Ezif5IEgJcYdRkXtYPp6JNbVV+KjDTIMEaUVMpGMGefrt22E+4nSa3 qFvcbzYEKeANe9IAxdPzeWiQ2U90PqWFYCA9sOVsrlSwrup+yYXl0yhTxKY67NCX gyVtZRnzawv0AVFsfCOT4V0wJSuUz4BV6sH7kl2C7FW3zqYVdFEDigbUNsEEh/jF 
3JiAtgNbpJ8TtiCFrCI4g9Jepa3polVPzDD8mLtkWWnfSBN/28cxa2jiUlfQxB39 kyqu4rWbm01lyucJxVgJzH0SGyEM5OvF/OIOU3Q7UIXEcZSX3m4Xo59+v6ZNDwKL PcFDNK+PL3WNYfdexQCSAbLm1gkUrVIqvidpCSSVv5oWwTM5m7rbA16Hlu4Ea2ep Pl7I9YXXXnIEFqLYZDnCJglcXmlt6OjI8D3w0TRWHb6bFqubDP417sJDX1S6udN5 wOnOIqg0ZZcqfvpxXA== -----END CERTIFICATE----- -----BEGIN CERTIFICATE----- MIID7zCCAdcCAQEwDQYJKoZIhvcNAQEFBQAwPjELMAkGA1UEBhMCVVMxDDAKBgNV BAoMA3h5ejEMMAoGA1UECwwDYWJjMRMwEQYDVQQDDApJTlRFUklNLUNOMCAXDTIw MDgyMTE3MTA0M1oYDzMzODkwODA0MTcxMDQzWjA7MQswCQYDVQQGEwJVUzEMMAoG A1UECgwDeHl6MQwwCgYDVQQLDANhYmMxEDAOBgNVBAMMB1VTRVItQ04wggEiMA0G CSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQC6eQYdbIFhsinob3t3AV4yEH/tz/LV I+UAGLpxQnqGnuAV5GY3CXiAO8GZjx7y3oA1DGfe+/cc6n9BmYWXsKvxpKO8PQkB PYIFtD878uDNv7kVoZG8EVsEngBxd4efMniKWwKtMle0hZ+jj3u4Ad49DsXcC0L2 8uV/eQ6hzsQiR0nTQJ/4QqNNtThSGAFSr7Oo8xzxBNTJhe+BvwDE8JMkCS0v22JW my2GYrRKw4RlSKxwv9QZr83gSicKSUPUACBYfJ7RuXSQOHOMlIcC4oGtDrMshGzr 704Ho+DiByYf5G6nkfZ1I7T039gEKKIllNKWqhyQHejKba3nP163ZKI3AgMBAAEw DQYJKoZIhvcNAQEFBQADggIBADfitSfjlYa2inBKlpWN8VT0DPm5uw8EHuwLymCM WYrQMCuQVE2xYoqCSmXj6KLFt8ycgxHsthdkAzXxDhawaKjz2UFp6nszmUA4xfvS mxLSajwzK/KMBkjdFL7TM+TTBJ1bleDbmoJvDiUeQwisbb1Uh8b3v/jpBwoiamm8 Y4Ca5A15SeBUvAt0/Mc4XJfZ/Ts+LBAPevI9ZyU7C5JZky1q41KPklEHfFZKQRfP cTyTYYvlPoq57C8XPDs6r50EV3B6Z8MN21OB6MVGi8BOY/c7a2h1ZOhxNyBnJuQX w4meJthoKcHUnAs8YCrEoQKayMqPH0Vdhaii/gx4jAgh4PNyIZz5cAst+ybPtQj4 i7LFEWjxis+NLQMHhyE4fIGIkEjzU0uGDugifheIwKALqYEgMDrcoolwvGMdPxGo Qps7tkad5vZV9d9+tTbI+DMB16Y51S04/u1dGFz3jSrDVF08PznJc99VB69OReiC K17n8Xyox/VAaYsRFbOAJpLRWwcnotDpFQbgiLrmXxNOoiWPNbQsQzaQx7cR9okQ v5RTpFAkrdjadhMsXFFiQh+axlaGD368ZGAj5ZoyOiXkV88tNCtyP/RDgW5ftQQ7 fdv05bNXhDfLgEgQvVSDfClDL1hKukLmLQS3ILfB4FlM/XmE+FW/qgo9aSx2XIbx E4ie -----END CERTIFICATE----- -----BEGIN RSA PRIVATE KEY----- MIIEowIBAAKCAQEAunkGHWyBYbIp6G97dwFeMhB/7c/y1SPlABi6cUJ6hp7gFeRm Nwl4gDvBmY8e8t6ANQxn3vv3HOp/QZmFl7Cr8aSjvD0JAT2CBbQ/O/Lgzb+5FaGR vBFbBJ4AcXeHnzJ4ilsCrTJXtIWfo497uAHePQ7F3AtC9vLlf3kOoc7EIkdJ00Cf 
+EKjTbU4UhgBUq+zqPMc8QTUyYXvgb8AxPCTJAktL9tiVpsthmK0SsOEZUiscL/U Ga/N4EonCklD1AAgWHye0bl0kDhzjJSHAuKBrQ6zLIRs6+9OB6Pg4gcmH+Rup5H2 dSO09N/YBCiiJZTSlqockB3oym2t5z9et2SiNwIDAQABAoIBAQCKzivPG0X0AztO 2i19mHcVrVKNI44POnjsaXvfcyzhqMIFic7MiTA5xEGInRDcmOO2mVV4lvaLf8La gfz/vXNAnN2E8aoSUkbHGDU52sGcZmrPv0VMSV8HQNXzoJZD2r3/v19urVq79fuv NM9TWZCkwqpl8bwXNxe+m85YhCFboY9G543qmuXzKAQLoSupT0e4eIo2IGp7eJYK 5J/wtlEumUdhsKo1ajLojDgsgPKfrCyvsmO+bj1dRKGXVLO2SL2pFVCjjHF4SP3q 1WX39beu61Zu+kGthDgj5muHgH06FtnWoHLIUrRmYpM+ezCxQHdRWz7AYjheeE7q QqJv1PqBAoGBAOlb/gzsps+rInE+LQoEzVj8osILI4NxIpNc6+iG81dEi+zQABX/ bHV6hXGGceozVcX4B+V7f08PlZIAgM3IDqfy0fH2pwEQahJ8a3MwzCgR66RxYlkX E8czkoz0pcHW58FnLLlWXpHRALTtqoPP5LnWs0SmoNvcHZ9yjJ6tvpRlAoGBAMyQ fytsyla1ujO0l/kuLFG7gndeOc96SutH3V17lZ1pN0efHyk2aglOnl6YsdPKLZvZ 3ghj01HV0Q0f//xpftduuA7gdgDzSG1irXsxEidfVxX7RsPxX6cx8dhYnuk5rz5E XyTko7zTpr+A4XMnq6+JNSSCIE+CVYcYf/hyemxrAoGAeC9py4xCaWgxR/OGzMcm X3NV++wysSqebRkJYuvF/icOjbuen7W6TVL50Ts2BjHENj6FCpqtObHEDbr2m4Uy jysPF7g50OF8T+MGkAAM1YJNQ5cl2M564DhefPwvNoMRP1l8/kNOV3k2DPjuvg5f NZsvHudWp4VZOFqNs9e19MUCgYAjewCDoKfrqDN2mmEtmAOZ3YMAfzhZsyVhb6KG f1Pw7HnpE0FNXaHAoYE4eRWG3W9Rs9Ud8WqKrCJJO36j4gxdA1grRGVTPt8WEeJz FozGhXPOXTnl7GyhzDjdRGmznAy4KRWziXCY5MDsQEdaOMw/cvXjsio2gC2jc+1m QzzWpwKBgHzszJ5s6vcWElox4Yc1elQ8xniPpo3RtfXZOLX8xA4eR9yQawah1zd6 ChfeYbHVfq007s+RWGTb+KYQ6ic9nkW464qmVxHGBatUo9+MR4Gk8blANoAfHxdV g6JNgT2kIGu9IEwoD6XQldC/v24bvFSesyGRHNdI4mUG+hhU4aNw -----END RSA PRIVATE KEY----- microsoft-authentication-library-for-go-1.0.0/apps/testdata/test-cert-chain.pem000066400000000000000000000115011442026362400276350ustar00rootroot00000000000000-----BEGIN RSA PRIVATE KEY----- MIIEowIBAAKCAQEAunkGHWyBYbIp6G97dwFeMhB/7c/y1SPlABi6cUJ6hp7gFeRm Nwl4gDvBmY8e8t6ANQxn3vv3HOp/QZmFl7Cr8aSjvD0JAT2CBbQ/O/Lgzb+5FaGR vBFbBJ4AcXeHnzJ4ilsCrTJXtIWfo497uAHePQ7F3AtC9vLlf3kOoc7EIkdJ00Cf +EKjTbU4UhgBUq+zqPMc8QTUyYXvgb8AxPCTJAktL9tiVpsthmK0SsOEZUiscL/U Ga/N4EonCklD1AAgWHye0bl0kDhzjJSHAuKBrQ6zLIRs6+9OB6Pg4gcmH+Rup5H2 
dSO09N/YBCiiJZTSlqockB3oym2t5z9et2SiNwIDAQABAoIBAQCKzivPG0X0AztO 2i19mHcVrVKNI44POnjsaXvfcyzhqMIFic7MiTA5xEGInRDcmOO2mVV4lvaLf8La gfz/vXNAnN2E8aoSUkbHGDU52sGcZmrPv0VMSV8HQNXzoJZD2r3/v19urVq79fuv NM9TWZCkwqpl8bwXNxe+m85YhCFboY9G543qmuXzKAQLoSupT0e4eIo2IGp7eJYK 5J/wtlEumUdhsKo1ajLojDgsgPKfrCyvsmO+bj1dRKGXVLO2SL2pFVCjjHF4SP3q 1WX39beu61Zu+kGthDgj5muHgH06FtnWoHLIUrRmYpM+ezCxQHdRWz7AYjheeE7q QqJv1PqBAoGBAOlb/gzsps+rInE+LQoEzVj8osILI4NxIpNc6+iG81dEi+zQABX/ bHV6hXGGceozVcX4B+V7f08PlZIAgM3IDqfy0fH2pwEQahJ8a3MwzCgR66RxYlkX E8czkoz0pcHW58FnLLlWXpHRALTtqoPP5LnWs0SmoNvcHZ9yjJ6tvpRlAoGBAMyQ fytsyla1ujO0l/kuLFG7gndeOc96SutH3V17lZ1pN0efHyk2aglOnl6YsdPKLZvZ 3ghj01HV0Q0f//xpftduuA7gdgDzSG1irXsxEidfVxX7RsPxX6cx8dhYnuk5rz5E XyTko7zTpr+A4XMnq6+JNSSCIE+CVYcYf/hyemxrAoGAeC9py4xCaWgxR/OGzMcm X3NV++wysSqebRkJYuvF/icOjbuen7W6TVL50Ts2BjHENj6FCpqtObHEDbr2m4Uy jysPF7g50OF8T+MGkAAM1YJNQ5cl2M564DhefPwvNoMRP1l8/kNOV3k2DPjuvg5f NZsvHudWp4VZOFqNs9e19MUCgYAjewCDoKfrqDN2mmEtmAOZ3YMAfzhZsyVhb6KG f1Pw7HnpE0FNXaHAoYE4eRWG3W9Rs9Ud8WqKrCJJO36j4gxdA1grRGVTPt8WEeJz FozGhXPOXTnl7GyhzDjdRGmznAy4KRWziXCY5MDsQEdaOMw/cvXjsio2gC2jc+1m QzzWpwKBgHzszJ5s6vcWElox4Yc1elQ8xniPpo3RtfXZOLX8xA4eR9yQawah1zd6 ChfeYbHVfq007s+RWGTb+KYQ6ic9nkW464qmVxHGBatUo9+MR4Gk8blANoAfHxdV g6JNgT2kIGu9IEwoD6XQldC/v24bvFSesyGRHNdI4mUG+hhU4aNw -----END RSA PRIVATE KEY----- -----BEGIN CERTIFICATE----- MIID7zCCAdcCAQEwDQYJKoZIhvcNAQEFBQAwPjELMAkGA1UEBhMCVVMxDDAKBgNV BAoMA3h5ejEMMAoGA1UECwwDYWJjMRMwEQYDVQQDDApJTlRFUklNLUNOMCAXDTIw MDgyMTE3MTA0M1oYDzMzODkwODA0MTcxMDQzWjA7MQswCQYDVQQGEwJVUzEMMAoG A1UECgwDeHl6MQwwCgYDVQQLDANhYmMxEDAOBgNVBAMMB1VTRVItQ04wggEiMA0G CSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQC6eQYdbIFhsinob3t3AV4yEH/tz/LV I+UAGLpxQnqGnuAV5GY3CXiAO8GZjx7y3oA1DGfe+/cc6n9BmYWXsKvxpKO8PQkB PYIFtD878uDNv7kVoZG8EVsEngBxd4efMniKWwKtMle0hZ+jj3u4Ad49DsXcC0L2 8uV/eQ6hzsQiR0nTQJ/4QqNNtThSGAFSr7Oo8xzxBNTJhe+BvwDE8JMkCS0v22JW my2GYrRKw4RlSKxwv9QZr83gSicKSUPUACBYfJ7RuXSQOHOMlIcC4oGtDrMshGzr 704Ho+DiByYf5G6nkfZ1I7T039gEKKIllNKWqhyQHejKba3nP163ZKI3AgMBAAEw 
DQYJKoZIhvcNAQEFBQADggIBADfitSfjlYa2inBKlpWN8VT0DPm5uw8EHuwLymCM WYrQMCuQVE2xYoqCSmXj6KLFt8ycgxHsthdkAzXxDhawaKjz2UFp6nszmUA4xfvS mxLSajwzK/KMBkjdFL7TM+TTBJ1bleDbmoJvDiUeQwisbb1Uh8b3v/jpBwoiamm8 Y4Ca5A15SeBUvAt0/Mc4XJfZ/Ts+LBAPevI9ZyU7C5JZky1q41KPklEHfFZKQRfP cTyTYYvlPoq57C8XPDs6r50EV3B6Z8MN21OB6MVGi8BOY/c7a2h1ZOhxNyBnJuQX w4meJthoKcHUnAs8YCrEoQKayMqPH0Vdhaii/gx4jAgh4PNyIZz5cAst+ybPtQj4 i7LFEWjxis+NLQMHhyE4fIGIkEjzU0uGDugifheIwKALqYEgMDrcoolwvGMdPxGo Qps7tkad5vZV9d9+tTbI+DMB16Y51S04/u1dGFz3jSrDVF08PznJc99VB69OReiC K17n8Xyox/VAaYsRFbOAJpLRWwcnotDpFQbgiLrmXxNOoiWPNbQsQzaQx7cR9okQ v5RTpFAkrdjadhMsXFFiQh+axlaGD368ZGAj5ZoyOiXkV88tNCtyP/RDgW5ftQQ7 fdv05bNXhDfLgEgQvVSDfClDL1hKukLmLQS3ILfB4FlM/XmE+FW/qgo9aSx2XIbx E4ie -----END CERTIFICATE----- -----BEGIN CERTIFICATE----- MIIFGTCCAwGgAwIBAgIUBpOlpNN/cgasvozVw6mfa04+ZC0wDQYJKoZIhvcNAQEL BQAwOzELMAkGA1UEBhMCVVMxDDAKBgNVBAoMA3h6eTEMMAoGA1UECwwDYWJjMRAw DgYDVQQDDAdST09ULUNOMCAXDTIwMDgyMTE3MTAyNVoYDzMzODkwODA0MTcxMDI1 WjA+MQswCQYDVQQGEwJVUzEMMAoGA1UECgwDeHl6MQwwCgYDVQQLDANhYmMxEzAR BgNVBAMMCklOVEVSSU0tQ04wggIiMA0GCSqGSIb3DQEBAQUAA4ICDwAwggIKAoIC AQCr+Tblr4DhX3Xahbei00OJnUgRw6FMsnyROZ170Lx0YNcOrRJ9PuaOZiYXY2Hm t71o/PZjMtmiYMIxFaiMnql/dCca777l+uBmlwFOR8bquBWiLStmPpvf7Kh5GZNw XvLGAhk/oxG0O9Pa3OfrlD5vrn/UEGJBu0C+c6ZSLyRk8RjAh8ZbUvnDhhQw3PoK MQSmFK8BN8X34elu7kq0j7nS0D6Mt7eS40oYeHEaQDdBGl8f7rcqC3RjJ/b/F9wA +CsKaps6TvpxE7ln9Y3+0yscgeRbyHW0zem6U7MMvVnK/znuNY90Wmajbea7SUj6 nGZpLGS1TqS4H5rn9U1N1WCSyFukTpAQLCPQHeUrSiHKa9Ye5KuC6u2ZXgy0qpGj nMLu+7746wemi7jN06yZjEmDVneMNCxjLYs4ZhuhiTEItlZpR0VBugNbKo2mJw2U UesizB3AzQkqGOKp70y74yC+ykLkR5vRNyY3MENJ+W83U1haS7C1rhqFV4eXflVe EHl8tj7p4KrfhSPr0Rd12UIWDXkYUpCAPlDMdEa9+SDAyuSnkN4P1fAeuzG01jeJ bnsrWgs3gH3KaGBcPTV4tOTavilGNYDvHZbN9XpYZoZQoPrDZc61M5Ol/cxBahkO n4aDyhpx5hHnSs7VQuHnjeMUxt3J5HqrXPvaf6uPYNT8KQIDAQABoxAwDjAMBgNV HRMEBTADAQH/MA0GCSqGSIb3DQEBCwUAA4ICAQCHCxFqJwfVMI9kMvwlj+sxd4Q5 KuyWxlXRfzpYZ/6JCUq7VBceRVJ87KytMNCyq61rd3Jhb8ssoMCENB68HYhIFUGz 
GR92AAc6LTh2Y3vQAg640Cz2vLCGnqnlbIslYV6fzxYqgSopR5wJ4D/kJ9w7NSrC paN6bS8Olv//tN6RSnvEMJZdXFA40xFin6qT8Op3nrysEE7Z84wPG9Wj2DXskX6v bZenCEgl1/Ezif5IEgJcYdRkXtYPp6JNbVV+KjDTIMEaUVMpGMGefrt22E+4nSa3 qFvcbzYEKeANe9IAxdPzeWiQ2U90PqWFYCA9sOVsrlSwrup+yYXl0yhTxKY67NCX gyVtZRnzawv0AVFsfCOT4V0wJSuUz4BV6sH7kl2C7FW3zqYVdFEDigbUNsEEh/jF 3JiAtgNbpJ8TtiCFrCI4g9Jepa3polVPzDD8mLtkWWnfSBN/28cxa2jiUlfQxB39 kyqu4rWbm01lyucJxVgJzH0SGyEM5OvF/OIOU3Q7UIXEcZSX3m4Xo59+v6ZNDwKL PcFDNK+PL3WNYfdexQCSAbLm1gkUrVIqvidpCSSVv5oWwTM5m7rbA16Hlu4Ea2ep Pl7I9YXXXnIEFqLYZDnCJglcXmlt6OjI8D3w0TRWHb6bFqubDP417sJDX1S6udN5 wOnOIqg0ZZcqfvpxXA== -----END CERTIFICATE----- microsoft-authentication-library-for-go-1.0.0/apps/testdata/test-cert.pem000066400000000000000000000051501442026362400265600ustar00rootroot00000000000000-----BEGIN CERTIFICATE----- MIICljCCAX4CCQDNgteZ+lJH4zANBgkqhkiG9w0BAQsFADANMQswCQYDVQQGEwJ1 czAeFw0yMTAxMDQyMzQzNDVaFw0yMTAyMDMyMzQzNDVaMA0xCzAJBgNVBAYTAnVz MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA1r58wq7JQxM12viLNbdG fFizeVQwWRwrx/4CH3kU8jjGovbhkvC/uLWqVGchgATThhGkvNrA92WvdkVwsZMk Qf7ZnTA7kemo4VFtgo5XCGEej9gOTW13Evdc/0Flip+RXl3h3Q6BbbB9IFE0c6cS 3i/v/t8KGpVYQHQzBwTcYehM6eDO8ZjUyUUcJOMXdMCctamig7fMGlziKFahn4dX JoiiK4oNKE9okXIAXCTbVkAxxH0hD+5XH1nn5LJnHe0e5DflI3YIiPgmRL5uC89K XqmYCKWrq5z2D5k+5fQLmbOcxErBcFCh8hA+Xu0RLT4BHPEgc6iVIqxL4CZi/cke uwIDAQABMA0GCSqGSIb3DQEBCwUAA4IBAQAAyDbm0Fda0/vY6ZVDML2IbGWbro1w nWYNw6wclNU6sx1oeG/k/y2ni7NImPpbFN+594WS6rYHgFdROfeuNgGnjgQCJogk +8ouf1R6vFMUAScWeSaFnZmBEgwofWsnIcUKkbDIXbpRhMrkNEcY09VgjmCKhspQ iX2bJQTj49XBac9tBaJJYDZ4HgkO4nU7QeEPpvwlELZFoZZXtd3fan+VUyFS2a9n gkAMDYoQPGN4tyGFabWws/GlMxelWvqUzpQKmeRPVz+cij75l8eKThEiu0zbjOTD Gq81BcY61SPqN02zoPCtqZ/zU6HhaL3x7zUuzhLhNoh83A43UVYEoOOf -----END CERTIFICATE----- -----BEGIN PRIVATE KEY----- MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQDWvnzCrslDEzXa +Is1t0Z8WLN5VDBZHCvH/gIfeRTyOMai9uGS8L+4tapUZyGABNOGEaS82sD3Za92 RXCxkyRB/tmdMDuR6ajhUW2CjlcIYR6P2A5NbXcS91z/QWWKn5FeXeHdDoFtsH0g 
UTRzpxLeL+/+3woalVhAdDMHBNxh6Ezp4M7xmNTJRRwk4xd0wJy1qaKDt8waXOIo VqGfh1cmiKIrig0oT2iRcgBcJNtWQDHEfSEP7lcfWefksmcd7R7kN+UjdgiI+CZE vm4Lz0peqZgIpaurnPYPmT7l9AuZs5zESsFwUKHyED5e7REtPgEc8SBzqJUirEvg JmL9yR67AgMBAAECggEAAQ/IBh5fGFnL9l0sMwPI8Wxu1ra31njxLnfvAsDSfbAS K1QVIWjXSc58HRa1b7CWax9DNTvPoGl8SJVnTTlxAHKGGOTYJoyFLTf91ptlisEQ KZ3j1DYqVImsiAaGvfyz90d3imQ795Lby4EbRUcaLMcH5LatkhwS556rcelwPXuq M43XaZu5Es4pG0EmzfXplO/awt5HdUDPEAY3yw7QH8D1/l/toLPyiFv37RezkVK9 ffcUQpH7uH000Gja+JSEHgpWZhE96ac6H0zBtlM1VkMtfBuczz5tkKN/p70fhr8T ZXARZqIaF4vx7RkBBzCfhvrgGqxXMuvTaW6N4RDWYQKBgQD1iZ7/xr9qy4cPFSOt yBnG5cE6wC7wP8qgr0N7MgAii5OZgx6rtfGIVJDY58CFijnT8jZ5pjNS3p7j/Rzp lQJMIwC5kIe/7FU7nmE3ko7Wg+bpd8iWLLIi/QWVFLbS7qVmulTc+CEXWyhAiI2u RL/1APjIDFKp9gqtKmwb9erxDwKBgQDf5PbGHuPv5RBLJz9du+M/BIBY+HDltG89 p3huHHTjkJ5R38oximf2HnV4ygT/p2+ZUD6TJZZw6qou3/GiU5gZbRpg+4LXtQUR vV+S2n/t86NG1YcGmM29r8LWqrK9gxLW0X62Fpps16rHSP7kVc4SvmrYwqNzqKlC D9QbFYYflQKBgQCKEVzrDuNMNi43+PcbHU4BXeiOFMtQJU7XlDYp7C/PPRU+WVDB 1Yl/062vioHjlZp259hiB2cMzkoigY3kevnTvksGDZOIBGjZIXIhQbQ4Q+twlP6i E3gH3Kdq8T7s1W0EmvplVtGkxImZ4C9rMxWNu4IpW2SQVd4jCZvJDTuTWQKBgQCn LGjuCYacSubdlpKDxJSrKwtCY0641P7yhCcx4GGOwR7Vd0mbsAJsDNYduIn+8eAs E3SFnl00NqOXmHLth4lcAtDddS5/LZR5aHMCTc+TtoVFkI3faRzF84SBkLchNctN RuNbxojLmETVxDU9/Kt/51oUO1CcPWUUBImVJ38b+QKBgQCTbi0nS0n8kC7nlXWN QtPcf4UraJAxv1DGq4lnJ8AHSZqqkP5fyjfknSw5ExOPDg4mEHhnnpsvwJuSX00d UYUN2ZJXPZeaO0HmbYZ3/vC9bo6KW95PhidEUQpGlKrFY342khjQHJtH67YUThwU lQFhpxvPgPNBuxVRnsxoH/sLOA== -----END PRIVATE KEY----- microsoft-authentication-library-for-go-1.0.0/apps/tests/000077500000000000000000000000001442026362400234735ustar00rootroot00000000000000microsoft-authentication-library-for-go-1.0.0/apps/tests/benchmarks/000077500000000000000000000000001442026362400256105ustar00rootroot00000000000000microsoft-authentication-library-for-go-1.0.0/apps/tests/benchmarks/confidential.go000066400000000000000000000131001442026362400305710ustar00rootroot00000000000000// Copyright (c) Microsoft Corporation. 
// Licensed under the MIT license. package main import ( "context" "fmt" "os" "runtime" "strconv" "sync" "text/template" "time" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/base" internalTime "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json/types/time" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/fake" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority" ) const accessToken = "fake_token" var tokenScope = []string{"fake_scope"} type testParams struct { // the number of goroutines to use Concurrency int // the number of tokens in the cache // must be divisible by Concurrency TokenCount int } func fakeClient() (base.Client, error) { // we use a base.Client so we can provide a fake OAuth client return base.New("fake_client_id", "https://fake_authority/fake", &oauth.Client{ AccessTokens: &fake.AccessTokens{ AccessToken: accesstokens.TokenResponse{ AccessToken: accessToken, ExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)}, GrantedScopes: accesstokens.Scopes{Slice: tokenScope}, }, }, Authority: &fake.Authority{ InstanceResp: authority.InstanceDiscoveryResponse{ Metadata: []authority.InstanceDiscoveryMetadata{ { PreferredNetwork: "fake_authority", Aliases: []string{"fake_authority"}, }, }, }, }, Resolver: &fake.ResolveEndpoints{ Endpoints: authority.Endpoints{ AuthorizationEndpoint: "auth_endpoint", TokenEndpoint: "token_endpoint", }, }, WSTrust: &fake.WSTrust{}, }) } type execTime struct { start time.Time end time.Time } func populateTokenCache(client base.Client, params testParams) execTime { if r := params.TokenCount % params.Concurrency; r != 0 { panic("TokenCount must be divisible by Concurrency") } parts := params.TokenCount / 
params.Concurrency authParams := client.AuthParams authParams.Scopes = tokenScope authParams.AuthorizationType = authority.ATClientCredentials wg := &sync.WaitGroup{} fmt.Printf("Populating token cache with %d tokens...", params.TokenCount) start := time.Now() for n := 0; n < params.Concurrency; n++ { wg.Add(1) go func(chunk int) { for i := parts * chunk; i < parts*(chunk+1); i++ { // we use this to add a fake token to the cache. // each token has a different scope which is what makes them unique _, err := client.AuthResultFromToken(context.Background(), authParams, accesstokens.TokenResponse{ AccessToken: accessToken, ExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)}, GrantedScopes: accesstokens.Scopes{Slice: []string{strconv.FormatInt(int64(i), 10)}}, }, true) if err != nil { panic(err) } } wg.Done() }(n) } wg.Wait() return execTime{start: start, end: time.Now()} } func executeTest(client base.Client, params testParams) execTime { wg := &sync.WaitGroup{} fmt.Printf("Begin token retrieval.....") start := time.Now() for n := 0; n < params.Concurrency; n++ { wg.Add(1) go func() { // retrieve each token once per goroutine for tk := 0; tk < params.TokenCount; tk++ { _, err := client.AcquireTokenSilent(context.Background(), base.AcquireTokenSilentParameters{ Scopes: []string{strconv.FormatInt(int64(tk), 10)}, RequestType: accesstokens.ATConfidential, Credential: &accesstokens.Credential{ Secret: "fake_secret", }, }) if err != nil { panic(err) } } wg.Done() }() } wg.Wait() return execTime{start: start, end: time.Now()} } // Stats is used with statsTemplText for reporting purposes type Stats struct { popExec execTime retExec execTime Concurrency int Count int64 } // PopDur returns the total duration for populating the cache. func (s *Stats) PopDur() time.Duration { return s.popExec.end.Sub(s.popExec.start) } // RetDur returns the total duration for retrieving tokens. 
func (s *Stats) RetDur() time.Duration { return s.retExec.end.Sub(s.retExec.start) } // PopAvg returns the mean average of caching a token. func (s *Stats) PopAvg() time.Duration { return s.PopDur() / time.Duration(s.Count) } // RetAvg returns the mean average of retrieving a token. func (s *Stats) RetAvg() time.Duration { return s.RetDur() / time.Duration(s.Count) } var statsTemplText = ` Test Results: [{{.Concurrency}} goroutines][{{.Count}} tokens] [population: total {{.PopDur}}, avg {{.PopAvg}}] [retrieval: total {{.RetDur}}, avg {{.RetAvg}}] ========================================================================== ` var statsTempl = template.Must(template.New("stats").Parse(statsTemplText)) func main() { tests := []testParams{ { Concurrency: runtime.NumCPU(), TokenCount: 100, }, { Concurrency: runtime.NumCPU(), TokenCount: 1000, }, { Concurrency: runtime.NumCPU(), TokenCount: 10000, }, { Concurrency: runtime.NumCPU(), TokenCount: 20000, }, } for _, t := range tests { client, err := fakeClient() if err != nil { panic(err) } fmt.Printf("Test Params: %#v\n", t) ptime := populateTokenCache(client, t) ttime := executeTest(client, t) if err := statsTempl.Execute(os.Stdout, &Stats{ popExec: ptime, retExec: ttime, Concurrency: t.Concurrency, Count: int64(t.TokenCount), }); err != nil { panic(err) } } } microsoft-authentication-library-for-go-1.0.0/apps/tests/devapps/000077500000000000000000000000001442026362400251355ustar00rootroot00000000000000microsoft-authentication-library-for-go-1.0.0/apps/tests/devapps/README.md000066400000000000000000000037071442026362400264230ustar00rootroot00000000000000# Running the Dev Apps for MSAL Go To run one of the dev app which uses MSAL Go, the `config.json` file and the `confidential_config.json` should look like the following: ```json { "authority": "https://login.microsoftonline.com/organizations", "client_id": "your_client_id", "scopes": ["user.read"], "username": "your_username", "password": "your_password", "redirect_uri": 
"redirect uri registered on the portal", "code_challenge": "transformed code verifier from PKCE", "state": "state parameter for authorization code flow", "client_secret": "client secret you generated for your app", "thumbprint": "the certificate thumbprint defined in your app generation", "pem_file": "the file path of your private key pem" } ``` The dev apps in this repo get tokens for the MS Graph API. To find permissible scopes for MS Graph, visit this [link](https://docs.microsoft.com/graph/permissions-reference). PKCE is explained [here](https://tools.ietf.org/html/rfc7636#section-4.1). ## On Windows To run the dev samples: `cd test/devapps` run the command: 'go run ./ 1' Alternatives: * 1 build and run "locally" * In the devapps folder * type 'go build' * type 'devapps.exe 1' to run the device code flow * 2 (Advanced) install and run from the gobin folder * See more: https://golang.org/cmd/go/#hdr-Compile_and_install_packages_and_dependencies * In the devapps folder * type 'go install' * locate your gobin folder e.g. type 'go env' to find your gobin folder location cd to your gobin folder * type 'devapps.exe 1' to run the device code flow ## On Mac To run one of the devapps, run the command `go run src/test/devapps/*.go `. The devapp numbers are as follows: * 1 - `device_code_flow_sample.go` * 2 - `authorization_code_sample.go` * 3 - `username_password_sample.go` * 4 - `confidential_auth_code_sample.go` * 5 - `client_secret_sample.go` * 6 - `client_certificate_sample.go` microsoft-authentication-library-for-go-1.0.0/apps/tests/devapps/authorization_code_sample.go000066400000000000000000000042251442026362400327220ustar00rootroot00000000000000// Copyright (c) Microsoft Corporation. // Licensed under the MIT license. package main // TODO(msal expert): This should be refactored into an example maybe? // a "main" with a bunch of private functions that can't run isn't a good code sample. 
/* func getToken(w http.ResponseWriter, r *http.Request) { // Getting the authorization code from the URL's query states, ok := r.URL.Query()["state"] if !ok || len(states[0]) < 1 { log.Fatal(errors.New("State parameter missing, can't verify authorization code")) } codes, ok := r.URL.Query()["code"] if !ok || len(codes[0]) < 1 { log.Fatal(errors.New("Authorization code missing")) } if states[0] != config.State { log.Fatal(errors.New("State parameter is incorrect")) } code := codes[0] // Getting the access token using the authorization code result, err := publicClientApp.AcquireTokenByAuthCode(context.Background(), config.Scopes, &msal.AcquireTokenByAuthCodeOptions{ Code: code, CodeChallenge: config.CodeChallenge, }) if err != nil { log.Fatal(err) } // Prints the access token on the webpage fmt.Fprintf(w, "Access token is "+result.GetAccessToken()) } func acquireByAuthorizationCodePublic() { options := msal.DefaultPublicClientApplicationOptions() options.Authority = config.Authority publicClientApp, err := msal.NewPublicClientApplication(config.ClientID, &options) if err != nil { panic(err) } http.HandleFunc("/", redirectToURL) // The redirect uri set in our app's registration is http://localhost:port/redirect http.HandleFunc("/redirect", getToken) log.Fatal(http.ListenAndServe(":"+port, nil)) } func redirectToURL(w http.ResponseWriter, r *http.Request) { // Getting the URL to redirect to acquire the authorization code authCodeURLParams := msal.CreateAuthorizationCodeURLParameters(config.ClientID, config.RedirectURI, config.Scopes) authCodeURLParams.CodeChallenge = config.CodeChallenge authCodeURLParams.State = config.State authURL, err := publicClientApp.AuthCodeURL(context.Background(), authCodeURLParams) if err != nil { log.Fatal(err) } // Redirecting to the URL we have received log.Info(authURL) http.Redirect(w, r, authURL, http.StatusSeeOther) } */ 
microsoft-authentication-library-for-go-1.0.0/apps/tests/devapps/client_certificate_sample.go000066400000000000000000000023201442026362400326420ustar00rootroot00000000000000// Copyright (c) Microsoft Corporation. // Licensed under the MIT license. package main import ( "context" "fmt" "log" "os" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential" ) func acquireTokenClientCertificate() { config := CreateConfig("confidential_config.json") pemData, err := os.ReadFile(config.PemData) if err != nil { log.Fatal(err) } // This extracts our public certificates and private key from the PEM file. If it is // encrypted, the second argument must be password to decode. certs, privateKey, err := confidential.CertFromPEM(pemData, "") if err != nil { log.Fatal(err) } cred, err := confidential.NewCredFromCert(certs, privateKey) if err != nil { log.Fatal(err) } app, err := confidential.New(config.Authority, config.ClientID, cred, confidential.WithCache(cacheAccessor)) if err != nil { log.Fatal(err) } result, err := app.AcquireTokenSilent(context.Background(), config.Scopes) if err != nil { result, err = app.AcquireTokenByCredential(context.Background(), config.Scopes) if err != nil { log.Fatal(err) } fmt.Println("Access Token Is " + result.AccessToken) return } fmt.Println("Silently acquired token " + result.AccessToken) } microsoft-authentication-library-for-go-1.0.0/apps/tests/devapps/client_secret_sample.go000066400000000000000000000015611442026362400316530ustar00rootroot00000000000000// Copyright (c) Microsoft Corporation. // Licensed under the MIT license. 
package main import ( "context" "fmt" "log" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential" ) func acquireTokenClientSecret() { config := CreateConfig("confidential_config.json") cred, err := confidential.NewCredFromSecret(config.ClientSecret) if err != nil { log.Fatal(err) } app, err := confidential.New(config.Authority, config.ClientID, cred, confidential.WithCache(cacheAccessor)) if err != nil { log.Fatal(err) } result, err := app.AcquireTokenSilent(context.Background(), config.Scopes) if err != nil { result, err = app.AcquireTokenByCredential(context.Background(), config.Scopes) if err != nil { log.Fatal(err) } fmt.Println("Access Token Is " + result.AccessToken) } fmt.Println("Silently acquired token " + result.AccessToken) } microsoft-authentication-library-for-go-1.0.0/apps/tests/devapps/confidential_auth_code_sample.go000066400000000000000000000065671442026362400335150ustar00rootroot00000000000000// Copyright (c) Microsoft Corporation. // Licensed under the MIT license. package main /* var ( accessToken string confidentialConfig = CreateConfig("confidential_config.json") app confidential.Client ) // TODO(msal): I'm not sure what to do here with the CodeChallenge and State. authCodeURLParams // is no more. CodeChallenge is only used now in a confidential.AcquireTokenByAuthCode(), which // this is not using. Maybe now this is a two step process???? 
func redirectToURLConfidential(w http.ResponseWriter, r *http.Request) { // Getting the URL to redirect to acquire the authorization code authCodeURLParams.CodeChallenge = confidentialConfig.CodeChallenge authCodeURLParams.State = confidentialConfig.State authURL, err := app.AuthCodeURL(context.Background(), confidentialConfig.ClientID, confidentialConfig.RedirectURI, confidentialConfig.Scopes) if err != nil { http.Error(w, err.Error(), http.StatusUnauthorized) return } // Redirecting to the URL we have received log.Println("redirecting to auth: ", authURL) http.Redirect(w, r, authURL, http.StatusSeeOther) } func getTokenConfidential(w http.ResponseWriter, r *http.Request) { // Getting the authorization code from the URL's query states, ok := r.URL.Query()["state"] if !ok || len(states[0]) < 1 { log.Fatal(errors.New("State parameter missing, can't verify authorization code")) } codes, ok := r.URL.Query()["code"] if !ok || len(codes[0]) < 1 { log.Fatal(errors.New("Authorization code missing")) } if states[0] != config.State { log.Fatal(errors.New("State parameter is incorrect")) } code := codes[0] // Getting the access token using the authorization code result, err := app.AcquireTokenByAuthCode( context.Background(), confidentialConfig.Scopes, confidential.CodeChallenge(code, confidentialConfig.CodeChallenge), ) if err != nil { log.Fatal(err) } // Prints the access token on the webpage fmt.Fprintf(w, "Access token is "+result.GetAccessToken()) accessToken = result.GetAccessToken() } // TODO(msal): Needs to use an x509 certificate like the other now that we are not using a // thumbprint directly. 
/* func acquireByAuthorizationCodeConfidential(ctx context.Context) { key, err := os.ReadFile(confidentialConfig.KeyFile) if err != nil { log.Fatal(err) } certificate, err := msal.CreateClientCredentialFromCertificate(confidentialConfig.Thumbprint, key) if err != nil { log.Fatal(err) } options := msal.DefaultConfidentialClientApplicationOptions() options.Accessor = cacheAccessor options.Authority = confidentialConfig.Authority app, err := msal.NewConfidentialClientApplication(confidentialConfig.ClientID, certificate, &options) if err != nil { log.Fatal(err) } var userAccount shared.Account for _, account := range app.Accounts(ctx) { if account.PreferredUsername == confidentialConfig.Username { userAccount = account } } result, err := app.AcquireTokenSilent( context.Background(), confidentialConfig.Scopes, &msal.AcquireTokenSilentOptions{ Account: userAccount, }, ) if err != nil { panic(err) } fmt.Printf("Access token is " + result.GetAccessToken()) accessToken = result.GetAccessToken() http.HandleFunc("/", redirectToURLConfidential) // The redirect uri set in our app's registration is http://localhost:port/redirect http.HandleFunc("/redirect", getTokenConfidential) log.Fatal(http.ListenAndServe(":"+port, nil)) } */ microsoft-authentication-library-for-go-1.0.0/apps/tests/devapps/confidential_config.json000066400000000000000000000010041442026362400320070ustar00rootroot00000000000000{ "authority": "https://login.microsoftonline.com/your_tenant_id", "client_id": "your_client_id", "scopes": ["requested_scope- usually of the format - {ResourceId}+/.default"], "redirect_uri": "redirect uri registered on the portal", "code_challenge": "transformed code verifier from PKCE", "state": "state parameter for authorization code flow", "client_secret": "client secret you generated for your app", "pem_file": "the file path of pem containing public cert and private key" } 
microsoft-authentication-library-for-go-1.0.0/apps/tests/devapps/config.json000066400000000000000000000011111442026362400272670ustar00rootroot00000000000000{ "authority": "https://login.microsoftonline.com/organizations", "client_id": "your_client_id", "scopes": ["user.read"], "username": "your_username", "password": "your_password", "redirect_uri": "redirect uri registered on the portal", "code_challenge": "transformed code verifier from PKCE", "state": "state parameter for authorization code flow", "client_secret": "client secret you generated for your app", "thumbprint": "the certificate thumbprint defined in your app generation", "pem_file": "the file path of your private key pem" }microsoft-authentication-library-for-go-1.0.0/apps/tests/devapps/device_code_flow_sample.go000066400000000000000000000034261442026362400323120ustar00rootroot00000000000000// Copyright (c) Microsoft Corporation. // Licensed under the MIT license. package main import ( "context" "fmt" "time" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/public" ) func acquireTokenDeviceCode() { config := CreateConfig("config.json") app, err := public.New(config.ClientID, public.WithCache(cacheAccessor), public.WithAuthority(config.Authority)) if err != nil { panic(err) } // look in the cache to see if the account to use has been cached var userAccount public.Account accounts, err := app.Accounts(context.Background()) if err != nil { panic("failed to read the cache") } for _, account := range accounts { if account.PreferredUsername == config.Username { userAccount = account } } // found a cached account, now see if an applicable token has been cached // NOTE: this API conflates error states, i.e. err is non-nil if an applicable token isn't // cached or if something goes wrong (making the HTTP request, unmarshalling, etc). 
authResult, err := app.AcquireTokenSilent(context.Background(), config.Scopes, public.WithSilentAccount(userAccount)) if err != nil { // either there was no cached account/token or the call to AcquireTokenSilent() failed // make a new request to AAD ctx, cancel := context.WithTimeout(context.Background(), 100*time.Second) defer cancel() devCode, err := app.AcquireTokenByDeviceCode(ctx, config.Scopes) if err != nil { panic(err) } fmt.Printf("Device Code is: %s\n", devCode.Result.Message) result, err := devCode.AuthenticationResult(ctx) if err != nil { panic(fmt.Sprintf("got error while waiting for user to input the device code: %s", err)) } fmt.Println("Access token is " + result.AccessToken) return } fmt.Println("Access token is " + authResult.AccessToken) } microsoft-authentication-library-for-go-1.0.0/apps/tests/devapps/main.go000066400000000000000000000013631442026362400264130ustar00rootroot00000000000000package main import ( "context" "os" ) var ( //config = CreateConfig("config.json") cacheAccessor = &TokenCache{file: "serialized_cache.json"} ) func main() { ctx := context.Background() // TODO(msal): This is pretty yikes. At least we should use the flag package. exampleType := os.Args[1] if exampleType == "1" { acquireTokenDeviceCode() /*} else if exampleType == "2" { acquireByAuthorizationCodePublic() */ } else if exampleType == "3" { acquireByUsernamePasswordPublic(ctx) } else if exampleType == "4" { panic("currently not implemented") //acquireByAuthorizationCodeConfidential() } else if exampleType == "5" { acquireTokenClientSecret() } else if exampleType == "6" { acquireTokenClientCertificate() } } microsoft-authentication-library-for-go-1.0.0/apps/tests/devapps/sample_cache_accessor.go000066400000000000000000000012411442026362400317500ustar00rootroot00000000000000// Copyright (c) Microsoft Corporation. // Licensed under the MIT license. 
package main import ( "context" "log" "os" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache" ) type TokenCache struct { file string } func (t *TokenCache) Replace(ctx context.Context, cache cache.Unmarshaler, hints cache.ReplaceHints) error { data, err := os.ReadFile(t.file) if err != nil { log.Println(err) } return cache.Unmarshal(data) } func (t *TokenCache) Export(ctx context.Context, cache cache.Marshaler, hints cache.ExportHints) error { data, err := cache.Marshal() if err != nil { log.Println(err) } return os.WriteFile(t.file, data, 0600) } microsoft-authentication-library-for-go-1.0.0/apps/tests/devapps/sample_utils.go000066400000000000000000000021321442026362400301630ustar00rootroot00000000000000// Copyright (c) Microsoft Corporation. // Licensed under the MIT license. package main import ( "encoding/json" "log" "os" ) // Config represents the config.json required to run the samples type Config struct { ClientID string `json:"client_id"` Authority string `json:"authority"` Scopes []string `json:"scopes"` Username string `json:"username"` Password string `json:"password"` RedirectURI string `json:"redirect_uri"` CodeChallenge string `json:"code_challenge"` CodeChallengeMethod string `json:"code_challenge_method"` State string `json:"state"` ClientSecret string `json:"client_secret"` Thumbprint string `json:"thumbprint"` PemData string `json:"pem_file"` } // CreateConfig creates the Config struct from a json file. 
func CreateConfig(fileName string) *Config { data, err := os.ReadFile(fileName) if err != nil { log.Fatal(err) } config := &Config{} err = json.Unmarshal(data, config) if err != nil { log.Fatal(err) } return config } microsoft-authentication-library-for-go-1.0.0/apps/tests/devapps/serialized_cache.json000066400000000000000000000032741442026362400313140ustar00rootroot00000000000000{ "Account": { "uid.utid-login.windows.net-contoso": { "username": "John Doe", "local_account_id": "object1234", "realm": "contoso", "environment": "login.windows.net", "home_account_id": "uid.utid", "authority_type": "MSSTS" } }, "RefreshToken": { "uid.utid-login.windows.net-refreshtoken-my_client_id--s2 s1 s3": { "target": "s2 s1 s3", "environment": "login.windows.net", "credential_type": "RefreshToken", "secret": "a refresh token", "client_id": "my_client_id", "home_account_id": "uid.utid" } }, "AccessToken": { "uid.utid-login.windows.net-accesstoken-my_client_id-contoso-s2 s1 s3": { "environment": "login.windows.net", "credential_type": "AccessToken", "secret": "an access token", "realm": "contoso", "target": "s2 s1 s3", "client_id": "my_client_id", "cached_at": "1000", "home_account_id": "uid.utid", "extended_expires_on": "4600", "expires_on": "4600" } }, "IdToken": { "uid.utid-login.windows.net-idtoken-my_client_id-contoso-": { "realm": "contoso", "environment": "login.windows.net", "credential_type": "IdToken", "secret": "header.eyJvaWQiOiAib2JqZWN0MTIzNCIsICJwcmVmZXJyZWRfdXNlcm5hbWUiOiAiSm9obiBEb2UiLCAic3ViIjogInN1YiJ9.signature", "client_id": "my_client_id", "home_account_id": "uid.utid" } }, "AppMetadata": { "AppMetadata-login.windows.net-my_client_id": { "environment": "login.windows.net", "family_id": null, "client_id": "my_client_id" } } }microsoft-authentication-library-for-go-1.0.0/apps/tests/devapps/username_password_sample.go000066400000000000000000000031121442026362400325630ustar00rootroot00000000000000// Copyright (c) Microsoft Corporation. 
// Licensed under the MIT license. package main import ( "context" "fmt" "log" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/public" ) func acquireByUsernamePasswordPublic(ctx context.Context) { config := CreateConfig("config.json") app, err := public.New(config.ClientID, public.WithCache(cacheAccessor), public.WithAuthority(config.Authority)) if err != nil { panic(err) } // look in the cache to see if the account to use has been cached var userAccount public.Account accounts, err := app.Accounts(ctx) if err != nil { panic("failed to read the cache") } for _, account := range accounts { if account.PreferredUsername == config.Username { userAccount = account } } // found a cached account, now see if an applicable token has been cached // NOTE: this API conflates error states, i.e. err is non-nil if an applicable token isn't // cached or if something goes wrong (making the HTTP request, unmarshalling, etc). result, err := app.AcquireTokenSilent( context.Background(), config.Scopes, public.WithSilentAccount(userAccount), ) if err != nil { // either there's no applicable token in the cache or something failed log.Println(err) // either there was no cached account/token or the call to AcquireTokenSilent() failed // make a new request to AAD result, err = app.AcquireTokenByUsernamePassword( context.Background(), config.Scopes, config.Username, config.Password, ) if err != nil { log.Fatal(err) } } fmt.Println("Access token is " + result.AccessToken) } microsoft-authentication-library-for-go-1.0.0/apps/tests/integration/000077500000000000000000000000001442026362400260165ustar00rootroot00000000000000microsoft-authentication-library-for-go-1.0.0/apps/tests/integration/integration_test.go000066400000000000000000000332061442026362400317330ustar00rootroot00000000000000// Copyright (c) Microsoft Corporation. // Licensed under the MIT license. 
package integration import ( "context" "encoding/json" "fmt" "io" "net/http" "net/url" "os" "testing" "time" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/errors" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/public" ) const ( msIDlabDefaultScope = "https://msidlab.com/.default" graphDefaultScope = "https://graph.windows.net/.default" ) const microsoftAuthorityHost = "https://login.microsoftonline.com/" const ( organizationsAuthority = microsoftAuthorityHost + "organizations/" microsoftAuthority = microsoftAuthorityHost + "microsoft.onmicrosoft.com" //msIDlabTenantAuthority = microsoftAuthorityHost + "msidlab4.onmicrosoft.com" - Will be needed in the future ) var httpClient = http.Client{} func httpRequest(ctx context.Context, url string, query url.Values, accessToken string) ([]byte, error) { if _, ok := ctx.Deadline(); !ok { var cancel context.CancelFunc ctx, cancel = context.WithTimeout(ctx, 10*time.Second) defer cancel() } req, err := http.NewRequestWithContext(ctx, "GET", url, nil) if err != nil { return nil, fmt.Errorf("failed to build new http request: %w", err) } req.Header.Set("Authorization", "Bearer "+accessToken) req.URL.RawQuery = query.Encode() resp, err := httpClient.Do(req) if err != nil { return nil, fmt.Errorf("http.Get(%s) failed: %w", req.URL.String(), err) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { return nil, fmt.Errorf("http.Get(%s): could not read body: %w", req.URL.String(), err) } return body, nil } type labClient struct { app confidential.Client } // TODO : Add app object type user struct { AppID string `json:"appId"` ObjectID string `json:"objectId"` UserType string `json:"userType"` DisplayName string `json:"displayName"` Licenses string `json:"licences"` Upn string `json:"upn"` Mfa string `json:"mfa"` ProtectionPolicy string `json:"protectionPolicy"` HomeDomain string `json:"homeDomain"` 
HomeUPN string `json:"homeUPN"` B2cProvider string `json:"b2cProvider"` LabName string `json:"labName"` LastUpdatedBy string `json:"lastUpdatedBy"` LastUpdatedDate string `json:"lastUpdatedDate"` Password string } type secret struct { Value string `json:"value"` } func newLabClient() (*labClient, error) { clientID := os.Getenv("clientId") secret := os.Getenv("clientSecret") cred, err := confidential.NewCredFromSecret(secret) if err != nil { return nil, fmt.Errorf("could not create a cred from a secret: %w", err) } app, err := confidential.New(microsoftAuthority, clientID, cred) if err != nil { return nil, err } return &labClient{app: app}, nil } func (l *labClient) labAccessToken() (string, error) { scopes := []string{msIDlabDefaultScope} result, err := l.app.AcquireTokenSilent(context.Background(), scopes) if err != nil { result, err = l.app.AcquireTokenByCredential(context.Background(), scopes) if err != nil { return "", fmt.Errorf("AcquireTokenByCredential() error: %w", err) } } return result.AccessToken, nil } func (l *labClient) user(ctx context.Context, query url.Values) (user, error) { accessToken, err := l.labAccessToken() if err != nil { return user{}, fmt.Errorf("problem getting lab access token: %w", err) } responseBody, err := httpRequest(ctx, "https://msidlab.com/api/user", query, accessToken) if err != nil { return user{}, err } var users []user err = json.Unmarshal(responseBody, &users) if err != nil { return user{}, err } if len(users) == 0 { return user{}, errors.New("No user found") } user := users[0] user.Password, err = l.secret(ctx, url.Values{"Secret": []string{user.LabName}}) if err != nil { return user, err } return user, nil } func (l *labClient) secret(ctx context.Context, query url.Values) (string, error) { accessToken, err := l.labAccessToken() if err != nil { return "", err } responseBody, err := httpRequest(ctx, "https://msidlab.com/api/LabUserSecret", query, accessToken) if err != nil { return "", err } var secret secret err = 
json.Unmarshal(responseBody, &secret) if err != nil { return "", err } return secret.Value, nil } // TODO: Add getApp() when needed func testUser(ctx context.Context, desc string, lc *labClient, query url.Values) user { testUser, err := lc.user(ctx, query) if err != nil { panic(fmt.Sprintf("TestUsernamePassword(%s) setup: testUser(): Failed to get input user: %s", desc, err)) } return testUser } func TestUsernamePassword(t *testing.T) { if testing.Short() { t.Skip("skipping integration test") } labClientInstance, err := newLabClient() if err != nil { panic("failed to get a lab client: " + err.Error()) } tests := []struct { desc string vals url.Values }{ {"Managed", url.Values{"usertype": []string{"cloud"}}}, {"ADFSv2", url.Values{"usertype": []string{"federated"}, "federationProvider": []string{"ADFSv2"}}}, {"ADFSv3", url.Values{"usertype": []string{"federated"}, "federationProvider": []string{"ADFSv3"}}}, {"ADFSv4", url.Values{"usertype": []string{"federated"}, "federationProvider": []string{"ADFSv4"}}}, } for _, test := range tests { ctx := context.Background() user := testUser(ctx, test.desc, labClientInstance, test.vals) app, err := public.New(user.AppID, public.WithAuthority(organizationsAuthority)) if err != nil { panic(errors.Verbose(err)) } result, err := app.AcquireTokenByUsernamePassword( context.Background(), []string{graphDefaultScope}, user.Upn, user.Password, ) if err != nil { t.Fatalf("TestUsernamePassword(%s): on AcquireTokenByUsernamePassword(): got err == %s, want err == nil", test.desc, errors.Verbose(err)) } if result.AccessToken == "" { t.Fatalf("TestUsernamePassword(%s): got AccessToken == '', want AccessToken != ''", test.desc) } if result.IDToken.IsZero() { t.Fatalf("TestUsernamePassword(%s): got IDToken == empty, want IDToken == non-empty struct", test.desc) } if result.Account.PreferredUsername != user.Upn { t.Fatalf("TestUsernamePassword(%s): got Username == %s, want Username == %s", test.desc, result.Account.PreferredUsername, user.Upn) 
} } } func TestConfidentialClientwithSecret(t *testing.T) { if testing.Short() { t.Skip("skipping integration test") } clientID := os.Getenv("clientId") secret := os.Getenv("clientSecret") cred, err := confidential.NewCredFromSecret(secret) if err != nil { panic(errors.Verbose(err)) } app, err := confidential.New(microsoftAuthority, clientID, cred) if err != nil { panic(errors.Verbose(err)) } scopes := []string{msIDlabDefaultScope} result, err := app.AcquireTokenByCredential(context.Background(), scopes) if err != nil { t.Fatalf("TestConfidentialClientwithSecret: on AcquireTokenByCredential(): got err == %s, want err == nil", errors.Verbose(err)) } if result.AccessToken == "" { t.Fatal("TestConfidentialClientwithSecret: on AcquireTokenByCredential(): got AccessToken == '', want AccessToken != ''") } silentResult, err := app.AcquireTokenSilent(context.Background(), scopes) if err != nil { t.Fatalf("TestConfidentialClientwithSecret: on AcquireTokenSilent(): got err == %s, want err == nil", errors.Verbose(err)) } if silentResult.AccessToken == "" { t.Fatal("TestConfidentialClientwithSecret: on AcquireTokenSilent(): got AccessToken == '', want AccessToken != ''") } } func TestOnBehalfOf(t *testing.T) { if testing.Short() { t.Skip("skipping integration test") } labClientInstance, err := newLabClient() if err != nil { panic("failed to get a lab client: " + err.Error()) } ctx := context.Background() //Confidential Client Application Config ccaClientID := os.Getenv("oboConfidentialClientId") ccaClientSecret := os.Getenv("oboConfidentialClientSecret") ccaScopes := []string{"https://graph.microsoft.com/user.read"} // Public Client Application Confifg pcaClientID := os.Getenv("oboPublicClientId") user := testUser(ctx, "OnBehalfOf", labClientInstance, url.Values{"usertype": []string{"cloud"}}) pcaScopes := []string{fmt.Sprintf("api://%s/.default", ccaClientID)} // 1. 
An app obtains a token representing a user, for our mid-tier service pca, err := public.New(pcaClientID, public.WithAuthority(organizationsAuthority)) if err != nil { panic(errors.Verbose(err)) } result, err := pca.AcquireTokenByUsernamePassword( ctx, pcaScopes, user.Upn, user.Password, ) if err != nil { t.Fatalf("TestOnBehalfOf: on AcquireTokenByUsernamePassword(): got err == %s, want err == nil", errors.Verbose(err)) } if result.AccessToken == "" { t.Fatal("TestOnBehalfOf: on AcquireTokenByUsernamePassword(): got AccessToken == '', want AccessToken != ''") } // 2. Our mid-tier service uses OBO to obtain a token for downstream service cred, err := confidential.NewCredFromSecret(ccaClientSecret) if err != nil { panic(errors.Verbose(err)) } cca, err := confidential.New("https://login.microsoftonline.com/common", ccaClientID, cred) if err != nil { panic(errors.Verbose(err)) } result1, err := cca.AcquireTokenOnBehalfOf(ctx, result.AccessToken, ccaScopes) if err != nil { t.Fatalf("TestOnBehalfOf: on AcquireTokenOnBehalfOf(): got err == %s, want err == nil", errors.Verbose(err)) } if result1.AccessToken == "" { t.Fatal("TestOnBehalfOf: on AcquireTokenOnBehalfOf(): got AccessToken == '', want AccessToken != ''") } // 3. Same scope and assertion should return cached access token result2, err := cca.AcquireTokenOnBehalfOf(ctx, result.AccessToken, ccaScopes) if err != nil { t.Fatalf("TestOnBehalfOf: on AcquireTokenOnBehalfOf() silent token retrieval: got err == %s, want err == nil", errors.Verbose(err)) } if result1.AccessToken != result2.AccessToken { t.Fatal("TestOnBehalfOf: on AcquireTokenOnBehalfOf(): Access Tokens don't match") } // 4. 
scope2 should return new token scope2 := []string{"https://graph.windows.net/.default"} result3, err := cca.AcquireTokenOnBehalfOf(ctx, result.AccessToken, scope2) if err != nil { t.Fatalf("TestOnBehalfOf: on AcquireTokenOnBehalfOf(): got err == %s, want err == nil", errors.Verbose(err)) } if result3.AccessToken == "" { t.Fatal("TestOnBehalfOf: on AcquireTokenOnBehalfOf(): got AccessToken == '', want AccessToken != ''") } if result3.AccessToken == result2.AccessToken { t.Fatal("TestOnBehalfOf: on AcquireTokenOnBehalfOf(): Access Tokens match when they should not") } // 5. scope2 should return cached token result4, err := cca.AcquireTokenOnBehalfOf(ctx, result.AccessToken, scope2) if err != nil { t.Fatalf("TestOnBehalfOf: on AcquireTokenOnBehalfOf(): got err == %s, want err == nil", errors.Verbose(err)) } if result4.AccessToken == "" { t.Fatal("TestOnBehalfOf: on AcquireTokenOnBehalfOf(): got AccessToken == '', want AccessToken != ''") } if result4.AccessToken != result3.AccessToken { t.Fatal("TestOnBehalfOf: on AcquireTokenOnBehalfOf(): Access Tokens don't match") } // 6. 
New user assertion should return new token pca1, err := public.New(pcaClientID, public.WithAuthority(organizationsAuthority)) if err != nil { panic(errors.Verbose(err)) } result5, err := pca1.AcquireTokenByUsernamePassword( ctx, pcaScopes, user.Upn, user.Password, ) if err != nil { t.Fatalf("TestOnBehalfOf: on AcquireTokenByUsernamePassword(): got err == %s, want err == nil", errors.Verbose(err)) } result6, err := cca.AcquireTokenOnBehalfOf(ctx, result5.AccessToken, scope2) if err != nil { t.Fatalf("TestOnBehalfOf: on AcquireTokenOnBehalfOf(): got err == %s, want err == nil", errors.Verbose(err)) } if result6.AccessToken == "" { t.Fatal("TestOnBehalfOf: on AcquireTokenOnBehalfOf(): got AccessToken == '', want AccessToken != ''") } if result6.AccessToken == result4.AccessToken { t.Fatal("TestOnBehalfOf: on AcquireTokenOnBehalfOf(): Access Tokens match when they should not") } if result6.AccessToken == result3.AccessToken { t.Fatal("TestOnBehalfOf: on AcquireTokenOnBehalfOf(): Access Tokens match when they should not") } if result6.AccessToken == result2.AccessToken { t.Fatal("TestOnBehalfOf: on AcquireTokenOnBehalfOf(): Access Tokens match when they should not") } } func TestRemoveAccount(t *testing.T) { if testing.Short() { t.Skip("skipping integration test") } labClientInstance, err := newLabClient() if err != nil { panic("failed to get a lab client: " + err.Error()) } ctx := context.Background() user := testUser(ctx, "TestRemoveAccount", labClientInstance, url.Values{"usertype": []string{"cloud"}}) app, err := public.New(user.AppID, public.WithAuthority(organizationsAuthority)) if err != nil { panic(errors.Verbose(err)) } // Populate the cache _, err = app.AcquireTokenByUsernamePassword( context.Background(), []string{graphDefaultScope}, user.Upn, user.Password, ) if err != nil { t.Fatalf("TestRemoveAccount: on AcquireTokenByUsernamePassword(): got err == %s, want err == nil", errors.Verbose(err)) } accounts, err := app.Accounts(ctx) if err != nil { t.Fatal(err) 
} if len(accounts) == 0 { t.Fatal("TestRemoveAccount: No user accounts found in cache") } testAccount := accounts[0] // Only one account is populated and that is what we will remove. err = app.RemoveAccount(ctx, testAccount) if err != nil { t.Fatalf("TestRemoveAccount: on RemoveAccount(): got err == %s, want err == nil", errors.Verbose(err)) } // Remove Account will clear the cache fields associated with this account so acquire token silent should fail _, err = app.AcquireTokenSilent(ctx, []string{graphDefaultScope}, public.WithSilentAccount(testAccount)) if err == nil { t.Fatal("TestRemoveAccount: RemoveAccount() didn't clear the cache as expected") } } microsoft-authentication-library-for-go-1.0.0/apps/tests/performance/000077500000000000000000000000001442026362400257745ustar00rootroot00000000000000microsoft-authentication-library-for-go-1.0.0/apps/tests/performance/performance_test.go000066400000000000000000000110401442026362400316570ustar00rootroot00000000000000// Copyright (c) Microsoft Corporation. // Licensed under the MIT license. 
package performance import ( "context" "fmt" "math/rand" "os" "testing" "time" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/base" internalTime "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json/types/time" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/fake" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority" "github.com/montanaflynn/stats" ) func fakeClient() (base.Client, error) { // we use a base.Client so we can provide a fake OAuth client return base.New("fake_client_id", "https://fake_authority/my_utid", &oauth.Client{ Authority: &fake.Authority{ InstanceResp: authority.InstanceDiscoveryResponse{ Metadata: []authority.InstanceDiscoveryMetadata{ { PreferredNetwork: "fake_authority", Aliases: []string{"fake_authority"}, }, }, }, }, Resolver: &fake.ResolveEndpoints{ Endpoints: authority.Endpoints{ AuthorizationEndpoint: "auth_endpoint", TokenEndpoint: "token_endpoint", }, }, WSTrust: &fake.WSTrust{}, }) } func populateCache(users int, tokens int, authParams authority.AuthParams, client base.Client) { for user := 0; user < users; user++ { for token := 0; token < tokens; token++ { authParams := client.AuthParams authParams.UserAssertion = fmt.Sprintf("fake_access_token%d", user) authParams.AuthorizationType = authority.ATOnBehalfOf scope := fmt.Sprintf("scope%d", token) _, err := client.AuthResultFromToken(context.Background(), authParams, accesstokens.TokenResponse{ AccessToken: fmt.Sprintf("fake_access_token%d", user), RefreshToken: "fake_refresh_token", ClientInfo: accesstokens.ClientInfo{UID: "my_uid", UTID: fmt.Sprintf("%dmy_utid", user)}, ExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)}, GrantedScopes: accesstokens.Scopes{Slice: 
[]string{scope}}, IDToken: accesstokens.IDToken{ RawToken: "x.e30", }, }, true) if err != nil { panic(err) } } } } func calculateStats(users, tokens int, duration []float64) { fmt.Printf("No of users: %d, No of tokens per user: %d \n", users, tokens) mean, err := stats.Mean(duration) if err != nil { panic(err) } meanTime := mean / float64(time.Microsecond) fmt.Println("Mean") fmt.Println(meanTime) median, err := stats.Median(duration) medianTime := median / float64(time.Microsecond) if err != nil { panic(err) } fmt.Println("Median") fmt.Println(medianTime) stdDev, err := stats.StandardDeviation(duration) stdDevTime := stdDev / float64(time.Microsecond) if err != nil { panic(err) } fmt.Println("Standard Deviation") fmt.Println(stdDevTime) min, err := stats.Min(duration) minTime := min / float64(time.Microsecond) if err != nil { panic(err) } fmt.Println("Min Time") fmt.Println(minTime) max, err := stats.Max(duration) maxTime := max / float64(time.Microsecond) if err != nil { panic(err) } fmt.Println("Max Time") fmt.Println(maxTime) } func benchMarkObo(users int, tokens int, client base.Client) { var duration []float64 for start := time.Now(); time.Since(start) < time.Minute*1; { s := time.Now() queryCache(users, tokens, client) e := time.Now() duration = append(duration, float64(e.Sub(s))) } calculateStats(users, tokens, duration) } func queryCache(users int, tokens int, client base.Client) { userAssertion := fmt.Sprintf("fake_access_token%d", rand.Intn(users)) scope := []string{fmt.Sprintf("scope%d", rand.Intn(tokens))} params := base.AcquireTokenOnBehalfOfParameters{ Scopes: scope, UserAssertion: userAssertion, Credential: &accesstokens.Credential{Secret: "fake_secret"}, } _, err := client.AcquireTokenOnBehalfOf(context.Background(), params) if err != nil { panic(err) } } func TestOnBehalfOfCacheTests(t *testing.T) { if os.Getenv("CI") != "" { t.Skip("Skipping testing in CI environment") } tests := []struct { Users int Tokens int }{ {1, 10000}, {1, 100000}, {100, 
10000}, {1000, 10000}, {10000, 100}, } for _, test := range tests { client, err := fakeClient() if err != nil { panic(err) } authParams := client.AuthParams populateCache(test.Users, test.Tokens, authParams, client) benchMarkObo(test.Users, test.Tokens, client) } } microsoft-authentication-library-for-go-1.0.0/changelog.md000066400000000000000000000000001442026362400236250ustar00rootroot00000000000000microsoft-authentication-library-for-go-1.0.0/go.mod000066400000000000000000000005201442026362400224710ustar00rootroot00000000000000module github.com/AzureAD/microsoft-authentication-library-for-go go 1.18 require ( github.com/golang-jwt/jwt/v4 v4.4.3 github.com/google/uuid v1.3.0 github.com/kylelemons/godebug v1.1.0 github.com/montanaflynn/stats v0.7.0 github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 ) require golang.org/x/sys v0.5.0 // indirect microsoft-authentication-library-for-go-1.0.0/go.sum000066400000000000000000000022211442026362400225160ustar00rootroot00000000000000github.com/golang-jwt/jwt/v4 v4.4.3 h1:Hxl6lhQFj4AnOX6MLrsCb/+7tCj7DxP7VA+2rDIq5AU= github.com/golang-jwt/jwt/v4 v4.4.3/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/montanaflynn/stats v0.7.0 h1:r3y12KyNxj/Sb/iOE46ws+3mS1+MZca1wlHQFPsY/JU= github.com/montanaflynn/stats v0.7.0/go.mod h1:etXPPgVO6n31NxCd9KQUMvCM+ve0ruNzt6R8Bnaayow= github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 h1:KoWmjvw+nsYOo29YJK9vDA65RGE3NrOnUtO7a+RF9HU= github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8/go.mod h1:HKlIX3XHQyzLZPlr7++PzdhaXEj94dEiJgZDTsxEqUI= golang.org/x/sys v0.0.0-20210616045830-e2b7044e8c71/go.mod 
h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=